package api import ( "context" "net/http" "github.com/jacek/pamietnik/backend/internal/auth" "github.com/jacek/pamietnik/backend/internal/domain" ) type contextKey string const ctxUserID contextKey = "user_id" const ctxUser contextKey = "user" const sessionCookieName = "session" // RequireAuth validates the session cookie and stores user info in context. // On failure it redirects to /login for browser requests (text/html) or returns JSON 401. func RequireAuth(authStore *auth.Store) func(http.Handler) http.Handler { return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { user, err := userFromRequest(r, authStore) if err != nil { redirectOrUnauthorized(w, r) return } ctx := context.WithValue(r.Context(), ctxUserID, user.UserID) ctx = context.WithValue(ctx, ctxUser, user) next.ServeHTTP(w, r.WithContext(ctx)) }) } } // requireAdmin checks that the authenticated user is an admin. func requireAdmin(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { u, ok := r.Context().Value(ctxUser).(domain.User) if !ok || !u.IsAdmin { http.Redirect(w, r, "/days", http.StatusSeeOther) return } next.ServeHTTP(w, r) }) } func userFromRequest(r *http.Request, authStore *auth.Store) (domain.User, error) { cookie, err := r.Cookie(sessionCookieName) if err != nil { return domain.User{}, auth.ErrSessionNotFound } return authStore.GetUserBySession(r.Context(), cookie.Value) } func redirectOrUnauthorized(w http.ResponseWriter, r *http.Request) { accept := r.Header.Get("Accept") if len(accept) > 0 && (accept == "application/json" || r.Header.Get("X-Requested-With") == "XMLHttpRequest") { writeError(w, http.StatusUnauthorized, "UNAUTHORIZED", "login required") return } http.Redirect(w, r, "/login", http.StatusSeeOther) } func userIDFromContext(ctx context.Context) string { v, _ := ctx.Value(ctxUserID).(string) return v } func userFromContext(ctx context.Context) domain.User { v, _ := ctx.Value(ctxUser).(domain.User) return v } func contextWithUserID(ctx context.Context, userID string) context.Context { return context.WithValue(ctx, ctxUserID, userID) }