diff --git a/Chat.go b/Chat.go
index 39fde56..16ffd5a 100644
--- a/Chat.go
+++ b/Chat.go
@@ -156,7 +156,7 @@ func GetMessageContentHandler(c *fiber.Ctx) error {
out := "
"
out += ""
@@ -449,19 +449,12 @@ func GenerateModelPopoverHTML(refresh bool) string {
}
}
}
- FILTER .user = global currentUser AND .name != 'none' AND .to_delete = false
+ FILTER .user = global currentUser AND .name != 'none' AND .to_delete = false
`, &llms)
if err != nil {
panic(err)
}
- //for i := 0; i < len(llms); i++ {
- // // If the modelID len is higher than 15, truncate it
- // if len(llms[i].Model.ModelID) > 12 {
- // llms[i].Model.ModelID = llms[i].Model.ModelID[0:12] + "..."
- // }
- //}
-
modelInfos := GetAvailableModels()
out, err := pongo2.Must(pongo2.FromFile("views/partials/popover-models.html")).Execute(pongo2.Context{
diff --git a/LLM.go b/LLM.go
index ff70b71..1692c40 100644
--- a/LLM.go
+++ b/LLM.go
@@ -2,6 +2,7 @@ package main
import (
"encoding/json"
+ "fmt"
"strconv"
"github.com/edgedb/edgedb-go"
@@ -52,12 +53,43 @@ func createLLM(c *fiber.Ctx) error {
name := c.FormValue("model-name-input")
modelID := c.FormValue("selectedLLMId")
temperature := c.FormValue("temperature-slider")
- systemPrompt := c.FormValue("model-prompt-input")
-
f, _ := strconv.ParseFloat(temperature, 32)
temperatureFloat := float32(f)
- err := edgeClient.Execute(edgeCtx, `
+ systemPrompt := c.FormValue("model-prompt-input")
+ url := c.FormValue("model-url-input")
+ token := c.FormValue("model-key-input")
+ customID := c.FormValue("model-cid-input")
+
+ fmt.Println(name, modelID, temperatureFloat, systemPrompt, url, token)
+
+ if modelID == "custom" {
+ err := edgeClient.Execute(edgeCtx, `
+ INSERT LLM {
+ name := $0,
+ context := $1,
+ temperature := $2,
+ modelInfo := (INSERT ModelInfo {
+ name := $0,
+ modelID := $5,
+ inputPrice := 0.0,
+ outputPrice := 0.0,
+ maxToken := 500,
+ company := (SELECT Company FILTER .name = "huggingface" LIMIT 1),
+ }),
+ custom_endpoint := (INSERT CustomEndpoint {
+ endpoint := $3,
+ key := $4,
+ }),
+ user := global currentUser
+ };
+ `, name, systemPrompt, temperatureFloat, url, token, customID) // TODO Add real max token
+ if err != nil {
+ fmt.Println("Error in createLLM: ", err)
+ panic(err)
+ }
+ } else {
+ err := edgeClient.Execute(edgeCtx, `
INSERT LLM {
name := $0,
context := $1,
@@ -66,8 +98,9 @@ func createLLM(c *fiber.Ctx) error {
user := global currentUser
}
`, name, systemPrompt, temperatureFloat, modelID)
- if err != nil {
- panic(err)
+ if err != nil {
+ panic(err)
+ }
}
return c.SendString(GenerateModelPopoverHTML(true))
diff --git a/RequestGooseai.go b/RequestGooseai.go
index 378336f..5348f69 100644
--- a/RequestGooseai.go
+++ b/RequestGooseai.go
@@ -14,6 +14,7 @@ type GooseaiCompletionRequest struct {
Model string `json:"model"`
Prompt []string `json:"prompt"`
Temperature float64 `json:"temperature"`
+ MaxToken int32 `json:"max_tokens"`
}
type GooseaiCompletionResponse struct {
@@ -53,6 +54,7 @@ func TestGooseaiKey(apiKey string) bool {
Model: "gpt-j-6b",
Prompt: []string{"Hello, how are you?"},
Temperature: 0,
+ MaxToken: 10,
}
jsonBody, err := json.Marshal(requestBody)
@@ -111,10 +113,16 @@ func RequestGooseai(model string, messages []Message, temperature float64) (Goos
url := "https://api.goose.ai/v1/engines/" + model + "/completions"
+ var prompt string
+ for _, message := range messages {
+ prompt += message.Content
+ }
+
requestBody := GooseaiCompletionRequest{
Model: model,
- Prompt: []string{messages[len(messages)-1].Content},
+ Prompt: []string{prompt},
Temperature: temperature,
+ MaxToken: 300,
}
jsonBody, err := json.Marshal(requestBody)
diff --git a/RequestHuggingface.go b/RequestHuggingface.go
index 6635dc6..88a5365 100644
--- a/RequestHuggingface.go
+++ b/RequestHuggingface.go
@@ -77,6 +77,8 @@ func RequestHuggingface(llm LLM, messages []Message, temperature float64) (Huggi
req.Header.Set("Authorization", "Bearer "+llm.Endpoint.Key)
req.Header.Set("Content-Type", "application/json")
+ fmt.Println(url, llm.Endpoint.Key)
+
client := &http.Client{}
resp, err := client.Do(req)
if err != nil {
diff --git a/RequestMistral.go b/RequestMistral.go
index 895efca..7700025 100644
--- a/RequestMistral.go
+++ b/RequestMistral.go
@@ -177,6 +177,9 @@ func RequestMistral(model string, messages []Message, temperature float64) (Mist
FILTER .modelID = $0
LIMIT 1
`, &usedModelInfo, model)
+ if err != nil {
+ return MistralChatCompletionResponse{}, fmt.Errorf("error getting model info: %w", err)
+ }
if usedModelInfo.InputPrice == 0 || usedModelInfo.OutputPrice == 0 {
return MistralChatCompletionResponse{}, fmt.Errorf("model %s not found in Mistral", model)
diff --git a/utils.go b/utils.go
index a967dbd..e2fa741 100644
--- a/utils.go
+++ b/utils.go
@@ -145,7 +145,7 @@ func GetAvailableModels() []ModelInfo {
name,
icon
}
- } FILTER .modelID != 'none'
+ } FILTER .modelID != 'none' AND .company.name != 'huggingface'
`, &models)
if err != nil {
panic(err)
diff --git a/views/partials/message-bot.html b/views/partials/message-bot.html
index 51113eb..89ac318 100644
--- a/views/partials/message-bot.html
+++ b/views/partials/message-bot.html
@@ -29,7 +29,7 @@
{% if not message.Hidden %}
diff --git a/views/partials/popover-models.html b/views/partials/popover-models.html
index c517082..8d1b9b3 100644
--- a/views/partials/popover-models.html
+++ b/views/partials/popover-models.html
@@ -1,14 +1,14 @@
-