the rest
This commit is contained in:
112
internal/auth/middleware.go
Normal file
112
internal/auth/middleware.go
Normal file
@@ -0,0 +1,112 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"errors"
|
||||
"log"
|
||||
"net/http"
|
||||
|
||||
ory "github.com/ory/client-go"
|
||||
|
||||
"decor-by-hannahs/internal/db"
|
||||
)
|
||||
|
||||
type contextKey string
|
||||
|
||||
const sessionContextKey contextKey = "req.session"
|
||||
|
||||
func SessionMiddleware(oryClient *ory.APIClient, tunnelURL string) func(http.HandlerFunc) http.HandlerFunc {
|
||||
return func(next http.HandlerFunc) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
cookies := r.Header.Get("Cookie")
|
||||
|
||||
session, resp, err := oryClient.FrontendAPI.ToSession(r.Context()).Cookie(cookies).Execute()
|
||||
if err != nil {
|
||||
log.Printf("Session check failed: %v", err)
|
||||
if resp != nil {
|
||||
log.Printf("Response status: %d", resp.StatusCode)
|
||||
}
|
||||
log.Printf("Redirecting to login: %s/ui/login", tunnelURL)
|
||||
http.Redirect(w, r, tunnelURL+"/ui/login", http.StatusSeeOther)
|
||||
return
|
||||
}
|
||||
|
||||
if session == nil || !*session.Active {
|
||||
log.Printf("Session inactive, redirecting to login")
|
||||
http.Redirect(w, r, tunnelURL+"/ui/login", http.StatusSeeOther)
|
||||
return
|
||||
}
|
||||
|
||||
ctx := context.WithValue(r.Context(), sessionContextKey, session)
|
||||
next.ServeHTTP(w, r.WithContext(ctx))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func AuthMiddleware(oryClient *ory.APIClient, queries *db.Queries) func(http.HandlerFunc) http.HandlerFunc {
|
||||
return func(next http.HandlerFunc) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
session, err := GetSession(r.Context())
|
||||
if err != nil {
|
||||
http.Error(w, "Unauthorized", http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
email := getEmailFromSession(session)
|
||||
if email == "" {
|
||||
http.Error(w, "No email in session", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
oryID := sql.NullString{String: session.Identity.Id, Valid: true}
|
||||
user, err := queries.GetUserByOryID(r.Context(), oryID)
|
||||
if err != nil {
|
||||
user, err = queries.CreateUser(r.Context(), db.CreateUserParams{
|
||||
Email: email,
|
||||
OryIdentityID: oryID,
|
||||
})
|
||||
if err != nil {
|
||||
log.Printf("Failed to create user: %v", err)
|
||||
http.Error(w, "Internal server error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
log.Printf("Created new user: %s (ID: %d)", email, user.ID)
|
||||
}
|
||||
|
||||
ctx := context.WithValue(r.Context(), "user", user)
|
||||
next.ServeHTTP(w, r.WithContext(ctx))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func GetSession(ctx context.Context) (*ory.Session, error) {
|
||||
session, ok := ctx.Value(sessionContextKey).(*ory.Session)
|
||||
if !ok || session == nil {
|
||||
return nil, errors.New("session not found in context")
|
||||
}
|
||||
return session, nil
|
||||
}
|
||||
|
||||
func GetUser(ctx context.Context) (*db.User, error) {
|
||||
user, ok := ctx.Value("user").(*db.User)
|
||||
if !ok || user == nil {
|
||||
return nil, errors.New("user not found in context")
|
||||
}
|
||||
return user, nil
|
||||
}
|
||||
|
||||
func getEmailFromSession(session *ory.Session) string {
|
||||
if session.Identity.Traits == nil {
|
||||
return ""
|
||||
}
|
||||
traits, ok := session.Identity.Traits.(map[string]interface{})
|
||||
if !ok {
|
||||
return ""
|
||||
}
|
||||
email, ok := traits["email"].(string)
|
||||
if !ok {
|
||||
return ""
|
||||
}
|
||||
return email
|
||||
}
|
||||
Reference in New Issue
Block a user