113 lines
2.9 KiB
Go
113 lines
2.9 KiB
Go
package main
|
|
|
|
import (
|
|
"context"
|
|
"sync"
|
|
"time"
|
|
|
|
"github.com/flosch/pongo2"
|
|
"github.com/gofiber/fiber/v2"
|
|
"go.mongodb.org/mongo-driver/bson"
|
|
"go.mongodb.org/mongo-driver/bson/primitive"
|
|
)
|
|
|
|
// ModelInfo holds the metadata kept for one selectable model.
//
// ID is the value appended to "model-check-" to form the form field that
// RequestMultipleMessages inspects for this model.
// NOTE(review): the units of MaxToken, InputPrice and OutputPrice are not
// visible in this file — confirm before relying on them.
type ModelInfo struct {
	ID, Name, Icon          string
	MaxToken                int
	InputPrice, OutputPrice float64
}
|
|
|
|
type CompanyInfo struct {
|
|
ID string
|
|
Name string
|
|
Icon string
|
|
ModelInfos []ModelInfo
|
|
}
|
|
|
|
var CompanyInfos []CompanyInfo
|
|
var ModelsInfos []ModelInfo
|
|
|
|
type MultipleModelsCompletionRequest struct {
|
|
ModelIds []string
|
|
Messages []Message
|
|
Message string
|
|
}
|
|
|
|
// BotContentMessage pairs a piece of content with a Hidden flag.
// NOTE(review): this type is not used anywhere in the visible part of the
// file; presumably Hidden controls whether the content is shown in the chat
// view — confirm against the templates that consume it.
type BotContentMessage struct {
	Content string
	Hidden  bool
}
|
|
|
|
// lastSelectedModelIds holds the model IDs chosen in the most recent call to
// RequestMultipleMessages; GenerateMultipleMessages reads it to decide which
// models to query. It is package-level state, so it is shared by every
// request the process handles.
var lastSelectedModelIds []string
|
|
|
|
func GenerateMultipleMessages(c *fiber.Ctx) error {
|
|
// Create a wait group to synchronize the goroutines
|
|
var wg sync.WaitGroup
|
|
var InsertedIDs []string
|
|
|
|
// Add the length of lastSelectedModelIds goroutines to the wait group
|
|
wg.Add(len(lastSelectedModelIds))
|
|
|
|
for i := range lastSelectedModelIds {
|
|
if model2Icon(lastSelectedModelIds[i]) == "openai" {
|
|
go func() {
|
|
defer wg.Done()
|
|
response := addOpenaiMessage(lastSelectedModelIds[i], i == 0)
|
|
InsertedIDs = append(InsertedIDs, response)
|
|
}()
|
|
} else if model2Icon(lastSelectedModelIds[i]) == "anthropic" {
|
|
go func() {
|
|
defer wg.Done()
|
|
response := addAnthropicMessage(lastSelectedModelIds[i], i == 0)
|
|
InsertedIDs = append(InsertedIDs, response)
|
|
}()
|
|
}
|
|
}
|
|
|
|
// Wait for both goroutines to finish
|
|
wg.Wait()
|
|
|
|
collection := mongoClient.Database("chat").Collection("messages")
|
|
for i := range InsertedIDs {
|
|
objectID, _ := primitive.ObjectIDFromHex(InsertedIDs[i])
|
|
collection.UpdateOne(context.Background(), bson.M{"_id": objectID}, bson.M{"$set": bson.M{"linked_message_ids": InsertedIDs}})
|
|
}
|
|
|
|
return c.SendString(generateChatHTML())
|
|
}
|
|
|
|
func RequestMultipleMessages(c *fiber.Ctx) error {
|
|
message := c.FormValue("message")
|
|
if chatString, commandExecuted := DetectCommand(message, c); commandExecuted {
|
|
return c.SendString(chatString)
|
|
}
|
|
|
|
collection := mongoClient.Database("chat").Collection("messages")
|
|
result, _ := collection.InsertOne(context.Background(), bson.M{"message": message, "role": "user", "date": time.Now()})
|
|
|
|
selectedModelIds := []string{}
|
|
for CompanyInfo := range CompanyInfos {
|
|
for ModelInfo := range CompanyInfos[CompanyInfo].ModelInfos {
|
|
out := c.FormValue("model-check-" + CompanyInfos[CompanyInfo].ModelInfos[ModelInfo].ID)
|
|
if out != "" {
|
|
selectedModelIds = append(selectedModelIds, CompanyInfos[CompanyInfo].ModelInfos[ModelInfo].ID)
|
|
}
|
|
}
|
|
}
|
|
lastSelectedModelIds = selectedModelIds
|
|
|
|
out := ""
|
|
|
|
HexID := result.InsertedID.(primitive.ObjectID).Hex()
|
|
messageOut, _ := userTmpl.Execute(pongo2.Context{"Content": markdownToHTML(message), "ID": HexID})
|
|
out += messageOut
|
|
|
|
messageOut, _ = botTmpl.Execute(pongo2.Context{"IsPlaceholder": true, "SelectedModelIds": selectedModelIds, "Message": message})
|
|
out += messageOut
|
|
|
|
return c.SendString(out)
|
|
}
|