Working context

parent ca4c1ba885
commit 8a20078430
@@ -15,6 +15,7 @@ type AnthropicChatCompletionRequest struct {
 	Messages    []RequestMessage `json:"messages"`
 	MaxTokens   int              `json:"max_tokens"`
 	Temperature float64          `json:"temperature"`
+	Context     string           `json:"system"`
 }
 
 type AnthropicChatCompletionResponse struct {
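The new Context field serializes under the top-level "system" key, which is where Anthropic's Messages API takes the system prompt (unlike the OpenAI-style APIs below, which take it as a message in the messages array). A minimal, self-contained sketch of the resulting payload — the RequestMessage shape and the JSON tag on Model are assumptions for illustration, not part of this diff:

package main

import (
	"encoding/json"
	"fmt"
)

// Stand-ins reproduced from the diff; the RequestMessage tags are assumed.
type RequestMessage struct {
	Role    string `json:"role"`
	Content string `json:"content"`
}

type AnthropicChatCompletionRequest struct {
	Model       string           `json:"model"`
	Messages    []RequestMessage `json:"messages"`
	MaxTokens   int              `json:"max_tokens"`
	Temperature float64          `json:"temperature"`
	Context     string           `json:"system"`
}

func main() {
	body, _ := json.Marshal(AnthropicChatCompletionRequest{
		Model:     "claude-3-haiku-20240307", // hypothetical model ID
		Messages:  []RequestMessage{{Role: "user", Content: "Hi"}},
		MaxTokens: 10,
		Context:   "You are terse.",
	})
	fmt.Println(string(body))
	// {"model":"claude-3-haiku-20240307","messages":[{"role":"user","content":"Hi"}],"max_tokens":10,"temperature":0,"system":"You are terse."}
}

This is also why RequestAnthropic below passes "" into Message2RequestMessage: the prompt travels in Context rather than as a messages entry.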
@@ -38,7 +39,7 @@ type AnthropicUsage struct {
 func addAnthropicMessage(llm LLM, selected bool) edgedb.UUID {
 	Messages := getAllSelectedMessages()
 
-	chatCompletion, err := RequestAnthropic(llm.Model.ModelID, Messages, int(llm.Model.MaxToken), float64(llm.Temperature))
+	chatCompletion, err := RequestAnthropic(llm.Model.ModelID, Messages, int(llm.Model.MaxToken), float64(llm.Temperature), llm.Context)
 	if err != nil {
 		panic(err)
 	} else if len(chatCompletion.Content) == 0 {
@@ -66,6 +67,7 @@ func TestAnthropicKey(apiKey string) bool {
 		Messages:    AnthropicMessages,
 		MaxTokens:   10,
 		Temperature: 0,
+		Context:     "",
 	}
 
 	jsonBody, err := json.Marshal(requestBody)
@@ -105,7 +107,7 @@ func TestAnthropicKey(apiKey string) bool {
 	return true
 }
 
-func RequestAnthropic(model string, messages []Message, maxTokens int, temperature float64) (AnthropicChatCompletionResponse, error) {
+func RequestAnthropic(model string, messages []Message, maxTokens int, temperature float64, context string) (AnthropicChatCompletionResponse, error) {
 	var apiKey struct {
 		Key string `edgedb:"key"`
 	}
@@ -124,13 +126,12 @@ func RequestAnthropic(model string, messages []Message, maxTokens int, temperatu
 
 	requestBody := AnthropicChatCompletionRequest{
 		Model:       model,
-		Messages:    Message2RequestMessage(messages),
+		Messages:    Message2RequestMessage(messages, ""),
 		MaxTokens:   maxTokens,
 		Temperature: temperature,
+		Context:     context,
 	}
 
-	fmt.Println("Message2RequestMessage(messages):", Message2RequestMessage(messages))
-
 	jsonBody, err := json.Marshal(requestBody)
 	if err != nil {
 		return AnthropicChatCompletionResponse{}, fmt.Errorf("error marshaling JSON: %w", err)
@@ -47,7 +47,7 @@ type GoogleChoice struct {
 func addGoogleMessage(llm LLM, selected bool) edgedb.UUID {
 	Messages := getAllSelectedMessages()
 
-	chatCompletion, err := RequestGoogle(llm.Model.ModelID, Messages, float64(llm.Temperature))
+	chatCompletion, err := RequestGoogle(llm.Model.ModelID, Messages, float64(llm.Temperature), llm.Context)
 	if err != nil {
 		panic(err)
 	} else if len(chatCompletion.Choices) == 0 {
@@ -116,7 +116,7 @@ func TestGoogleKey(apiKey string) bool {
 	return true
 }
 
-func RequestGoogle(model string, messages []Message, temperature float64) (OpenaiChatCompletionResponse, error) {
+func RequestGoogle(model string, messages []Message, temperature float64, context string) (OpenaiChatCompletionResponse, error) {
 	var apiKey string
 	err := edgeClient.QuerySingle(edgeCtx, `
 		with
@@ -135,7 +135,7 @@ func RequestGoogle(model string, messages []Message, temperature float64) (Opena
 
 	requestBody := OpenaiChatCompletionRequest{
 		Model:       model,
-		Messages:    Message2RequestMessage(messages),
+		Messages:    Message2RequestMessage(messages, context),
 		Temperature: temperature,
 	}
 
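The OpenAI-style providers in this commit (Google, Groq, Mistral, OpenAI, Huggingface) take the opposite route from Anthropic: Message2RequestMessage (changed in utils.go below) prepends the context as a leading system message. A self-contained sketch of the resulting payload; the stand-in types, JSON tags, and model ID are illustrative assumptions, not code from the repo:

package main

import (
	"encoding/json"
	"fmt"
)

// Assumed shapes for illustration.
type RequestMessage struct {
	Role    string `json:"role"`
	Content string `json:"content"`
}

type OpenaiChatCompletionRequest struct {
	Model       string           `json:"model"`
	Messages    []RequestMessage `json:"messages"`
	Temperature float64          `json:"temperature"`
}

func main() {
	body, _ := json.Marshal(OpenaiChatCompletionRequest{
		Model: "gemini-1.5-flash", // hypothetical model ID
		Messages: []RequestMessage{
			{Role: "system", Content: "You are terse."}, // injected from context
			{Role: "user", Content: "Hi"},
		},
	})
	fmt.Println(string(body))
	// {"model":"gemini-1.5-flash","messages":[{"role":"system","content":"You are terse."},{"role":"user","content":"Hi"}],"temperature":0}
}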
@@ -40,7 +40,7 @@ type GroqChoice struct {
 func addGroqMessage(llm LLM, selected bool) edgedb.UUID {
 	Messages := getAllSelectedMessages()
 
-	chatCompletion, err := RequestGroq(llm.Model.ModelID, Messages, float64(llm.Temperature))
+	chatCompletion, err := RequestGroq(llm.Model.ModelID, Messages, float64(llm.Temperature), llm.Context)
 	if err != nil {
 		panic(err)
 	} else if len(chatCompletion.Choices) == 0 {
@@ -67,7 +67,7 @@ func TestGroqKey(apiKey string) bool {
 
 	requestBody := GroqChatCompletionRequest{
 		Model:       "llama3-8b-8192",
-		Messages:    Message2RequestMessage(groqMessages),
+		Messages:    Message2RequestMessage(groqMessages, ""),
 		Temperature: 0,
 	}
 
@@ -107,7 +107,7 @@ func TestGroqKey(apiKey string) bool {
 	return true
 }
 
-func RequestGroq(model string, messages []Message, temperature float64) (GroqChatCompletionResponse, error) {
+func RequestGroq(model string, messages []Message, temperature float64, context string) (GroqChatCompletionResponse, error) {
 	var apiKey string
 	err := edgeClient.QuerySingle(edgeCtx, `
 		with
@@ -126,7 +126,7 @@ func RequestGroq(model string, messages []Message, temperature float64) (GroqCha
 
 	requestBody := GroqChatCompletionRequest{
 		Model:       model,
-		Messages:    Message2RequestMessage(messages),
+		Messages:    Message2RequestMessage(messages, context),
 		Temperature: temperature,
 	}
 
@@ -59,7 +59,7 @@ func RequestHuggingface(llm LLM, messages []Message, temperature float64) (Huggi
 
	requestBody := HuggingfaceChatCompletionRequest{
 		Model:       "tgi",
-		Messages:    Message2RequestMessage(messages),
+		Messages:    Message2RequestMessage(messages, llm.Context),
 		Temperature: temperature,
 		Stream:      false,
 	}
@@ -39,7 +39,7 @@ type MistralChoice struct {
 func addMistralMessage(llm LLM, selected bool) edgedb.UUID {
 	Messages := getAllSelectedMessages()
 
-	chatCompletion, err := RequestMistral(llm.Model.ModelID, Messages, float64(llm.Temperature))
+	chatCompletion, err := RequestMistral(llm.Model.ModelID, Messages, float64(llm.Temperature), llm.Context)
 	if err != nil {
 		panic(err)
 	} else if len(chatCompletion.Choices) == 0 {
@@ -113,7 +113,7 @@ func TestMistralKey(apiKey string) bool {
 	return true
 }
 
-func RequestMistral(model string, messages []Message, temperature float64) (MistralChatCompletionResponse, error) {
+func RequestMistral(model string, messages []Message, temperature float64, context string) (MistralChatCompletionResponse, error) {
 	var apiKey string
 	err := edgeClient.QuerySingle(edgeCtx, `
 		with
@@ -132,7 +132,7 @@ func RequestMistral(model string, messages []Message, temperature float64) (Mist
 
 	requestBody := MistralChatCompletionRequest{
 		Model:       model,
-		Messages:    Message2RequestMessage(messages),
+		Messages:    Message2RequestMessage(messages, context),
 		Temperature: temperature,
 	}
 
@@ -40,7 +40,7 @@ type OpenaiChoice struct {
 func addOpenaiMessage(llm LLM, selected bool) edgedb.UUID {
 	Messages := getAllSelectedMessages()
 
-	chatCompletion, err := RequestOpenai(llm.Model.ModelID, Messages, float64(llm.Temperature))
+	chatCompletion, err := RequestOpenai(llm.Model.ModelID, Messages, float64(llm.Temperature), llm.Context)
 	if err != nil {
 		panic(err)
 	} else if len(chatCompletion.Choices) == 0 {
@@ -107,7 +107,7 @@ func TestOpenaiKey(apiKey string) bool {
 	return true
 }
 
-func RequestOpenai(model string, messages []Message, temperature float64) (OpenaiChatCompletionResponse, error) {
+func RequestOpenai(model string, messages []Message, temperature float64, context string) (OpenaiChatCompletionResponse, error) {
 	var apiKey string
 	err := edgeClient.QuerySingle(edgeCtx, `
 		with
@@ -124,9 +124,13 @@ func RequestOpenai(model string, messages []Message, temperature float64) (Opena
 
 	url := "https://api.openai.com/v1/chat/completions"
 
+	fmt.Println(context)
+
+	fmt.Println(Message2RequestMessage(messages, context))
+
 	requestBody := OpenaiChatCompletionRequest{
 		Model:       model,
-		Messages:    Message2RequestMessage(messages),
+		Messages:    Message2RequestMessage(messages, context),
 		Temperature: temperature,
 	}
 
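The two fmt.Println calls added here read like temporary debug output, much like the line this commit removes from RequestAnthropic above. If such output is meant to stay, one common pattern is to gate it behind a flag — a generic sketch, not an API this repo defines:

package main

import "fmt"

// debug toggles diagnostic output; hypothetical, not part of the commit.
const debug = false

func debugf(format string, args ...any) {
	if debug {
		fmt.Printf(format+"\n", args...)
	}
}

func main() {
	debugf("context: %q", "You are terse.") // prints only when debug is true
}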
utils.go
@@ -114,24 +114,51 @@ func getExistingKeys() (bool, bool, bool, bool, bool, bool) {
 	return openaiExists, anthropicExists, mistralExists, groqExists, gooseaiExists, googleExists
 }
 
-func Message2RequestMessage(messages []Message) []RequestMessage {
-	m := make([]RequestMessage, len(messages))
-	for i, msg := range messages {
-		var role string
-		switch msg.Role {
-		case "user":
-			role = "user"
-		case "bot":
-			role = "assistant"
-		default:
-			role = "system"
-		}
-		m[i] = RequestMessage{
-			Role:    role,
-			Content: msg.Content,
-		}
-	}
-	return m
-}
+func Message2RequestMessage(messages []Message, context string) []RequestMessage {
+	// Add context if it exists
+	if context != "" {
+		m := make([]RequestMessage, len(messages)+1)
+		m[0] = RequestMessage{
+			Role:    "system",
+			Content: context,
+		}
+
+		for i, msg := range messages {
+			var role string
+			switch msg.Role {
+			case "user":
+				role = "user"
+			case "bot":
+				role = "assistant"
+			default:
+				role = "system"
+			}
+			m[i+1] = RequestMessage{
+				Role:    role,
+				Content: msg.Content,
+			}
+		}
+		return m
+	} else {
+		m := make([]RequestMessage, len(messages))
+
+		for i, msg := range messages {
+			var role string
+			switch msg.Role {
+			case "user":
+				role = "user"
+			case "bot":
+				role = "assistant"
+			default:
+				role = "system"
+			}
+			m[i] = RequestMessage{
+				Role:    role,
+				Content: msg.Content,
+			}
+		}
+		return m
+	}
+}
 
 func GetAvailableModels() []ModelInfo {
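One observation on the new implementation: the role-mapping loop is duplicated across the two branches, and only the index offset (i+1 vs i) differs. For comparison, a deduplicated sketch with the same observable behavior — the stand-in Message and RequestMessage types are assumed, and this is not part of the commit:

package main

import "fmt"

// Stand-ins for the repo's types (shapes assumed from the diff).
type Message struct{ Role, Content string }
type RequestMessage struct{ Role, Content string }

// Single loop; the optional system message is prepended via append.
func Message2RequestMessage(messages []Message, context string) []RequestMessage {
	m := make([]RequestMessage, 0, len(messages)+1)
	if context != "" {
		m = append(m, RequestMessage{Role: "system", Content: context})
	}
	for _, msg := range messages {
		role := "system" // default for unrecognized roles, as in the original
		switch msg.Role {
		case "user":
			role = "user"
		case "bot":
			role = "assistant"
		}
		m = append(m, RequestMessage{Role: role, Content: msg.Content})
	}
	return m
}

func main() {
	fmt.Println(Message2RequestMessage(
		[]Message{{Role: "user", Content: "Hi"}, {Role: "bot", Content: "Hello"}},
		"You are terse.",
	))
	// [{system You are terse.} {user Hi} {assistant Hello}]
}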