// Package tokenstore manages single-use tokens for email verification, // password reset, and tenant invitations (PROJ-28). package tokenstore import ( "context" "crypto/rand" "crypto/sha256" "encoding/base64" "fmt" "time" "github.com/jackc/pgx/v5/pgxpool" ) const ( TypeVerify = "verify" TypeReset = "reset" TypeInvite = "invite" ) // Token represents a stored auth token record. type Token struct { ID int64 Type string UserID *int64 TenantID *int64 ExpiresAt time.Time } // Store manages auth_tokens in PostgreSQL. type Store struct { pool *pgxpool.Pool } // New connects to PostgreSQL and initialises the token schema. func New(pool *pgxpool.Pool) (*Store, error) { s := &Store{pool: pool} if err := s.initSchema(context.Background()); err != nil { return nil, fmt.Errorf("tokenstore: init schema: %w", err) } return s, nil } func (s *Store) initSchema(ctx context.Context) error { _, err := s.pool.Exec(ctx, ` CREATE TABLE IF NOT EXISTS auth_tokens ( id BIGSERIAL PRIMARY KEY, type VARCHAR(50) NOT NULL, token_hash TEXT NOT NULL UNIQUE, user_id BIGINT REFERENCES users(id) ON DELETE CASCADE, tenant_id BIGINT, expires_at TIMESTAMPTZ NOT NULL, used_at TIMESTAMPTZ ); CREATE INDEX IF NOT EXISTS idx_auth_tokens_hash ON auth_tokens (token_hash); CREATE INDEX IF NOT EXISTS idx_auth_tokens_expires ON auth_tokens (expires_at); `) return err } // Create generates a new random token, stores its SHA-256 hash, and returns the plaintext token. // userID may be nil for invite tokens (user does not exist yet). func (s *Store) Create(ctx context.Context, tokenType string, userID *int64, tenantID *int64, ttl time.Duration) (string, error) { raw := make([]byte, 32) if _, err := rand.Read(raw); err != nil { return "", fmt.Errorf("tokenstore: rand: %w", err) } plain := base64.RawURLEncoding.EncodeToString(raw) hash := hashToken(plain) _, err := s.pool.Exec(ctx, `INSERT INTO auth_tokens (type, token_hash, user_id, tenant_id, expires_at) VALUES ($1, $2, $3, $4, $5)`, tokenType, hash, userID, tenantID, time.Now().Add(ttl), ) if err != nil { return "", fmt.Errorf("tokenstore: create: %w", err) } return plain, nil } // Use validates a plaintext token (type must match), marks it as used, and returns the record. // Returns an error if the token is invalid, expired, or already used. func (s *Store) Use(ctx context.Context, tokenType, plain string) (*Token, error) { hash := hashToken(plain) var t Token err := s.pool.QueryRow(ctx, `SELECT id, type, user_id, tenant_id, expires_at, used_at FROM auth_tokens WHERE token_hash = $1 AND type = $2`, hash, tokenType, ).Scan(&t.ID, &t.Type, &t.UserID, &t.TenantID, &t.ExpiresAt, new(*time.Time)) if err != nil { return nil, fmt.Errorf("tokenstore: token not found or invalid") } if time.Now().After(t.ExpiresAt) { return nil, fmt.Errorf("tokenstore: token expired") } // Mark used tag, err := s.pool.Exec(ctx, `UPDATE auth_tokens SET used_at = NOW() WHERE id = $1 AND used_at IS NULL`, t.ID, ) if err != nil { return nil, fmt.Errorf("tokenstore: mark used: %w", err) } if tag.RowsAffected() == 0 { return nil, fmt.Errorf("tokenstore: token already used") } return &t, nil } // Peek validates a token without consuming it. Used for invite token preview. func (s *Store) Peek(ctx context.Context, tokenType, plain string) (*Token, error) { hash := hashToken(plain) var t Token var usedAt *time.Time err := s.pool.QueryRow(ctx, `SELECT id, type, user_id, tenant_id, expires_at, used_at FROM auth_tokens WHERE token_hash = $1 AND type = $2`, hash, tokenType, ).Scan(&t.ID, &t.Type, &t.UserID, &t.TenantID, &t.ExpiresAt, &usedAt) if err != nil { return nil, fmt.Errorf("tokenstore: token not found or invalid") } if usedAt != nil { return nil, fmt.Errorf("tokenstore: token already used") } if time.Now().After(t.ExpiresAt) { return nil, fmt.Errorf("tokenstore: token expired") } return &t, nil } // Cleanup deletes tokens that are expired or used. func (s *Store) Cleanup(ctx context.Context) error { _, err := s.pool.Exec(ctx, `DELETE FROM auth_tokens WHERE expires_at < NOW() OR used_at IS NOT NULL`, ) return err } func hashToken(plain string) string { sum := sha256.Sum256([]byte(plain)) return fmt.Sprintf("%x", sum[:]) }