Jade/LLM.go

189 lines
4.9 KiB
Go

package main
import (
"encoding/json"
"fmt"
"strconv"
"github.com/edgedb/edgedb-go"
"github.com/gofiber/fiber/v2"
)
// HTTP handlers for creating, deleting, and reordering LLM configurations.
// deleteLLM soft-deletes the LLMs selected in the model popover (sets
// to_delete := true), then calls deleteLLMtoDelete to purge flagged LLMs that
// no Message references. On success it responds with the refreshed model
// popover HTML.
//
// The "selectedLLMIds" form field must hold a JSON array of UUID strings.
// Only LLMs whose .user is global currentUser are updated. Malformed input
// yields 400 Bad Request; a database failure yields 500 Internal Server Error.
func deleteLLM(c *fiber.Ctx) error {
	var selectedLLMIds []string
	if err := json.Unmarshal([]byte(c.FormValue("selectedLLMIds")), &selectedLLMIds); err != nil {
		fmt.Println("Error unmarshalling selected LLM IDs:", err)
		return fiber.NewError(fiber.StatusBadRequest, "invalid selectedLLMIds")
	}

	// Parse every ID before touching the database so a malformed entry
	// cannot leave the selection half-deleted.
	ids := make([]edgedb.UUID, 0, len(selectedLLMIds))
	for _, id := range selectedLLMIds {
		idUUID, err := edgedb.ParseUUID(id)
		if err != nil {
			fmt.Println("Error parsing LLM UUID:", err)
			return fiber.NewError(fiber.StatusBadRequest, "invalid LLM id")
		}
		ids = append(ids, idUUID)
	}

	client := edgeGlobalClient.WithGlobals(map[string]interface{}{"ext::auth::client_token": c.Cookies("jade-edgedb-auth-token")})
	for _, idUUID := range ids {
		err := client.Execute(edgeCtx, `
UPDATE LLM
FILTER .id = <uuid>$0 AND .user = global currentUser
SET {
to_delete := true
};
`, idUUID)
		if err != nil {
			fmt.Println("Error deleting LLM:", err)
			return fiber.NewError(fiber.StatusInternalServerError, "error deleting LLM")
		}
	}

	deleteLLMtoDelete(c)
	return c.SendString(GenerateModelPopoverHTML(true, c))
}
// deleteLLMtoDelete permanently deletes every LLM flagged with to_delete that
// is not referenced by any Message; flagged LLMs still used by a Message are
// kept. The query runs with the caller's "jade-edgedb-auth-token" cookie as
// the ext::auth client token. It panics if the query fails.
//
// NOTE(review): the query has no explicit `.user = global currentUser`
// filter; it presumably relies on access policies tied to the auth token to
// scope the delete to the caller — confirm against the schema.
func deleteLLMtoDelete(c *fiber.Ctx) {
	authToken := c.Cookies("jade-edgedb-auth-token")
	globals := map[string]interface{}{"ext::auth::client_token": authToken}
	const purgeQuery = `
delete LLM
filter .to_delete = true and not exists(
select Message filter .llm = LLM
);
`
	if err := edgeGlobalClient.WithGlobals(globals).Execute(edgeCtx, purgeQuery); err != nil {
		panic(err)
	}
}
// createLLM inserts a new LLM for the current user, positioned after the
// user's existing LLMs (position = count + 1), and responds with the
// refreshed model popover HTML.
//
// Form fields read:
//   - model-name-input: display name of the LLM
//   - selectedLLMId: modelID of an existing ModelInfo, or "custom"
//   - temperature-slider: temperature; 0 if missing or unparsable
//   - model-prompt-input: system prompt, stored as the LLM's context
//   - max-token-input: max tokens; 1024 if missing, non-positive, or out of
//     int32 range
//   - model-url-input, model-key-input, model-cid-input: endpoint URL, key and
//     model ID, used only when selectedLLMId is "custom"
//
// A database failure yields 500 Internal Server Error.
func createLLM(c *fiber.Ctx) error {
	const defaultMaxTokens = 1024

	name := c.FormValue("model-name-input")
	modelID := c.FormValue("selectedLLMId")
	systemPrompt := c.FormValue("model-prompt-input")

	// A missing or malformed temperature falls back to 0; checking the error
	// also keeps a range error from storing ±Inf.
	var temperature float32
	if f, err := strconv.ParseFloat(c.FormValue("temperature-slider"), 32); err == nil {
		temperature = float32(f)
	}

	// bitSize 32 rejects values that would silently wrap when narrowed to the
	// <int32> query parameter; a non-positive limit is meaningless.
	maxTokens := int32(defaultMaxTokens)
	if n, err := strconv.ParseInt(c.FormValue("max-token-input"), 10, 32); err == nil && n > 0 {
		maxTokens = int32(n)
	}
	fmt.Println("Adding LLM with maxtoken:", maxTokens)

	client := edgeGlobalClient.WithGlobals(map[string]interface{}{"ext::auth::client_token": c.Cookies("jade-edgedb-auth-token")})

	var err error
	if modelID == "custom" {
		url := c.FormValue("model-url-input")
		token := c.FormValue("model-key-input")
		customID := c.FormValue("model-cid-input")
		// TODO change the company
		err = client.Execute(edgeCtx, `
WITH
countLLM := count((SELECT LLM FILTER .user = global currentUser))
INSERT LLM {
name := <str>$0,
context := <str>$1,
temperature := <float32>$2,
position := countLLM + 1,
max_tokens := <int32>$6,
modelInfo := (INSERT ModelInfo {
name := <str>$0,
modelID := <str>$5,
inputPrice := 0.0,
outputPrice := 0.0,
company := (SELECT Company FILTER .name = "huggingface" LIMIT 1),
}),
custom_endpoint := (INSERT CustomEndpoint {
endpoint := <str>$3,
key := <str>$4,
}),
user := global currentUser
};
`, name, systemPrompt, temperature, url, token, customID, maxTokens)
	} else {
		err = client.Execute(edgeCtx, `
WITH
countLLM := count((SELECT LLM FILTER .user = global currentUser))
INSERT LLM {
name := <str>$0,
context := <str>$1,
temperature := <float32>$2,
position := countLLM + 1,
max_tokens := <int32>$4,
modelInfo := (SELECT ModelInfo FILTER .modelID = <str>$3 LIMIT 1),
user := global currentUser
}
`, name, systemPrompt, temperature, modelID, maxTokens)
	}
	if err != nil {
		fmt.Println("Error creating LLM:", err)
		return fiber.NewError(fiber.StatusInternalServerError, "error creating LLM")
	}

	return c.SendString(GenerateModelPopoverHTML(true, c))
}
// PositionUpdate is one entry of a reorder request: the new position to store
// for the record whose UUID string is ID. Slices of it are filled from the
// request body via c.BodyParser in updateLLMPositionBatch and
// updateConversationPositionBatch; the json tags map the "position" and "id"
// keys.
type PositionUpdate struct {
Position int `json:"position"`
ID string `json:"id"`
}
// updateLLMPositionBatch applies a batch of LLM reorderings parsed from the
// request body as []PositionUpdate. Each LLM's position is updated only if
// its .user is global currentUser.
//
// A body that BodyParser rejects is returned as-is. A malformed id or a
// position outside int32 range yields 400 Bad Request, and no position is
// written. A database failure yields 500 Internal Server Error. On success it
// returns nil without writing a body.
func updateLLMPositionBatch(c *fiber.Ctx) error {
	var positionUpdates []PositionUpdate
	if err := c.BodyParser(&positionUpdates); err != nil {
		return err
	}

	// Validate the whole batch up front so bad input cannot leave the list
	// partially reordered.
	ids := make([]edgedb.UUID, len(positionUpdates))
	for i, update := range positionUpdates {
		idUUID, err := edgedb.ParseUUID(update.ID)
		if err != nil {
			fmt.Println("Error parsing UUID:", err)
			return fiber.NewError(fiber.StatusBadRequest, "invalid LLM id")
		}
		// The value is sent as <int32>; reject anything that would wrap.
		if int(int32(update.Position)) != update.Position {
			return fiber.NewError(fiber.StatusBadRequest, "LLM position out of range")
		}
		ids[i] = idUUID
	}

	client := edgeGlobalClient.WithGlobals(map[string]interface{}{"ext::auth::client_token": c.Cookies("jade-edgedb-auth-token")})
	for i, update := range positionUpdates {
		err := client.Execute(edgeCtx, `
UPDATE LLM
FILTER .id = <uuid>$0 AND .user = global currentUser
SET {
position := <int32>$1
};
`, ids[i], int32(update.Position))
		if err != nil {
			fmt.Println("Error updating LLM position:", err)
			return fiber.NewError(fiber.StatusInternalServerError, "error updating LLM position")
		}
	}
	return nil
}
// updateConversationPositionBatch applies a batch of conversation reorderings
// parsed from the request body as []PositionUpdate. Each Conversation's
// position is updated only if its .user is global currentUser.
//
// A body that BodyParser rejects is returned as-is. A malformed id or a
// position outside int32 range yields 400 Bad Request, and no position is
// written. A database failure yields 500 Internal Server Error. On success it
// returns nil without writing a body.
func updateConversationPositionBatch(c *fiber.Ctx) error {
	var positionUpdates []PositionUpdate
	if err := c.BodyParser(&positionUpdates); err != nil {
		return err
	}

	// Validate the whole batch up front so bad input cannot leave the list
	// partially reordered.
	ids := make([]edgedb.UUID, len(positionUpdates))
	for i, update := range positionUpdates {
		idUUID, err := edgedb.ParseUUID(update.ID)
		if err != nil {
			fmt.Println("Error parsing UUID:", err)
			return fiber.NewError(fiber.StatusBadRequest, "invalid conversation id")
		}
		// The value is sent as <int32>; reject anything that would wrap.
		if int(int32(update.Position)) != update.Position {
			return fiber.NewError(fiber.StatusBadRequest, "conversation position out of range")
		}
		ids[i] = idUUID
	}

	client := edgeGlobalClient.WithGlobals(map[string]interface{}{"ext::auth::client_token": c.Cookies("jade-edgedb-auth-token")})
	for i, update := range positionUpdates {
		err := client.Execute(edgeCtx, `
UPDATE Conversation
FILTER .id = <uuid>$0 AND .user = global currentUser
SET {
position := <int32>$1
};
`, ids[i], int32(update.Position))
		if err != nil {
			fmt.Println("Error updating conversation position:", err)
			return fiber.NewError(fiber.StatusInternalServerError, "error updating conversation position")
		}
	}
	return nil
}