Jade/Request.go
2024-05-20 09:24:02 +02:00

214 lines
5.3 KiB
Go

package main
import (
"context"
"fmt"
"strings"
"sync"
"time"
"github.com/edgedb/edgedb-go"
"github.com/flosch/pongo2"
"github.com/gofiber/fiber/v2"
)
// RequestMessage is one chat message: who authored it and what it says.
// It marshals to JSON as {"role": ..., "content": ...}.
type RequestMessage struct {
	Role    string `json:"role"`    // role of the message author
	Content string `json:"content"` // message text
}
// lastSelectedLLMs holds the LLMs loaded by the most recent call to
// GeneratePlaceholderHandler. GenerateMultipleMessagesHandler reads it to
// decide which models to ask for a reply.
//
// NOTE(review): this is package-level state shared by every request and is
// read and written without synchronisation. Concurrent requests can race on
// it or overwrite each other's selection.
var (
	lastSelectedLLMs []LLM
)
// GeneratePlaceholderHandler is step 1 of the generation flow.
//
// It performs these steps in order:
//   - Stores the user's "message" form value and renders it with userTmpl.
//   - Appends a bot placeholder (botTmpl with IsPlaceholder set) for the next
//     conversation area. That placeholder presumably triggers the request to
//     GenerateMultipleMessagesHandler when it loads in the browser.
//
// The LLMs to query are currently a hard-coded list of IDs. They are loaded
// from EdgeDB and remembered in lastSelectedLLMs for step 2.
//
// Failures are returned to Fiber rather than panicking or being dropped.
// This covers an invalid ID, a failed EdgeDB query, or a template rendering
// error.
//
// NOTE(review): lastSelectedLLMs is shared by all requests and written here
// without synchronisation, so concurrent requests race on it.
func GeneratePlaceholderHandler(c *fiber.Ctx) error {
	message := c.FormValue("message", "")

	// TODO: let the user pick these in the UI instead of hard-coding them.
	selectedLLMIds := []string{
		"1e5a07c4-12fe-11ef-8da6-67d29b408c53",
		"3cd15ca8-1433-11ef-9f22-93f2b78c78de",
		"95774e62-1447-11ef-bfea-33f555b75c17",
		"af3d8686-1447-11ef-bfea-07d880a979ff",
		"be7a922a-1478-11ef-a819-238de8775b87",
	}

	selectedLLMs := make([]LLM, 0, len(selectedLLMIds))
	for _, id := range selectedLLMIds {
		idUUID, err := edgedb.ParseUUID(id)
		if err != nil {
			return fmt.Errorf("parsing LLM id %q: %w", id, err)
		}

		// Declared inside the loop so each LLM is decoded into a zero value
		// rather than into the struct left over from the previous iteration.
		var selectedLLM LLM
		err = edgeClient.QuerySingle(c.UserContext(), `
SELECT LLM {
id,
name,
context,
temperature,
modelInfo : {
modelID,
maxToken,
company : {
icon,
name
}
}
}
FILTER
.id = <uuid>$0;
`, &selectedLLM, idUUID)
		if err != nil {
			return fmt.Errorf("loading LLM %s: %w", id, err)
		}
		selectedLLMs = append(selectedLLMs, selectedLLM)
	}
	lastSelectedLLMs = selectedLLMs

	_, position := insertArea()
	messageID := insertUserMessage(message)

	userOut, err := userTmpl.Execute(pongo2.Context{"Content": markdownToHTML(message), "ID": messageID.String()})
	if err != nil {
		return fmt.Errorf("rendering user message: %w", err)
	}
	// The placeholder targets the area after the one just inserted.
	botOut, err := botTmpl.Execute(pongo2.Context{"IsPlaceholder": true, "SelectedLLMs": selectedLLMs, "ConversationAreaId": position + 1})
	if err != nil {
		return fmt.Errorf("rendering bot placeholder: %w", err)
	}

	return c.SendString(userOut + botOut)
}
// GenerateMultipleMessagesHandler is step 2 of the generation flow. It inserts
// a new conversation area, then asks every LLM stored in lastSelectedLLMs for
// a reply concurrently, one goroutine per LLM. Results are pushed to the
// client with sendEvent.
//
// The first goroutine to finish without timing out "wins" the placeholder. It
// swaps in its message content, its selection button and the area icon. Every
// other goroutine only swaps in its own selection button.
//
// Failures inside a goroutine are logged and that LLM is skipped. These are an
// unsupported company, a failed message query, or a template error. They are
// not panicked on, because a panic in a goroutine cannot be recovered by Fiber
// and would crash the whole server.
//
// The handler blocks until every goroutine has returned, then replies with an
// empty body.
func GenerateMultipleMessagesHandler(c *fiber.Ctx) error {
	insertArea()
	selectedLLMs := lastSelectedLLMs

	var wg sync.WaitGroup
	wg.Add(len(selectedLLMs))

	// Capacity 1 and never drained: exactly one goroutine can claim the
	// "first done" slot. Every later non-blocking send takes the default branch.
	firstDone := make(chan int, 1)

	for i := range selectedLLMs {
		idx := i
		go func() {
			defer wg.Done()

			// The deadline is only checked after the reply has been generated.
			// It is not passed to the add*Message functions, so it discards
			// late results but does not cancel the underlying request.
			ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
			defer cancel()

			llm := selectedLLMs[idx]

			// Pick the generator for this model's company.
			var addMessageFunc func(selectedLLM LLM, selected bool) edgedb.UUID
			switch llm.Model.Company.Name {
			case "openai":
				addMessageFunc = addOpenaiMessage
			case "anthropic":
				addMessageFunc = addAnthropicMessage
			case "mistral":
				addMessageFunc = addMistralMessage
			case "groq":
				addMessageFunc = addGroqMessage
			case "gooseai":
				addMessageFunc = addGooseaiMessage
			default:
				// No generator means no message to load or display.
				fmt.Printf("Goroutine %d: unsupported company %q\n", idx, llm.Model.Company.Name)
				return
			}

			// NOTE(review): `selected` is true for index 0, while the UI below
			// highlights whichever goroutine finishes first. Confirm this
			// mismatch is intended.
			messageID := addMessageFunc(llm, idx == 0)

			var message Message
			err := edgeClient.QuerySingle(edgeCtx, `
SELECT Message {
id,
content,
area : {
id,
position
},
llm : {
modelInfo : {
modelID,
name,
company : {
icon,
}
}
}
}
FILTER .id = <uuid>$0;
`, &message, messageID)
			if err != nil {
				fmt.Printf("Goroutine %d: loading message %s: %v\n", idx, messageID.String(), err)
				return
			}

			// Drop results that arrived after the deadline.
			select {
			case <-ctx.Done():
				fmt.Printf("Goroutine %d timed out\n", idx)
				return
			default:
			}

			templateMessage := TemplateMessage{
				Icon:    message.LLM.Model.Company.Icon,
				Content: message.Content,
				Hidden:  false,
				Id:      message.ID.String(),
				Name:    message.LLM.Model.Name,
				ModelID: message.LLM.Model.ModelID,
			}

			// Both the winner and the others send the same selection button,
			// so render it once.
			outBtn, err := selectBtnTmpl.Execute(map[string]interface{}{
				"message":            templateMessage,
				"ConversationAreaId": message.Area.Position,
			})
			if err != nil {
				fmt.Printf("Goroutine %d: rendering selection button: %v\n", idx, err)
				return
			}
			// Replace newline characters to prevent premature termination of
			// the event payload.
			outBtn = strings.ReplaceAll(outBtn, "\n", "")

			areaPos := fmt.Sprintf("%d", message.Area.Position)
			btnEvent := "swapSelectionBtn-" + llm.ID.String()

			select {
			case firstDone <- idx:
				// First finisher: also swap in its content and the area icon.
				outIcon := `<img src="` + llm.Model.Company.Icon + `" alt="User Image" id="selectedIcon-` + areaPos + `">`
				go func() {
					sendEvent(
						"swapContent-"+areaPos,
						`<hx hx-get="/messageContent?id=`+message.ID.String()+`" hx-trigger="load" hx-swap="outerHTML"></hx>`,
					)
					sendEvent(btnEvent, outBtn)
					sendEvent("swapIcon-"+areaPos, outIcon)
				}()
			default:
				go sendEvent(btnEvent, outBtn)
			}
		}()
	}

	// Wait for every LLM goroutine to finish before answering.
	wg.Wait()
	return c.SendString("")
}