From 103d6371c8d6e030b14941e8c5edd2af2968e239 Mon Sep 17 00:00:00 2001 From: Adrien Date: Thu, 23 May 2024 00:26:13 +0200 Subject: [PATCH] Working custom endpoint --- Chat.go | 11 ++----- LLM.go | 43 ++++++++++++++++++++++++---- RequestGooseai.go | 10 ++++++- RequestHuggingface.go | 2 ++ RequestMistral.go | 3 ++ utils.go | 2 +- views/partials/message-bot.html | 2 +- views/partials/popover-models.html | 40 +++++++++++++++++++++----- views/partials/popover-settings.html | 25 ++++++++-------- 9 files changed, 102 insertions(+), 36 deletions(-) diff --git a/Chat.go b/Chat.go index 39fde56..16ffd5a 100644 --- a/Chat.go +++ b/Chat.go @@ -156,7 +156,7 @@ func GetMessageContentHandler(c *fiber.Ctx) error { out := "
" out += "

" - out += "" + selectedMessage.LLM.Name + " " + selectedMessage.LLM.Model.Name + "" + out += "" + selectedMessage.LLM.Name + " " + selectedMessage.LLM.Model.ModelID + "" out += "

" out += "
" out += "
" @@ -449,19 +449,12 @@ func GenerateModelPopoverHTML(refresh bool) string { } } } - FILTER .user = global currentUser AND .name != 'none' AND .to_delete = false + FILTER .user = global currentUser AND .name != 'none' AND .to_delete = false `, &llms) if err != nil { panic(err) } - //for i := 0; i < len(llms); i++ { - // // If the modelID len is higher than 15, truncate it - // if len(llms[i].Model.ModelID) > 12 { - // llms[i].Model.ModelID = llms[i].Model.ModelID[0:12] + "..." - // } - //} - modelInfos := GetAvailableModels() out, err := pongo2.Must(pongo2.FromFile("views/partials/popover-models.html")).Execute(pongo2.Context{ diff --git a/LLM.go b/LLM.go index ff70b71..1692c40 100644 --- a/LLM.go +++ b/LLM.go @@ -2,6 +2,7 @@ package main import ( "encoding/json" + "fmt" "strconv" "github.com/edgedb/edgedb-go" @@ -52,12 +53,43 @@ func createLLM(c *fiber.Ctx) error { name := c.FormValue("model-name-input") modelID := c.FormValue("selectedLLMId") temperature := c.FormValue("temperature-slider") - systemPrompt := c.FormValue("model-prompt-input") - f, _ := strconv.ParseFloat(temperature, 32) temperatureFloat := float32(f) - err := edgeClient.Execute(edgeCtx, ` + systemPrompt := c.FormValue("model-prompt-input") + url := c.FormValue("model-url-input") + token := c.FormValue("model-key-input") + customID := c.FormValue("model-cid-input") + + fmt.Println(name, modelID, temperatureFloat, systemPrompt, url, token) + + if modelID == "custom" { + err := edgeClient.Execute(edgeCtx, ` + INSERT LLM { + name := $0, + context := $1, + temperature := $2, + modelInfo := (INSERT ModelInfo { + name := $0, + modelID := $5, + inputPrice := 0.0, + outputPrice := 0.0, + maxToken := 500, + company := (SELECT Company FILTER .name = "huggingface" LIMIT 1), + }), + custom_endpoint := (INSERT CustomEndpoint { + endpoint := $3, + key := $4, + }), + user := global currentUser + }; + `, name, systemPrompt, temperatureFloat, url, token, customID) // TODO Add real max token + if err != nil { + fmt.Println("Error in createLLM: ", err) + panic(err) + } + } else { + err := edgeClient.Execute(edgeCtx, ` INSERT LLM { name := $0, context := $1, @@ -66,8 +98,9 @@ func createLLM(c *fiber.Ctx) error { user := global currentUser } `, name, systemPrompt, temperatureFloat, modelID) - if err != nil { - panic(err) + if err != nil { + panic(err) + } } return c.SendString(GenerateModelPopoverHTML(true)) diff --git a/RequestGooseai.go b/RequestGooseai.go index 378336f..5348f69 100644 --- a/RequestGooseai.go +++ b/RequestGooseai.go @@ -14,6 +14,7 @@ type GooseaiCompletionRequest struct { Model string `json:"model"` Prompt []string `json:"prompt"` Temperature float64 `json:"temperature"` + MaxToken int32 `json:"max_tokens"` } type GooseaiCompletionResponse struct { @@ -53,6 +54,7 @@ func TestGooseaiKey(apiKey string) bool { Model: "gpt-j-6b", Prompt: []string{"Hello, how are you?"}, Temperature: 0, + MaxToken: 10, } jsonBody, err := json.Marshal(requestBody) @@ -111,10 +113,16 @@ func RequestGooseai(model string, messages []Message, temperature float64) (Goos url := "https://api.goose.ai/v1/engines/" + model + "/completions" + var prompt string + for _, message := range messages { + prompt += message.Content + } + requestBody := GooseaiCompletionRequest{ Model: model, - Prompt: []string{messages[len(messages)-1].Content}, + Prompt: []string{prompt}, Temperature: temperature, + MaxToken: 300, } jsonBody, err := json.Marshal(requestBody) diff --git a/RequestHuggingface.go b/RequestHuggingface.go index 6635dc6..88a5365 100644 --- a/RequestHuggingface.go +++ b/RequestHuggingface.go @@ -77,6 +77,8 @@ func RequestHuggingface(llm LLM, messages []Message, temperature float64) (Huggi req.Header.Set("Authorization", "Bearer "+llm.Endpoint.Key) req.Header.Set("Content-Type", "application/json") + fmt.Println(url, llm.Endpoint.Key) + client := &http.Client{} resp, err := client.Do(req) if err != nil { diff --git a/RequestMistral.go b/RequestMistral.go index 895efca..7700025 100644 --- a/RequestMistral.go +++ b/RequestMistral.go @@ -177,6 +177,9 @@ func RequestMistral(model string, messages []Message, temperature float64) (Mist FILTER .modelID = $0 LIMIT 1 `, &usedModelInfo, model) + if err != nil { + return MistralChatCompletionResponse{}, fmt.Errorf("error getting model info: %w", err) + } if usedModelInfo.InputPrice == 0 || usedModelInfo.OutputPrice == 0 { return MistralChatCompletionResponse{}, fmt.Errorf("model %s not found in Mistral", model) diff --git a/utils.go b/utils.go index a967dbd..e2fa741 100644 --- a/utils.go +++ b/utils.go @@ -145,7 +145,7 @@ func GetAvailableModels() []ModelInfo { name, icon } - } FILTER .modelID != 'none' + } FILTER .modelID != 'none' AND .company.name != 'huggingface' `, &models) if err != nil { panic(err) diff --git a/views/partials/message-bot.html b/views/partials/message-bot.html index 51113eb..89ac318 100644 --- a/views/partials/message-bot.html +++ b/views/partials/message-bot.html @@ -29,7 +29,7 @@ {% if not message.Hidden %}

- {{ message.Name }} {{ message.Model }} + {{ message.Name }} {{ message.ModelID }}

diff --git a/views/partials/popover-models.html b/views/partials/popover-models.html index c517082..8d1b9b3 100644 --- a/views/partials/popover-models.html +++ b/views/partials/popover-models.html @@ -1,14 +1,14 @@ -