211 lines
5.1 KiB
Go
211 lines
5.1 KiB
Go
package main
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"log"
|
|
"sync"
|
|
"time"
|
|
|
|
"github.com/edgedb/edgedb-go"
|
|
"github.com/flosch/pongo2"
|
|
"github.com/gofiber/fiber/v2"
|
|
)
|
|
|
|
// ModelInfo describes one chat model that can be offered to the user.
//
// ID is also used as the suffix of the "model-check-<ID>" form field that
// GenerateMultipleMessages reads to find out which models were ticked.
// MaxToken, InputPrice and OutputPrice are not interpreted in this file.
// NOTE(review): confirm their units (tokens / price per token or per million).
type ModelInfo struct {
	ID, Name, Icon          string
	MaxToken                int
	InputPrice, OutputPrice float32
}
|
|
|
|
type CompanyInfo struct {
|
|
ID string
|
|
Name string
|
|
Icon string
|
|
ModelInfos []ModelInfo
|
|
}
|
|
|
|
// SelectedModel is the view data for one model the user ticked.
// GenerateMultipleMessages builds one per selected model ID and passes the
// list to botTmpl as "selectedModels".
type SelectedModel struct {
	ID, Name, Icon string
}
|
|
|
|
var CompanyInfos []CompanyInfo
|
|
var ModelsInfos []ModelInfo
|
|
|
|
func GenerateMultipleMessages(c *fiber.Ctx) error {
|
|
message := c.FormValue("message", "")
|
|
selectedModelIds := []string{}
|
|
for ModelInfo := range ModelsInfos {
|
|
out := c.FormValue("model-check-" + ModelsInfos[ModelInfo].ID)
|
|
if out != "" {
|
|
selectedModelIds = append(selectedModelIds, ModelsInfos[ModelInfo].ID)
|
|
}
|
|
}
|
|
|
|
_, position := insertArea()
|
|
messageID := insertUserMessage(message)
|
|
|
|
out := ""
|
|
messageOut, _ := userTmpl.Execute(pongo2.Context{"Content": markdownToHTML(message), "ID": messageID.String()})
|
|
out += messageOut
|
|
|
|
var selectedModels []SelectedModel
|
|
for i := range selectedModelIds {
|
|
selectedModels = append(selectedModels, SelectedModel{ID: selectedModelIds[i], Name: model2Name(selectedModelIds[i]), Icon: model2Icon(selectedModelIds[i])})
|
|
}
|
|
|
|
messageOut, _ = botTmpl.Execute(pongo2.Context{"IsPlaceholder": true, "selectedModels": selectedModels, "ConversationAreaId": position + 1})
|
|
out += messageOut
|
|
|
|
go HandleGenerateMultipleMessages(selectedModelIds)
|
|
|
|
return c.SendString(out)
|
|
}
|
|
|
|
// HandleGenerateMultipleMessages produces one reply per model in
// selectedModelIds concurrently and pushes the results to the page over SSE.
//
// It first inserts a new conversation area via insertArea. Each model then
// runs in its own goroutine (see generateMultiModelMessage):
//   - The first model to finish without exceeding its one-minute deadline
//     fills the placeholder with its content, selection button and icon.
//   - Every other model only sends its selection button.
//
// The function blocks until all goroutines have returned.
//
// Failures are logged and abandon only the affected model. They never
// terminate the process, since this runs in a background goroutine of the
// web server.
func HandleGenerateMultipleMessages(selectedModelIds []string) {
	insertArea()

	var wg sync.WaitGroup
	wg.Add(len(selectedModelIds))

	// Capacity 1 and never read: the first successful send claims the
	// "first finished" slot, every later send falls through to default.
	firstDone := make(chan int, 1)

	for i := range selectedModelIds {
		idx := i // per-iteration copy for the goroutine (required before Go 1.22)
		go func() {
			defer wg.Done()
			generateMultiModelMessage(idx, selectedModelIds[idx], firstDone)
		}()
	}

	wg.Wait()
}

