diff --git a/Chat.go b/Chat.go
index 58e1698..0ec65bd 100644
--- a/Chat.go
+++ b/Chat.go
@@ -12,7 +12,15 @@ import (
)
func ChatPageHandler(c *fiber.Ctx) error {
- return c.Render("chat", fiber.Map{}, "layouts/main")
+ authCookie := c.Cookies("jade-edgedb-auth-token", "")
+
+ if authCookie != "" && !checkIfLogin() {
+ edgeClient = edgeClient.WithGlobals(map[string]interface{}{"ext::auth::client_token": authCookie})
+ }
+
+ fmt.Println("Current User: ", getCurrentUser())
+
+ return c.Render("chat", fiber.Map{"IsLogin": checkIfLogin()}, "layouts/main")
}
func LoadModelSelectionHandler(c *fiber.Ctx) error {
@@ -45,8 +53,6 @@ func LoadUsageKPIHandler(c *fiber.Ctx) error {
log.Fatal(err)
}
- fmt.Println(TotalUsage)
-
out, err := pongo2.Must(pongo2.FromFile("views/partials/usagePopover.html")).Execute(pongo2.Context{
"TotalUsage": TotalUsage,
})
@@ -85,7 +91,11 @@ func DeleteMessageHandler(c *fiber.Ctx) error {
}
func LoadChatHandler(c *fiber.Ctx) error {
- return c.SendString(generateChatHTML())
+ if checkIfLogin() {
+ return c.SendString(generateChatHTML())
+ } else {
+ return c.SendString(generateWelcomeChatHTML())
+ }
}
type NextMessage struct {
@@ -115,7 +125,6 @@ func generateChatHTML() string {
// Reset NextMessages when a user message is encountered
NextMessages = []NextMessage{}
} else {
- fmt.Println(i)
modelID, exist := message.ModelID.Get()
if !exist {
modelID = "gpt-3.5-turbo"
@@ -199,3 +208,28 @@ func GetMessageContentHandler(c *fiber.Ctx) error {
return c.SendString(out)
}
+
+func generateWelcomeChatHTML() string {
+ htmlString := "
"
+
+ NextMessages := []NextMessage{}
+ nextMsg := NextMessage{
+ Icon: "bouvai2", // Assuming Icon is a field you want to include from Message
+ Content: markdownToHTML("Hi, I'm Bouvai. How can I help you today?"),
+ Hidden: false, // Assuming Hidden is a field you want to include from Message
+ Id: "0",
+ Name: "JADE",
+ }
+ NextMessages = append(NextMessages, nextMsg)
+
+ botOut, err := botTmpl.Execute(pongo2.Context{"Messages": NextMessages, "ConversationAreaId": 0})
+ if err != nil {
+ panic(err)
+ }
+ htmlString += botOut
+ htmlString += "
"
+ htmlString += "
"
+
+ // Render the HTML template with the messages
+ return htmlString
+}
diff --git a/database.go b/database.go
index 2c6ecdc..e3a31d6 100644
--- a/database.go
+++ b/database.go
@@ -73,15 +73,8 @@ func init() {
log.Fatal(err)
}
- // TODO Change
edgeCtx = ctx
- var clientUUID edgedb.UUID
- clientUUID, err = edgedb.ParseUUID("9323365e-0b09-11ef-8f41-c3575d386283")
- if err != nil {
- fmt.Println("Error in edgedb.ParseUUID: in init")
- log.Fatal(err)
- }
- edgeClient = client.WithGlobals(map[string]interface{}{"current_user_id": clientUUID})
+ edgeClient = client
}
func getLastArea() edgedb.UUID {
@@ -99,15 +92,21 @@ func getLastArea() edgedb.UUID {
return inserted.id
}
-func checkIfLogin() bool {
+func getCurrentUser() User {
var result User
err := edgeClient.QuerySingle(edgeCtx, "SELECT global currentUser LIMIT 1;", &result)
if err != nil {
- fmt.Println("Error in edgedb.QuerySingle: in checkIfLogin")
+ fmt.Println("Error in edgedb.QuerySingle: in getCurrentUser")
fmt.Println(err)
- return false
+ return User{}
}
- return true
+ return result
+}
+
+func checkIfLogin() bool {
+ var result User
+ err := edgeClient.QuerySingle(edgeCtx, "SELECT global currentUser LIMIT 1;", &result)
+ return err == nil
}
func insertArea() edgedb.UUID {
diff --git a/dbschema/default.esdl b/dbschema/default.esdl
index 521cf3c..3445cbe 100644
--- a/dbschema/default.esdl
+++ b/dbschema/default.esdl
@@ -1,10 +1,11 @@
using extension auth;
module default {
- global current_user_id: uuid;
global currentUser := (
+ assert_single((
select User
- filter .id = global current_user_id
+ filter .identity = global ext::auth::ClientTokenIdentity
+ ))
);
type User {
@@ -58,13 +59,13 @@ module default {
}
type Usage {
- required model_id: str;
- required user: User;
- input_cost: float32;
- output_cost: float32;
- input_token: int32;
- output_token: int32;
- required date: datetime {
+ required model_id: str;
+ user: User;
+ input_cost: float32;
+ output_cost: float32;
+ input_token: int32;
+ output_token: int32;
+ required date: datetime {
default := datetime_current();
}
}
diff --git a/dbschema/migrations/00017-m1g3r5w.edgeql b/dbschema/migrations/00017-m1g3r5w.edgeql
new file mode 100644
index 0000000..964451e
--- /dev/null
+++ b/dbschema/migrations/00017-m1g3r5w.edgeql
@@ -0,0 +1,11 @@
+CREATE MIGRATION m1g3r5wwsplyqphor3yvb7dzp2txacgd6ed3udrpvakuoz3e2zj7ka
+ ONTO m1t2tddtq2grfxf4ldn6aj4yotvwiqclsrks6chobvs33d5mcd62mq
+{
+ DROP GLOBAL default::currentUser;
+ CREATE GLOBAL default::current_user := (std::assert_single((SELECT
+ default::User
+ FILTER
+ (.identity = GLOBAL ext::auth::ClientTokenIdentity)
+ )));
+ DROP GLOBAL default::current_user_id;
+};
diff --git a/dbschema/migrations/00018-m1c2b3o.edgeql b/dbschema/migrations/00018-m1c2b3o.edgeql
new file mode 100644
index 0000000..f4f7f37
--- /dev/null
+++ b/dbschema/migrations/00018-m1c2b3o.edgeql
@@ -0,0 +1,5 @@
+CREATE MIGRATION m1c2b3o4bhfgldk3kebneqyv6oxfgmorrp62uv5rreu773ryed3azq
+ ONTO m1g3r5wwsplyqphor3yvb7dzp2txacgd6ed3udrpvakuoz3e2zj7ka
+{
+ ALTER GLOBAL default::current_user RENAME TO default::currentUser;
+};
diff --git a/dbschema/migrations/00019-m1rzoj5.edgeql b/dbschema/migrations/00019-m1rzoj5.edgeql
new file mode 100644
index 0000000..229212c
--- /dev/null
+++ b/dbschema/migrations/00019-m1rzoj5.edgeql
@@ -0,0 +1,9 @@
+CREATE MIGRATION m1rzoj5rvhxkec6tsew6dc6bd3ksrrmf7t3k532avozbg4722xulhq
+ ONTO m1c2b3o4bhfgldk3kebneqyv6oxfgmorrp62uv5rreu773ryed3azq
+{
+ ALTER TYPE default::Usage {
+ ALTER LINK user {
+ RESET OPTIONALITY;
+ };
+ };
+};
diff --git a/login.go b/login.go
index f09506d..fee2cb2 100644
--- a/login.go
+++ b/login.go
@@ -8,62 +8,106 @@ import (
"fmt"
"io"
"net/http"
- "time"
+ "github.com/edgedb/edgedb-go"
"github.com/gofiber/fiber/v2"
)
const EDGEDB_AUTH_BASE_URL = "http://127.0.0.1:10700/db/main/ext/auth"
func generatePKCE() (string, string) {
- fmt.Println("Generating PKCE")
- verifier := make([]byte, 32)
- _, err := rand.Read(verifier)
+ verifier_source := make([]byte, 32)
+ _, err := rand.Read(verifier_source)
if err != nil {
panic(err)
}
- challenge := sha256.Sum256(verifier)
-
- var URLEncoding = base64.NewEncoding("ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-_")
- var RawURLEncoding = URLEncoding.WithPadding(base64.NoPadding)
- encodedVerifier := RawURLEncoding.EncodeToString(verifier)
- encodedChallenge := RawURLEncoding.EncodeToString(challenge[:])
-
- fmt.Println("verifier: " + encodedVerifier)
- fmt.Println("challenge: " + encodedChallenge)
-
- return encodedVerifier, encodedChallenge
+ verifier := base64.RawURLEncoding.EncodeToString(verifier_source)
+ challenge := sha256.Sum256([]byte(verifier))
+ return verifier, base64.RawURLEncoding.EncodeToString(challenge[:])
}
func handleUiSignIn(c *fiber.Ctx) error {
verifier, challenge := generatePKCE()
- fmt.Println("handleUiSignIn verifier: " + verifier)
- fmt.Println("handleUiSignIn challenge: " + challenge)
-
- cookie := new(fiber.Cookie)
- cookie.Name = "jade-edgedb-pkce-verifier"
- cookie.Value = verifier
- cookie.Expires = time.Now().Add(10 * time.Minute)
- c.Cookie(cookie)
+ c.Cookie(&fiber.Cookie{
+ Name: "jade-edgedb-pkce-verifier",
+ Value: verifier,
+ HTTPOnly: true,
+ Path: "/",
+ Secure: true,
+ SameSite: "Strict",
+ })
return c.Redirect(fmt.Sprintf("%s/ui/signup?challenge=%s", EDGEDB_AUTH_BASE_URL, challenge), fiber.StatusTemporaryRedirect)
}
-func handleUiSignUp(c *fiber.Ctx) error {
- verifier, challenge := generatePKCE()
+func handleCallbackSignup(c *fiber.Ctx) error {
+ code := c.Query("code")
+ if code == "" {
+ error := c.Query("error")
+ return c.Status(fiber.StatusBadRequest).SendString(fmt.Sprintf("OAuth callback is missing 'code'. OAuth provider responded with error: %s", error))
+ }
- fmt.Println("handleUiSignUp verifier: " + verifier)
- fmt.Println("handleUiSignUp challenge: " + challenge)
+ verifier := c.Cookies("jade-edgedb-pkce-verifier")
+ if verifier == "" {
+ return c.Status(fiber.StatusBadRequest).SendString("Could not find 'verifier' in the cookie store. Is this the same user agent/browser that started the authorization flow?")
+ }
- cookie := new(fiber.Cookie)
- cookie.Name = "jade-edgedb-pkce-verifier"
- cookie.Value = verifier
- cookie.Expires = time.Now().Add(10 * time.Minute)
- c.Cookie(cookie)
+ codeExchangeURL := fmt.Sprintf("%s/token?code=%s&verifier=%s", EDGEDB_AUTH_BASE_URL, code, verifier)
+ resp, err := http.Get(codeExchangeURL)
+ if err != nil {
+ return c.Status(fiber.StatusBadRequest).SendString(fmt.Sprintf("Error from the auth server: %s", err.Error()))
+ }
+ defer resp.Body.Close()
- return c.Redirect(fmt.Sprintf("%s/ui/signup?challenge=%s", EDGEDB_AUTH_BASE_URL, challenge), fiber.StatusTemporaryRedirect)
+ if resp.StatusCode != fiber.StatusOK {
+ body, _ := io.ReadAll(resp.Body)
+ return c.Status(fiber.StatusBadRequest).SendString(fmt.Sprintf("Error from the auth server: %s", string(body)))
+ }
+
+ var tokenResponse struct {
+ AuthToken string `json:"auth_token"`
+ IdentityID string `json:"identity_id"`
+ }
+ err = json.NewDecoder(resp.Body).Decode(&tokenResponse)
+ if err != nil {
+ return c.Status(fiber.StatusBadRequest).SendString(fmt.Sprintf("Error decoding auth server response: %s", err.Error()))
+ }
+
+ c.Cookie(&fiber.Cookie{
+ Name: "jade-edgedb-auth-token",
+ Value: tokenResponse.AuthToken,
+ HTTPOnly: true,
+ Path: "/",
+ Secure: true,
+ SameSite: "Strict",
+ })
+
+ // Create a new User and attach the identity
+ var identityUUID edgedb.UUID
+ identityUUID, err = edgedb.ParseUUID(tokenResponse.IdentityID)
+ if err != nil {
+ return c.Status(fiber.StatusInternalServerError).SendString(fmt.Sprintf("Error in edgedb.ParseUUID: in handleCallbackSignup: %s", err.Error()))
+ }
+
+ err = edgeClient.Execute(edgeCtx, `
+ INSERT User {
+ setting := (
+ INSERT Setting {
+ default_model := "gpt-3.5-turbo"
+ }
+ ),
+ identity := (SELECT ext::auth::Identity FILTER .id = $0)
+ }
+ `, identityUUID)
+ if err != nil {
+ return c.Status(fiber.StatusInternalServerError).SendString(fmt.Sprintf("Error in edgedb.QuerySingle: in handleCallbackSignup: %s", err.Error()))
+ }
+
+ edgeClient = edgeClient.WithGlobals(map[string]interface{}{"ext::auth::client_token": tokenResponse.AuthToken})
+
+ return c.Redirect("/", fiber.StatusTemporaryRedirect)
}
func handleCallback(c *fiber.Ctx) error {
@@ -78,11 +122,7 @@ func handleCallback(c *fiber.Ctx) error {
return c.Status(fiber.StatusBadRequest).SendString("Could not find 'verifier' in the cookie store. Is this the same user agent/browser that started the authorization flow?")
}
- fmt.Println("handleCallback code: " + code)
- fmt.Println("handleCallback verifier: " + verifier)
-
codeExchangeURL := fmt.Sprintf("%s/token?code=%s&verifier=%s", EDGEDB_AUTH_BASE_URL, code, verifier)
- fmt.Println("codeExchangeURL: " + codeExchangeURL)
resp, err := http.Get(codeExchangeURL)
if err != nil {
return c.Status(fiber.StatusBadRequest).SendString(fmt.Sprintf("Error from the auth server: %s", err.Error()))
@@ -103,7 +143,7 @@ func handleCallback(c *fiber.Ctx) error {
}
c.Cookie(&fiber.Cookie{
- Name: "edgedb-auth-token",
+ Name: "jade-edgedb-auth-token",
Value: tokenResponse.AuthToken,
HTTPOnly: true,
Path: "/",
@@ -111,8 +151,13 @@ func handleCallback(c *fiber.Ctx) error {
SameSite: "Strict",
})
- fmt.Println("Login successful")
- fmt.Println("edgedb-auth-token: " + tokenResponse.AuthToken)
+ edgeClient = edgeClient.WithGlobals(map[string]interface{}{"ext::auth::client_token": tokenResponse.AuthToken})
- return c.SendStatus(fiber.StatusNoContent)
+ return c.Redirect("/", fiber.StatusTemporaryRedirect)
+}
+
+func handleSignOut(c *fiber.Ctx) error {
+ c.ClearCookie("jade-edgedb-auth-token")
+ edgeClient = edgeClient.WithGlobals(map[string]interface{}{"ext::auth::client_token": ""})
+ return c.Redirect("/", fiber.StatusTemporaryRedirect)
}
diff --git a/main.go b/main.go
index 9bf8d42..4663818 100644
--- a/main.go
+++ b/main.go
@@ -1,8 +1,6 @@
package main
import (
- "fmt"
-
"github.com/flosch/pongo2"
"github.com/gofiber/fiber/v2"
"github.com/gofiber/fiber/v2/middleware/logger"
@@ -32,24 +30,26 @@ func main() {
// Add static files
app.Static("/", "./static")
- // Add routes
- app.Get("/", ChatPageHandler) // Complete chat page
- app.Post("/requestMultipleMessages", RequestMultipleMessages) // Request multiple messages
- app.Get("/loadChat", LoadChatHandler) // Load chat
- app.Post("/deleteMessage", DeleteMessageHandler) // Delete message
- app.Get("/loadModelSelection", LoadModelSelectionHandler) // Load model selection
- app.Get("/loadUsageKPI", LoadUsageKPIHandler) // Load usage KPI
- app.Get("/generateMultipleMessages", GenerateMultipleMessages) // Generate multiple messages
+ // Main routes
+ app.Get("/", ChatPageHandler)
+ app.Get("/loadChat", LoadChatHandler)
+
+ // Chat routes
+ app.Post("/requestMultipleMessages", RequestMultipleMessages)
+ app.Post("/deleteMessage", DeleteMessageHandler)
+ app.Get("/generateMultipleMessages", GenerateMultipleMessages)
app.Get("/messageContent", GetMessageContentHandler)
- app.Get("/signin", handleUiSignIn)
- app.Get("/signup", handleUiSignUp)
- app.Get("/callback", handleCallback)
+ // Popovers
+ app.Get("/loadModelSelection", LoadModelSelectionHandler)
+ app.Get("/loadUsageKPI", LoadUsageKPIHandler)
+
+ // Authentication
+ app.Get("/signin", handleUiSignIn)
+ app.Get("/signout", handleSignOut)
+ app.Get("/callback", handleCallback)
+ app.Get("/callbackSignup", handleCallbackSignup)
- app.Get("test", func(c *fiber.Ctx) error {
- fmt.Println("test")
- return c.SendString("")
- })
// Start server
app.Listen(":8080")
}
diff --git a/static/style.css b/static/style.css
index 3a511f3..be2ec5b 100644
--- a/static/style.css
+++ b/static/style.css
@@ -1,23 +1,10 @@
-.my-indicator {
- display: none;
-}
-
-.htmx-request .my-indicator {
- display: inline;
-}
-
-.htmx-request.my-indicator {
- display: inline;
-}
-
-svg text {
- font-family: 'Russo One', sans-serif;
- text-transform: uppercase;
- fill: #000;
- stroke: #000;
- font-size: 240px;
+body,
+html {
+ height: 100%;
+ margin: 0;
}
+/* Stuff for message boxes */
.message-content {
background-color: #303030;
border-radius: 5px;
@@ -28,6 +15,7 @@ svg text {
overflow-x: auto;
white-space: pre;
max-width: 662px;
+ position: relative;
}
#chat-messages .message-content pre code {
@@ -36,17 +24,23 @@ svg text {
white-space: inherit;
}
+.copy-button {
+ position: absolute;
+ top: 5px;
+ right: 5px;
+ cursor: pointer;
+ border: none;
+ border-radius: 5px;
+ padding: 5px 10px;
+}
+
.content {
font-size: 14px;
line-height: 1.5;
color: #ffffff;
}
-.content pre {
- font-family: monospace;
- background-color: #363636 !important;
- border-radius: 3px;
-}
+/* Style for the overall chat container */
#chat-messages {
max-width: 780px;
@@ -55,23 +49,13 @@ svg text {
margin-bottom: 180px;
}
-#chat-input-form {
- max-width: 900px;
- margin: auto;
- width: 98%;
-}
-
+/* Primary color */
:root {
--bulma-body-background-color: #202020;
--bulma-primary: #126d0f;
}
-body,
-html {
- height: 100%;
- margin: 0;
-}
-
+/* Chat input stuff */
.chat-input-container {
display: flex;
flex-direction: column;
@@ -106,4 +90,26 @@ html {
display: flex;
align-items: center;
gap: 10px;
+}
+
+/* Indicator */
+.my-indicator {
+ display: none;
+}
+
+.htmx-request .my-indicator {
+ display: inline;
+}
+
+.htmx-request.my-indicator {
+ display: inline;
+}
+
+/* Logo */
+svg text {
+ font-family: 'Russo One', sans-serif;
+ text-transform: uppercase;
+ fill: #000;
+ stroke: #000;
+ font-size: 240px;
}
\ No newline at end of file
diff --git a/utils.go b/utils.go
index 39918ca..1247859 100644
--- a/utils.go
+++ b/utils.go
@@ -2,6 +2,7 @@ package main
import (
"bytes"
+ "regexp"
"strings"
"github.com/yuin/goldmark"
@@ -19,7 +20,22 @@ func markdownToHTML(markdownText string) string {
panic(err) // Handle the error appropriately
}
- return buf.String()
+ return addCopyButtonsToCode(buf.String())
+}
+
+func addCopyButtonsToCode(htmlContent string) string {
+ buttonHTML := ` `
+
+ // Regular expression pattern to match elements and insert the button right before
+ pattern := `(]*>)`
+
+ // Compile the regular expression
+ re := regexp.MustCompile(pattern)
+
+ // Replace each matched element with the updated HTML
+ updatedHTML := re.ReplaceAllString(htmlContent, "$1"+buttonHTML)
+
+ return updatedHTML
}
func model2Icon(model string) string {
diff --git a/views/partials/navbar.html b/views/partials/navbar.html
index 4bf6c20..cf0f535 100644
--- a/views/partials/navbar.html
+++ b/views/partials/navbar.html
@@ -12,12 +12,15 @@