Working context

Adrien Bouvais 2024-05-23 19:19:40 +02:00
parent ca4c1ba885
commit 8a20078430
7 changed files with 66 additions and 34 deletions

View File

@@ -15,6 +15,7 @@ type AnthropicChatCompletionRequest struct {
Messages []RequestMessage `json:"messages"`
MaxTokens int `json:"max_tokens"`
Temperature float64 `json:"temperature"`
+ Context string `json:"system"`
}
type AnthropicChatCompletionResponse struct {
@@ -38,7 +39,7 @@ type AnthropicUsage struct {
func addAnthropicMessage(llm LLM, selected bool) edgedb.UUID {
Messages := getAllSelectedMessages()
- chatCompletion, err := RequestAnthropic(llm.Model.ModelID, Messages, int(llm.Model.MaxToken), float64(llm.Temperature))
+ chatCompletion, err := RequestAnthropic(llm.Model.ModelID, Messages, int(llm.Model.MaxToken), float64(llm.Temperature), llm.Context)
if err != nil {
panic(err)
} else if len(chatCompletion.Content) == 0 {
@@ -66,6 +67,7 @@ func TestAnthropicKey(apiKey string) bool {
Messages: AnthropicMessages,
MaxTokens: 10,
Temperature: 0,
Context: "",
}
jsonBody, err := json.Marshal(requestBody)
@@ -105,7 +107,7 @@ func TestAnthropicKey(apiKey string) bool {
return true
}
- func RequestAnthropic(model string, messages []Message, maxTokens int, temperature float64) (AnthropicChatCompletionResponse, error) {
+ func RequestAnthropic(model string, messages []Message, maxTokens int, temperature float64, context string) (AnthropicChatCompletionResponse, error) {
var apiKey struct {
Key string `edgedb:"key"`
}
@@ -124,13 +126,12 @@ func RequestAnthropic(model string, messages []Message, maxTokens int, temperatu
requestBody := AnthropicChatCompletionRequest{
Model: model,
- Messages: Message2RequestMessage(messages),
+ Messages: Message2RequestMessage(messages, ""),
MaxTokens: maxTokens,
Temperature: temperature,
+ Context: context,
}
fmt.Println("Message2RequestMessage(messages):", Message2RequestMessage(messages))
jsonBody, err := json.Marshal(requestBody)
if err != nil {
return AnthropicChatCompletionResponse{}, fmt.Errorf("error marshaling JSON: %w", err)
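Note: Anthropic's Messages API takes the system prompt as a top-level "system" field rather than as a role "system" entry in the message list, which is why this file maps Context to `json:"system"` and builds the message list with an empty context. A minimal, runnable sketch of the resulting payload (type definitions and values are illustrative stand-ins, not this repo's exact code):

    package main

    import (
        "encoding/json"
        "fmt"
    )

    // Illustrative stand-ins for the types in this commit.
    type RequestMessage struct {
        Role    string `json:"role"`
        Content string `json:"content"`
    }

    type AnthropicChatCompletionRequest struct {
        Model       string           `json:"model"`
        Messages    []RequestMessage `json:"messages"`
        MaxTokens   int              `json:"max_tokens"`
        Temperature float64          `json:"temperature"`
        Context     string           `json:"system"` // system prompt rides at the top level
    }

    func main() {
        req := AnthropicChatCompletionRequest{
            Model:       "claude-3-haiku-20240307",
            Messages:    []RequestMessage{{Role: "user", Content: "Hi"}},
            MaxTokens:   1024,
            Temperature: 0.7,
            Context:     "You are a concise assistant.",
        }
        b, _ := json.Marshal(req)
        fmt.Println(string(b))
        // {"model":"claude-3-haiku-20240307","messages":[{"role":"user","content":"Hi"}],
        //  "max_tokens":1024,"temperature":0.7,"system":"You are a concise assistant."}
    }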

View File

@@ -47,7 +47,7 @@ type GoogleChoice struct {
func addGoogleMessage(llm LLM, selected bool) edgedb.UUID {
Messages := getAllSelectedMessages()
- chatCompletion, err := RequestGoogle(llm.Model.ModelID, Messages, float64(llm.Temperature))
+ chatCompletion, err := RequestGoogle(llm.Model.ModelID, Messages, float64(llm.Temperature), llm.Context)
if err != nil {
panic(err)
} else if len(chatCompletion.Choices) == 0 {
@@ -116,7 +116,7 @@ func TestGoogleKey(apiKey string) bool {
return true
}
- func RequestGoogle(model string, messages []Message, temperature float64) (OpenaiChatCompletionResponse, error) {
+ func RequestGoogle(model string, messages []Message, temperature float64, context string) (OpenaiChatCompletionResponse, error) {
var apiKey string
err := edgeClient.QuerySingle(edgeCtx, `
with
@@ -135,7 +135,7 @@ func RequestGoogle(model string, messages []Message, temperature float64) (Opena
requestBody := OpenaiChatCompletionRequest{
Model: model,
- Messages: Message2RequestMessage(messages),
+ Messages: Message2RequestMessage(messages, context),
Temperature: temperature,
}

View File

@@ -40,7 +40,7 @@ type GroqChoice struct {
func addGroqMessage(llm LLM, selected bool) edgedb.UUID {
Messages := getAllSelectedMessages()
- chatCompletion, err := RequestGroq(llm.Model.ModelID, Messages, float64(llm.Temperature))
+ chatCompletion, err := RequestGroq(llm.Model.ModelID, Messages, float64(llm.Temperature), llm.Context)
if err != nil {
panic(err)
} else if len(chatCompletion.Choices) == 0 {
@@ -67,7 +67,7 @@ func TestGroqKey(apiKey string) bool {
requestBody := GroqChatCompletionRequest{
Model: "llama3-8b-8192",
- Messages: Message2RequestMessage(groqMessages),
+ Messages: Message2RequestMessage(groqMessages, ""),
Temperature: 0,
}
@@ -107,7 +107,7 @@ func TestGroqKey(apiKey string) bool {
return true
}
- func RequestGroq(model string, messages []Message, temperature float64) (GroqChatCompletionResponse, error) {
+ func RequestGroq(model string, messages []Message, temperature float64, context string) (GroqChatCompletionResponse, error) {
var apiKey string
err := edgeClient.QuerySingle(edgeCtx, `
with
@@ -126,7 +126,7 @@ func RequestGroq(model string, messages []Message, temperature float64) (GroqCha
requestBody := GroqChatCompletionRequest{
Model: model,
- Messages: Message2RequestMessage(messages),
+ Messages: Message2RequestMessage(messages, context),
Temperature: temperature,
}

View File

@@ -59,7 +59,7 @@ func RequestHuggingface(llm LLM, messages []Message, temperature float64) (Huggi
requestBody := HuggingfaceChatCompletionRequest{
Model: "tgi",
- Messages: Message2RequestMessage(messages),
+ Messages: Message2RequestMessage(messages, llm.Context),
Temperature: temperature,
Stream: false,
}

View File

@@ -39,7 +39,7 @@ type MistralChoice struct {
func addMistralMessage(llm LLM, selected bool) edgedb.UUID {
Messages := getAllSelectedMessages()
- chatCompletion, err := RequestMistral(llm.Model.ModelID, Messages, float64(llm.Temperature))
+ chatCompletion, err := RequestMistral(llm.Model.ModelID, Messages, float64(llm.Temperature), llm.Context)
if err != nil {
panic(err)
} else if len(chatCompletion.Choices) == 0 {
@@ -113,7 +113,7 @@ func TestMistralKey(apiKey string) bool {
return true
}
- func RequestMistral(model string, messages []Message, temperature float64) (MistralChatCompletionResponse, error) {
+ func RequestMistral(model string, messages []Message, temperature float64, context string) (MistralChatCompletionResponse, error) {
var apiKey string
err := edgeClient.QuerySingle(edgeCtx, `
with
@@ -132,7 +132,7 @@ func RequestMistral(model string, messages []Message, temperature float64) (Mist
requestBody := MistralChatCompletionRequest{
Model: model,
- Messages: Message2RequestMessage(messages),
+ Messages: Message2RequestMessage(messages, context),
Temperature: temperature,
}

View File

@@ -40,7 +40,7 @@ type OpenaiChoice struct {
func addOpenaiMessage(llm LLM, selected bool) edgedb.UUID {
Messages := getAllSelectedMessages()
- chatCompletion, err := RequestOpenai(llm.Model.ModelID, Messages, float64(llm.Temperature))
+ chatCompletion, err := RequestOpenai(llm.Model.ModelID, Messages, float64(llm.Temperature), llm.Context)
if err != nil {
panic(err)
} else if len(chatCompletion.Choices) == 0 {
@@ -107,7 +107,7 @@ func TestOpenaiKey(apiKey string) bool {
return true
}
- func RequestOpenai(model string, messages []Message, temperature float64) (OpenaiChatCompletionResponse, error) {
+ func RequestOpenai(model string, messages []Message, temperature float64, context string) (OpenaiChatCompletionResponse, error) {
var apiKey string
err := edgeClient.QuerySingle(edgeCtx, `
with
@@ -124,9 +124,13 @@ func RequestOpenai(model string, messages []Message, temperature float64) (Opena
url := "https://api.openai.com/v1/chat/completions"
+ fmt.Println(context)
+ fmt.Println(Message2RequestMessage(messages, context))
requestBody := OpenaiChatCompletionRequest{
Model: model,
- Messages: Message2RequestMessage(messages),
+ Messages: Message2RequestMessage(messages, context),
Temperature: temperature,
}
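Note: the OpenAI-style providers (Google, Groq, Mistral, OpenAI) carry the context inside the message list instead: Message2RequestMessage, updated in the last file of this commit, prepends it as a role "system" entry when it is non-empty. A sketch of the expected shape, assuming this repo's Message and RequestMessage types:

    // Sketch (assumes this repo's types): a non-empty context becomes the
    // first entry of the request's message list; "bot" maps to "assistant".
    msgs := Message2RequestMessage(
        []Message{{Role: "user", Content: "Hello"}, {Role: "bot", Content: "Hi!"}},
        "You are terse.",
    )
    // msgs[0] == RequestMessage{Role: "system", Content: "You are terse."}
    // msgs[1] == RequestMessage{Role: "user", Content: "Hello"}
    // msgs[2] == RequestMessage{Role: "assistant", Content: "Hi!"}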

View File

@@ -114,8 +114,34 @@ func getExistingKeys() (bool, bool, bool, bool, bool, bool) {
return openaiExists, anthropicExists, mistralExists, groqExists, gooseaiExists, googleExists
}
- func Message2RequestMessage(messages []Message) []RequestMessage {
+ func Message2RequestMessage(messages []Message, context string) []RequestMessage {
+ // Add context if it exists
+ if context != "" {
+ m := make([]RequestMessage, len(messages)+1)
+ m[0] = RequestMessage{
+ Role: "system",
+ Content: context,
+ }
+ for i, msg := range messages {
+ var role string
+ switch msg.Role {
+ case "user":
+ role = "user"
+ case "bot":
+ role = "assistant"
+ default:
+ role = "system"
+ }
+ m[i+1] = RequestMessage{
+ Role: role,
+ Content: msg.Content,
+ }
+ }
+ return m
+ } else {
m := make([]RequestMessage, len(messages))
for i, msg := range messages {
var role string
switch msg.Role {
@@ -132,6 +158,7 @@ func Message2RequestMessage(messages []Message) []RequestMessage {
}
}
return m
+ }
}
func GetAvailableModels() []ModelInfo {
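Note: both branches of the new Message2RequestMessage repeat the same role-mapping loop; only the leading system entry differs. An equivalent single-loop version (a sketch, not what this commit ships) could look like:

    func Message2RequestMessage(messages []Message, context string) []RequestMessage {
        // Reserve one extra slot in case a system entry is prepended.
        m := make([]RequestMessage, 0, len(messages)+1)
        if context != "" {
            // A non-empty context becomes the leading system message.
            m = append(m, RequestMessage{Role: "system", Content: context})
        }
        for _, msg := range messages {
            role := "system" // default, mirroring the switch above
            switch msg.Role {
            case "user":
                role = "user"
            case "bot":
                role = "assistant"
            }
            m = append(m, RequestMessage{Role: role, Content: msg.Content})
        }
        return m
    }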