from datetime import datetime, timedelta, timezone
from typing import Any
import secrets
import hashlib

import bcrypt
from jose import JWTError, jwt

from app.core.config import settings


# ── Password ────────────────────────────────────────────────────────────────

def hash_password(password: str) -> str:
    """bcrypt hash with cost factor 12; the salt is embedded in the returned string."""
    return bcrypt.hashpw(password.encode(), bcrypt.gensalt(rounds=12)).decode()


def verify_password(plain_password: str, hashed_password: str) -> bool:
    """Return True if the plain password matches the stored bcrypt hash."""
    return bcrypt.checkpw(plain_password.encode(), hashed_password.encode())


def hash_token(token: str) -> str:
    """SHA-256 hex digest for storing tokens (refresh, invite, reset) in the DB."""
    return hashlib.sha256(token.encode()).hexdigest()


# ── JWT ─────────────────────────────────────────────────────────────────────

def create_access_token(subject: str, extra: dict[str, Any] | None = None) -> str:
    expire = datetime.now(timezone.utc) + timedelta(
        minutes=settings.access_token_expire_minutes
    )
    payload = {"sub": subject, "exp": expire, "type": "access"}
    if extra:
        # Keys in `extra` override the standard claims set above.
        payload.update(extra)
    return jwt.encode(payload, settings.secret_key, algorithm=settings.algorithm)


def create_refresh_token() -> tuple[str, str]:
    """Returns (raw_token, hashed_token). Store the hash in the DB, send the raw token to the client."""
    raw = secrets.token_urlsafe(64)
    return raw, hash_token(raw)


def decode_access_token(token: str) -> dict[str, Any]:
    """Raises JWTError on an invalid, expired, or non-access token."""
    payload = jwt.decode(token, settings.secret_key, algorithms=[settings.algorithm])
    if payload.get("type") != "access":
        raise JWTError("Invalid token type")
    return payload


# ── Partial token (TOTP pending) ─────────────────────────────────────────────

def create_partial_token(user_id: str) -> str:
    """Short-lived token issued after the password check succeeds but before
    TOTP verification. Valid for 5 minutes."""
    expire = datetime.now(timezone.utc) + timedelta(minutes=5)
    payload = {"sub": user_id, "exp": expire, "type": "partial"}
    return jwt.encode(payload, settings.secret_key, algorithm=settings.algorithm)


def decode_partial_token(token: str) -> str:
    """Returns user_id (sub). Raises JWTError on an invalid, expired, or wrong-type token."""
    payload = jwt.decode(token, settings.secret_key, algorithms=[settings.algorithm])
    if payload.get("type") != "partial":
        raise JWTError("Invalid token type")
    return payload["sub"]


# ── One-time tokens ──────────────────────────────────────────────────────────

def generate_invite_token() -> tuple[str, str]:
    """Returns (raw, hashed). Invite valid for 7 days (expiry is not encoded in the token)."""
    raw = secrets.token_urlsafe(32)
    return raw, hash_token(raw)


def generate_reset_token() -> tuple[str, str]:
    """Returns (raw, hashed). Reset valid for 1 hour (expiry is not encoded in the token)."""
    raw = secrets.token_urlsafe(32)
    return raw, hash_token(raw)
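

# ── Example: checking a presented token (hypothetical sketch) ───────────────
# Not part of the original module. A minimal sketch, assuming the DB row keeps
# the SHA-256 hex digest produced by hash_token() plus a timezone-aware
# creation timestamp. The helper and constant names below are hypothetical;
# the TTLs mirror the docstrings above (invite 7 days, reset 1 hour).

INVITE_TOKEN_TTL = timedelta(days=7)  # hypothetical constant
RESET_TOKEN_TTL = timedelta(hours=1)  # hypothetical constant


def verify_token_hash(raw_token: str, stored_hash: str) -> bool:
    """Hypothetical helper: hash the presented raw token and compare it with
    the stored digest in constant time."""
    candidate = hashlib.sha256(raw_token.encode()).hexdigest()  # same digest as hash_token()
    return secrets.compare_digest(candidate, stored_hash)


def is_token_expired(
    created_at: datetime, ttl: timedelta, now: datetime | None = None
) -> bool:
    """Hypothetical helper: True once `ttl` has elapsed since `created_at`.
    Both datetimes must be timezone-aware."""
    if now is None:
        now = datetime.now(timezone.utc)
    return now >= created_at + ttl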