From 08e8e27bf4c2c0eab49e5dd65a406469f9fc66e9 Mon Sep 17 00:00:00 2001 From: Adrien Date: Wed, 24 Apr 2024 08:49:58 +0200 Subject: [PATCH] Working one message chat --- RequestOpenai.go | 153 ++++++++++++++++++ .../Simplicity.drawio.png | Bin .../Simplicity2.drawio.png | Bin .../Simplicity3.drawio.png | Bin article.md => article/article.md | 0 main.go | 95 +++++++++-- views/chat.html | 4 +- views/layouts/main.html | 6 +- views/partials/bot-message-placeholder.html | 13 ++ views/partials/bot-message.gohtml | 12 ++ views/partials/chat-input.html | 2 +- views/partials/chat-messages.html | 6 +- views/partials/test-button.html | 1 + views/partials/test-display.html | 1 + views/partials/user-message.html | 5 +- views/welcome.html | 6 +- 16 files changed, 282 insertions(+), 22 deletions(-) create mode 100644 RequestOpenai.go rename Simplicity.drawio.png => article/Simplicity.drawio.png (100%) rename Simplicity2.drawio.png => article/Simplicity2.drawio.png (100%) rename Simplicity3.drawio.png => article/Simplicity3.drawio.png (100%) rename article.md => article/article.md (100%) create mode 100644 views/partials/bot-message-placeholder.html create mode 100644 views/partials/bot-message.gohtml create mode 100644 views/partials/test-button.html create mode 100644 views/partials/test-display.html diff --git a/RequestOpenai.go b/RequestOpenai.go new file mode 100644 index 0000000..7a6b4bf --- /dev/null +++ b/RequestOpenai.go @@ -0,0 +1,153 @@ +package main + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "text/template" + "time" + + "github.com/gofiber/fiber/v2" + "go.mongodb.org/mongo-driver/bson" + "go.mongodb.org/mongo-driver/bson/primitive" +) + +type dataForTemplate struct { + Message OpenaiMessage +} + +type ChatCompletionRequest struct { + Model string `json:"model"` + Messages []OpenaiMessage `json:"messages"` + Temperature float64 `json:"temperature"` +} + +type OpenaiMessage struct { + Role string `json:"role"` + Content string `json:"content"` +} + +type ChatCompletionResponse struct { + ID string `json:"id"` + Object string `json:"object"` + Created int64 `json:"created"` + Model string `json:"model"` + Usage OpenaiUsage `json:"usage"` + Choices []OpenaiChoice `json:"choices"` +} + +type OpenaiUsage struct { + PromptTokens int `json:"prompt_tokens"` + CompletionTokens int `json:"completion_tokens"` + TotalTokens int `json:"total_tokens"` +} + +type OpenaiChoice struct { + Message OpenaiMessage `json:"message"` + FinishReason string `json:"finish_reason"` + Index int `json:"index"` +} + +var lastMessageAsked string // TODO Remove this + +func addOpenaiMessage(c *fiber.Ctx) error { + message := lastMessageAsked // TODO Remove this + + chatCompletion, err := RequestOpenai("gpt-3.5-turbo", []Message{{Content: message, Role: "user", Date: time.Now(), ID: primitive.NilObjectID}}, 0.7) + if err != nil { + // Print error + fmt.Println("Error:", err) + return err + } else if len(chatCompletion.Choices) == 0 { + fmt.Println("No response from OpenAI") + return err + } + + collection := mongoClient.Database("chat").Collection("messages") + collection.InsertOne(context.Background(), bson.M{"message": message, "role": "user", "date": time.Now()}) + + // Render bot message MAYBE to optimize + // HOW TO GET STRING OF HTML FROM TEMPLATE + tmpl, err := template.ParseFiles("views/partials/bot-message.gohtml") + if err != nil { + fmt.Println("Error parsing template:", err) + return err + } + + // Add bot message if there is no error + var renderedMessage bytes.Buffer + Message := chatCompletion.Choices[0].Message + if err := tmpl.Execute(&renderedMessage, Message); err != nil { + fmt.Println("Error rendering template:", err) + return err + } + + collection.InsertOne(context.Background(), bson.M{"message": Message.Content, "role": "bot", "date": time.Now()}) + + return c.SendString(renderedMessage.String()) +} + +func Message2OpenaiMessage(message Message) OpenaiMessage { + return OpenaiMessage{ + Role: message.Role, + Content: message.Content, + } +} + +func Messages2OpenaiMessages(messages []Message) []OpenaiMessage { + var openaiMessages []OpenaiMessage + for _, message := range messages { + openaiMessages = append(openaiMessages, Message2OpenaiMessage(message)) + } + return openaiMessages +} + +func RequestOpenai(model string, messages []Message, temperature float64) (ChatCompletionResponse, error) { + apiKey := "sk-proj-f7StCvXCtcmiOOayiVmgT3BlbkFJlVtAcOo3JcrnGq1cPa5o" // TODO Use env variable + url := "https://api.openai.com/v1/chat/completions" + + // Convert messages to OpenAI format + openaiMessages := Messages2OpenaiMessages(messages) + + requestBody := ChatCompletionRequest{ + Model: model, + Messages: openaiMessages, + Temperature: temperature, + } + + jsonBody, err := json.Marshal(requestBody) + if err != nil { + return ChatCompletionResponse{}, fmt.Errorf("error marshaling JSON: %w", err) + } + + req, err := http.NewRequest("POST", url, bytes.NewBuffer(jsonBody)) + if err != nil { + return ChatCompletionResponse{}, fmt.Errorf("error creating request: %w", err) + } + + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", "Bearer "+apiKey) + + client := &http.Client{} + resp, err := client.Do(req) + if err != nil { + return ChatCompletionResponse{}, fmt.Errorf("error sending request: %w", err) + } + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + return ChatCompletionResponse{}, fmt.Errorf("error reading response body: %w", err) + } + + var chatCompletionResponse ChatCompletionResponse + err = json.Unmarshal(body, &chatCompletionResponse) + if err != nil { + return ChatCompletionResponse{}, fmt.Errorf("error unmarshaling JSON: %w", err) + } + + return chatCompletionResponse, nil +} diff --git a/Simplicity.drawio.png b/article/Simplicity.drawio.png similarity index 100% rename from Simplicity.drawio.png rename to article/Simplicity.drawio.png diff --git a/Simplicity2.drawio.png b/article/Simplicity2.drawio.png similarity index 100% rename from Simplicity2.drawio.png rename to article/Simplicity2.drawio.png diff --git a/Simplicity3.drawio.png b/article/Simplicity3.drawio.png similarity index 100% rename from Simplicity3.drawio.png rename to article/Simplicity3.drawio.png diff --git a/article.md b/article/article.md similarity index 100% rename from article.md rename to article/article.md diff --git a/main.go b/main.go index 91a8bb1..8a5cf64 100644 --- a/main.go +++ b/main.go @@ -1,12 +1,15 @@ package main import ( + "bufio" "context" + "fmt" "time" "github.com/gofiber/fiber/v2" "github.com/gofiber/fiber/v2/middleware/logger" "github.com/gofiber/template/django/v3" + "github.com/valyala/fasthttp" "go.mongodb.org/mongo-driver/bson" "go.mongodb.org/mongo-driver/bson/primitive" "go.mongodb.org/mongo-driver/mongo" @@ -27,6 +30,19 @@ type Conversation struct { Messages []Message } +type User struct { + ID string `bson:"_id"` + Username string `bson:"username"` + OAth2Token string `bson:"oauth2token"` + IsSub bool `bson:"isSub"` +} + +type CurrentSession struct { + ID string + CurrentConversationID string + CurrentUserID string +} + func connectToMongoDB(uri string) { serverAPI := options.ServerAPI(options.ServerAPIVersion1) opts := options.Client().ApplyURI(uri).SetServerAPIOptions(serverAPI) @@ -44,6 +60,7 @@ func connectToMongoDB(uri string) { func main() { // Import HTML using django engine/template engine := django.New("./views", ".html") + if engine == nil { panic("Failed to create django engine") } @@ -71,6 +88,14 @@ func main() { app.Get("/chat", chatPageHandler) // Complete chat page app.Put("/chat", addMessageHandler) // Add message app.Delete("/chat", deleteMessageHandler) // Delete message + app.Get("/loadChat", generateChatHTML) // Load chat + + app.Get("/generateOpenai", addOpenaiMessage) + + app.Get("/sse", sseHandler) // SSE handler + + // Add test button + app.Get("/test-button", testButtonHandler) // Start server app.Listen(":3000") @@ -86,14 +111,22 @@ func chatPageHandler(c *fiber.Ctx) error { }, "layouts/main") } +func testButtonHandler(c *fiber.Ctx) error { + return c.Render("partials/test-button", fiber.Map{}) +} + func addMessageHandler(c *fiber.Ctx) error { message := c.FormValue("message") + lastMessageAsked = message - collection := mongoClient.Database("chat").Collection("messages") - collection.InsertOne(context.Background(), bson.M{"message": message, "role": "user", "date": time.Now()}) - collection.InsertOne(context.Background(), bson.M{"message": "I did something!", "role": "bot", "date": time.Now()}) - - return generateChatHTML(c) + return c.Render("partials/user-message", fiber.Map{ + "Message": Message{ + Content: message, + Role: "user", + Date: time.Now(), + }, + "IncludePlaceholder": true, + }) } func deleteMessageHandler(c *fiber.Ctx) error { @@ -139,7 +172,7 @@ func deleteMessageHandler(c *fiber.Ctx) error { }) } - return generateChatHTML(c) + return c.SendString("") } func generateChatHTML(c *fiber.Ctx) error { @@ -155,16 +188,17 @@ func generateChatHTML(c *fiber.Ctx) error { } // Convert the cursor to an array of messages - var messages []Message - if err = cursor.All(context.TODO(), &messages); err != nil { + var Messages []Message + if err = cursor.All(context.TODO(), &Messages); err != nil { return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{ "error": "Failed to convert cursor to array", }) } // Render the HTML template with the messages - return c.Render("chat", fiber.Map{ - "messages": messages, + return c.Render("partials/chat-messages", fiber.Map{ + "Messages": Messages, + "IncludePlaceholder": false, }) } @@ -172,5 +206,44 @@ func isMongoDBConnectedHandler(c *fiber.Ctx) error { if mongoClient != nil { return c.SendString("

