diff --git a/database.go b/database.go index 1537312..486e4df 100644 --- a/database.go +++ b/database.go @@ -11,7 +11,6 @@ import ( var edgeCtx context.Context var edgeClient *edgedb.Client -var CurrentUser User type User struct { ID edgedb.UUID `edgedb:"id"` @@ -74,22 +73,15 @@ func init() { log.Fatal(err) } - // Get the user test@example.com - var user User - err = client.QuerySingle(ctx, ` - SELECT User { id, email, name, setting } FILTER .email = $0 LIMIT 1 - `, &user, "test@example.com") + // TODO Change + edgeCtx = ctx + var clientUUID edgedb.UUID + clientUUID, err = edgedb.ParseUUID("9323365e-0b09-11ef-8f41-c3575d386283") if err != nil { - fmt.Println("Error in edgedb.QuerySingle: in init") + fmt.Println("Error in edgedb.ParseUUID: in init") log.Fatal(err) } - CurrentUser = user - - fmt.Print("Current User: ") - fmt.Println(CurrentUser) - - edgeCtx = ctx - edgeClient = client + edgeClient = client.WithGlobals(map[string]interface{}{"current_user_id": clientUUID}) } func getLastArea() edgedb.UUID { @@ -99,7 +91,7 @@ func getLastArea() edgedb.UUID { filter .conversation.name = 'Default' AND .conversation.user.id = $0 order by .position desc limit 1 - `, &inserted, CurrentUser.ID) + `, &inserted, getCurrentUserID()) if err != nil { fmt.Println("Error in edgedb.QuerySingle: in getLastArea") log.Fatal(err) @@ -107,21 +99,43 @@ func getLastArea() edgedb.UUID { return inserted.id } +func getCurrentUserID() edgedb.UUID { + var result User + err := edgeClient.QuerySingle(edgeCtx, "SELECT global currentUser LIMIT 1;", &result) + if err != nil { + fmt.Println("Error in edgedb.QuerySingle: in getCurrentUserID") + fmt.Println(err) + + } + return result.ID +} + +func checkIfLogin() bool { + var result User + err := edgeClient.QuerySingle(edgeCtx, "SELECT global currentUser LIMIT 1;", &result) + if err != nil { + fmt.Println("Error in edgedb.QuerySingle: in checkIfLogin") + fmt.Println(err) + return false + } + return true +} + func insertArea() edgedb.UUID { // Insert a new area. var inserted struct{ id edgedb.UUID } err := edgeClient.QuerySingle(edgeCtx, ` WITH - positionVar := count((SELECT Area FILTER .conversation.name = 'Default' AND .conversation.user.id = $0)) + 1 + positionVar := count((SELECT Area FILTER .conversation.name = 'Default' AND .conversation.user = global currentUser)) + 1 INSERT Area { position := positionVar, conversation := ( SELECT Conversation - FILTER .name = 'Default' AND .user.id = $0 + FILTER .name = 'Default' AND .user = global currentUser LIMIT 1 ) } - `, &inserted, CurrentUser.ID) + `, &inserted) if err != nil { fmt.Println("Error in edgedb.QuerySingle: in insertArea") log.Fatal(err) @@ -144,11 +158,11 @@ func insertUserMessage(content string) edgedb.UUID { ), conversation := ( SELECT Conversation - FILTER .name = 'Default' AND .user.id = $3 + FILTER .name = 'Default' AND .user = global currentUser LIMIT 1 ) } - `, &inserted, "user", content, lastAreaID, CurrentUser.ID) + `, &inserted, "user", content, lastAreaID) if err != nil { fmt.Println("Error in edgedb.QuerySingle: in insertUserMessage") log.Fatal(err) @@ -167,16 +181,16 @@ func insertBotMessage(content string, selected bool, model string) edgedb.UUID { selected := $3, conversation := ( SELECT Conversation - FILTER .name = 'Default' AND .user.id = $4 + FILTER .name = 'Default' AND .user = global currentUser LIMIT 1 ), area := ( SELECT Area - FILTER .id = $5 + FILTER .id = $4 LIMIT 1 ) } - `, &inserted, "bot", model, content, selected, CurrentUser.ID, lastAreaID) + `, &inserted, "bot", model, content, selected, lastAreaID) if err != nil { fmt.Println("Error in edgedb.QuerySingle: in insertBotMessage") log.Fatal(err) @@ -185,6 +199,11 @@ func insertBotMessage(content string, selected bool, model string) edgedb.UUID { } func getAllMessages() []Message { + // If no CurrentUser, return an empty array + if !checkIfLogin() { + return []Message{} + } + var messages []Message err := edgeClient.Query(edgeCtx, ` @@ -195,9 +214,9 @@ func getAllMessages() []Message { role, content, date - } FILTER .conversation.name = 'Default' AND .conversation.user.id = $0 + } FILTER .conversation.name = 'Default' AND .conversation.user = global currentUser ORDER BY .date ASC - `, &messages, CurrentUser.ID) + `, &messages) if err != nil { fmt.Println("Error in edgedb.Query: in getAllMessages") fmt.Println(err) diff --git a/dbschema/default.esdl b/dbschema/default.esdl index cf9ca93..521cf3c 100644 --- a/dbschema/default.esdl +++ b/dbschema/default.esdl @@ -1,8 +1,15 @@ +using extension auth; + module default { + global current_user_id: uuid; + global currentUser := ( + select User + filter .id = global current_user_id + ); + type User { - required email: str; - required name: str; required setting: Setting; + required identity: ext::auth::Identity; } type Key { diff --git a/dbschema/migrations/00013-m1frbbh.edgeql b/dbschema/migrations/00013-m1frbbh.edgeql new file mode 100644 index 0000000..5e0331e --- /dev/null +++ b/dbschema/migrations/00013-m1frbbh.edgeql @@ -0,0 +1,6 @@ +CREATE MIGRATION m1frbbhs2tdrqsle67rgzwazozf3qp4xylmnepeacjodtdtcpmqzgq + ONTO m16dflw7c2tzuugatxe7g7ngx6ddn6eczk7a7a6oo5zlixscu565ta +{ + CREATE EXTENSION pgcrypto VERSION '1.3'; + CREATE EXTENSION auth VERSION '1.0'; +}; diff --git a/dbschema/migrations/00014-m1ipsij.edgeql b/dbschema/migrations/00014-m1ipsij.edgeql new file mode 100644 index 0000000..400b157 --- /dev/null +++ b/dbschema/migrations/00014-m1ipsij.edgeql @@ -0,0 +1,14 @@ +CREATE MIGRATION m1ipsijwf3e65w6mvm2622aekxdgvd5hpdlej6mrn2twfbsh47s4mq + ONTO m1frbbhs2tdrqsle67rgzwazozf3qp4xylmnepeacjodtdtcpmqzgq +{ + ALTER TYPE default::User { + CREATE REQUIRED LINK identity: ext::auth::Identity { + SET REQUIRED USING ({}); + }; + }; + CREATE GLOBAL default::currentUser := (SELECT + default::User + FILTER + (.identity ?= GLOBAL ext::auth::ClientTokenIdentity) + ); +}; diff --git a/dbschema/migrations/00015-m1pknui.edgeql b/dbschema/migrations/00015-m1pknui.edgeql new file mode 100644 index 0000000..ca3d978 --- /dev/null +++ b/dbschema/migrations/00015-m1pknui.edgeql @@ -0,0 +1,8 @@ +CREATE MIGRATION m1pknuisxcvpidw4gpdqip6qkiohucqgzmv3rcbli3dyvk2hlbmvaa + ONTO m1ipsijwf3e65w6mvm2622aekxdgvd5hpdlej6mrn2twfbsh47s4mq +{ + ALTER TYPE default::User { + DROP PROPERTY email; + DROP PROPERTY name; + }; +}; diff --git a/dbschema/migrations/00016-m1t2tdd.edgeql b/dbschema/migrations/00016-m1t2tdd.edgeql new file mode 100644 index 0000000..dc893ce --- /dev/null +++ b/dbschema/migrations/00016-m1t2tdd.edgeql @@ -0,0 +1,10 @@ +CREATE MIGRATION m1t2tddtq2grfxf4ldn6aj4yotvwiqclsrks6chobvs33d5mcd62mq + ONTO m1pknuisxcvpidw4gpdqip6qkiohucqgzmv3rcbli3dyvk2hlbmvaa +{ + CREATE GLOBAL default::current_user_id -> std::uuid; + ALTER GLOBAL default::currentUser USING (SELECT + default::User + FILTER + (.id = GLOBAL default::current_user_id) + ); +}; diff --git a/login.go b/login.go new file mode 100644 index 0000000..216cb97 --- /dev/null +++ b/login.go @@ -0,0 +1,128 @@ +package main + +import ( + "crypto/rand" + "crypto/sha256" + "encoding/base64" + "encoding/json" + "fmt" + "io" + "net/http" + + "github.com/gofiber/fiber/v2" +) + +const EDGEDB_AUTH_BASE_URL = "http://127.0.0.1:10700/db/main/ext/auth" + +func generatePKCE() (string, string) { + verifier := make([]byte, 32) + _, err := rand.Read(verifier) + if err != nil { + panic(err) + } + + challenge := sha256.Sum256(verifier) + return base64.RawURLEncoding.EncodeToString(verifier), base64.RawURLEncoding.EncodeToString(challenge[:]) +} + +func handleUiSignIn(c *fiber.Ctx) error { + verifier, challenge := generatePKCE() + + fmt.Println("handleUiSignIn verifier: " + verifier) + fmt.Println("handleUiSignIn challenge: " + challenge) + + redirectURL := fmt.Sprintf("%s/ui/signin?challenge=%s", EDGEDB_AUTH_BASE_URL, challenge) + + c.Cookie(&fiber.Cookie{ + Name: "jade-edgedb-pkce-verifier", + Value: verifier, + HTTPOnly: true, + Path: "/", + Secure: true, + SameSite: "Strict", + }) + + return c.Redirect(redirectURL, fiber.StatusMovedPermanently) +} + +func handleUiSignUp(c *fiber.Ctx) error { + verifier, challenge := generatePKCE() + + fmt.Println("handleUiSignUp verifier: " + verifier) + fmt.Println("handleUiSignUp challenge: " + challenge) + + redirectURL := fmt.Sprintf("%s/ui/signup?challenge=%s", EDGEDB_AUTH_BASE_URL, challenge) + + c.Cookie(&fiber.Cookie{ + Name: "jade-edgedb-pkce-verifier", + Value: verifier, + HTTPOnly: true, + Path: "/", + Secure: true, + SameSite: "Strict", + }) + + return c.Redirect(redirectURL, fiber.StatusMovedPermanently) +} + +func handleCallbackSignup(c *fiber.Ctx) error { + code := c.Query("code") + fmt.Println("Callback signup code: " + code) + + verifier := c.Cookies("jade-edgedb-pkce-verifier") + if verifier == "" { + return c.Status(fiber.StatusBadRequest).SendString("Could not find 'verifier' in the cookie store. Is this the same user agent/browser that started the authorization flow?") + } + + return c.SendString("OK") +} + +func handleCallback(c *fiber.Ctx) error { + code := c.Query("code") + if code == "" { + error := c.Query("error") + return c.Status(fiber.StatusBadRequest).SendString(fmt.Sprintf("OAuth callback is missing 'code'. OAuth provider responded with error: %s", error)) + } + + verifier := c.Cookies("jade-edgedb-pkce-verifier") + if verifier == "" { + return c.Status(fiber.StatusBadRequest).SendString("Could not find 'verifier' in the cookie store. Is this the same user agent/browser that started the authorization flow?") + } + + fmt.Println("handleCallback code: " + code) + fmt.Println("handleCallback verifier: " + verifier) + + codeExchangeURL := fmt.Sprintf("%s/token?code=%s&verifier=%s", EDGEDB_AUTH_BASE_URL, code, verifier) + resp, err := http.Get(codeExchangeURL) + if err != nil { + return c.Status(fiber.StatusBadRequest).SendString(fmt.Sprintf("Error from the auth server: %s", err.Error())) + } + defer resp.Body.Close() + + if resp.StatusCode != fiber.StatusOK { + body, _ := io.ReadAll(resp.Body) + return c.Status(fiber.StatusBadRequest).SendString(fmt.Sprintf("Error from the auth server: %s", string(body))) + } + + var tokenResponse struct { + AuthToken string `json:"auth_token"` + } + err = json.NewDecoder(resp.Body).Decode(&tokenResponse) + if err != nil { + return c.Status(fiber.StatusBadRequest).SendString(fmt.Sprintf("Error decoding auth server response: %s", err.Error())) + } + + c.Cookie(&fiber.Cookie{ + Name: "edgedb-auth-token", + Value: tokenResponse.AuthToken, + HTTPOnly: true, + Path: "/", + Secure: true, + SameSite: "Strict", + }) + + fmt.Println("Login successful") + fmt.Println("edgedb-auth-token: " + tokenResponse.AuthToken) + + return c.SendStatus(fiber.StatusNoContent) +} diff --git a/main.go b/main.go index 24fad16..be59e8e 100644 --- a/main.go +++ b/main.go @@ -12,13 +12,11 @@ import ( var userTmpl *pongo2.Template var botTmpl *pongo2.Template var modelsPopoverTmpl *pongo2.Template -var conversationsPopoverTmpl *pongo2.Template func main() { botTmpl = pongo2.Must(pongo2.FromFile("views/partials/message-bot.html")) userTmpl = pongo2.Must(pongo2.FromFile("views/partials/message-user.html")) modelsPopoverTmpl = pongo2.Must(pongo2.FromFile("views/partials/modelsPopover.html")) - conversationsPopoverTmpl = pongo2.Must(pongo2.FromFile("views/partials/conversationsPopover.html")) // Import HTML using django engine/template engine := django.New("./views", ".html") @@ -45,6 +43,11 @@ func main() { app.Get("/generateMultipleMessages", GenerateMultipleMessages) // Generate multiple messages app.Get("/messageContent", GetMessageContentHandler) + app.Get("/auth/ui/signin", handleUiSignIn) + app.Get("/auth/ui/signup", handleUiSignUp) + app.Get("/auth/callback", handleCallback) + app.Get("/auth/callbackSignup", handleCallbackSignup) + app.Get("test", func(c *fiber.Ctx) error { fmt.Println("test") return c.SendString("") diff --git a/static/jade.png b/static/jade.png index f30210c..0e20344 100644 Binary files a/static/jade.png and b/static/jade.png differ diff --git a/static/jade_dark.png b/static/jade_dark.png index 62cd7ae..4f8dfca 100644 Binary files a/static/jade_dark.png and b/static/jade_dark.png differ diff --git a/views/chat.html b/views/chat.html index c9ed4f6..bd39682 100644 --- a/views/chat.html +++ b/views/chat.html @@ -1,4 +1,5 @@
+ diff --git a/views/layouts/main.html b/views/layouts/main.html index e9ce60d..3180c44 100644 --- a/views/layouts/main.html +++ b/views/layouts/main.html @@ -53,6 +53,8 @@ gap: 10px; } + + diff --git a/views/login.html b/views/login.html deleted file mode 100644 index 97e0b1a..0000000 --- a/views/login.html +++ /dev/null @@ -1,32 +0,0 @@ -
- - - -
-
-
- -
- - - -
-
-
-
-
\ No newline at end of file diff --git a/views/partials/message-bot.html b/views/partials/message-bot.html index e04bcd9..dc80da9 100644 --- a/views/partials/message-bot.html +++ b/views/partials/message-bot.html @@ -8,7 +8,7 @@ hx-target="#content-{{ ConversationAreaId }}" onclick="toggleGrayscale(this)">
User Image + style="filter: grayscale(100%);" {% endif %} title="{{ message.Name }}">
diff --git a/views/partials/navbar.html b/views/partials/navbar.html index 6e236c6..3b02ab6 100644 --- a/views/partials/navbar.html +++ b/views/partials/navbar.html @@ -10,7 +10,16 @@ \ No newline at end of file