package auth import ( "context" "database/sql" "errors" "log" "net/http" ory "github.com/ory/client-go" "decor-by-hannahs/internal/db" ) type contextKey string const sessionContextKey contextKey = "req.session" func SessionMiddleware(oryClient *ory.APIClient, tunnelURL string) func(http.HandlerFunc) http.HandlerFunc { return func(next http.HandlerFunc) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { cookies := r.Header.Get("Cookie") session, resp, err := oryClient.FrontendAPI.ToSession(r.Context()).Cookie(cookies).Execute() if err != nil { log.Printf("Session check failed: %v", err) if resp != nil { log.Printf("Response status: %d", resp.StatusCode) } log.Printf("Redirecting to login: %s/ui/login", tunnelURL) http.Redirect(w, r, tunnelURL+"/ui/login", http.StatusSeeOther) return } if session == nil || !*session.Active { log.Printf("Session inactive, redirecting to login") http.Redirect(w, r, tunnelURL+"/ui/login", http.StatusSeeOther) return } ctx := context.WithValue(r.Context(), sessionContextKey, session) next.ServeHTTP(w, r.WithContext(ctx)) } } } func AuthMiddleware(oryClient *ory.APIClient, queries *db.Queries) func(http.HandlerFunc) http.HandlerFunc { return func(next http.HandlerFunc) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { session, err := GetSession(r.Context()) if err != nil { http.Error(w, "Unauthorized", http.StatusUnauthorized) return } email := getEmailFromSession(session) if email == "" { http.Error(w, "No email in session", http.StatusBadRequest) return } oryID := sql.NullString{String: session.Identity.Id, Valid: true} user, err := queries.GetUserByOryID(r.Context(), oryID) if err != nil { user, err = queries.CreateUser(r.Context(), db.CreateUserParams{ Email: email, OryIdentityID: oryID, }) if err != nil { log.Printf("Failed to create user: %v", err) http.Error(w, "Internal server error", http.StatusInternalServerError) return } log.Printf("Created new user: %s (ID: %d)", email, user.ID) } ctx := context.WithValue(r.Context(), "user", user) next.ServeHTTP(w, r.WithContext(ctx)) } } } func GetSession(ctx context.Context) (*ory.Session, error) { session, ok := ctx.Value(sessionContextKey).(*ory.Session) if !ok || session == nil { return nil, errors.New("session not found in context") } return session, nil } func GetUser(ctx context.Context) (*db.User, error) { user, ok := ctx.Value("user").(*db.User) if !ok || user == nil { return nil, errors.New("user not found in context") } return user, nil } func getEmailFromSession(session *ory.Session) string { if session.Identity.Traits == nil { return "" } traits, ok := session.Identity.Traits.(map[string]interface{}) if !ok { return "" } email, ok := traits["email"].(string) if !ok { return "" } return email }