Connected

") } - return c.SendString("

Not connected

") + return c.SendString("

Not connected

") +} + +// SSE Stuff +var ( + eventChannel = make(chan string) +) + +func sseHandler(c *fiber.Ctx) error { + c.Set("Content-Type", "text/event-stream") + c.Set("Cache-Control", "no-cache") + c.Set("Connection", "keep-alive") + c.Set("Transfer-Encoding", "chunked") + + c.Context().SetBodyStreamWriter(fasthttp.StreamWriter(func(w *bufio.Writer) { + fmt.Println("WRITER") + for { + select { + case msg := <-eventChannel: + fmt.Fprintf(w, "data: %s\n\n", msg) + err := w.Flush() + if err != nil { + fmt.Printf("Error while flushing: %v. Closing http connection.\n", err) + return + } + default: + if c.Context() != nil && c.Context().Done() != nil { + select { + case <-c.Context().Done(): + fmt.Println("Client connection closed") + return + default: + } + } + time.Sleep(100 * time.Millisecond) + } + } + })) + + return nil } diff --git a/views/chat.html b/views/chat.html index eb752ca..53e97f4 100644 --- a/views/chat.html +++ b/views/chat.html @@ -1,5 +1,5 @@
-

Chat Page

- {% include "partials/chat-messages.html" %} +

Chat Page

+
{% include "partials/chat-input.html" %}
\ No newline at end of file diff --git a/views/layouts/main.html b/views/layouts/main.html index 7b70653..71559ad 100644 --- a/views/layouts/main.html +++ b/views/layouts/main.html @@ -8,12 +8,14 @@ - + + + + {% include "partials/navbar.html" %} {{embed}} -