From 1125f86331dead248af8e5638a33afcb86e88a4f Mon Sep 17 00:00:00 2001 From: Adrien Date: Sun, 16 Jun 2024 00:23:30 +0200 Subject: [PATCH] SSE per user --- Chat.go | 18 ++++--- Request.go | 15 +++--- RequestGoogle.go | 4 +- main.go | 89 ++++++++++++++++++--------------- views/chat.html | 3 +- views/partials/message-bot.html | 6 +-- 6 files changed, 74 insertions(+), 61 deletions(-) diff --git a/Chat.go b/Chat.go index ee2871b..f88a918 100644 --- a/Chat.go +++ b/Chat.go @@ -126,7 +126,16 @@ func generateChatHTML(c *fiber.Ctx) string { panic(err) } - htmlString := "
" + var user User + err = edgeGlobalClient.WithGlobals(map[string]interface{}{"ext::auth::client_token": c.Cookies("jade-edgedb-auth-token")}).QuerySingle(edgeCtx, ` + SELECT global currentUser { id } LIMIT 1 + `, &user) + if err != nil { + fmt.Println("Error getting user") + panic(err) + } + + htmlString := "
" var templateMessages []TemplateMessage @@ -177,13 +186,6 @@ func generateChatHTML(c *fiber.Ctx) string { } } - // out, err := messagesPlaceholderTmpl.Execute(pongo2.Context{}) - // if err != nil { - // fmt.Println("Error executing message user placeholder template") - // panic(err) - // } - // htmlString += out - htmlString += "
" // Render the HTML template with the messages diff --git a/Request.go b/Request.go index 52783af..b6c5ceb 100644 --- a/Request.go +++ b/Request.go @@ -230,18 +230,20 @@ func GenerateMultipleMessagesHandler(c *fiber.Ctx) error { outIcon := `User Image` go func() { - // I do a ping because of sse size limit - fmt.Println("Sending event: ", "swapContent-"+fmt.Sprintf("%d", message.Area.Position)+"-"+user.ID.String()) + // I do a ping because of sse size limit. Do see if it's possible to do without it TODO sendEvent( - "swapContent-"+fmt.Sprintf("%d", message.Area.Position)+"-"+user.ID.String(), + user.ID.String(), + "swapContent-"+fmt.Sprintf("%d", message.Area.Position), ``, ) sendEvent( - "swapSelectionBtn-"+selectedLLMs[idx].ID.String()+"-"+user.ID.String(), + user.ID.String(), + "swapSelectionBtn-"+selectedLLMs[idx].ID.String(), outBtn, ) sendEvent( - "swapIcon-"+fmt.Sprintf("%d", message.Area.Position)+"-"+user.ID.String(), + user.ID.String(), + "swapIcon-"+fmt.Sprintf("%d", message.Area.Position), outIcon, ) }() @@ -261,7 +263,8 @@ func GenerateMultipleMessagesHandler(c *fiber.Ctx) error { // Send Content event go func() { sendEvent( - "swapSelectionBtn-"+selectedLLMs[idx].ID.String()+"-"+user.ID.String(), + user.ID.String(), + "swapSelectionBtn-"+selectedLLMs[idx].ID.String(), outBtn, ) }() diff --git a/RequestGoogle.go b/RequestGoogle.go index 4b528ed..99831e6 100644 --- a/RequestGoogle.go +++ b/RequestGoogle.go @@ -154,12 +154,12 @@ func RequestGoogle(c *fiber.Ctx, model string, messages []Message, temperature f if message.Role == "user" { googleMessages = append(googleMessages, GoogleRequestMessage{ Role: "user", - Parts: []GooglePart{GooglePart{Text: message.Content}}, + Parts: []GooglePart{{Text: message.Content}}, // Changed something here, to test }) } else { googleMessages = append(googleMessages, GoogleRequestMessage{ Role: "model", - Parts: []GooglePart{GooglePart{Text: message.Content}}, + Parts: []GooglePart{{Text: message.Content}}, }) } } diff --git a/main.go b/main.go index 152df8f..cfa487e 100644 --- a/main.go +++ b/main.go @@ -27,9 +27,9 @@ var ( welcomeChatTmpl *pongo2.Template chatInputTmpl *pongo2.Template explainLLMconvChatTmpl *pongo2.Template - messagesPlaceholderTmpl *pongo2.Template - clients = make(map[chan SSE]bool) mu sync.Mutex + app *fiber.App + userSSEChannels = make(map[string]chan SSE) ) // SSE event structure @@ -39,13 +39,16 @@ type SSE struct { } // Function to send events to all clients -func sendEvent(event, data string) { +func sendEvent(userID string, event string, data string) { mu.Lock() defer mu.Unlock() - for client := range clients { - client <- SSE{Event: event, Data: data} + userEvents, ok := userSSEChannels[userID] + if !ok { + return } + + userEvents <- SSE{Event: event, Data: data} } func main() { @@ -63,13 +66,12 @@ func main() { welcomeChatTmpl = pongo2.Must(pongo2.FromFile("views/partials/welcome-chat.html")) chatInputTmpl = pongo2.Must(pongo2.FromFile("views/partials/chat-input.html")) explainLLMconvChatTmpl = pongo2.Must(pongo2.FromFile("views/partials/explain-llm-conv-chat.html")) - messagesPlaceholderTmpl = pongo2.Must(pongo2.FromFile("views/partials/messages-placeholder.html")) // Import HTML using django engine/template engine := django.New("./views", ".html") // Create new Fiber instance - app := fiber.New(fiber.Config{ + app = fiber.New(fiber.Config{ Views: engine, AppName: "JADE", }) @@ -133,39 +135,7 @@ func main() { return c.SendString("") }) - app.Get("/sse", func(c *fiber.Ctx) error { - c.Set("Content-Type", "text/event-stream") - c.Set("Cache-Control", "no-cache") - c.Set("Connection", "keep-alive") - - events := make(chan SSE, 500) - mu.Lock() - clients[events] = true - mu.Unlock() - - // Create a context copy to use in the goroutine - ctx := c.Context() - - go func() { - <-ctx.Done() - mu.Lock() - delete(clients, events) - mu.Unlock() - close(events) - }() - - c.Context().SetBodyStreamWriter(func(w *bufio.Writer) { - for event := range events { - if _, err := fmt.Fprintf(w, "event: %s\ndata: %s\n\n", event.Event, event.Data); err != nil { - fmt.Println(err) - return - } - w.Flush() - } - }) - - return nil - }) + app.Get("/sse", handleSSE) // Start server if err := app.Listen(":8080"); err != nil { @@ -173,6 +143,45 @@ func main() { } } +func handleSSE(c *fiber.Ctx) error { + userID := c.Query("userID") // Get userID from query parameter + if userID == "" { + return c.Status(fiber.StatusBadRequest).SendString("Missing userID") + } + + events := make(chan SSE, 500) + mu.Lock() + userSSEChannels[userID] = events + mu.Unlock() + + // Create a context copy to use in the goroutine + ctx := c.Context() + + go func() { + <-ctx.Done() + mu.Lock() + delete(userSSEChannels, userID) + mu.Unlock() + close(events) + }() + + c.Set("Content-Type", "text/event-stream") + c.Set("Cache-Control", "no-cache") + c.Set("Connection", "keep-alive") + + c.Context().SetBodyStreamWriter(func(w *bufio.Writer) { + for event := range events { + if _, err := fmt.Fprintf(w, "event: %s\ndata: %s\n\n", event.Event, event.Data); err != nil { + fmt.Println(err) + return + } + w.Flush() + } + }) + + return nil +} + func addKeys(c *fiber.Ctx) error { keys := map[string]string{ "openai": c.FormValue("openai_key"), diff --git a/views/chat.html b/views/chat.html index 2a9a359..1c014c0 100644 --- a/views/chat.html +++ b/views/chat.html @@ -1,5 +1,4 @@ -
+
-
\ No newline at end of file diff --git a/views/partials/message-bot.html b/views/partials/message-bot.html index 48294ae..c77487c 100644 --- a/views/partials/message-bot.html +++ b/views/partials/message-bot.html @@ -6,7 +6,7 @@ {% if IsPlaceholder %}
+ sse-swap="swapIcon-{{ ConversationAreaId }}"> User Image
@@ -72,7 +72,7 @@ {% elif IsPlaceholder %}
+ sse-swap="swapContent-{{ ConversationAreaId }}">

@@ -98,7 +98,7 @@ {% for selectedLLM in SelectedLLMs %}