264 lines
7.9 KiB
Go
264 lines
7.9 KiB
Go
package main
|
|
|
|
import (
	"bufio"
	"fmt"
	"log"
	"os"
	"strings"
	"sync"

	"github.com/flosch/pongo2"
	"github.com/gofiber/fiber/v2"
	"github.com/gofiber/fiber/v2/middleware/logger"
	"github.com/gofiber/fiber/v2/middleware/recover"
	"github.com/gofiber/template/django/v3"
	"github.com/stripe/stripe-go"
)
|
|
|
|
var (
|
|
userTmpl *pongo2.Template
|
|
botTmpl *pongo2.Template
|
|
selectBtnTmpl *pongo2.Template
|
|
modelPopoverTmpl *pongo2.Template
|
|
usagePopoverTmpl *pongo2.Template
|
|
settingPopoverTmpl *pongo2.Template
|
|
messageEditTmpl *pongo2.Template
|
|
conversationPopoverTmpl *pongo2.Template
|
|
welcomeChatTmpl *pongo2.Template
|
|
chatInputTmpl *pongo2.Template
|
|
explainLLMconvChatTmpl *pongo2.Template
|
|
mu sync.Mutex
|
|
app *fiber.App
|
|
userSSEChannels = make(map[string]chan SSE)
|
|
)
|
|
|
|
// SSE is one Server-Sent Event queued for a user's stream. Event is the value
// of the "event:" field and Data is the payload written in the "data:" field.
type SSE struct {
	Event, Data string
}
|
|
|
|
// sendEvent queues an event for the single stream registered under userID in
// userSSEChannels (see handleSSE). It does nothing when that user has no open
// stream.
//
// The send is non-blocking. If the user's channel buffer is full because the
// client is not draining it, the event is dropped and logged. Blocking here
// while mu is held would otherwise stall every other sendEvent call and every
// stream registration in handleSSE.
func sendEvent(userID string, event string, data string) {
	mu.Lock()
	defer mu.Unlock()

	ch, ok := userSSEChannels[userID]
	if !ok {
		return
	}

	// Sending while holding mu is deliberate. handleSSE removes a channel from
	// the map under mu before closing it, so a channel found here is still open.
	select {
	case ch <- SSE{Event: event, Data: data}:
	default:
		log.Printf("sse: dropping %q event for user %q: buffer full", event, userID)
	}
}
|
|
|
|
// main starts the JADE web server. It reads the Stripe API key from the
// STRIPE_KEY environment variable and parses the partial templates. It then
// builds a Fiber app that renders views with the django engine from ./views,
// registers logging and panic-recovery middleware, the application routes and
// the ./static file server, and serves on :8080. A non-nil error from Listen
// is fatal.
func main() {
	stripe.Key = os.Getenv("STRIPE_KEY")

	// Parse every partial up front, in a fixed order. pongo2.Must panics on the
	// first template that fails to load, so a broken view stops startup.
	partials := []struct {
		dst  **pongo2.Template
		file string
	}{
		{&botTmpl, "views/partials/message-bot.html"},
		{&userTmpl, "views/partials/message-user.html"},
		{&selectBtnTmpl, "views/partials/model-selection-btn.html"},
		{&modelPopoverTmpl, "views/partials/popover-models.html"},
		{&conversationPopoverTmpl, "views/partials/popover-conversation.html"},
		{&usagePopoverTmpl, "views/partials/popover-usage.html"},
		{&settingPopoverTmpl, "views/partials/popover-settings.html"},
		{&messageEditTmpl, "views/partials/message-edit-form.html"},
		{&welcomeChatTmpl, "views/partials/welcome-chat.html"},
		{&chatInputTmpl, "views/partials/chat-input.html"},
		{&explainLLMconvChatTmpl, "views/partials/explain-llm-conv-chat.html"},
	}
	for _, p := range partials {
		*p.dst = pongo2.Must(pongo2.FromFile(p.file))
	}

	// Full-page views are rendered by the django engine rooted at ./views.
	app = fiber.New(fiber.Config{
		AppName: "JADE",
		Views:   django.New("./views", ".html"),
	})
	defer app.Shutdown()

	// Middleware applied ahead of every route: request logging, then panic recovery.
	app.Use(logger.New())
	app.Use(recover.New())

	// Pages and page fragments.
	app.Get("/", ChatPageHandler)
	app.Get("/loadChat", LoadChatHandler)
	app.Get("/loadChatInput", LoadChatInputHandler)
	app.Get("/pricingTable", PricingTableHandler)
	app.Get("/generateTermAndService", generateTermAndServiceHandler)

	// Chat messages.
	app.Post("/deleteMessage", DeleteMessageHandler)
	app.Post("/generatePlaceholder", GeneratePlaceholderHandler)
	app.Get("/generateMultipleMessages", GenerateMultipleMessagesHandler)
	app.Get("/messageContent", GetMessageContentHandler)
	app.Get("/editMessageForm", GetEditMessageFormHandler)
	app.Post("/redoMessage", RedoMessageHandler)
	app.Post("/clearChat", ClearChatHandler)
	app.Get("/userMessage", GetUserMessageHandler)
	app.Post("/editMessage", EditMessageHandler)
	app.Get("/help", generateHelpChatHandler)
	app.Get("/selectionBtn", GetSelectionBtnHandler)

	// Settings.
	app.Post("/addKeys", addKeys)

	// Popovers.
	app.Get("/loadModelSelection", LoadModelSelectionHandler)
	app.Get("/loadConversationSelection", LoadConversationSelectionHandler)
	app.Get("/loadUsageKPI", LoadUsageKPIHandler)
	app.Get("/loadSettings", LoadSettingsHandler)
	app.Get("/refreshConversationSelection", RefreshConversationSelectionHandler)

	// Conversations.
	app.Get("/createConversation", CreateConversationHandler)
	app.Get("/deleteConversation", DeleteConversationHandler)
	app.Get("/selectConversation", SelectConversationHandler)
	app.Post("/updateConversationPositionBatch", updateConversationPositionBatch)
	app.Post("/archiveDefaultConversation", ArchiveDefaultConversationHandler)

	// Authentication.
	app.Get("/signin", handleUiSignIn)
	app.Get("/signout", handleSignOut)
	app.Get("/callback", handleCallback)
	app.Get("/callbackSignup", handleCallbackSignup)

	// LLM configuration.
	app.Get("/deleteLLM", deleteLLM)
	app.Post("/createLLM", createLLM)
	app.Post("/updateLLMPositionBatch", updateLLMPositionBatch)

	// Static assets, an empty-body endpoint, and the SSE stream.
	app.Static("/", "./static")
	app.Get("/empty", func(c *fiber.Ctx) error {
		return c.SendString("")
	})
	app.Get("/sse", handleSSE)

	err := app.Listen(":8080")
	if err != nil {
		log.Fatal(err)
	}
}
|
|
|
|
func handleSSE(c *fiber.Ctx) error {
|
|
userID := c.Query("userID") // Get userID from query parameter
|
|
if userID == "" {
|
|
return c.Status(fiber.StatusBadRequest).SendString("Missing userID")
|
|
}
|
|
|
|
events := make(chan SSE, 500)
|
|
mu.Lock()
|
|
userSSEChannels[userID] = events
|
|
mu.Unlock()
|
|
|
|
// Create a context copy to use in the goroutine
|
|
ctx := c.Context()
|
|
|
|
go func() {
|
|
<-ctx.Done()
|
|
mu.Lock()
|
|
delete(userSSEChannels, userID)
|
|
mu.Unlock()
|
|
close(events)
|
|
}()
|
|
|
|
c.Set("Content-Type", "text/event-stream")
|
|
c.Set("Cache-Control", "no-cache")
|
|
c.Set("Connection", "keep-alive")
|
|
|
|
c.Context().SetBodyStreamWriter(func(w *bufio.Writer) {
|
|
for event := range events {
|
|
if _, err := fmt.Fprintf(w, "event: %s\ndata: %s\n\n", event.Event, event.Data); err != nil {
|
|
fmt.Println(err)
|
|
return
|
|
}
|
|
w.Flush()
|
|
}
|
|
})
|
|
|
|
return nil
|
|
}
|
|
|
|
// addKeys handles the settings form. It validates each provider API key the
// user submitted and stores the valid ones on the current user's settings.
//
// Form fields are named "<provider>_key" (GooseAI uses "goose_key"), and empty
// fields are skipped. Every submitted key is validated before any is written,
// so one invalid key cannot leave the settings partially updated. In that case
// the response is the plain-text message "Invalid <company> API Key". On
// success the response is a script that reloads the page. Database failures
// are returned as errors for Fiber's error handler instead of panicking.
func addKeys(c *fiber.Ctx) error {
	// A slice rather than a map keeps validation order deterministic, and with
	// it which invalid key gets reported.
	providers := []struct {
		company string // Company.name used in the queries below
		field   string // form field carrying the key
		test    func(string) bool
	}{
		{"openai", "openai_key", TestOpenaiKey},
		{"anthropic", "anthropic_key", TestAnthropicKey},
		{"mistral", "mistral_key", TestMistralKey},
		{"groq", "groq_key", TestGroqKey},
		{"gooseai", "goose_key", TestGooseaiKey},
		{"google", "google_key", TestGoogleKey},
		{"nim", "nim_key", TestNimKey},
		{"perplexity", "perplexity_key", TestPerplexityKey},
		{"fireworks", "fireworks_key", TestFireworkKey},
	}

	type submittedKey struct{ company, key string }
	var valid []submittedKey
	for _, p := range providers {
		key := c.FormValue(p.field)
		if key == "" {
			continue
		}
		if !p.test(key) {
			return c.SendString(fmt.Sprintf("Invalid %s API Key\n", p.company))
		}
		valid = append(valid, submittedKey{p.company, key})
	}

	// One client carrying the caller's auth token as a query global.
	// NOTE(review): presumably `global currentUser` is derived from this token —
	// confirm against the schema.
	db := edgeGlobalClient.WithGlobals(map[string]interface{}{
		"ext::auth::client_token": c.Cookies("jade-edgedb-auth-token"),
	})

	for _, k := range valid {
		var exists bool
		err := db.QuerySingle(edgeCtx, `
			select exists (
				select global currentUser.setting.keys
				filter .company.name = <str>$0
			);
		`, &exists, k.company)
		if err != nil {
			return fmt.Errorf("checking for existing %s key: %w", k.company, err)
		}

		if exists {
			// Replace the current user's existing key for this company. Filter
			// on company only: the stored value differs from the new one, so
			// also filtering on `.key = <str>$1` would never match a changed key.
			err = db.Execute(edgeCtx, `
				UPDATE global currentUser.setting.keys
				FILTER .company.name = <str>$0
				SET {
					key := <str>$1,
				}
			`, k.company, k.key)
			if err != nil {
				return fmt.Errorf("updating %s key: %w", k.company, err)
			}
			continue
		}

		err = db.Execute(edgeCtx, `
			WITH
				c := (SELECT Company FILTER .name = <str>$0 LIMIT 1)
			UPDATE global currentUser.setting
			SET {
				keys += (
					INSERT Key {
						company := c,
						key := <str>$1,
						name := <str>$2 ++ " API Key",
					}
				)
			}`, k.company, k.key, k.company)
		if err != nil {
			return fmt.Errorf("adding %s key: %w", k.company, err)
		}
	}

	return c.SendString("<script>window.location.reload()</script>")
}
|