// generateMultiModelMessage produces and publishes the reply of one model for
// HandleGenerateMultipleMessages. idx is the model's index in the selection,
// requestedModelID its ID, and firstDone the shared latch that decides which
// model fills the placeholder.
func generateMultiModelMessage(idx int, requestedModelID string, firstDone chan<- int) {
	// NOTE(review): the deadline is only checked after the provider call and
	// the database lookups have returned; nothing below is cancelled by it.
	ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
	defer cancel()

	icon := model2Icon(requestedModelID)
	addMessage := multiMessageAdder(icon)
	if addMessage == nil {
		// Without a provider there is no message to load (the zero UUID would
		// only fail the lookup), so report and skip this model.
		log.Printf("GenerateMultipleMessages: no provider for model %q (icon %q), skipping", requestedModelID, icon)
		return
	}

	// NOTE(review): "selected" goes to the first model of the selection, while
	// the placeholder is filled by whichever model finishes first; confirm the
	// two are meant to be able to differ.
	messageID := addMessage(requestedModelID, idx == 0)

	message, area, err := loadMultiModelMessage(messageID)
	if err != nil {
		log.Printf("GenerateMultipleMessages: %v", err)
		return
	}
	modelID, _ := message.ModelID.Get()

	fmt.Println(area)

	if ctx.Err() != nil {
		fmt.Printf("Goroutine %d timed out\n", idx)
		return
	}

	select {
	case firstDone <- idx:
		sendFirstMultiModelMessage(message, modelID, fmt.Sprintf("%d", area.Position))
	default:
		fmt.Println("Sending event")
		sendMultiModelSelectionBtn(message, modelID)
	}
}

// multiMessageAdder maps a provider icon to its add*Message function, or
// returns nil for an unknown icon. The returned function's UUID is looked up
// as a Message id by loadMultiModelMessage.
func multiMessageAdder(icon string) func(modelID string, selected bool) edgedb.UUID {
	switch icon {
	case "openai":
		return addOpenaiMessage
	case "anthropic":
		return addAnthropicMessage
	case "mistral":
		return addMistralMessage
	case "groq":
		return addGroqMessage
	default:
		return nil
	}
}

// loadMultiModelMessage reads back the message stored under messageID
// (model_id, content, area) and the position of the area it belongs to.
func loadMultiModelMessage(messageID edgedb.UUID) (Message, Area, error) {
	var message Message
	err := edgeClient.QuerySingle(edgeCtx, `
		SELECT Message {
			model_id,
			content,
			area
		}
		FILTER .id = <uuid>$0;
	`, &message, messageID)
	if err != nil {
		return Message{}, Area{}, fmt.Errorf("loading message %v: %w", messageID, err)
	}

	var area Area
	err = edgeClient.QuerySingle(edgeCtx, `
		SELECT Area {
			position
		}
		FILTER .id = <uuid>$0;
	`, &area, message.Area.ID)
	if err != nil {
		return Message{}, Area{}, fmt.Errorf("loading area of message %v: %w", messageID, err)
	}

	return message, area, nil
}

// sendFirstMultiModelMessage swaps the first finished reply into the
// placeholder at the given area position, in this order: content, selection
// button, model icon.
func sendFirstMultiModelMessage(message Message, modelID, position string) {
	content := "<div class='message-header'>" +
		"<p>" + model2Name(modelID) + " </p>" +
		"</div>" +
		"<div class='message-body'>" +
		" <ct class='content'>" + markdownToHTML(message.Content) + " </ct>" +
		"</div>"

	fmt.Println("Sending event from first")
	fmt.Println("swapContent-" + position)

	sseChanel.SendEvent("swapContent-"+position, content)

	sendMultiModelSelectionBtn(message, modelID)

	sseChanel.SendEvent(
		"swapIcon-"+position,
		`<img src="icons/`+model2Icon(modelID)+`.png" alt="User Image">`,
	)
}

// sendMultiModelSelectionBtn renders modelSelecBtnTmpl for message and sends
// it as the "swapSelectionBtn-<modelID>" event. A render error is logged and
// the event is skipped.
func sendMultiModelSelectionBtn(message Message, modelID string) {
	btn, err := modelSelecBtnTmpl.Execute(map[string]interface{}{
		"message": message,
	})
	if err != nil {
		log.Printf("GenerateMultipleMessages: rendering selection button for %q: %v", modelID, err)
		return
	}

	sseChanel.SendEvent("swapSelectionBtn-"+modelID, btn)
}
|