305 lines
9.9 KiB
Go
305 lines
9.9 KiB
Go
package main
|
||
|
||
import (
|
||
"bytes"
|
||
"encoding/json"
|
||
"io"
|
||
"math"
|
||
"net/http"
|
||
"strings"
|
||
|
||
"fmt"
|
||
|
||
"github.com/gofiber/fiber/v2"
|
||
)
|
||
|
||
type TogetherChatCompletionResponse struct {
|
||
ID string `json:"id"`
|
||
Object string `json:"object"`
|
||
Created int64 `json:"created"`
|
||
Model string `json:"model"`
|
||
Usage OpenaiUsage `json:"usage"`
|
||
Choices []TogetherChoice `json:"choices"`
|
||
}
|
||
|
||
// TogetherChoice is one generated alternative inside a
// TogetherChatCompletionResponse.
type TogetherChoice struct {
	// Text is the generated completion returned to the caller of
	// RequestTogether.
	Text string `json:"text"`
	// FinishReason is the provider's reason for stopping generation.
	FinishReason string `json:"finish_reason"`
	// Index is the position of this choice in the response.
	Index int `json:"index"`
}
|
||
|
||
// TogetherErrorCodes maps HTTP status codes returned by Together AI, written
// as decimal strings, to the user-facing message RequestTogether returns
// instead of a completion when the response status contains that code.
var TogetherErrorCodes = map[string]string{
	"400": "Provider error: Invalid Request - Please contact the support.",
	"401": "Provider error: Invalid Authentication - Ensure that the API key is still valid.",
	"403": "Provider error: Set max_tokens to a lower number. Or leave it empty for using max value.",
	"404": "Provider error: Model not found.",
	"429": "Provider error: Rate limit reached for requests - You are sending requests too quickly.",
	"500": "Provider error: Issue on Together AI servers - Retry your request after a brief wait and contact Together AI if the issue persists.",
	"503": "Provider error: Servers are experiencing high traffic - Please retry your requests after a brief wait.",
	"504": "Provider error: Servers are experiencing high traffic - Please retry your requests after a brief wait.",
	"520": "Provider error: An unexpected error has occurred internal to Together’s systems.",
	"524": "Provider error: An unexpected error has occurred internal to Together’s systems.",
	"529": "Provider error: An unexpected error has occurred internal to Together’s systems.",
}
|
||
|
||
func RequestTogether(c *fiber.Ctx, llm LLM, messages []Message, testApiKey string) string {
|
||
model := llm.Model.ModelID
|
||
temperature := float64(llm.Temperature)
|
||
context := llm.Context
|
||
maxTokens := int(llm.MaxToken)
|
||
|
||
var apiKey string
|
||
if testApiKey == "" {
|
||
err := edgeGlobalClient.WithGlobals(map[string]interface{}{"ext::auth::client_token": c.Cookies("jade-edgedb-auth-token")}).QuerySingle(edgeCtx, `
|
||
with
|
||
filtered_keys := (
|
||
select Key {
|
||
key
|
||
} filter .company.name = <str>$0 AND .<keys[is Setting].<setting[is User] = global currentUser
|
||
)
|
||
select filtered_keys.key limit 1
|
||
`, &apiKey, "together")
|
||
if err != nil {
|
||
logErrorCode.Println("07-00-0000")
|
||
return "JADE internal error: 07-00-0000. Please contact the support."
|
||
}
|
||
} else {
|
||
apiKey = testApiKey
|
||
}
|
||
|
||
url := "https://api.together.xyz/v1/completions"
|
||
|
||
requestBody := OpenaiChatCompletionRequest{
|
||
Model: model,
|
||
Messages: Message2RequestMessage(messages, context),
|
||
MaxTokens: maxTokens,
|
||
Temperature: temperature,
|
||
}
|
||
|
||
jsonBody, err := json.Marshal(requestBody)
|
||
if err != nil {
|
||
logErrorCode.Println("07-01-0001")
|
||
return "JADE internal error: 07-01-0001. Please contact the support."
|
||
}
|
||
|
||
req, err := http.NewRequest("POST", url, bytes.NewBuffer(jsonBody))
|
||
if err != nil {
|
||
logErrorCode.Println("07-02-0002")
|
||
return "JADE internal error: 07-02-0002. Please contact the support."
|
||
}
|
||
|
||
req.Header.Set("Content-Type", "application/json")
|
||
req.Header.Set("Authorization", "Bearer "+apiKey)
|
||
|
||
client := &http.Client{}
|
||
resp, err := client.Do(req)
|
||
if err != nil {
|
||
logErrorCode.Println("07-02-0003")
|
||
return "JADE internal error: 07-02-0003. Please contact the support."
|
||
}
|
||
defer resp.Body.Close()
|
||
|
||
body, err := io.ReadAll(resp.Body)
|
||
if err != nil {
|
||
logErrorCode.Println("07-01-0004")
|
||
return "JADE internal error: 07-01-0004. Please contact the support."
|
||
}
|
||
|
||
for key, value := range TogetherErrorCodes {
|
||
if strings.Contains(resp.Status, key) {
|
||
return value
|
||
}
|
||
}
|
||
|
||
var chatCompletionResponse TogetherChatCompletionResponse
|
||
err = json.Unmarshal(body, &chatCompletionResponse)
|
||
if err != nil {
|
||
logErrorCode.Println("07-01-0005")
|
||
return "JADE internal error: 07-01-0005. Please contact the support."
|
||
}
|
||
|
||
if testApiKey != "" {
|
||
return chatCompletionResponse.Choices[0].Text
|
||
}
|
||
|
||
var usedModelInfo ModelInfo
|
||
usedModelInfo, found := getModelInfoByID(llm.Model.ID)
|
||
if !found {
|
||
logErrorCode.Println("07-00-0006")
|
||
return "JADE internal error: 07-00-0006. Please contact the support."
|
||
}
|
||
|
||
var inputCost float32 = float32(chatCompletionResponse.Usage.PromptTokens) * usedModelInfo.InputPrice
|
||
var outputCost float32 = float32(chatCompletionResponse.Usage.CompletionTokens) * usedModelInfo.OutputPrice
|
||
addUsage(c, inputCost, outputCost, chatCompletionResponse.Usage.PromptTokens, chatCompletionResponse.Usage.CompletionTokens, model)
|
||
|
||
if len(chatCompletionResponse.Choices) == 0 {
|
||
logErrorCode.Println("07-03-0007 -", resp.Status, "-", string(body))
|
||
return "JADE internal error: 07-03-0007. Please contact the support."
|
||
}
|
||
|
||
return chatCompletionResponse.Choices[0].Text
|
||
}
|
||
|
||
type TogetherModel struct {
|
||
ID string `json:"id"`
|
||
Object string `json:"object"`
|
||
Created int64 `json:"created"`
|
||
Type string `json:"type"`
|
||
DisplayName string `json:"display_name"`
|
||
Organization string `json:"organization"`
|
||
Link string `json:"link"`
|
||
License string `json:"license"`
|
||
ContextLength int `json:"context_length"`
|
||
Pricing TogetherPricing `json:"pricing"`
|
||
}
|
||
|
||
// TogetherPricing is the "pricing" object of a TogetherModel.
//
// UpdateTogetherModels divides Input and Output by 1,000,000 before storing
// them, so they are presumably prices per million tokens — confirm against
// the Together API docs.
type TogetherPricing struct {
	// Hourly is the hourly price figure reported by the API.
	Hourly float64 `json:"hourly"`
	// Input is the price figure for prompt tokens.
	Input float64 `json:"input"`
	// Output is the price figure for completion tokens.
	Output float64 `json:"output"`
	// Base is the base price figure reported by the API.
	Base float64 `json:"base"`
	// Finetune is the fine-tuning price figure reported by the API.
	Finetune float64 `json:"finetune"`
}
|
||
|
||
// TODO: Use my key everytime, no auth needed
|
||
func UpdateTogetherModels(c *fiber.Ctx, currentModelInfos []ModelInfo) {
|
||
url := "https://api.together.xyz/v1/models"
|
||
|
||
var apiKey string
|
||
err := edgeGlobalClient.WithGlobals(map[string]interface{}{"ext::auth::client_token": c.Cookies("jade-edgedb-auth-token")}).QuerySingle(edgeCtx, `
|
||
with
|
||
filtered_keys := (
|
||
select Key {
|
||
key
|
||
} filter .company.name = <str>$0 AND .<keys[is Setting].<setting[is User] = global currentUser
|
||
)
|
||
select filtered_keys.key limit 1
|
||
`, &apiKey, "together")
|
||
if err != nil {
|
||
logErrorCode.Println("08-00-0000")
|
||
return
|
||
}
|
||
|
||
req, err := http.NewRequest("GET", url, nil)
|
||
if err != nil {
|
||
logErrorCode.Println("08-01-0001")
|
||
return
|
||
}
|
||
|
||
req.Header.Set("Accept", "application/json")
|
||
req.Header.Set("Authorization", "Bearer "+apiKey)
|
||
|
||
client := &http.Client{}
|
||
resp, err := client.Do(req)
|
||
if err != nil {
|
||
logErrorCode.Println("08-02-0002")
|
||
return
|
||
}
|
||
defer resp.Body.Close()
|
||
|
||
body, err := io.ReadAll(resp.Body)
|
||
if err != nil {
|
||
logErrorCode.Println("08-01-0003")
|
||
return
|
||
}
|
||
|
||
if resp.StatusCode != http.StatusOK {
|
||
logErrorCode.Println("08-03-0004 -", resp.Status, "-", string(body))
|
||
return
|
||
}
|
||
|
||
var togetherModels []TogetherModel
|
||
err = json.Unmarshal(body, &togetherModels)
|
||
if err != nil {
|
||
logErrorCode.Println("08-01-0005")
|
||
return
|
||
}
|
||
|
||
var exist bool
|
||
var savedModel ModelInfo
|
||
var newModelFound bool = false
|
||
const epsilon = 1e-9
|
||
|
||
fmt.Println("Updating Together ModelInfo:")
|
||
for _, model := range togetherModels {
|
||
if model.Type == "chat" {
|
||
exist = false
|
||
for _, currentModel := range currentModelInfos {
|
||
if currentModel.ModelID == model.ID {
|
||
exist = true
|
||
savedModel = currentModel
|
||
break
|
||
}
|
||
}
|
||
|
||
if exist {
|
||
if math.Abs(model.Pricing.Input/1000000-float64(savedModel.InputPrice)) > epsilon ||
|
||
math.Abs(model.Pricing.Output/1000000-float64(savedModel.OutputPrice)) > epsilon {
|
||
fmt.Println("Found one existing model with changed price:", model.ID, ". Updating price from ", float64(savedModel.InputPrice), " to ", model.Pricing.Input/1000000)
|
||
err = edgeGlobalClient.Execute(edgeCtx, `
|
||
UPDATE ModelInfo FILTER .modelID = <str>$0 SET { inputPrice := <float32>$1, outputPrice := <float32>$2}
|
||
`, model.ID, float32(model.Pricing.Input)/1000000, float32(model.Pricing.Output)/1000000)
|
||
if err != nil {
|
||
fmt.Println("Error updating price:", err.Error())
|
||
}
|
||
}
|
||
} else {
|
||
fmt.Println("Found new model:", model.ID)
|
||
newModelFound = true
|
||
err = edgeGlobalClient.Execute(edgeCtx, `
|
||
INSERT ModelInfo { name := "New model, name coming soon...", modelID := <str>$0, inputPrice := <float32>$1, outputPrice := <float32>$2, company := assert_single(( SELECT Company FILTER .name = <str>$3))}
|
||
`, model.ID, float32(model.Pricing.Input)/1000000, float32(model.Pricing.Output)/1000000, "together")
|
||
if err != nil {
|
||
fmt.Println("Error creating new modelInfo:", err.Error())
|
||
}
|
||
}
|
||
}
|
||
}
|
||
|
||
if newModelFound {
|
||
SendNoreplyEmail("adrien.bouvais.pro@gmail.com", "New Together models found", "Look like new model have been found, you should name it.")
|
||
}
|
||
|
||
// Delete all unfound models
|
||
for _, currentModel := range currentModelInfos {
|
||
if currentModel.Company.Name != "together" {
|
||
continue
|
||
}
|
||
|
||
exist = false
|
||
for _, model := range togetherModels {
|
||
if model.ID == currentModel.ModelID {
|
||
exist = true
|
||
}
|
||
}
|
||
|
||
if !exist {
|
||
fmt.Println("A ModelInfo using this modelID:", currentModel.ModelID, ", was found using a modelID that doesn't exist anymore. Deleting it.")
|
||
err = edgeGlobalClient.Execute(edgeCtx, `
|
||
DELETE ModelInfo FILTER .modelID = <str>$0 AND .company.name = <str>$1
|
||
`, currentModel.ModelID, "together")
|
||
if err != nil {
|
||
fmt.Println("Error deleting a ModelInfo with an unfound modelID.")
|
||
}
|
||
}
|
||
}
|
||
}
|
||
|
||
func UpdateTogetherModelsHandler(c *fiber.Ctx) error {
|
||
if !IsUserAdmin(c) {
|
||
return c.SendString("That's for admin, how did you manage to come here ?")
|
||
}
|
||
|
||
var currentModelInfos []ModelInfo
|
||
err := edgeGlobalClient.Query(edgeCtx, `SELECT ModelInfo { id, name, modelID, inputPrice, outputPrice, company}`, ¤tModelInfos)
|
||
if err != nil {
|
||
fmt.Println("Error getting all ModelInfo for updating them")
|
||
}
|
||
UpdateTogetherModels(c, currentModelInfos)
|
||
return c.SendString("Models from together ai updated")
|
||
}
|