diff --git a/Chat.go b/Chat.go index 58e1698..0ec65bd 100644 --- a/Chat.go +++ b/Chat.go @@ -12,7 +12,15 @@ import ( ) func ChatPageHandler(c *fiber.Ctx) error { - return c.Render("chat", fiber.Map{}, "layouts/main") + authCookie := c.Cookies("jade-edgedb-auth-token", "") + + if authCookie != "" && !checkIfLogin() { + edgeClient = edgeClient.WithGlobals(map[string]interface{}{"ext::auth::client_token": authCookie}) + } + + fmt.Println("Current User: ", getCurrentUser()) + + return c.Render("chat", fiber.Map{"IsLogin": checkIfLogin()}, "layouts/main") } func LoadModelSelectionHandler(c *fiber.Ctx) error { @@ -45,8 +53,6 @@ func LoadUsageKPIHandler(c *fiber.Ctx) error { log.Fatal(err) } - fmt.Println(TotalUsage) - out, err := pongo2.Must(pongo2.FromFile("views/partials/usagePopover.html")).Execute(pongo2.Context{ "TotalUsage": TotalUsage, }) @@ -85,7 +91,11 @@ func DeleteMessageHandler(c *fiber.Ctx) error { } func LoadChatHandler(c *fiber.Ctx) error { - return c.SendString(generateChatHTML()) + if checkIfLogin() { + return c.SendString(generateChatHTML()) + } else { + return c.SendString(generateWelcomeChatHTML()) + } } type NextMessage struct { @@ -115,7 +125,6 @@ func generateChatHTML() string { // Reset NextMessages when a user message is encountered NextMessages = []NextMessage{} } else { - fmt.Println(i) modelID, exist := message.ModelID.Get() if !exist { modelID = "gpt-3.5-turbo" @@ -199,3 +208,28 @@ func GetMessageContentHandler(c *fiber.Ctx) error { return c.SendString(out) } + +func generateWelcomeChatHTML() string { + htmlString := "
" + + NextMessages := []NextMessage{} + nextMsg := NextMessage{ + Icon: "bouvai2", // Assuming Icon is a field you want to include from Message + Content: markdownToHTML("Hi, I'm Bouvai. How can I help you today?"), + Hidden: false, // Assuming Hidden is a field you want to include from Message + Id: "0", + Name: "JADE", + } + NextMessages = append(NextMessages, nextMsg) + + botOut, err := botTmpl.Execute(pongo2.Context{"Messages": NextMessages, "ConversationAreaId": 0}) + if err != nil { + panic(err) + } + htmlString += botOut + htmlString += "
" + htmlString += "
" + + // Render the HTML template with the messages + return htmlString +} diff --git a/database.go b/database.go index 2c6ecdc..e3a31d6 100644 --- a/database.go +++ b/database.go @@ -73,15 +73,8 @@ func init() { log.Fatal(err) } - // TODO Change edgeCtx = ctx - var clientUUID edgedb.UUID - clientUUID, err = edgedb.ParseUUID("9323365e-0b09-11ef-8f41-c3575d386283") - if err != nil { - fmt.Println("Error in edgedb.ParseUUID: in init") - log.Fatal(err) - } - edgeClient = client.WithGlobals(map[string]interface{}{"current_user_id": clientUUID}) + edgeClient = client } func getLastArea() edgedb.UUID { @@ -99,15 +92,21 @@ func getLastArea() edgedb.UUID { return inserted.id } -func checkIfLogin() bool { +func getCurrentUser() User { var result User err := edgeClient.QuerySingle(edgeCtx, "SELECT global currentUser LIMIT 1;", &result) if err != nil { - fmt.Println("Error in edgedb.QuerySingle: in checkIfLogin") + fmt.Println("Error in edgedb.QuerySingle: in getCurrentUser") fmt.Println(err) - return false + return User{} } - return true + return result +} + +func checkIfLogin() bool { + var result User + err := edgeClient.QuerySingle(edgeCtx, "SELECT global currentUser LIMIT 1;", &result) + return err == nil } func insertArea() edgedb.UUID { diff --git a/dbschema/default.esdl b/dbschema/default.esdl index 521cf3c..3445cbe 100644 --- a/dbschema/default.esdl +++ b/dbschema/default.esdl @@ -1,10 +1,11 @@ using extension auth; module default { - global current_user_id: uuid; global currentUser := ( + assert_single(( select User - filter .id = global current_user_id + filter .identity = global ext::auth::ClientTokenIdentity + )) ); type User { @@ -58,13 +59,13 @@ module default { } type Usage { - required model_id: str; - required user: User; - input_cost: float32; - output_cost: float32; - input_token: int32; - output_token: int32; - required date: datetime { + required model_id: str; + user: User; + input_cost: float32; + output_cost: float32; + input_token: int32; + output_token: int32; + required date: datetime { default := datetime_current(); } } diff --git a/dbschema/migrations/00017-m1g3r5w.edgeql b/dbschema/migrations/00017-m1g3r5w.edgeql new file mode 100644 index 0000000..964451e --- /dev/null +++ b/dbschema/migrations/00017-m1g3r5w.edgeql @@ -0,0 +1,11 @@ +CREATE MIGRATION m1g3r5wwsplyqphor3yvb7dzp2txacgd6ed3udrpvakuoz3e2zj7ka + ONTO m1t2tddtq2grfxf4ldn6aj4yotvwiqclsrks6chobvs33d5mcd62mq +{ + DROP GLOBAL default::currentUser; + CREATE GLOBAL default::current_user := (std::assert_single((SELECT + default::User + FILTER + (.identity = GLOBAL ext::auth::ClientTokenIdentity) + ))); + DROP GLOBAL default::current_user_id; +}; diff --git a/dbschema/migrations/00018-m1c2b3o.edgeql b/dbschema/migrations/00018-m1c2b3o.edgeql new file mode 100644 index 0000000..f4f7f37 --- /dev/null +++ b/dbschema/migrations/00018-m1c2b3o.edgeql @@ -0,0 +1,5 @@ +CREATE MIGRATION m1c2b3o4bhfgldk3kebneqyv6oxfgmorrp62uv5rreu773ryed3azq + ONTO m1g3r5wwsplyqphor3yvb7dzp2txacgd6ed3udrpvakuoz3e2zj7ka +{ + ALTER GLOBAL default::current_user RENAME TO default::currentUser; +}; diff --git a/dbschema/migrations/00019-m1rzoj5.edgeql b/dbschema/migrations/00019-m1rzoj5.edgeql new file mode 100644 index 0000000..229212c --- /dev/null +++ b/dbschema/migrations/00019-m1rzoj5.edgeql @@ -0,0 +1,9 @@ +CREATE MIGRATION m1rzoj5rvhxkec6tsew6dc6bd3ksrrmf7t3k532avozbg4722xulhq + ONTO m1c2b3o4bhfgldk3kebneqyv6oxfgmorrp62uv5rreu773ryed3azq +{ + ALTER TYPE default::Usage { + ALTER LINK user { + RESET OPTIONALITY; + }; + }; +}; diff --git a/login.go b/login.go index f09506d..fee2cb2 100644 --- a/login.go +++ b/login.go @@ -8,62 +8,106 @@ import ( "fmt" "io" "net/http" - "time" + "github.com/edgedb/edgedb-go" "github.com/gofiber/fiber/v2" ) const EDGEDB_AUTH_BASE_URL = "http://127.0.0.1:10700/db/main/ext/auth" func generatePKCE() (string, string) { - fmt.Println("Generating PKCE") - verifier := make([]byte, 32) - _, err := rand.Read(verifier) + verifier_source := make([]byte, 32) + _, err := rand.Read(verifier_source) if err != nil { panic(err) } - challenge := sha256.Sum256(verifier) - - var URLEncoding = base64.NewEncoding("ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-_") - var RawURLEncoding = URLEncoding.WithPadding(base64.NoPadding) - encodedVerifier := RawURLEncoding.EncodeToString(verifier) - encodedChallenge := RawURLEncoding.EncodeToString(challenge[:]) - - fmt.Println("verifier: " + encodedVerifier) - fmt.Println("challenge: " + encodedChallenge) - - return encodedVerifier, encodedChallenge + verifier := base64.RawURLEncoding.EncodeToString(verifier_source) + challenge := sha256.Sum256([]byte(verifier)) + return verifier, base64.RawURLEncoding.EncodeToString(challenge[:]) } func handleUiSignIn(c *fiber.Ctx) error { verifier, challenge := generatePKCE() - fmt.Println("handleUiSignIn verifier: " + verifier) - fmt.Println("handleUiSignIn challenge: " + challenge) - - cookie := new(fiber.Cookie) - cookie.Name = "jade-edgedb-pkce-verifier" - cookie.Value = verifier - cookie.Expires = time.Now().Add(10 * time.Minute) - c.Cookie(cookie) + c.Cookie(&fiber.Cookie{ + Name: "jade-edgedb-pkce-verifier", + Value: verifier, + HTTPOnly: true, + Path: "/", + Secure: true, + SameSite: "Strict", + }) return c.Redirect(fmt.Sprintf("%s/ui/signup?challenge=%s", EDGEDB_AUTH_BASE_URL, challenge), fiber.StatusTemporaryRedirect) } -func handleUiSignUp(c *fiber.Ctx) error { - verifier, challenge := generatePKCE() +func handleCallbackSignup(c *fiber.Ctx) error { + code := c.Query("code") + if code == "" { + error := c.Query("error") + return c.Status(fiber.StatusBadRequest).SendString(fmt.Sprintf("OAuth callback is missing 'code'. OAuth provider responded with error: %s", error)) + } - fmt.Println("handleUiSignUp verifier: " + verifier) - fmt.Println("handleUiSignUp challenge: " + challenge) + verifier := c.Cookies("jade-edgedb-pkce-verifier") + if verifier == "" { + return c.Status(fiber.StatusBadRequest).SendString("Could not find 'verifier' in the cookie store. Is this the same user agent/browser that started the authorization flow?") + } - cookie := new(fiber.Cookie) - cookie.Name = "jade-edgedb-pkce-verifier" - cookie.Value = verifier - cookie.Expires = time.Now().Add(10 * time.Minute) - c.Cookie(cookie) + codeExchangeURL := fmt.Sprintf("%s/token?code=%s&verifier=%s", EDGEDB_AUTH_BASE_URL, code, verifier) + resp, err := http.Get(codeExchangeURL) + if err != nil { + return c.Status(fiber.StatusBadRequest).SendString(fmt.Sprintf("Error from the auth server: %s", err.Error())) + } + defer resp.Body.Close() - return c.Redirect(fmt.Sprintf("%s/ui/signup?challenge=%s", EDGEDB_AUTH_BASE_URL, challenge), fiber.StatusTemporaryRedirect) + if resp.StatusCode != fiber.StatusOK { + body, _ := io.ReadAll(resp.Body) + return c.Status(fiber.StatusBadRequest).SendString(fmt.Sprintf("Error from the auth server: %s", string(body))) + } + + var tokenResponse struct { + AuthToken string `json:"auth_token"` + IdentityID string `json:"identity_id"` + } + err = json.NewDecoder(resp.Body).Decode(&tokenResponse) + if err != nil { + return c.Status(fiber.StatusBadRequest).SendString(fmt.Sprintf("Error decoding auth server response: %s", err.Error())) + } + + c.Cookie(&fiber.Cookie{ + Name: "jade-edgedb-auth-token", + Value: tokenResponse.AuthToken, + HTTPOnly: true, + Path: "/", + Secure: true, + SameSite: "Strict", + }) + + // Create a new User and attach the identity + var identityUUID edgedb.UUID + identityUUID, err = edgedb.ParseUUID(tokenResponse.IdentityID) + if err != nil { + return c.Status(fiber.StatusInternalServerError).SendString(fmt.Sprintf("Error in edgedb.ParseUUID: in handleCallbackSignup: %s", err.Error())) + } + + err = edgeClient.Execute(edgeCtx, ` + INSERT User { + setting := ( + INSERT Setting { + default_model := "gpt-3.5-turbo" + } + ), + identity := (SELECT ext::auth::Identity FILTER .id = $0) + } + `, identityUUID) + if err != nil { + return c.Status(fiber.StatusInternalServerError).SendString(fmt.Sprintf("Error in edgedb.QuerySingle: in handleCallbackSignup: %s", err.Error())) + } + + edgeClient = edgeClient.WithGlobals(map[string]interface{}{"ext::auth::client_token": tokenResponse.AuthToken}) + + return c.Redirect("/", fiber.StatusTemporaryRedirect) } func handleCallback(c *fiber.Ctx) error { @@ -78,11 +122,7 @@ func handleCallback(c *fiber.Ctx) error { return c.Status(fiber.StatusBadRequest).SendString("Could not find 'verifier' in the cookie store. Is this the same user agent/browser that started the authorization flow?") } - fmt.Println("handleCallback code: " + code) - fmt.Println("handleCallback verifier: " + verifier) - codeExchangeURL := fmt.Sprintf("%s/token?code=%s&verifier=%s", EDGEDB_AUTH_BASE_URL, code, verifier) - fmt.Println("codeExchangeURL: " + codeExchangeURL) resp, err := http.Get(codeExchangeURL) if err != nil { return c.Status(fiber.StatusBadRequest).SendString(fmt.Sprintf("Error from the auth server: %s", err.Error())) @@ -103,7 +143,7 @@ func handleCallback(c *fiber.Ctx) error { } c.Cookie(&fiber.Cookie{ - Name: "edgedb-auth-token", + Name: "jade-edgedb-auth-token", Value: tokenResponse.AuthToken, HTTPOnly: true, Path: "/", @@ -111,8 +151,13 @@ func handleCallback(c *fiber.Ctx) error { SameSite: "Strict", }) - fmt.Println("Login successful") - fmt.Println("edgedb-auth-token: " + tokenResponse.AuthToken) + edgeClient = edgeClient.WithGlobals(map[string]interface{}{"ext::auth::client_token": tokenResponse.AuthToken}) - return c.SendStatus(fiber.StatusNoContent) + return c.Redirect("/", fiber.StatusTemporaryRedirect) +} + +func handleSignOut(c *fiber.Ctx) error { + c.ClearCookie("jade-edgedb-auth-token") + edgeClient = edgeClient.WithGlobals(map[string]interface{}{"ext::auth::client_token": ""}) + return c.Redirect("/", fiber.StatusTemporaryRedirect) } diff --git a/main.go b/main.go index 9bf8d42..4663818 100644 --- a/main.go +++ b/main.go @@ -1,8 +1,6 @@ package main import ( - "fmt" - "github.com/flosch/pongo2" "github.com/gofiber/fiber/v2" "github.com/gofiber/fiber/v2/middleware/logger" @@ -32,24 +30,26 @@ func main() { // Add static files app.Static("/", "./static") - // Add routes - app.Get("/", ChatPageHandler) // Complete chat page - app.Post("/requestMultipleMessages", RequestMultipleMessages) // Request multiple messages - app.Get("/loadChat", LoadChatHandler) // Load chat - app.Post("/deleteMessage", DeleteMessageHandler) // Delete message - app.Get("/loadModelSelection", LoadModelSelectionHandler) // Load model selection - app.Get("/loadUsageKPI", LoadUsageKPIHandler) // Load usage KPI - app.Get("/generateMultipleMessages", GenerateMultipleMessages) // Generate multiple messages + // Main routes + app.Get("/", ChatPageHandler) + app.Get("/loadChat", LoadChatHandler) + + // Chat routes + app.Post("/requestMultipleMessages", RequestMultipleMessages) + app.Post("/deleteMessage", DeleteMessageHandler) + app.Get("/generateMultipleMessages", GenerateMultipleMessages) app.Get("/messageContent", GetMessageContentHandler) - app.Get("/signin", handleUiSignIn) - app.Get("/signup", handleUiSignUp) - app.Get("/callback", handleCallback) + // Popovers + app.Get("/loadModelSelection", LoadModelSelectionHandler) + app.Get("/loadUsageKPI", LoadUsageKPIHandler) + + // Authentication + app.Get("/signin", handleUiSignIn) + app.Get("/signout", handleSignOut) + app.Get("/callback", handleCallback) + app.Get("/callbackSignup", handleCallbackSignup) - app.Get("test", func(c *fiber.Ctx) error { - fmt.Println("test") - return c.SendString("") - }) // Start server app.Listen(":8080") } diff --git a/static/style.css b/static/style.css index 3a511f3..be2ec5b 100644 --- a/static/style.css +++ b/static/style.css @@ -1,23 +1,10 @@ -.my-indicator { - display: none; -} - -.htmx-request .my-indicator { - display: inline; -} - -.htmx-request.my-indicator { - display: inline; -} - -svg text { - font-family: 'Russo One', sans-serif; - text-transform: uppercase; - fill: #000; - stroke: #000; - font-size: 240px; +body, +html { + height: 100%; + margin: 0; } +/* Stuff for message boxes */ .message-content { background-color: #303030; border-radius: 5px; @@ -28,6 +15,7 @@ svg text { overflow-x: auto; white-space: pre; max-width: 662px; + position: relative; } #chat-messages .message-content pre code { @@ -36,17 +24,23 @@ svg text { white-space: inherit; } +.copy-button { + position: absolute; + top: 5px; + right: 5px; + cursor: pointer; + border: none; + border-radius: 5px; + padding: 5px 10px; +} + .content { font-size: 14px; line-height: 1.5; color: #ffffff; } -.content pre { - font-family: monospace; - background-color: #363636 !important; - border-radius: 3px; -} +/* Style for the overall chat container */ #chat-messages { max-width: 780px; @@ -55,23 +49,13 @@ svg text { margin-bottom: 180px; } -#chat-input-form { - max-width: 900px; - margin: auto; - width: 98%; -} - +/* Primary color */ :root { --bulma-body-background-color: #202020; --bulma-primary: #126d0f; } -body, -html { - height: 100%; - margin: 0; -} - +/* Chat input stuff */ .chat-input-container { display: flex; flex-direction: column; @@ -106,4 +90,26 @@ html { display: flex; align-items: center; gap: 10px; +} + +/* Indicator */ +.my-indicator { + display: none; +} + +.htmx-request .my-indicator { + display: inline; +} + +.htmx-request.my-indicator { + display: inline; +} + +/* Logo */ +svg text { + font-family: 'Russo One', sans-serif; + text-transform: uppercase; + fill: #000; + stroke: #000; + font-size: 240px; } \ No newline at end of file diff --git a/utils.go b/utils.go index 39918ca..1247859 100644 --- a/utils.go +++ b/utils.go @@ -2,6 +2,7 @@ package main import ( "bytes" + "regexp" "strings" "github.com/yuin/goldmark" @@ -19,7 +20,22 @@ func markdownToHTML(markdownText string) string { panic(err) // Handle the error appropriately } - return buf.String() + return addCopyButtonsToCode(buf.String()) +} + +func addCopyButtonsToCode(htmlContent string) string { + buttonHTML := `` + + // Regular expression pattern to match
 elements and insert the button right before 
+	pattern := `(]*>)`
+
+	// Compile the regular expression
+	re := regexp.MustCompile(pattern)
+
+	// Replace each matched 
 element with the updated HTML
+	updatedHTML := re.ReplaceAllString(htmlContent, "$1"+buttonHTML)
+
+	return updatedHTML
 }
 
 func model2Icon(model string) string {
diff --git a/views/partials/navbar.html b/views/partials/navbar.html
index 4bf6c20..cf0f535 100644
--- a/views/partials/navbar.html
+++ b/views/partials/navbar.html
@@ -12,12 +12,15 @@