package auth

import (
	"context"
	"crypto/rand"
	"crypto/subtle"
	"encoding/hex"
	"errors"
	"fmt"
	"strings"
	"time"

	"github.com/jackc/pgx/v5"
	"github.com/jackc/pgx/v5/pgxpool"
	"golang.org/x/crypto/argon2"

	"github.com/jacek/pamietnik/backend/internal/domain"
)

// sessionDuration is how long a newly created session stays valid.
const sessionDuration = 24 * time.Hour

// Argon2id parameters shared by HashPassword and VerifyPassword.
//
// The cost parameters (time, memory, threads, key length) are NOT recorded in
// the stored hash string, so changing any of them makes every existing hash
// fail verification. Keep HashPassword and VerifyPassword on these constants.
const (
	argonTime    = 1         // passes over memory
	argonMemory  = 64 * 1024 // in KiB (64 MiB)
	argonThreads = 4
	argonKeyLen  = 32 // bytes of derived key
	argonSaltLen = 16 // bytes of random salt
)

var (
	// ErrInvalidCredentials is returned by Login both when the username is
	// unknown and when the password does not match, so callers cannot tell
	// the two cases apart.
	ErrInvalidCredentials = errors.New("invalid username or password")

	// ErrSessionNotFound is returned by GetSession when no unexpired session
	// has the requested ID.
	ErrSessionNotFound = errors.New("session not found or expired")
)

// Store runs authentication and session queries against PostgreSQL through a
// pgx connection pool.
type Store struct {
	pool *pgxpool.Pool
}

// NewStore returns a Store that executes its queries on pool.
func NewStore(pool *pgxpool.Pool) *Store {
	return &Store{pool: pool}
}

// HashPassword returns an argon2id hash of password using a fresh random salt.
// The result has the form "$argon2id$<hex salt>$<hex hash>" and is what
// VerifyPassword expects as its stored argument.
func HashPassword(password string) (string, error) {
	salt := make([]byte, argonSaltLen)
	if _, err := rand.Read(salt); err != nil {
		return "", fmt.Errorf("generate salt: %w", err)
	}
	hash := argon2.IDKey([]byte(password), salt, argonTime, argonMemory, argonThreads, argonKeyLen)
	return fmt.Sprintf("$argon2id$%x$%x", salt, hash), nil
}

// VerifyPassword reports whether password matches stored, a hash in the form
// "$argon2id$<hex salt>$<hex hash>" as produced by HashPassword. A malformed
// stored value never matches. The final comparison is constant-time.
func VerifyPassword(password, stored string) bool {
	// A well-formed value splits into ["", "argon2id", "<hex salt>", "<hex hash>"].
	parts := strings.Split(stored, "$")
	if len(parts) != 4 || parts[1] != "argon2id" {
		return false
	}
	salt, err := hex.DecodeString(parts[2])
	if err != nil {
		return false
	}
	expected, err := hex.DecodeString(parts[3])
	// A hash of the wrong length can never compare equal to argonKeyLen bytes;
	// reject it before paying for the argon2id computation.
	if err != nil || len(expected) != argonKeyLen {
		return false
	}
	hash := argon2.IDKey([]byte(password), salt, argonTime, argonMemory, argonThreads, argonKeyLen)
	return subtle.ConstantTimeCompare(hash, expected) == 1
}

// Login checks username and password against the users table and, on
// success, stores and returns a new session that expires after
// sessionDuration. An unknown username and a wrong password both yield
// ErrInvalidCredentials.
func (s *Store) Login(ctx context.Context, username, password string) (domain.Session, error) { var user domain.User err := s.pool.QueryRow(ctx, `SELECT user_id, username, password_hash FROM users WHERE username = $1`, username, ).Scan(&user.UserID, &user.Username, &user.PasswordHash) if err != nil { if errors.Is(err, pgx.ErrNoRows) { return domain.Session{}, ErrInvalidCredentials } return domain.Session{}, err } if !VerifyPassword(password, user.PasswordHash) { return domain.Session{}, ErrInvalidCredentials } sessionID, err := newSessionID() if err != nil { return domain.Session{}, fmt.Errorf("create session: %w", err) } now := time.Now().UTC() sess := domain.Session{ SessionID: sessionID, UserID: user.UserID, CreatedAt: now, ExpiresAt: now.Add(sessionDuration), } _, err = s.pool.Exec(ctx, `INSERT INTO sessions (session_id, user_id, created_at, expires_at) VALUES ($1, $2, $3, $4)`, sess.SessionID, sess.UserID, sess.CreatedAt, sess.ExpiresAt, ) if err != nil { return domain.Session{}, err } return sess, nil } // GetSession validates a session and returns user_id. func (s *Store) GetSession(ctx context.Context, sessionID string) (domain.Session, error) { var sess domain.Session err := s.pool.QueryRow(ctx, `SELECT session_id, user_id, created_at, expires_at FROM sessions WHERE session_id = $1 AND expires_at > NOW()`, sessionID, ).Scan(&sess.SessionID, &sess.UserID, &sess.CreatedAt, &sess.ExpiresAt) if err != nil { if errors.Is(err, pgx.ErrNoRows) { return domain.Session{}, ErrSessionNotFound } return domain.Session{}, err } return sess, nil } // Logout deletes a session. 
func (s *Store) Logout(ctx context.Context, sessionID string) error { _, err := s.pool.Exec(ctx, `DELETE FROM sessions WHERE session_id = $1`, sessionID) if err != nil { return fmt.Errorf("delete session: %w", err) } return nil } func newSessionID() (string, error) { b := make([]byte, 32) if _, err := rand.Read(b); err != nil { return "", fmt.Errorf("generate session id: %w", err) } return hex.EncodeToString(b), nil }