Working context

This commit is contained in:
Adrien Bouvais 2024-05-23 19:19:40 +02:00
parent ca4c1ba885
commit 8a20078430
7 changed files with 66 additions and 34 deletions

View File

@ -15,6 +15,7 @@ type AnthropicChatCompletionRequest struct {
Messages []RequestMessage `json:"messages"` Messages []RequestMessage `json:"messages"`
MaxTokens int `json:"max_tokens"` MaxTokens int `json:"max_tokens"`
Temperature float64 `json:"temperature"` Temperature float64 `json:"temperature"`
Context string `json:"system"`
} }
type AnthropicChatCompletionResponse struct { type AnthropicChatCompletionResponse struct {
@ -38,7 +39,7 @@ type AnthropicUsage struct {
func addAnthropicMessage(llm LLM, selected bool) edgedb.UUID { func addAnthropicMessage(llm LLM, selected bool) edgedb.UUID {
Messages := getAllSelectedMessages() Messages := getAllSelectedMessages()
chatCompletion, err := RequestAnthropic(llm.Model.ModelID, Messages, int(llm.Model.MaxToken), float64(llm.Temperature)) chatCompletion, err := RequestAnthropic(llm.Model.ModelID, Messages, int(llm.Model.MaxToken), float64(llm.Temperature), llm.Context)
if err != nil { if err != nil {
panic(err) panic(err)
} else if len(chatCompletion.Content) == 0 { } else if len(chatCompletion.Content) == 0 {
@ -66,6 +67,7 @@ func TestAnthropicKey(apiKey string) bool {
Messages: AnthropicMessages, Messages: AnthropicMessages,
MaxTokens: 10, MaxTokens: 10,
Temperature: 0, Temperature: 0,
Context: "",
} }
jsonBody, err := json.Marshal(requestBody) jsonBody, err := json.Marshal(requestBody)
@ -105,7 +107,7 @@ func TestAnthropicKey(apiKey string) bool {
return true return true
} }
func RequestAnthropic(model string, messages []Message, maxTokens int, temperature float64) (AnthropicChatCompletionResponse, error) { func RequestAnthropic(model string, messages []Message, maxTokens int, temperature float64, context string) (AnthropicChatCompletionResponse, error) {
var apiKey struct { var apiKey struct {
Key string `edgedb:"key"` Key string `edgedb:"key"`
} }
@ -124,13 +126,12 @@ func RequestAnthropic(model string, messages []Message, maxTokens int, temperatu
requestBody := AnthropicChatCompletionRequest{ requestBody := AnthropicChatCompletionRequest{
Model: model, Model: model,
Messages: Message2RequestMessage(messages), Messages: Message2RequestMessage(messages, ""),
MaxTokens: maxTokens, MaxTokens: maxTokens,
Temperature: temperature, Temperature: temperature,
Context: context,
} }
fmt.Println("Message2RequestMessage(messages):", Message2RequestMessage(messages))
jsonBody, err := json.Marshal(requestBody) jsonBody, err := json.Marshal(requestBody)
if err != nil { if err != nil {
return AnthropicChatCompletionResponse{}, fmt.Errorf("error marshaling JSON: %w", err) return AnthropicChatCompletionResponse{}, fmt.Errorf("error marshaling JSON: %w", err)

View File

@ -47,7 +47,7 @@ type GoogleChoice struct {
func addGoogleMessage(llm LLM, selected bool) edgedb.UUID { func addGoogleMessage(llm LLM, selected bool) edgedb.UUID {
Messages := getAllSelectedMessages() Messages := getAllSelectedMessages()
chatCompletion, err := RequestGoogle(llm.Model.ModelID, Messages, float64(llm.Temperature)) chatCompletion, err := RequestGoogle(llm.Model.ModelID, Messages, float64(llm.Temperature), llm.Context)
if err != nil { if err != nil {
panic(err) panic(err)
} else if len(chatCompletion.Choices) == 0 { } else if len(chatCompletion.Choices) == 0 {
@ -116,7 +116,7 @@ func TestGoogleKey(apiKey string) bool {
return true return true
} }
func RequestGoogle(model string, messages []Message, temperature float64) (OpenaiChatCompletionResponse, error) { func RequestGoogle(model string, messages []Message, temperature float64, context string) (OpenaiChatCompletionResponse, error) {
var apiKey string var apiKey string
err := edgeClient.QuerySingle(edgeCtx, ` err := edgeClient.QuerySingle(edgeCtx, `
with with
@ -135,7 +135,7 @@ func RequestGoogle(model string, messages []Message, temperature float64) (Opena
requestBody := OpenaiChatCompletionRequest{ requestBody := OpenaiChatCompletionRequest{
Model: model, Model: model,
Messages: Message2RequestMessage(messages), Messages: Message2RequestMessage(messages, context),
Temperature: temperature, Temperature: temperature,
} }

View File

@ -40,7 +40,7 @@ type GroqChoice struct {
func addGroqMessage(llm LLM, selected bool) edgedb.UUID { func addGroqMessage(llm LLM, selected bool) edgedb.UUID {
Messages := getAllSelectedMessages() Messages := getAllSelectedMessages()
chatCompletion, err := RequestGroq(llm.Model.ModelID, Messages, float64(llm.Temperature)) chatCompletion, err := RequestGroq(llm.Model.ModelID, Messages, float64(llm.Temperature), llm.Context)
if err != nil { if err != nil {
panic(err) panic(err)
} else if len(chatCompletion.Choices) == 0 { } else if len(chatCompletion.Choices) == 0 {
@ -67,7 +67,7 @@ func TestGroqKey(apiKey string) bool {
requestBody := GroqChatCompletionRequest{ requestBody := GroqChatCompletionRequest{
Model: "llama3-8b-8192", Model: "llama3-8b-8192",
Messages: Message2RequestMessage(groqMessages), Messages: Message2RequestMessage(groqMessages, ""),
Temperature: 0, Temperature: 0,
} }
@ -107,7 +107,7 @@ func TestGroqKey(apiKey string) bool {
return true return true
} }
func RequestGroq(model string, messages []Message, temperature float64) (GroqChatCompletionResponse, error) { func RequestGroq(model string, messages []Message, temperature float64, context string) (GroqChatCompletionResponse, error) {
var apiKey string var apiKey string
err := edgeClient.QuerySingle(edgeCtx, ` err := edgeClient.QuerySingle(edgeCtx, `
with with
@ -126,7 +126,7 @@ func RequestGroq(model string, messages []Message, temperature float64) (GroqCha
requestBody := GroqChatCompletionRequest{ requestBody := GroqChatCompletionRequest{
Model: model, Model: model,
Messages: Message2RequestMessage(messages), Messages: Message2RequestMessage(messages, context),
Temperature: temperature, Temperature: temperature,
} }

View File

@ -59,7 +59,7 @@ func RequestHuggingface(llm LLM, messages []Message, temperature float64) (Huggi
requestBody := HuggingfaceChatCompletionRequest{ requestBody := HuggingfaceChatCompletionRequest{
Model: "tgi", Model: "tgi",
Messages: Message2RequestMessage(messages), Messages: Message2RequestMessage(messages, llm.Context),
Temperature: temperature, Temperature: temperature,
Stream: false, Stream: false,
} }

View File

@ -39,7 +39,7 @@ type MistralChoice struct {
func addMistralMessage(llm LLM, selected bool) edgedb.UUID { func addMistralMessage(llm LLM, selected bool) edgedb.UUID {
Messages := getAllSelectedMessages() Messages := getAllSelectedMessages()
chatCompletion, err := RequestMistral(llm.Model.ModelID, Messages, float64(llm.Temperature)) chatCompletion, err := RequestMistral(llm.Model.ModelID, Messages, float64(llm.Temperature), llm.Context)
if err != nil { if err != nil {
panic(err) panic(err)
} else if len(chatCompletion.Choices) == 0 { } else if len(chatCompletion.Choices) == 0 {
@ -113,7 +113,7 @@ func TestMistralKey(apiKey string) bool {
return true return true
} }
func RequestMistral(model string, messages []Message, temperature float64) (MistralChatCompletionResponse, error) { func RequestMistral(model string, messages []Message, temperature float64, context string) (MistralChatCompletionResponse, error) {
var apiKey string var apiKey string
err := edgeClient.QuerySingle(edgeCtx, ` err := edgeClient.QuerySingle(edgeCtx, `
with with
@ -132,7 +132,7 @@ func RequestMistral(model string, messages []Message, temperature float64) (Mist
requestBody := MistralChatCompletionRequest{ requestBody := MistralChatCompletionRequest{
Model: model, Model: model,
Messages: Message2RequestMessage(messages), Messages: Message2RequestMessage(messages, context),
Temperature: temperature, Temperature: temperature,
} }

View File

@ -40,7 +40,7 @@ type OpenaiChoice struct {
func addOpenaiMessage(llm LLM, selected bool) edgedb.UUID { func addOpenaiMessage(llm LLM, selected bool) edgedb.UUID {
Messages := getAllSelectedMessages() Messages := getAllSelectedMessages()
chatCompletion, err := RequestOpenai(llm.Model.ModelID, Messages, float64(llm.Temperature)) chatCompletion, err := RequestOpenai(llm.Model.ModelID, Messages, float64(llm.Temperature), llm.Context)
if err != nil { if err != nil {
panic(err) panic(err)
} else if len(chatCompletion.Choices) == 0 { } else if len(chatCompletion.Choices) == 0 {
@ -107,7 +107,7 @@ func TestOpenaiKey(apiKey string) bool {
return true return true
} }
func RequestOpenai(model string, messages []Message, temperature float64) (OpenaiChatCompletionResponse, error) { func RequestOpenai(model string, messages []Message, temperature float64, context string) (OpenaiChatCompletionResponse, error) {
var apiKey string var apiKey string
err := edgeClient.QuerySingle(edgeCtx, ` err := edgeClient.QuerySingle(edgeCtx, `
with with
@ -124,9 +124,13 @@ func RequestOpenai(model string, messages []Message, temperature float64) (Opena
url := "https://api.openai.com/v1/chat/completions" url := "https://api.openai.com/v1/chat/completions"
fmt.Println(context)
fmt.Println(Message2RequestMessage(messages, context))
requestBody := OpenaiChatCompletionRequest{ requestBody := OpenaiChatCompletionRequest{
Model: model, Model: model,
Messages: Message2RequestMessage(messages), Messages: Message2RequestMessage(messages, context),
Temperature: temperature, Temperature: temperature,
} }

View File

@ -114,24 +114,51 @@ func getExistingKeys() (bool, bool, bool, bool, bool, bool) {
return openaiExists, anthropicExists, mistralExists, groqExists, gooseaiExists, googleExists return openaiExists, anthropicExists, mistralExists, groqExists, gooseaiExists, googleExists
} }
func Message2RequestMessage(messages []Message) []RequestMessage { func Message2RequestMessage(messages []Message, context string) []RequestMessage {
m := make([]RequestMessage, len(messages)) // Add context if it exists
for i, msg := range messages { if context != "" {
var role string m := make([]RequestMessage, len(messages)+1)
switch msg.Role { m[0] = RequestMessage{
case "user": Role: "system",
role = "user" Content: context,
case "bot":
role = "assistant"
default:
role = "system"
} }
m[i] = RequestMessage{
Role: role, for i, msg := range messages {
Content: msg.Content, var role string
switch msg.Role {
case "user":
role = "user"
case "bot":
role = "assistant"
default:
role = "system"
}
m[i+1] = RequestMessage{
Role: role,
Content: msg.Content,
}
} }
return m
} else {
m := make([]RequestMessage, len(messages))
for i, msg := range messages {
var role string
switch msg.Role {
case "user":
role = "user"
case "bot":
role = "assistant"
default:
role = "system"
}
m[i] = RequestMessage{
Role: role,
Content: msg.Content,
}
}
return m
} }
return m
} }
func GetAvailableModels() []ModelInfo { func GetAvailableModels() []ModelInfo {