package userstore import ( "context" "errors" "fmt" "time" "github.com/jackc/pgx/v5" "github.com/jackc/pgx/v5/pgxpool" "golang.org/x/crypto/bcrypt" ) const ( RoleUser = "user" RoleAdmin = "admin" RoleAuditor = "auditor" ) // User represents a user account in the system. type User struct { ID int64 Username string Email string Role string Source string // "local" or "ldap" Active bool CreatedAt time.Time } // CreateUserRequest holds parameters for creating a new user. type CreateUserRequest struct { Username string Email string Password string Role string } // UpdateUserRequest holds optional fields for updating a user. type UpdateUserRequest struct { Email *string Role *string Active *bool Password *string } // Store is a PostgreSQL-backed user store. type Store struct { pool *pgxpool.Pool } // New connects to PostgreSQL using the given DSN and initialises the schema. func New(dsn string) (*Store, error) { ctx := context.Background() pool, err := pgxpool.New(ctx, dsn) if err != nil { return nil, fmt.Errorf("userstore: connect: %w", err) } s := &Store{pool: pool} if err := s.initSchema(ctx); err != nil { pool.Close() return nil, fmt.Errorf("userstore: init schema: %w", err) } return s, nil } func (s *Store) initSchema(ctx context.Context) error { _, err := s.pool.Exec(ctx, ` CREATE TABLE IF NOT EXISTS users ( id BIGSERIAL PRIMARY KEY, username VARCHAR(100) UNIQUE NOT NULL, email VARCHAR(255) UNIQUE NOT NULL, password_hash VARCHAR(255) NOT NULL DEFAULT '', role VARCHAR(20) NOT NULL CHECK (role IN ('user','auditor','admin')), source VARCHAR(20) NOT NULL DEFAULT 'local', active BOOLEAN NOT NULL DEFAULT true, created_at TIMESTAMPTZ NOT NULL DEFAULT NOW() ); CREATE TABLE IF NOT EXISTS token_blacklist ( jti VARCHAR(255) PRIMARY KEY, expires_at TIMESTAMPTZ NOT NULL ); `) return err } // Close closes the underlying connection pool. func (s *Store) Close() error { s.pool.Close() return nil } // Create inserts a new local user with a bcrypt-hashed password. func (s *Store) Create(req CreateUserRequest) (*User, error) { hash, err := bcrypt.GenerateFromPassword([]byte(req.Password), bcrypt.DefaultCost) if err != nil { return nil, fmt.Errorf("userstore: bcrypt: %w", err) } ctx := context.Background() var id int64 err = s.pool.QueryRow(ctx, `INSERT INTO users (username, email, password_hash, role, source, active, created_at) VALUES ($1, $2, $3, $4, 'local', true, NOW()) RETURNING id`, req.Username, req.Email, string(hash), req.Role, ).Scan(&id) if err != nil { return nil, fmt.Errorf("userstore: create: %w", err) } return s.GetByID(id) } // GetByID retrieves a user by their numeric ID. func (s *Store) GetByID(id int64) (*User, error) { ctx := context.Background() row := s.pool.QueryRow(ctx, `SELECT id, username, email, role, source, active, created_at FROM users WHERE id = $1`, id, ) return scanUser(row) } // GetByUsername retrieves a user by their username. func (s *Store) GetByUsername(username string) (*User, error) { ctx := context.Background() row := s.pool.QueryRow(ctx, `SELECT id, username, email, role, source, active, created_at FROM users WHERE username = $1`, username, ) return scanUser(row) } // VerifyPassword checks credentials and returns the user, or an error if the // password is wrong or the account is disabled. func (s *Store) VerifyPassword(username, password string) (*User, error) { ctx := context.Background() row := s.pool.QueryRow(ctx, `SELECT id, username, email, role, source, active, created_at, password_hash FROM users WHERE username = $1`, username, ) var u User var hash string err := row.Scan(&u.ID, &u.Username, &u.Email, &u.Role, &u.Source, &u.Active, &u.CreatedAt, &hash) if errors.Is(err, pgx.ErrNoRows) { return nil, errors.New("userstore: user not found") } if err != nil { return nil, fmt.Errorf("userstore: scan: %w", err) } if !u.Active { return nil, errors.New("userstore: account disabled") } if err := bcrypt.CompareHashAndPassword([]byte(hash), []byte(password)); err != nil { return nil, errors.New("userstore: wrong password") } return &u, nil } // Update applies a partial update to a user record. func (s *Store) Update(id int64, req UpdateUserRequest) (*User, error) { ctx := context.Background() if req.Email != nil { if _, err := s.pool.Exec(ctx, `UPDATE users SET email = $1 WHERE id = $2`, *req.Email, id); err != nil { return nil, fmt.Errorf("userstore: update email: %w", err) } } if req.Role != nil { if _, err := s.pool.Exec(ctx, `UPDATE users SET role = $1 WHERE id = $2`, *req.Role, id); err != nil { return nil, fmt.Errorf("userstore: update role: %w", err) } } if req.Active != nil { if _, err := s.pool.Exec(ctx, `UPDATE users SET active = $1 WHERE id = $2`, *req.Active, id); err != nil { return nil, fmt.Errorf("userstore: update active: %w", err) } } if req.Password != nil { hash, err := bcrypt.GenerateFromPassword([]byte(*req.Password), bcrypt.DefaultCost) if err != nil { return nil, fmt.Errorf("userstore: bcrypt: %w", err) } if _, err := s.pool.Exec(ctx, `UPDATE users SET password_hash = $1 WHERE id = $2`, string(hash), id); err != nil { return nil, fmt.Errorf("userstore: update password: %w", err) } } return s.GetByID(id) } // Delete removes a user by ID. Returns an error if the user does not exist. func (s *Store) Delete(id int64) error { ctx := context.Background() tag, err := s.pool.Exec(ctx, `DELETE FROM users WHERE id = $1`, id) if err != nil { return fmt.Errorf("userstore: delete: %w", err) } if tag.RowsAffected() == 0 { return fmt.Errorf("userstore: user %d not found", id) } return nil } // List returns all users, optionally filtered by role. Pass role="" to list all. func (s *Store) List(role string) ([]*User, error) { ctx := context.Background() var rows pgx.Rows var err error if role == "" { rows, err = s.pool.Query(ctx, `SELECT id, username, email, role, source, active, created_at FROM users ORDER BY id`) } else { rows, err = s.pool.Query(ctx, `SELECT id, username, email, role, source, active, created_at FROM users WHERE role = $1 ORDER BY id`, role) } if err != nil { return nil, fmt.Errorf("userstore: list: %w", err) } defer rows.Close() var users []*User for rows.Next() { u, err := scanUserRow(rows) if err != nil { return nil, err } users = append(users, u) } return users, rows.Err() } // BlacklistToken adds a JWT ID to the token blacklist. func (s *Store) BlacklistToken(jti string, expires time.Time) error { ctx := context.Background() _, err := s.pool.Exec(ctx, `INSERT INTO token_blacklist (jti, expires_at) VALUES ($1, $2) ON CONFLICT (jti) DO UPDATE SET expires_at = EXCLUDED.expires_at`, jti, expires.UTC(), ) return err } // IsBlacklisted returns true if the given JTI is in the blacklist. func (s *Store) IsBlacklisted(jti string) (bool, error) { ctx := context.Background() var count int err := s.pool.QueryRow(ctx, `SELECT COUNT(*) FROM token_blacklist WHERE jti = $1`, jti, ).Scan(&count) return count > 0, err } // CleanExpiredTokens removes blacklist entries whose expiry has passed. func (s *Store) CleanExpiredTokens() error { ctx := context.Background() _, err := s.pool.Exec(ctx, `DELETE FROM token_blacklist WHERE expires_at < NOW()`) return err } // UpsertLDAPUser creates or updates an LDAP-sourced user. func (s *Store) UpsertLDAPUser(username, email, role string) (*User, error) { ctx := context.Background() _, err := s.pool.Exec(ctx, ` INSERT INTO users (username, email, password_hash, role, source, active, created_at) VALUES ($1, $2, '', $3, 'ldap', true, NOW()) ON CONFLICT (username) DO UPDATE SET email = EXCLUDED.email, role = EXCLUDED.role, source = 'ldap' `, username, email, role) if err != nil { return nil, fmt.Errorf("userstore: upsert ldap: %w", err) } return s.GetByUsername(username) } // --- helpers --- func scanUser(row pgx.Row) (*User, error) { var u User err := row.Scan(&u.ID, &u.Username, &u.Email, &u.Role, &u.Source, &u.Active, &u.CreatedAt) if errors.Is(err, pgx.ErrNoRows) { return nil, fmt.Errorf("userstore: not found") } if err != nil { return nil, fmt.Errorf("userstore: scan: %w", err) } return &u, nil } func scanUserRow(rows pgx.Rows) (*User, error) { var u User if err := rows.Scan(&u.ID, &u.Username, &u.Email, &u.Role, &u.Source, &u.Active, &u.CreatedAt); err != nil { return nil, fmt.Errorf("userstore: scan row: %w", err) } return &u, nil }