Working context
This commit is contained in:
parent
ca4c1ba885
commit
8a20078430
@ -15,6 +15,7 @@ type AnthropicChatCompletionRequest struct {
|
||||
Messages []RequestMessage `json:"messages"`
|
||||
MaxTokens int `json:"max_tokens"`
|
||||
Temperature float64 `json:"temperature"`
|
||||
Context string `json:"system"`
|
||||
}
|
||||
|
||||
type AnthropicChatCompletionResponse struct {
|
||||
@ -38,7 +39,7 @@ type AnthropicUsage struct {
|
||||
func addAnthropicMessage(llm LLM, selected bool) edgedb.UUID {
|
||||
Messages := getAllSelectedMessages()
|
||||
|
||||
chatCompletion, err := RequestAnthropic(llm.Model.ModelID, Messages, int(llm.Model.MaxToken), float64(llm.Temperature))
|
||||
chatCompletion, err := RequestAnthropic(llm.Model.ModelID, Messages, int(llm.Model.MaxToken), float64(llm.Temperature), llm.Context)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
} else if len(chatCompletion.Content) == 0 {
|
||||
@ -66,6 +67,7 @@ func TestAnthropicKey(apiKey string) bool {
|
||||
Messages: AnthropicMessages,
|
||||
MaxTokens: 10,
|
||||
Temperature: 0,
|
||||
Context: "",
|
||||
}
|
||||
|
||||
jsonBody, err := json.Marshal(requestBody)
|
||||
@ -105,7 +107,7 @@ func TestAnthropicKey(apiKey string) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func RequestAnthropic(model string, messages []Message, maxTokens int, temperature float64) (AnthropicChatCompletionResponse, error) {
|
||||
func RequestAnthropic(model string, messages []Message, maxTokens int, temperature float64, context string) (AnthropicChatCompletionResponse, error) {
|
||||
var apiKey struct {
|
||||
Key string `edgedb:"key"`
|
||||
}
|
||||
@ -124,13 +126,12 @@ func RequestAnthropic(model string, messages []Message, maxTokens int, temperatu
|
||||
|
||||
requestBody := AnthropicChatCompletionRequest{
|
||||
Model: model,
|
||||
Messages: Message2RequestMessage(messages),
|
||||
Messages: Message2RequestMessage(messages, ""),
|
||||
MaxTokens: maxTokens,
|
||||
Temperature: temperature,
|
||||
Context: context,
|
||||
}
|
||||
|
||||
fmt.Println("Message2RequestMessage(messages):", Message2RequestMessage(messages))
|
||||
|
||||
jsonBody, err := json.Marshal(requestBody)
|
||||
if err != nil {
|
||||
return AnthropicChatCompletionResponse{}, fmt.Errorf("error marshaling JSON: %w", err)
|
||||
|
@ -47,7 +47,7 @@ type GoogleChoice struct {
|
||||
func addGoogleMessage(llm LLM, selected bool) edgedb.UUID {
|
||||
Messages := getAllSelectedMessages()
|
||||
|
||||
chatCompletion, err := RequestGoogle(llm.Model.ModelID, Messages, float64(llm.Temperature))
|
||||
chatCompletion, err := RequestGoogle(llm.Model.ModelID, Messages, float64(llm.Temperature), llm.Context)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
} else if len(chatCompletion.Choices) == 0 {
|
||||
@ -116,7 +116,7 @@ func TestGoogleKey(apiKey string) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func RequestGoogle(model string, messages []Message, temperature float64) (OpenaiChatCompletionResponse, error) {
|
||||
func RequestGoogle(model string, messages []Message, temperature float64, context string) (OpenaiChatCompletionResponse, error) {
|
||||
var apiKey string
|
||||
err := edgeClient.QuerySingle(edgeCtx, `
|
||||
with
|
||||
@ -135,7 +135,7 @@ func RequestGoogle(model string, messages []Message, temperature float64) (Opena
|
||||
|
||||
requestBody := OpenaiChatCompletionRequest{
|
||||
Model: model,
|
||||
Messages: Message2RequestMessage(messages),
|
||||
Messages: Message2RequestMessage(messages, context),
|
||||
Temperature: temperature,
|
||||
}
|
||||
|
||||
|
@ -40,7 +40,7 @@ type GroqChoice struct {
|
||||
func addGroqMessage(llm LLM, selected bool) edgedb.UUID {
|
||||
Messages := getAllSelectedMessages()
|
||||
|
||||
chatCompletion, err := RequestGroq(llm.Model.ModelID, Messages, float64(llm.Temperature))
|
||||
chatCompletion, err := RequestGroq(llm.Model.ModelID, Messages, float64(llm.Temperature), llm.Context)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
} else if len(chatCompletion.Choices) == 0 {
|
||||
@ -67,7 +67,7 @@ func TestGroqKey(apiKey string) bool {
|
||||
|
||||
requestBody := GroqChatCompletionRequest{
|
||||
Model: "llama3-8b-8192",
|
||||
Messages: Message2RequestMessage(groqMessages),
|
||||
Messages: Message2RequestMessage(groqMessages, ""),
|
||||
Temperature: 0,
|
||||
}
|
||||
|
||||
@ -107,7 +107,7 @@ func TestGroqKey(apiKey string) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func RequestGroq(model string, messages []Message, temperature float64) (GroqChatCompletionResponse, error) {
|
||||
func RequestGroq(model string, messages []Message, temperature float64, context string) (GroqChatCompletionResponse, error) {
|
||||
var apiKey string
|
||||
err := edgeClient.QuerySingle(edgeCtx, `
|
||||
with
|
||||
@ -126,7 +126,7 @@ func RequestGroq(model string, messages []Message, temperature float64) (GroqCha
|
||||
|
||||
requestBody := GroqChatCompletionRequest{
|
||||
Model: model,
|
||||
Messages: Message2RequestMessage(messages),
|
||||
Messages: Message2RequestMessage(messages, context),
|
||||
Temperature: temperature,
|
||||
}
|
||||
|
||||
|
@ -59,7 +59,7 @@ func RequestHuggingface(llm LLM, messages []Message, temperature float64) (Huggi
|
||||
|
||||
requestBody := HuggingfaceChatCompletionRequest{
|
||||
Model: "tgi",
|
||||
Messages: Message2RequestMessage(messages),
|
||||
Messages: Message2RequestMessage(messages, llm.Context),
|
||||
Temperature: temperature,
|
||||
Stream: false,
|
||||
}
|
||||
|
@ -39,7 +39,7 @@ type MistralChoice struct {
|
||||
func addMistralMessage(llm LLM, selected bool) edgedb.UUID {
|
||||
Messages := getAllSelectedMessages()
|
||||
|
||||
chatCompletion, err := RequestMistral(llm.Model.ModelID, Messages, float64(llm.Temperature))
|
||||
chatCompletion, err := RequestMistral(llm.Model.ModelID, Messages, float64(llm.Temperature), llm.Context)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
} else if len(chatCompletion.Choices) == 0 {
|
||||
@ -113,7 +113,7 @@ func TestMistralKey(apiKey string) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func RequestMistral(model string, messages []Message, temperature float64) (MistralChatCompletionResponse, error) {
|
||||
func RequestMistral(model string, messages []Message, temperature float64, context string) (MistralChatCompletionResponse, error) {
|
||||
var apiKey string
|
||||
err := edgeClient.QuerySingle(edgeCtx, `
|
||||
with
|
||||
@ -132,7 +132,7 @@ func RequestMistral(model string, messages []Message, temperature float64) (Mist
|
||||
|
||||
requestBody := MistralChatCompletionRequest{
|
||||
Model: model,
|
||||
Messages: Message2RequestMessage(messages),
|
||||
Messages: Message2RequestMessage(messages, context),
|
||||
Temperature: temperature,
|
||||
}
|
||||
|
||||
|
@ -40,7 +40,7 @@ type OpenaiChoice struct {
|
||||
func addOpenaiMessage(llm LLM, selected bool) edgedb.UUID {
|
||||
Messages := getAllSelectedMessages()
|
||||
|
||||
chatCompletion, err := RequestOpenai(llm.Model.ModelID, Messages, float64(llm.Temperature))
|
||||
chatCompletion, err := RequestOpenai(llm.Model.ModelID, Messages, float64(llm.Temperature), llm.Context)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
} else if len(chatCompletion.Choices) == 0 {
|
||||
@ -107,7 +107,7 @@ func TestOpenaiKey(apiKey string) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func RequestOpenai(model string, messages []Message, temperature float64) (OpenaiChatCompletionResponse, error) {
|
||||
func RequestOpenai(model string, messages []Message, temperature float64, context string) (OpenaiChatCompletionResponse, error) {
|
||||
var apiKey string
|
||||
err := edgeClient.QuerySingle(edgeCtx, `
|
||||
with
|
||||
@ -124,9 +124,13 @@ func RequestOpenai(model string, messages []Message, temperature float64) (Opena
|
||||
|
||||
url := "https://api.openai.com/v1/chat/completions"
|
||||
|
||||
fmt.Println(context)
|
||||
|
||||
fmt.Println(Message2RequestMessage(messages, context))
|
||||
|
||||
requestBody := OpenaiChatCompletionRequest{
|
||||
Model: model,
|
||||
Messages: Message2RequestMessage(messages),
|
||||
Messages: Message2RequestMessage(messages, context),
|
||||
Temperature: temperature,
|
||||
}
|
||||
|
||||
|
29
utils.go
29
utils.go
@ -114,8 +114,34 @@ func getExistingKeys() (bool, bool, bool, bool, bool, bool) {
|
||||
return openaiExists, anthropicExists, mistralExists, groqExists, gooseaiExists, googleExists
|
||||
}
|
||||
|
||||
func Message2RequestMessage(messages []Message) []RequestMessage {
|
||||
func Message2RequestMessage(messages []Message, context string) []RequestMessage {
|
||||
// Add context if it exists
|
||||
if context != "" {
|
||||
m := make([]RequestMessage, len(messages)+1)
|
||||
m[0] = RequestMessage{
|
||||
Role: "system",
|
||||
Content: context,
|
||||
}
|
||||
|
||||
for i, msg := range messages {
|
||||
var role string
|
||||
switch msg.Role {
|
||||
case "user":
|
||||
role = "user"
|
||||
case "bot":
|
||||
role = "assistant"
|
||||
default:
|
||||
role = "system"
|
||||
}
|
||||
m[i+1] = RequestMessage{
|
||||
Role: role,
|
||||
Content: msg.Content,
|
||||
}
|
||||
}
|
||||
return m
|
||||
} else {
|
||||
m := make([]RequestMessage, len(messages))
|
||||
|
||||
for i, msg := range messages {
|
||||
var role string
|
||||
switch msg.Role {
|
||||
@ -133,6 +159,7 @@ func Message2RequestMessage(messages []Message) []RequestMessage {
|
||||
}
|
||||
return m
|
||||
}
|
||||
}
|
||||
|
||||
func GetAvailableModels() []ModelInfo {
|
||||
// TODO Filter if key is not available
|
||||
|
Loading…
x
Reference in New Issue
Block a user