from typing import Annotated
from uuid import UUID

from fastapi import Depends, HTTPException, status
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
from jose import JWTError
from sqlalchemy import text
from sqlalchemy.ext.asyncio import AsyncSession

from app.core.database import get_db
from app.core.security import decode_access_token
from app.models.user import User, UserRole

bearer_scheme = HTTPBearer()


async def get_current_user(
    credentials: Annotated[HTTPAuthorizationCredentials, Depends(bearer_scheme)],
    db: Annotated[AsyncSession, Depends(get_db)],
) -> User:
    """Resolve the bearer token to an active ``User`` and apply the RLS context.

    The token is decoded with ``decode_access_token``; its ``sub`` claim must
    be a UUID string identifying a row in ``users``. For every role other than
    ``SUPER_ADMIN`` the row-level-security fence is switched on for the rest of
    the current transaction (``app.bypass_rls = 'off'`` and, when the user has
    one, ``app.company_id``).

    Raises:
        HTTPException: 401 when the token cannot be decoded, the ``sub`` claim
            is missing / not a string / not a valid UUID, or the user does not
            exist or is inactive.
    """
    credentials_exception = HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="Could not validate credentials",
        headers={"WWW-Authenticate": "Bearer"},
    )

    try:
        payload = decode_access_token(credentials.credentials)
    except JWTError:
        # Don't chain the JWT error into the response traceback.
        raise credentials_exception from None

    subject = payload.get("sub")
    if not isinstance(subject, str):
        # Covers a missing claim (None) as well as non-string values, which
        # UUID() would reject with TypeError/AttributeError instead of 401.
        raise credentials_exception
    try:
        user_id = UUID(subject)
    except ValueError:
        # A malformed subject must be an auth failure, not a 500.
        raise credentials_exception from None

    # User lookup happens while bypass_rls = 'on' (set in get_db), so the
    # SELECT on users is unrestricted — necessary because we don't yet know
    # the company_id at this point.
    user = await db.get(User, user_id)
    if user is None or not user.is_active:
        raise credentials_exception

    # ── RLS context ────────────────────────────────────────────────────────
    # SUPER_ADMIN can see all companies → keep bypass_rls = 'on'.
    # Every other role gets the RLS fence applied — including users with no
    # company_id. Leaving bypass on for them would fail open and expose every
    # company's data.
    if user.role != UserRole.SUPER_ADMIN:
        if user.company_id is not None:
            # set_config(name, value, is_local => true) is the function form
            # of SET LOCAL. Unlike SET it accepts bind parameters, so no value
            # is ever interpolated into SQL text.
            await db.execute(
                text("SELECT set_config('app.company_id', :company_id, true)"),
                {"company_id": str(user.company_id)},
            )
        # NOTE(review): for a user without a company, app.company_id stays
        # unset. This assumes the RLS policies treat an unset value as
        # "no rows", e.g. via current_setting('app.company_id', true). Confirm
        # against the migration.
        await db.execute(text("SET LOCAL app.bypass_rls = 'off'"))

    return user


CurrentUser = Annotated[User, Depends(get_current_user)]


def require_role(*roles: UserRole):
    """Dependency factory that restricts an endpoint to the given roles.

    Usage: ``require_role(UserRole.MANAGER, UserRole.COMPANY_ADMIN)``.

    Returns:
        A ``Depends(...)`` marker whose checker yields the current user.

    Raises (from the checker):
        HTTPException: 403 when the authenticated user's role is not in *roles*.
    """

    async def checker(current_user: CurrentUser) -> User:
        if current_user.role not in roles:
            raise HTTPException(
                status_code=status.HTTP_403_FORBIDDEN,
                detail="Insufficient permissions",
            )
        return current_user

    return Depends(checker)


def require_same_company(target_company_id: UUID, current_user: User) -> None:
    """Raise 403 if user tries to access another company's data.

    ``SUPER_ADMIN`` users are always allowed. Any other user is allowed only
    when their ``company_id`` equals *target_company_id*.

    Raises:
        HTTPException: 403 when the company check fails.
    """
    if (
        current_user.role != UserRole.SUPER_ADMIN
        and current_user.company_id != target_company_id
    ):
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="Access to this resource is not allowed",
        )