From 8a20078430d3b02402d5d84a7e421d753c5b6d82 Mon Sep 17 00:00:00 2001
From: Adrien
Date: Thu, 23 May 2024 19:19:40 +0200
Subject: [PATCH] Working context

---
 RequestAnthropic.go   | 11 +++++++----
 RequestGoogle.go      |  6 +++---
 RequestGroq.go        |  8 ++++----
 RequestHuggingface.go |  2 +-
 RequestMistral.go     |  6 +++---
 RequestOpenai.go      | 10 +++++++---
 utils.go              | 57 +++++++++++++++++++++++++++++++------------
 7 files changed, 66 insertions(+), 34 deletions(-)

diff --git a/RequestAnthropic.go b/RequestAnthropic.go
index 75ba6d0..9196055 100644
--- a/RequestAnthropic.go
+++ b/RequestAnthropic.go
@@ -15,6 +15,7 @@ type AnthropicChatCompletionRequest struct {
 	Messages    []RequestMessage `json:"messages"`
 	MaxTokens   int              `json:"max_tokens"`
 	Temperature float64          `json:"temperature"`
+	Context     string           `json:"system"`
 }
 
 type AnthropicChatCompletionResponse struct {
@@ -38,7 +39,7 @@ type AnthropicUsage struct {
 func addAnthropicMessage(llm LLM, selected bool) edgedb.UUID {
 	Messages := getAllSelectedMessages()
 
-	chatCompletion, err := RequestAnthropic(llm.Model.ModelID, Messages, int(llm.Model.MaxToken), float64(llm.Temperature))
+	chatCompletion, err := RequestAnthropic(llm.Model.ModelID, Messages, int(llm.Model.MaxToken), float64(llm.Temperature), llm.Context)
 	if err != nil {
 		panic(err)
 	} else if len(chatCompletion.Content) == 0 {
@@ -66,6 +67,7 @@ func TestAnthropicKey(apiKey string) bool {
 		Messages:    AnthropicMessages,
 		MaxTokens:   10,
 		Temperature: 0,
+		Context:     "",
 	}
 
 	jsonBody, err := json.Marshal(requestBody)
@@ -105,7 +107,7 @@ func TestAnthropicKey(apiKey string) bool {
 	return true
 }
 
-func RequestAnthropic(model string, messages []Message, maxTokens int, temperature float64) (AnthropicChatCompletionResponse, error) {
+func RequestAnthropic(model string, messages []Message, maxTokens int, temperature float64, context string) (AnthropicChatCompletionResponse, error) {
 	var apiKey struct {
 		Key string `edgedb:"key"`
 	}
@@ -124,13 +126,12 @@ func RequestAnthropic(model string, messages []Message, maxTokens int, temperature float64) (AnthropicChatCompletionResponse, error) {
 	requestBody := AnthropicChatCompletionRequest{
 		Model:       model,
-		Messages:    Message2RequestMessage(messages),
+		Messages:    Message2RequestMessage(messages, ""),
 		MaxTokens:   maxTokens,
 		Temperature: temperature,
+		Context:     context,
 	}
 
-	fmt.Println("Message2RequestMessage(messages):", Message2RequestMessage(messages))
-
 	jsonBody, err := json.Marshal(requestBody)
 	if err != nil {
 		return AnthropicChatCompletionResponse{}, fmt.Errorf("error marshaling JSON: %w", err)
 	}
diff --git a/RequestGoogle.go b/RequestGoogle.go
index 362e27e..ab8440d 100644
--- a/RequestGoogle.go
+++ b/RequestGoogle.go
@@ -47,7 +47,7 @@ type GoogleChoice struct {
 func addGoogleMessage(llm LLM, selected bool) edgedb.UUID {
 	Messages := getAllSelectedMessages()
 
-	chatCompletion, err := RequestGoogle(llm.Model.ModelID, Messages, float64(llm.Temperature))
+	chatCompletion, err := RequestGoogle(llm.Model.ModelID, Messages, float64(llm.Temperature), llm.Context)
 	if err != nil {
 		panic(err)
 	} else if len(chatCompletion.Choices) == 0 {
@@ -116,7 +116,7 @@ func TestGoogleKey(apiKey string) bool {
 	return true
 }
 
-func RequestGoogle(model string, messages []Message, temperature float64) (OpenaiChatCompletionResponse, error) {
+func RequestGoogle(model string, messages []Message, temperature float64, context string) (OpenaiChatCompletionResponse, error) {
 	var apiKey string
 	err := edgeClient.QuerySingle(edgeCtx, `
 		with
@@ -135,7 +135,7 @@ func RequestGoogle(model string, messages []Message, temperature float64) (OpenaiChatCompletionResponse, error) {
 
 	requestBody := OpenaiChatCompletionRequest{
 		Model:       model,
-		Messages:    Message2RequestMessage(messages),
+		Messages:    Message2RequestMessage(messages, context),
 		Temperature: temperature,
 	}
 
diff --git a/RequestGroq.go b/RequestGroq.go
index 0edbb11..0c5bf1c 100644
--- a/RequestGroq.go
+++ b/RequestGroq.go
@@ -40,7 +40,7 @@ type GroqChoice struct {
 func addGroqMessage(llm LLM, selected bool) edgedb.UUID {
 	Messages := getAllSelectedMessages()
 
-	chatCompletion, err := RequestGroq(llm.Model.ModelID, Messages, float64(llm.Temperature))
+	chatCompletion, err := RequestGroq(llm.Model.ModelID, Messages, float64(llm.Temperature), llm.Context)
 	if err != nil {
 		panic(err)
 	} else if len(chatCompletion.Choices) == 0 {
@@ -67,7 +67,7 @@ func TestGroqKey(apiKey string) bool {
 
 	requestBody := GroqChatCompletionRequest{
 		Model:       "llama3-8b-8192",
-		Messages:    Message2RequestMessage(groqMessages),
+		Messages:    Message2RequestMessage(groqMessages, ""),
 		Temperature: 0,
 	}
 
@@ -107,7 +107,7 @@ func TestGroqKey(apiKey string) bool {
 	return true
 }
 
-func RequestGroq(model string, messages []Message, temperature float64) (GroqChatCompletionResponse, error) {
+func RequestGroq(model string, messages []Message, temperature float64, context string) (GroqChatCompletionResponse, error) {
 	var apiKey string
 	err := edgeClient.QuerySingle(edgeCtx, `
 		with
@@ -126,7 +126,7 @@ func RequestGroq(model string, messages []Message, temperature float64) (GroqChatCompletionResponse, error) {
 
 	requestBody := GroqChatCompletionRequest{
 		Model:       model,
-		Messages:    Message2RequestMessage(messages),
+		Messages:    Message2RequestMessage(messages, context),
 		Temperature: temperature,
 	}
 
diff --git a/RequestHuggingface.go b/RequestHuggingface.go
index 88a5365..752e2ac 100644
--- a/RequestHuggingface.go
+++ b/RequestHuggingface.go
@@ -59,7 +59,7 @@ func RequestHuggingface(llm LLM, messages []Message, temperature float64) (HuggingfaceChatCompletionResponse, error) {
 
 	requestBody := HuggingfaceChatCompletionRequest{
 		Model:       "tgi",
-		Messages:    Message2RequestMessage(messages),
+		Messages:    Message2RequestMessage(messages, llm.Context),
 		Temperature: temperature,
 		Stream:      false,
 	}
diff --git a/RequestMistral.go b/RequestMistral.go
index 7700025..7f6b7e3 100644
--- a/RequestMistral.go
+++ b/RequestMistral.go
@@ -39,7 +39,7 @@ type MistralChoice struct {
 func addMistralMessage(llm LLM, selected bool) edgedb.UUID {
 	Messages := getAllSelectedMessages()
 
-	chatCompletion, err := RequestMistral(llm.Model.ModelID, Messages, float64(llm.Temperature))
+	chatCompletion, err := RequestMistral(llm.Model.ModelID, Messages, float64(llm.Temperature), llm.Context)
 	if err != nil {
 		panic(err)
 	} else if len(chatCompletion.Choices) == 0 {
@@ -113,7 +113,7 @@ func TestMistralKey(apiKey string) bool {
 	return true
 }
 
-func RequestMistral(model string, messages []Message, temperature float64) (MistralChatCompletionResponse, error) {
+func RequestMistral(model string, messages []Message, temperature float64, context string) (MistralChatCompletionResponse, error) {
 	var apiKey string
 	err := edgeClient.QuerySingle(edgeCtx, `
 		with
@@ -132,7 +132,7 @@ func RequestMistral(model string, messages []Message, temperature float64) (MistralChatCompletionResponse, error) {
 
 	requestBody := MistralChatCompletionRequest{
 		Model:       model,
-		Messages:    Message2RequestMessage(messages),
+		Messages:    Message2RequestMessage(messages, context),
 		Temperature: temperature,
 	}
 
diff --git a/RequestOpenai.go b/RequestOpenai.go
index 65cbb57..dbd0cfc 100644
--- a/RequestOpenai.go
+++ b/RequestOpenai.go
@@ -40,7 +40,7 @@ type OpenaiChoice struct {
 func addOpenaiMessage(llm LLM, selected bool) edgedb.UUID {
 	Messages := getAllSelectedMessages()
 
-	chatCompletion, err := RequestOpenai(llm.Model.ModelID, Messages, float64(llm.Temperature))
+	chatCompletion, err := RequestOpenai(llm.Model.ModelID, Messages, float64(llm.Temperature), llm.Context)
 	if err != nil {
 		panic(err)
 	} else if len(chatCompletion.Choices) == 0 {
@@ -107,7 +107,7 @@ func TestOpenaiKey(apiKey string) bool {
 	return true
 }
 
-func RequestOpenai(model string, messages []Message, temperature float64) (OpenaiChatCompletionResponse, error) {
+func RequestOpenai(model string, messages []Message, temperature float64, context string) (OpenaiChatCompletionResponse, error) {
 	var apiKey string
 	err := edgeClient.QuerySingle(edgeCtx, `
 		with
@@ -124,9 +124,13 @@ func RequestOpenai(model string, messages []Message, temperature float64) (OpenaiChatCompletionResponse, error) {
 
 	url := "https://api.openai.com/v1/chat/completions"
 
+	fmt.Println(context)
+
+	fmt.Println(Message2RequestMessage(messages, context))
+
 	requestBody := OpenaiChatCompletionRequest{
 		Model:       model,
-		Messages:    Message2RequestMessage(messages),
+		Messages:    Message2RequestMessage(messages, context),
 		Temperature: temperature,
 	}
 
diff --git a/utils.go b/utils.go
index e2fa741..92c1fc6 100644
--- a/utils.go
+++ b/utils.go
@@ -114,24 +114,51 @@ func getExistingKeys() (bool, bool, bool, bool, bool, bool) {
 	return openaiExists, anthropicExists, mistralExists, groqExists, gooseaiExists, googleExists
 }
 
-func Message2RequestMessage(messages []Message) []RequestMessage {
-	m := make([]RequestMessage, len(messages))
-	for i, msg := range messages {
-		var role string
-		switch msg.Role {
-		case "user":
-			role = "user"
-		case "bot":
-			role = "assistant"
-		default:
-			role = "system"
+func Message2RequestMessage(messages []Message, context string) []RequestMessage {
+	// Add context if it exists
+	if context != "" {
+		m := make([]RequestMessage, len(messages)+1)
+		m[0] = RequestMessage{
+			Role:    "system",
+			Content: context,
 		}
-		m[i] = RequestMessage{
-			Role:    role,
-			Content: msg.Content,
+
+		for i, msg := range messages {
+			var role string
+			switch msg.Role {
+			case "user":
+				role = "user"
+			case "bot":
+				role = "assistant"
+			default:
+				role = "system"
+			}
+			m[i+1] = RequestMessage{
+				Role:    role,
+				Content: msg.Content,
+			}
 		}
+		return m
+	} else {
+		m := make([]RequestMessage, len(messages))
+
+		for i, msg := range messages {
+			var role string
+			switch msg.Role {
+			case "user":
+				role = "user"
+			case "bot":
+				role = "assistant"
+			default:
+				role = "system"
+			}
+			m[i] = RequestMessage{
+				Role:    role,
+				Content: msg.Content,
+			}
+		}
+		return m
 	}
-	return m
 }
 
 func GetAvailableModels() []ModelInfo {
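
Behavior note: the new context parameter threads the LLM's Context (system prompt) through every provider request. For the OpenAI-style providers (Google, Groq, Huggingface, Mistral, OpenAI), a non-empty context is prepended to the message list as a leading "system" message inside Message2RequestMessage; Anthropic instead takes the system prompt as a top-level `system` field, so RequestAnthropic sets Context on the request body and passes "" into Message2RequestMessage. The reworked Message2RequestMessage duplicates the role-mapping loop across its two branches; the following is a minimal sketch of the equivalent behavior with the branches collapsed via append. It is an illustration, not the patch's code, and the Message and RequestMessage shapes are assumed from this repo:

package main

import "fmt"

// Assumed shapes of this repo's Message and RequestMessage types.
type Message struct {
	Role    string
	Content string
}

type RequestMessage struct {
	Role    string `json:"role"`
	Content string `json:"content"`
}

// Equivalent to the patched Message2RequestMessage: a non-empty context
// becomes a leading "system" message; "user" stays "user", "bot" maps to
// "assistant", and any other role falls back to "system", as in the patch.
func Message2RequestMessage(messages []Message, context string) []RequestMessage {
	m := make([]RequestMessage, 0, len(messages)+1)
	if context != "" {
		m = append(m, RequestMessage{Role: "system", Content: context})
	}
	for _, msg := range messages {
		role := "system" // default role, matching the patch's switch
		switch msg.Role {
		case "user":
			role = "user"
		case "bot":
			role = "assistant"
		}
		m = append(m, RequestMessage{Role: role, Content: msg.Content})
	}
	return m
}

func main() {
	out := Message2RequestMessage([]Message{
		{Role: "user", Content: "Hi"},
		{Role: "bot", Content: "Hello!"},
	}, "You are concise.")
	fmt.Println(out)
	// Prints: [{system You are concise.} {user Hi} {assistant Hello!}]
}

Run against a user/bot exchange and the context "You are concise.", this produces the same slice the patched function builds: the system message first, then the converted conversation.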