Jade/login.go

302 lines
9.0 KiB
Go

package main
import (
"crypto/rand"
"crypto/sha256"
"encoding/base64"
"encoding/json"
"fmt"
"io"
"net/http"
"github.com/edgedb/edgedb-go"
"github.com/gofiber/fiber/v2"
)
// DiscoveryDocument holds the one field this file reads from an OpenID
// Connect discovery document: the URL of the provider's userinfo endpoint.
// All other keys in the document are ignored when decoding.
type DiscoveryDocument struct {
	// UserInfoEndpoint is decoded from the "userinfo_endpoint" JSON key.
	UserInfoEndpoint string `json:"userinfo_endpoint"`
}
// UserProfile is the JSON shape decoded from a provider's user-profile
// response. The avatar is mapped from two different keys so the same struct
// can be used for both providers; each caller returns the field that matches
// its provider.
type UserProfile struct {
	Email        string `json:"email"`
	Name         string `json:"name"`
	AvatarGitHub string `json:"avatar_url"` // returned by getGitHubUserProfile
	AvatarGoogle string `json:"picture"`    // returned by getGoogleUserProfile
}
// TokenResponse is the JSON body decoded from the auth server's token
// endpoint after the PKCE code exchange.
type TokenResponse struct {
	AuthToken     string `json:"auth_token"`     // stored in the "jade-edgedb-auth-token" cookie
	IdentityID    string `json:"identity_id"`    // parsed as a UUID to look up ext::auth::Identity
	ProviderToken string `json:"provider_token"` // bearer token sent to the provider's profile API
}
// EDGEDB_AUTH_BASE_URL is the base URL of the EdgeDB auth extension. The
// handlers in this file append "/ui/signup" and "/token" to it.
// NOTE(review): hard-coded to a loopback address — presumably a development
// default; confirm how this is meant to be configured for deployment.
const EDGEDB_AUTH_BASE_URL = "http://127.0.0.1:10700/db/main/ext/auth"
// getGoogleUserProfile looks up the Google account behind providerToken and
// returns its email, display name and avatar URL, in that order.
//
// The userinfo endpoint is not hard-coded: it is read from Google's OpenID
// Connect discovery document, then queried with providerToken as a bearer
// token. Any failure (transport error, non-200 status, unreadable or malformed
// body) panics with a descriptive message, because the signature has no error
// return.
func getGoogleUserProfile(providerToken string) (string, string, string) {
	// Step 1: download and decode the discovery document.
	discoveryResp, err := http.Get("https://accounts.google.com/.well-known/openid-configuration")
	if err != nil {
		panic(fmt.Sprintf("failed to fetch discovery document: %v", err))
	}
	defer discoveryResp.Body.Close()

	if discoveryResp.StatusCode != http.StatusOK {
		panic(fmt.Sprintf("failed to fetch discovery document: status code %d", discoveryResp.StatusCode))
	}

	rawDiscovery, err := io.ReadAll(discoveryResp.Body)
	if err != nil {
		panic(fmt.Sprintf("failed to read discovery document response: %v", err))
	}

	var doc DiscoveryDocument
	if err = json.Unmarshal(rawDiscovery, &doc); err != nil {
		panic(fmt.Sprintf("failed to unmarshal discovery document: %v", err))
	}

	// Step 2: call the advertised userinfo endpoint on behalf of the user.
	profileReq, err := http.NewRequest(http.MethodGet, doc.UserInfoEndpoint, nil)
	if err != nil {
		panic(fmt.Sprintf("failed to create user profile request: %v", err))
	}
	profileReq.Header.Set("Authorization", "Bearer "+providerToken)
	profileReq.Header.Set("Accept", "application/json")

	var httpClient http.Client
	profileResp, err := httpClient.Do(profileReq)
	if err != nil {
		panic(fmt.Sprintf("failed to fetch user profile: %v", err))
	}
	defer profileResp.Body.Close()

	if profileResp.StatusCode != http.StatusOK {
		panic(fmt.Sprintf("failed to fetch user profile: status code %d", profileResp.StatusCode))
	}

	rawProfile, err := io.ReadAll(profileResp.Body)
	if err != nil {
		panic(fmt.Sprintf("failed to read user profile response: %v", err))
	}

	var profile UserProfile
	if err = json.Unmarshal(rawProfile, &profile); err != nil {
		panic(fmt.Sprintf("failed to unmarshal user profile: %v", err))
	}

	return profile.Email, profile.Name, profile.AvatarGoogle
}
// getGitHubUserProfile asks the GitHub REST API ("https://api.github.com/user")
// for the account behind providerToken and returns its email, display name and
// avatar URL, in that order.
//
// Any failure (request construction, transport error, non-200 status,
// unreadable or malformed body) panics with a descriptive message, because the
// signature has no error return.
func getGitHubUserProfile(providerToken string) (string, string, string) {
	profileReq, err := http.NewRequest(http.MethodGet, "https://api.github.com/user", nil)
	if err != nil {
		panic(fmt.Sprintf("failed to create user profile request: %v", err))
	}

	// Authenticate as the user and pin the API media type and version.
	headers := profileReq.Header
	headers.Set("Authorization", "Bearer "+providerToken)
	headers.Set("Accept", "application/vnd.github+json")
	headers.Set("X-GitHub-Api-Version", "2022-11-28")

	var httpClient http.Client
	profileResp, err := httpClient.Do(profileReq)
	if err != nil {
		panic(fmt.Sprintf("failed to fetch user profile: %v", err))
	}
	defer profileResp.Body.Close()

	if profileResp.StatusCode != http.StatusOK {
		panic(fmt.Sprintf("failed to fetch user profile: status code %d", profileResp.StatusCode))
	}

	rawProfile, err := io.ReadAll(profileResp.Body)
	if err != nil {
		panic(fmt.Sprintf("failed to read user profile response: %v", err))
	}

	var profile UserProfile
	if err = json.Unmarshal(rawProfile, &profile); err != nil {
		panic(fmt.Sprintf("failed to unmarshal user profile: %v", err))
	}

	return profile.Email, profile.Name, profile.AvatarGitHub
}
// generatePKCE creates a fresh PKCE pair and returns (verifier, challenge).
//
// The verifier is 32 bytes read from crypto/rand, base64url-encoded without
// padding. The challenge is the SHA-256 digest of that encoded verifier
// string, encoded the same way. It panics if the random source cannot be read.
func generatePKCE() (string, string) {
	var seed [32]byte
	if _, err := rand.Read(seed[:]); err != nil {
		panic(err)
	}

	enc := base64.RawURLEncoding
	verifier := enc.EncodeToString(seed[:])

	// The digest is taken over the encoded text, not over the raw seed bytes.
	digest := sha256.Sum256([]byte(verifier))
	return verifier, enc.EncodeToString(digest[:])
}
// handleUiSignIn starts the PKCE flow: it generates a verifier/challenge
// pair, stores the verifier in the "jade-edgedb-pkce-verifier" cookie and
// redirects the browser to the auth server's "/ui/signup" page with the
// challenge in the query string.
func handleUiSignIn(c *fiber.Ctx) error {
	verifier, challenge := generatePKCE()

	// Only the challenge is sent to the auth server; the verifier is kept in
	// a cookie so the callback handlers can read it back.
	pkceCookie := fiber.Cookie{
		Name:     "jade-edgedb-pkce-verifier",
		Value:    verifier,
		Path:     "/",
		Secure:   true,
		HTTPOnly: true,
		SameSite: "Strict",
	}
	c.Cookie(&pkceCookie)

	target := fmt.Sprintf("%s/ui/signup?challenge=%s", EDGEDB_AUTH_BASE_URL, challenge)
	return c.Redirect(target, fiber.StatusTemporaryRedirect)
}
// handleCallbackSignup finishes the PKCE flow for a newly signed-up user.
//
// It exchanges the "code" query parameter and the verifier cookie for a token
// at the auth server, stores the auth token in the "jade-edgedb-auth-token"
// cookie, looks up the identity's issuer, fetches the provider profile,
// creates a Stripe customer, inserts the User (with a default Setting) and a
// "Default" Conversation, and finally redirects to "/".
//
// Problems with the incoming request or the auth server's answer are reported
// as 400 responses; database failures are reported as 500 responses instead
// of panicking inside the handler. getGoogleUserProfile and
// getGitHubUserProfile can still panic on their own failures.
func handleCallbackSignup(c *fiber.Ctx) error {
	code := c.Query("code")
	if code == "" {
		providerErr := c.Query("error")
		return c.Status(fiber.StatusBadRequest).SendString(fmt.Sprintf("OAuth callback is missing 'code'. OAuth provider responded with error: %s", providerErr))
	}

	verifier := c.Cookies("jade-edgedb-pkce-verifier", "")
	if verifier == "" {
		return c.Status(fiber.StatusBadRequest).SendString("Could not find 'verifier' in the cookie store. Is this the same user agent/browser that started the authorization flow?")
	}

	// code and verifier are supplied by the client, so they are percent-encoded
	// through the request's query values rather than pasted into the URL; a
	// value containing '&' or '#' can no longer add or truncate parameters.
	req, err := http.NewRequest(http.MethodGet, EDGEDB_AUTH_BASE_URL+"/token", nil)
	if err != nil {
		return c.Status(fiber.StatusInternalServerError).SendString(fmt.Sprintf("Error building the auth server request: %s", err.Error()))
	}
	params := req.URL.Query()
	params.Set("code", code)
	params.Set("verifier", verifier)
	req.URL.RawQuery = params.Encode()

	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		return c.Status(fiber.StatusBadRequest).SendString(fmt.Sprintf("Error from the auth server: %s", err.Error()))
	}
	defer resp.Body.Close()

	if resp.StatusCode != fiber.StatusOK {
		// Best effort: the body is only used to enrich the error message.
		body, _ := io.ReadAll(resp.Body)
		return c.Status(fiber.StatusBadRequest).SendString(fmt.Sprintf("Error from the auth server: %s", string(body)))
	}

	var tokenResponse TokenResponse
	if err = json.NewDecoder(resp.Body).Decode(&tokenResponse); err != nil {
		return c.Status(fiber.StatusBadRequest).SendString(fmt.Sprintf("Error decoding auth server response: %s", err.Error()))
	}

	c.Cookie(&fiber.Cookie{
		Name:     "jade-edgedb-auth-token",
		Value:    tokenResponse.AuthToken,
		HTTPOnly: true,
		Path:     "/",
		Secure:   true,
		SameSite: "Strict",
	})

	// Get the issuer of the identity.
	identityUUID, err := edgedb.ParseUUID(tokenResponse.IdentityID)
	if err != nil {
		return c.Status(fiber.StatusInternalServerError).SendString(fmt.Sprintf("Error parsing identity id from the auth server: %s", err.Error()))
	}

	var identity Identity
	err = edgeClient.QuerySingle(edgeCtx, `
		SELECT ext::auth::Identity {
			issuer
		} FILTER .id = <uuid>$0
	`, &identity, identityUUID)
	if err != nil {
		return c.Status(fiber.StatusInternalServerError).SendString(fmt.Sprintf("Error in edgedb.QuerySingle: in handleCallbackSignup: %s", err.Error()))
	}

	// Get the email, name and avatar from the provider.
	// NOTE(review): any other issuer falls through with empty profile fields
	// and the user is still created; confirm whether that is intended.
	var providerEmail, providerName, providerAvatar string
	switch identity.Issuer {
	case "https://accounts.google.com":
		providerEmail, providerName, providerAvatar = getGoogleUserProfile(tokenResponse.ProviderToken)
	case "https://github.com":
		providerEmail, providerName, providerAvatar = getGitHubUserProfile(tokenResponse.ProviderToken)
	}

	stripeCustomerID := CreateNewStripeCustomer(providerName, providerEmail)

	err = edgeClient.Execute(edgeCtx, `
		INSERT User {
			stripe_id := <str>$0,
			email := <str>$1,
			name := <str>$2,
			avatar := <str>$3,
			setting := (
				INSERT Setting {
					default_model := "gpt-3.5-turbo"
				}
			),
			identity := (SELECT ext::auth::Identity FILTER .id = <uuid>$4)
		}
	`, stripeCustomerID, providerEmail, providerName, providerAvatar, identityUUID)
	if err != nil {
		return c.Status(fiber.StatusInternalServerError).SendString(fmt.Sprintf("Error in edgedb.Execute: in handleCallbackSignup: %s", err.Error()))
	}

	// NOTE(review): edgeClient is a package-level variable, so this stores one
	// user's token on the client every request shares. Kept because other code
	// presumably relies on it; a per-request client would be safer — confirm.
	edgeClient = edgeClient.WithGlobals(map[string]interface{}{"ext::auth::client_token": tokenResponse.AuthToken})

	err = edgeClient.Execute(edgeCtx, `
		INSERT Conversation {
			name := 'Default',
			user := global currentUser,
			position := 1
		}`)
	if err != nil {
		return c.Status(fiber.StatusInternalServerError).SendString(fmt.Sprintf("Error in edgedb.Execute: in handleCallbackSignup: %s", err.Error()))
	}

	// Temporary redirect, matching the other handlers in this file: the
	// outcome of a callback must not be cached as a permanent redirect.
	return c.Redirect("/", fiber.StatusTemporaryRedirect)
}
// handleCallback finishes the PKCE flow without creating any database
// records: it exchanges the "code" query parameter and the verifier cookie for
// a token at the auth server, stores the auth token in the
// "jade-edgedb-auth-token" cookie, attaches the token to edgeClient and
// redirects to "/".
//
// Problems with the incoming request or the auth server's answer are reported
// as 400 responses.
func handleCallback(c *fiber.Ctx) error {
	code := c.Query("code")
	if code == "" {
		providerErr := c.Query("error")
		return c.Status(fiber.StatusBadRequest).SendString(fmt.Sprintf("OAuth callback is missing 'code'. OAuth provider responded with error: %s", providerErr))
	}

	verifier := c.Cookies("jade-edgedb-pkce-verifier", "")
	if verifier == "" {
		return c.Status(fiber.StatusBadRequest).SendString("Could not find 'verifier' in the cookie store. Is this the same user agent/browser that started the authorization flow?")
	}

	// code and verifier are supplied by the client, so they are percent-encoded
	// through the request's query values rather than pasted into the URL; a
	// value containing '&' or '#' can no longer add or truncate parameters.
	req, err := http.NewRequest(http.MethodGet, EDGEDB_AUTH_BASE_URL+"/token", nil)
	if err != nil {
		return c.Status(fiber.StatusInternalServerError).SendString(fmt.Sprintf("Error building the auth server request: %s", err.Error()))
	}
	params := req.URL.Query()
	params.Set("code", code)
	params.Set("verifier", verifier)
	req.URL.RawQuery = params.Encode()

	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		return c.Status(fiber.StatusBadRequest).SendString(fmt.Sprintf("Error from the auth server: %s", err.Error()))
	}
	defer resp.Body.Close()

	if resp.StatusCode != fiber.StatusOK {
		// Best effort: the body is only used to enrich the error message.
		body, _ := io.ReadAll(resp.Body)
		return c.Status(fiber.StatusBadRequest).SendString(fmt.Sprintf("Error from the auth server: %s", string(body)))
	}

	var tokenResponse TokenResponse
	if err = json.NewDecoder(resp.Body).Decode(&tokenResponse); err != nil {
		return c.Status(fiber.StatusBadRequest).SendString(fmt.Sprintf("Error decoding auth server response: %s", err.Error()))
	}

	c.Cookie(&fiber.Cookie{
		Name:     "jade-edgedb-auth-token",
		Value:    tokenResponse.AuthToken,
		HTTPOnly: true,
		Path:     "/",
		Secure:   true,
		SameSite: "Strict",
	})

	// NOTE(review): edgeClient is a package-level variable, so this stores one
	// user's token on the client every request shares. Kept because other code
	// presumably relies on it; a per-request client would be safer — confirm.
	edgeClient = edgeClient.WithGlobals(map[string]interface{}{"ext::auth::client_token": tokenResponse.AuthToken})

	// Temporary redirect, matching the other handlers in this file: the
	// outcome of a callback must not be cached as a permanent redirect.
	return c.Redirect("/", fiber.StatusTemporaryRedirect)
}
// handleSignOut signs the user out: it clears the "jade-edgedb-auth-token"
// cookie, removes the "ext::auth::client_token" global from edgeClient and
// redirects to "/".
func handleSignOut(c *fiber.Ctx) error {
	const authCookieName = "jade-edgedb-auth-token"
	c.ClearCookie(authCookieName)

	// NOTE(review): edgeClient is a package-level variable, so this drops the
	// token for every request that shares the client — confirm this is intended.
	edgeClient = edgeClient.WithoutGlobals("ext::auth::client_token")

	return c.Redirect("/", fiber.StatusTemporaryRedirect)
}