From 1603532150a29afebcde371ebfe7c7983ead9e7c Mon Sep 17 00:00:00 2001 From: Adrien Date: Wed, 22 May 2024 22:48:53 +0200 Subject: [PATCH] Working huggingface endpoint --- Chat.go | 2 +- Request.go | 8 ++ RequestHuggingface.go | 101 +++++++++++++++++++++++ database.go | 18 ++-- dbschema/default.esdl | 8 ++ dbschema/migrations/00033-m1eiric.edgeql | 8 ++ dbschema/migrations/00034-m1x75hd.edgeql | 10 +++ views/partials/popover-models.html | 1 + 8 files changed, 150 insertions(+), 6 deletions(-) create mode 100644 RequestHuggingface.go create mode 100644 dbschema/migrations/00033-m1eiric.edgeql create mode 100644 dbschema/migrations/00034-m1x75hd.edgeql diff --git a/Chat.go b/Chat.go index 4749d38..39fde56 100644 --- a/Chat.go +++ b/Chat.go @@ -156,7 +156,7 @@ func GetMessageContentHandler(c *fiber.Ctx) error { out := "
" out += "

" - out += "" + selectedMessage.LLM.Name + " " + selectedMessage.LLM.Model.ModelID + "" + out += "" + selectedMessage.LLM.Name + " " + selectedMessage.LLM.Model.Name + "" out += "

" out += "
" out += "
" diff --git a/Request.go b/Request.go index 1dce59c..3f28182 100644 --- a/Request.go +++ b/Request.go @@ -43,6 +43,11 @@ func GeneratePlaceholderHandler(c *fiber.Ctx) error { name, context, temperature, + custom_endpoint : { + id, + endpoint, + key + }, modelInfo : { modelID, maxToken, @@ -112,6 +117,8 @@ func GenerateMultipleMessagesHandler(c *fiber.Ctx) error { addMessageFunc = addGroqMessage case "gooseai": addMessageFunc = addGooseaiMessage + case "huggingface": + addMessageFunc = addHuggingfaceMessage } var messageID edgedb.UUID @@ -141,6 +148,7 @@ func GenerateMultipleMessagesHandler(c *fiber.Ctx) error { FILTER .id = $0; `, &message, messageID) if err != nil { + fmt.Println("Is it here ?") panic(err) } diff --git a/RequestHuggingface.go b/RequestHuggingface.go new file mode 100644 index 0000000..6635dc6 --- /dev/null +++ b/RequestHuggingface.go @@ -0,0 +1,101 @@ +package main + +import ( + "bytes" + "encoding/json" + "fmt" + "io" + "net/http" + + "github.com/edgedb/edgedb-go" +) + +type HuggingfaceChatCompletionRequest struct { + Model string `json:"model"` + Messages []RequestMessage `json:"messages"` + Temperature float64 `json:"temperature"` + Stream bool `json:"stream"` +} + +type HuggingfaceChatCompletionResponse struct { + ID string `json:"id"` + Object string `json:"object"` + Created int64 `json:"created"` + Model string `json:"model"` + Choices []HuggingfaceChoice `json:"choices"` +} + +type HuggingfaceUsage struct { + PromptTokens int32 `json:"prompt_tokens"` + CompletionTokens int32 `json:"completion_tokens"` + TotalTokens int32 `json:"total_tokens"` +} + +type HuggingfaceChoice struct { + Message Message `json:"message"` + FinishReason string `json:"finish_reason"` + Index int `json:"index"` +} + +func addHuggingfaceMessage(llm LLM, selected bool) edgedb.UUID { + Messages := getAllSelectedMessages() + + chatCompletion, err := RequestHuggingface(llm, Messages, float64(llm.Temperature)) + if err != nil { + panic(err) + } else if len(chatCompletion.Choices) == 0 { + fmt.Println("No response from Endpoint") + id := insertBotMessage("No response from Endpoint", selected, llm.ID) + return id + } else { + Content := chatCompletion.Choices[0].Message.Content + id := insertBotMessage(Content, selected, llm.ID) + return id + } +} + +func RequestHuggingface(llm LLM, messages []Message, temperature float64) (HuggingfaceChatCompletionResponse, error) { + url := llm.Endpoint.Endpoint + + requestBody := HuggingfaceChatCompletionRequest{ + Model: "tgi", + Messages: Message2RequestMessage(messages), + Temperature: temperature, + Stream: false, + } + + jsonBody, err := json.Marshal(requestBody) + if err != nil { + return HuggingfaceChatCompletionResponse{}, fmt.Errorf("error marshaling JSON: %w", err) + } + + req, err := http.NewRequest("POST", url, bytes.NewBuffer(jsonBody)) + if err != nil { + return HuggingfaceChatCompletionResponse{}, fmt.Errorf("error creating request: %w", err) + } + + req.Header.Set("Authorization", "Bearer "+llm.Endpoint.Key) + req.Header.Set("Content-Type", "application/json") + + client := &http.Client{} + resp, err := client.Do(req) + if err != nil { + return HuggingfaceChatCompletionResponse{}, fmt.Errorf("error sending request: %w", err) + } + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + return HuggingfaceChatCompletionResponse{}, fmt.Errorf("error reading response body: %w", err) + } + + var chatCompletionResponse HuggingfaceChatCompletionResponse + err = json.Unmarshal(body, &chatCompletionResponse) + if err != nil { + return HuggingfaceChatCompletionResponse{}, fmt.Errorf("error unmarshaling JSON: %w", err) + } + + addUsage(0, 0, 0, 0, llm.Model.ModelID) + + return chatCompletionResponse, nil +} diff --git a/database.go b/database.go index 1eed4bf..b83ca2a 100644 --- a/database.go +++ b/database.go @@ -64,11 +64,19 @@ type Usage struct { } type LLM struct { - ID edgedb.UUID `edgedb:"id"` - Name string `edgedb:"name"` - Context string `edgedb:"context"` - Temperature float32 `edgedb:"temperature"` - Model ModelInfo `edgedb:"modelInfo"` + ID edgedb.UUID `edgedb:"id"` + Name string `edgedb:"name"` + Context string `edgedb:"context"` + Temperature float32 `edgedb:"temperature"` + Model ModelInfo `edgedb:"modelInfo"` + Endpoint CustomEndpoint `edgedb:"custom_endpoint"` +} + +type CustomEndpoint struct { + edgedb.Optional + ID edgedb.UUID `edgedb:"id"` + Endpoint string `edgedb:"endpoint"` + Key string `edgedb:"key"` } type ModelInfo struct { diff --git a/dbschema/default.esdl b/dbschema/default.esdl index 7cb99e4..530e74d 100644 --- a/dbschema/default.esdl +++ b/dbschema/default.esdl @@ -87,6 +87,14 @@ module default { required user: User { on target delete allow; }; + custom_endpoint: CustomEndpoint { + on source delete delete target; + }; + } + + type CustomEndpoint { + required endpoint: str; + required key: str; } type Company { diff --git a/dbschema/migrations/00033-m1eiric.edgeql b/dbschema/migrations/00033-m1eiric.edgeql new file mode 100644 index 0000000..34658b7 --- /dev/null +++ b/dbschema/migrations/00033-m1eiric.edgeql @@ -0,0 +1,8 @@ +CREATE MIGRATION m1eiric4fqayh7eieleesdm2s66f3sk7j4incugnyk2xzncp2t4rxa + ONTO m1nonmddagbu3p7dcqmy3bvxkwinjfosg7iuna5xxwruig4rcnr4yq +{ + CREATE TYPE default::customEndpoint { + CREATE REQUIRED PROPERTY endpoint: std::str; + CREATE REQUIRED PROPERTY key: std::str; + }; +}; diff --git a/dbschema/migrations/00034-m1x75hd.edgeql b/dbschema/migrations/00034-m1x75hd.edgeql new file mode 100644 index 0000000..920d976 --- /dev/null +++ b/dbschema/migrations/00034-m1x75hd.edgeql @@ -0,0 +1,10 @@ +CREATE MIGRATION m1x75hdgm27pmshypxbzfrhje6xru5ypx65efdiu6zuwnute2xschq + ONTO m1eiric4fqayh7eieleesdm2s66f3sk7j4incugnyk2xzncp2t4rxa +{ + ALTER TYPE default::customEndpoint RENAME TO default::CustomEndpoint; + ALTER TYPE default::LLM { + CREATE LINK custom_endpoint: default::CustomEndpoint { + ON SOURCE DELETE DELETE TARGET; + }; + }; +}; diff --git a/views/partials/popover-models.html b/views/partials/popover-models.html index 76ab806..c517082 100644 --- a/views/partials/popover-models.html +++ b/views/partials/popover-models.html @@ -46,6 +46,7 @@ placeholder="Model name">