diff --git a/Chat.go b/Chat.go index 9eaff00..ef620b5 100644 --- a/Chat.go +++ b/Chat.go @@ -3,7 +3,6 @@ package main import ( "context" "encoding/json" - "fmt" "sort" "strings" "time" @@ -20,7 +19,7 @@ func ChatPageHandler(c *fiber.Ctx) error { edgeClient = edgeClient.WithGlobals(map[string]interface{}{"ext::auth::client_token": authCookie}) } - return c.Render("chat", fiber.Map{"IsLogin": checkIfLogin(), "HaveKey": checkIfHaveKey()}, "layouts/main") + return c.Render("chat", fiber.Map{"IsLogin": checkIfLogin(), "HaveKey": checkIfHaveKey(), "IsSubscribed": IsCurrentUserSubscribed(), "IsLimiteReached": IsCurrentUserLimiteReached()}, "layouts/main") } func DeleteMessageHandler(c *fiber.Ctx) error { @@ -49,7 +48,9 @@ func LoadChatHandler(c *fiber.Ctx) error { deleteLLMtoDelete() if checkIfLogin() { - if getCurrentUserKeys() == nil { + if IsCurrentUserLimiteReached() && !IsCurrentUserSubscribed() { + return c.SendString(generateLimitReachedChatHTML()) + } else if getCurrentUserKeys() == nil { return c.SendString(generateEnterKeyChatHTML()) } return c.SendString(generateChatHTML()) @@ -272,6 +273,38 @@ func generateEnterKeyChatHTML() string { return htmlString } +func generateLimitReachedChatHTML() string { + welcomeMessage := `You have reached the maximum number of messages for a free account. Please upgrade your account to continue using JADE.` + + stripeTable := ` + + ` + + htmlString := "
" + + NextMessages := []TemplateMessage{} + nextMsg := TemplateMessage{ + Icon: "icons/bouvai2.png", // Assuming Icon is a field you want to include from Message + Content: "
" + markdownToHTML(welcomeMessage) + "
" + stripeTable, + Hidden: false, // Assuming Hidden is a field you want to include from Message + Id: "0", + Name: "JADE", + } + NextMessages = append(NextMessages, nextMsg) + + botOut, err := botTmpl.Execute(pongo2.Context{"Messages": NextMessages, "ConversationAreaId": 0, "NotClickable": true}) + if err != nil { + panic(err) + } + htmlString += botOut + htmlString += "
" + htmlString += "
" + + // Render the HTML template with the messages + return htmlString +} + // Buton actions func GetEditMessageFormHandler(c *fiber.Ctx) error { id := c.FormValue("id") @@ -430,7 +463,6 @@ func LoadUsageKPIHandler(c *fiber.Ctx) error { } FILTER .total_count > 0 ORDER BY .total_cost DESC `, &usages, InputDate, InputDate.AddDate(0, 1, 0)) if err != nil { - fmt.Println(err) panic(err) } @@ -552,6 +584,7 @@ func LoadSettingsHandler(c *fiber.Ctx) error { "GooseaiExists": gooseaiExists, "GoogleExists": googleExists, "AnyExists": openaiExists || anthropicExists || mistralExists || groqExists || gooseaiExists || googleExists, + "IsSub": IsCurrentUserSubscribed(), }) if err != nil { panic(err) diff --git a/LLM.go b/LLM.go index 1692c40..9a0015c 100644 --- a/LLM.go +++ b/LLM.go @@ -2,7 +2,6 @@ package main import ( "encoding/json" - "fmt" "strconv" "github.com/edgedb/edgedb-go" @@ -61,8 +60,6 @@ func createLLM(c *fiber.Ctx) error { token := c.FormValue("model-key-input") customID := c.FormValue("model-cid-input") - fmt.Println(name, modelID, temperatureFloat, systemPrompt, url, token) - if modelID == "custom" { err := edgeClient.Execute(edgeCtx, ` INSERT LLM { @@ -85,7 +82,6 @@ func createLLM(c *fiber.Ctx) error { }; `, name, systemPrompt, temperatureFloat, url, token, customID) // TODO Add real max token if err != nil { - fmt.Println("Error in createLLM: ", err) panic(err) } } else { diff --git a/Request.go b/Request.go index 2575191..cb684c7 100644 --- a/Request.go +++ b/Request.go @@ -153,7 +153,6 @@ func GenerateMultipleMessagesHandler(c *fiber.Ctx) error { FILTER .id = $0; `, &message, messageID) if err != nil { - fmt.Println("Is it here ?") panic(err) } diff --git a/RequestGoogle.go b/RequestGoogle.go index ab8440d..947150b 100644 --- a/RequestGoogle.go +++ b/RequestGoogle.go @@ -103,8 +103,6 @@ func TestGoogleKey(apiKey string) bool { return false } - fmt.Println(string(body)) - var chatCompletionResponse GoogleChatCompletionResponse err = json.Unmarshal(body, &chatCompletionResponse) if err != nil { diff --git a/RequestGooseai.go b/RequestGooseai.go index 5348f69..076d3d2 100644 --- a/RequestGooseai.go +++ b/RequestGooseai.go @@ -82,9 +82,6 @@ func TestGooseaiKey(apiKey string) bool { return false } - // Print the response body - fmt.Println(string(body)) - var chatCompletionResponse GooseaiCompletionResponse err = json.Unmarshal(body, &chatCompletionResponse) if err != nil { diff --git a/RequestHuggingface.go b/RequestHuggingface.go index 752e2ac..14fb59f 100644 --- a/RequestHuggingface.go +++ b/RequestHuggingface.go @@ -77,8 +77,6 @@ func RequestHuggingface(llm LLM, messages []Message, temperature float64) (Huggi req.Header.Set("Authorization", "Bearer "+llm.Endpoint.Key) req.Header.Set("Content-Type", "application/json") - fmt.Println(url, llm.Endpoint.Key) - client := &http.Client{} resp, err := client.Do(req) if err != nil { diff --git a/RequestMistral.go b/RequestMistral.go index 7f6b7e3..a53867d 100644 --- a/RequestMistral.go +++ b/RequestMistral.go @@ -43,7 +43,6 @@ func addMistralMessage(llm LLM, selected bool) edgedb.UUID { if err != nil { panic(err) } else if len(chatCompletion.Choices) == 0 { - fmt.Println("No response from Mistral") id := insertBotMessage("No response from Mistral", selected, llm.ID) return id } else { @@ -72,13 +71,11 @@ func TestMistralKey(apiKey string) bool { jsonBody, err := json.Marshal(requestBody) if err != nil { - fmt.Println("Error:", err) return false } req, err := http.NewRequest("POST", url, bytes.NewBuffer(jsonBody)) if err != nil { - fmt.Println("Error:", err) return false } @@ -89,25 +86,21 @@ func TestMistralKey(apiKey string) bool { client := &http.Client{} resp, err := client.Do(req) if err != nil { - fmt.Println("Error:", err) return false } defer resp.Body.Close() body, err := io.ReadAll(resp.Body) if err != nil { - fmt.Println("Error:", err) return false } var chatCompletionResponse MistralChatCompletionResponse err = json.Unmarshal(body, &chatCompletionResponse) if err != nil { - fmt.Println("Error:", err) return false } if chatCompletionResponse.Usage.CompletionTokens == 0 { - fmt.Println("Error: No response from Mistral") return false } return true diff --git a/RequestOpenai.go b/RequestOpenai.go index dbd0cfc..f166f55 100644 --- a/RequestOpenai.go +++ b/RequestOpenai.go @@ -124,10 +124,6 @@ func RequestOpenai(model string, messages []Message, temperature float64, contex url := "https://api.openai.com/v1/chat/completions" - fmt.Println(context) - - fmt.Println(Message2RequestMessage(messages, context)) - requestBody := OpenaiChatCompletionRequest{ Model: model, Messages: Message2RequestMessage(messages, context), diff --git a/Stripe.go b/Stripe.go index 51a08bd..b47162e 100644 --- a/Stripe.go +++ b/Stripe.go @@ -51,5 +51,10 @@ func generatePricingTableChatHTML() string { func IsCurrentUserSubscribed() bool { // TODO Ask Stripe if user is subscribed - return false + return true +} + +func IsCurrentUserLimiteReached() bool { + // TODO Ask Stripe if user is subscribed + return true } diff --git a/database.go b/database.go index b83ca2a..ff270a3 100644 --- a/database.go +++ b/database.go @@ -10,6 +10,11 @@ import ( var edgeCtx context.Context var edgeClient *edgedb.Client +type Identity struct { + ID edgedb.UUID `edgedb:"id"` + Issuer string `edgedb:"issuer"` +} + type User struct { ID edgedb.UUID `edgedb:"id"` Setting Setting `edgedb:"setting"` diff --git a/dbschema/default.esdl b/dbschema/default.esdl index 530e74d..8801852 100644 --- a/dbschema/default.esdl +++ b/dbschema/default.esdl @@ -10,7 +10,9 @@ module default { type User { required setting: Setting; - required identity: ext::auth::Identity; + required identity: ext::auth::Identity { + on source delete delete target; + } } type Key { diff --git a/dbschema/migrations/00035-m1e72ub.edgeql b/dbschema/migrations/00035-m1e72ub.edgeql new file mode 100644 index 0000000..28e814b --- /dev/null +++ b/dbschema/migrations/00035-m1e72ub.edgeql @@ -0,0 +1,9 @@ +CREATE MIGRATION m1e72ubyamp6762oufm57qqns7tkr4o25gc3l62iml7sltlytgx5lq + ONTO m1x75hdgm27pmshypxbzfrhje6xru5ypx65efdiu6zuwnute2xschq +{ + ALTER TYPE default::User { + ALTER LINK identity { + ON SOURCE DELETE DELETE TARGET; + }; + }; +}; diff --git a/go.mod b/go.mod index 466800e..b407a35 100644 --- a/go.mod +++ b/go.mod @@ -43,6 +43,7 @@ require ( github.com/pelletier/go-toml/v2 v2.2.1 // indirect github.com/rivo/uniseg v0.2.0 // indirect github.com/sigurn/crc16 v0.0.0-20211026045750-20ab5afb07e3 // indirect + github.com/stripe/stripe-go/v78 v78.7.0 // indirect github.com/twitchyliquid64/golang-asm v0.15.1 // indirect github.com/ugorji/go/codec v1.2.12 // indirect github.com/valyala/bytebufferpool v1.0.0 // indirect diff --git a/go.sum b/go.sum index f8b2195..b7f06c2 100644 --- a/go.sum +++ b/go.sum @@ -108,6 +108,8 @@ github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= +github.com/stripe/stripe-go/v78 v78.7.0 h1:TdTkzBn0wB0ntgOI74YHpvsNyHPBijX83n4ljsjXh6o= +github.com/stripe/stripe-go/v78 v78.7.0/go.mod h1:GjncxVLUc1xoIOidFqVwq+y3pYiG7JLVWiVQxTsLrvQ= github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS4MhqMhdFk5YI= github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08= github.com/ugorji/go/codec v1.2.12 h1:9LC83zGrHhuUA9l16C9AHXAqEV/2wBQ4nkvumAE65EE= @@ -152,6 +154,7 @@ golang.org/x/mod v0.10.0 h1:lFO9qtOdlre5W1jxS3r/4szv2/6iXxScdzjoBMXNhYk= golang.org/x/mod v0.10.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= +golang.org/x/net v0.0.0-20210520170846-37e1c6afe023/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c= golang.org/x/net v0.24.0 h1:1PcaxkF854Fu3+lvBIx5SYn9wRlBzzcnHZSiaFFAb0w= golang.org/x/net v0.24.0/go.mod h1:2Q7sJY5mzlzWjKtYUEXSlBWCdyaioyXzRB2RtU8KVE8= @@ -163,6 +166,7 @@ golang.org/x/sync v0.2.0 h1:PUR+T4wwASmuSTYdKjYHI5TD22Wy5ogLU5qZCOLxBrI= golang.org/x/sync v0.2.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= @@ -175,6 +179,7 @@ golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9sn golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= +golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= golang.org/x/text v0.3.8/go.mod h1:E6s5w1FMmriuDzIBO73fBruAKo1PCIq6d2Q6DHfQ8WQ= golang.org/x/text v0.14.0 h1:ScX5w1eTa3QqT8oi6+ziP7dTV1S2+ALU0bI+0zXKWiQ= diff --git a/login.go b/login.go index 5275236..81d7ae0 100644 --- a/login.go +++ b/login.go @@ -7,12 +7,110 @@ import ( "encoding/json" "fmt" "io" + "io/ioutil" "net/http" "github.com/edgedb/edgedb-go" "github.com/gofiber/fiber/v2" ) +type DiscoveryDocument struct { + UserInfoEndpoint string `json:"userinfo_endpoint"` +} + +type UserProfile struct { + Email string `json:"email"` + Name string `json:"name"` +} + +func getGoogleUserProfile(providerToken string) (string, string) { + // Fetch the discovery document + resp, err := http.Get("https://accounts.google.com/.well-known/openid-configuration") + if err != nil { + panic(fmt.Sprintf("failed to fetch discovery document: %v", err)) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + panic(fmt.Sprintf("failed to fetch discovery document: status code %d", resp.StatusCode)) + } + + body, err := ioutil.ReadAll(resp.Body) + if err != nil { + panic(fmt.Sprintf("failed to read discovery document response: %v", err)) + } + + var discoveryDocument DiscoveryDocument + if err := json.Unmarshal(body, &discoveryDocument); err != nil { + panic(fmt.Sprintf("failed to unmarshal discovery document: %v", err)) + } + + // Fetch the user profile + req, err := http.NewRequest("GET", discoveryDocument.UserInfoEndpoint, nil) + if err != nil { + panic(fmt.Sprintf("failed to create user profile request: %v", err)) + } + req.Header.Set("Authorization", "Bearer "+providerToken) + req.Header.Set("Accept", "application/json") + + client := &http.Client{} + resp, err = client.Do(req) + if err != nil { + panic(fmt.Sprintf("failed to fetch user profile: %v", err)) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + panic(fmt.Sprintf("failed to fetch user profile: status code %d", resp.StatusCode)) + } + + body, err = ioutil.ReadAll(resp.Body) + if err != nil { + panic(fmt.Sprintf("failed to read user profile response: %v", err)) + } + + var userProfile UserProfile + if err := json.Unmarshal(body, &userProfile); err != nil { + panic(fmt.Sprintf("failed to unmarshal user profile: %v", err)) + } + + return userProfile.Email, userProfile.Name +} + +func getGitHubUserProfile(providerToken string) (string, string) { + // Create the request to fetch the user profile + req, err := http.NewRequest("GET", "https://api.github.com/user", nil) + if err != nil { + panic(fmt.Sprintf("failed to create user profile request: %v", err)) + } + req.Header.Set("Authorization", "Bearer "+providerToken) + req.Header.Set("Accept", "application/vnd.github+json") + req.Header.Set("X-GitHub-Api-Version", "2022-11-28") + + client := &http.Client{} + resp, err := client.Do(req) + if err != nil { + panic(fmt.Sprintf("failed to fetch user profile: %v", err)) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + panic(fmt.Sprintf("failed to fetch user profile: status code %d", resp.StatusCode)) + } + + body, err := ioutil.ReadAll(resp.Body) + if err != nil { + panic(fmt.Sprintf("failed to read user profile response: %v", err)) + } + + var userProfile GitHubUserProfile + if err := json.Unmarshal(body, &userProfile); err != nil { + panic(fmt.Sprintf("failed to unmarshal user profile: %v", err)) + } + + return userProfile.Email, userProfile.Name +} + const EDGEDB_AUTH_BASE_URL = "http://127.0.0.1:10700/db/main/ext/auth" func generatePKCE() (string, string) { @@ -67,8 +165,9 @@ func handleCallbackSignup(c *fiber.Ctx) error { } var tokenResponse struct { - AuthToken string `json:"auth_token"` - IdentityID string `json:"identity_id"` + AuthToken string `json:"auth_token"` + IdentityID string `json:"identity_id"` + ProviderToken string `json:"provider_token"` } err = json.NewDecoder(resp.Body).Decode(&tokenResponse) if err != nil { @@ -84,9 +183,45 @@ func handleCallbackSignup(c *fiber.Ctx) error { SameSite: "Strict", }) + // Get the issuer of the identity + var identity Identity + identityUUID, err := edgedb.ParseUUID(tokenResponse.IdentityID) + if err != nil { + panic(err) + } + err = edgeClient.QuerySingle(edgeCtx, ` + SELECT ext::auth::Identity { + issuer + } FILTER .id = $0 + `, &identity, identityUUID) + if err != nil { + panic(err) + } + + var ( + providerEmail string + providerName string + ) + + // Get the email and name from the provider + if identity.Issuer == "https://accounts.google.com" { + providerEmail, providerName = getGoogleUserProfile(tokenResponse.ProviderToken) + } else if identity.Issuer == "https://github.com" { + providerEmail, providerName = getGithubUserProfile(tokenResponse.ProviderToken) + } + + fmt.Println(providerEmail, providerName) + + // Create stripe user + //stripe.Key = "sk_test_51OxXuWP2nW0okNQyiNAOcBTTWZSiyP1el5KOmV3yIv1DQR0415YPsH1eb89SLrsOFj80o9p2AxGOy042e53yDvZN00jHxHAbE6" + // + //params := &stripe.CustomerParams{ + // Name: stripe.String( + // Email: stripe.String(), + //} + //result, err := customer.New(params) + // Create a new User and attach the identity - var identityUUID edgedb.UUID - identityUUID, err = edgedb.ParseUUID(tokenResponse.IdentityID) if err != nil { return c.Status(fiber.StatusInternalServerError).SendString(fmt.Sprintf("Error in edgedb.ParseUUID: in handleCallbackSignup: %s", err.Error())) } @@ -119,7 +254,6 @@ func handleCallback(c *fiber.Ctx) error { verifier := c.Cookies("jade-edgedb-pkce-verifier", "") if verifier == "" { - fmt.Println("Could not find 'verifier' in the cookie store. Is this the same user agent/browser that started the authorization flow?") return c.Status(fiber.StatusBadRequest).SendString("Could not find 'verifier' in the cookie store. Is this the same user agent/browser that started the authorization flow?") } diff --git a/main.go b/main.go index f622a42..b3d67ed 100644 --- a/main.go +++ b/main.go @@ -96,11 +96,6 @@ func main() { app.Get("deleteLLM", deleteLLM) app.Post("/createLLM", createLLM) - app.Get("/test", func(c *fiber.Ctx) error { - fmt.Println("Hello from test") - return c.SendString("") - }) - app.Get("/empty", func(c *fiber.Ctx) error { return c.SendString("") }) @@ -155,7 +150,6 @@ func addKeys(c *fiber.Ctx) error { // Handle OpenAI key if openaiKey != "" { if !TestOpenaiKey(openaiKey) { - fmt.Println("Invalid OpenAI API Key") return c.SendString("Invalid OpenAI API Key\n") } @@ -171,7 +165,6 @@ func addKeys(c *fiber.Ctx) error { } if Exists { - fmt.Println("Company key already exists") err = edgeClient.Execute(edgeCtx, ` UPDATE Key filter .company.name = $0 AND .key = $1 SET { @@ -204,7 +197,6 @@ func addKeys(c *fiber.Ctx) error { // Handle Anthropic key if anthropicKey != "" { if !TestAnthropicKey(anthropicKey) { - fmt.Println("Invalid Anthropic API Key") return c.SendString("Invalid Anthropic API Key\n") } @@ -252,7 +244,6 @@ func addKeys(c *fiber.Ctx) error { // Handle Mistral key if mistralKey != "" { if !TestMistralKey(mistralKey) { - fmt.Println("Invalid Mistral API Key") return c.SendString("Invalid Mistral API Key\n") } @@ -300,7 +291,6 @@ func addKeys(c *fiber.Ctx) error { // Handle Groq key if groqKey != "" { if !TestGroqKey(groqKey) { - fmt.Println("Invalid Groq API Key") return c.SendString("Invalid Groq API Key\n") } @@ -348,7 +338,6 @@ func addKeys(c *fiber.Ctx) error { // Handle Gooseai key if gooseaiKey != "" { if !TestGooseaiKey(gooseaiKey) { - fmt.Println("Invalid Gooseai API Key") return c.SendString("Invalid Gooseai API Key\n") } @@ -396,7 +385,6 @@ func addKeys(c *fiber.Ctx) error { // Handle Google key if googleKey != "" { if !TestGoogleKey(googleKey) { - fmt.Println("Invalid Google API Key") return c.SendString("Invalid Google API Key\n") } diff --git a/views/chat.html b/views/chat.html index 95b0065..7e30506 100644 --- a/views/chat.html +++ b/views/chat.html @@ -1,6 +1,7 @@
+ {% if IsSubscribed or not IsLimiteReached %}