"""Add PostgreSQL Row Level Security (RLS) for tenant isolation

Revision ID: 0024
Revises: 0023
Create Date: 2026-05-23
"""
from alembic import op

revision = "0024"
down_revision = "0023"
branch_labels = None
depends_on = None


# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
# The helpers below only build SQL text. Statements are executed with
# ``op.execute`` (which coerces plain strings to ``text()``), so the migration
# also works in offline ``--sql`` mode, where ``op.get_bind()`` returns None.


def _rls_bypass_expr() -> str:
    """Return a predicate that is true when ``app.bypass_rls`` is ``'on'``.

    ``current_setting(name, true)`` yields NULL instead of raising when the
    setting is not defined, so COALESCE treats "unset" as ``'off'``.
    """
    return "COALESCE(current_setting('app.bypass_rls', true), 'off') = 'on'"


def _company_id_expr(col: str = "company_id") -> str:
    """Return a predicate matching *col* against the session's ``app.company_id``.

    NULLIF maps an empty setting to NULL before the ``::uuid`` cast, so an
    empty tenant id compares as NULL (never true) instead of failing the cast.
    """
    return f"{col} = NULLIF(current_setting('app.company_id', true), '')::uuid"


def _using(col: str = "company_id") -> str:
    """Return the USING predicate: bypass enabled OR *col* equals the tenant id."""
    return f"({_rls_bypass_expr()} OR {_company_id_expr(col)})"


def _with_check(col: str = "company_id") -> str:
    """Return the WITH CHECK predicate (identical to :func:`_using`)."""
    return _using(col)


def _user_join_expr() -> str:
    """Return a predicate for tables that reference ``users`` via ``user_id``.

    A row qualifies when bypass is on, or when its ``user_id`` belongs to a
    user whose ``company_id`` matches the session tenant. For roles subject
    to RLS, the subquery on ``users`` is additionally filtered by the
    ``users`` table's own policies created in this migration.
    """
    return (
        f"({_rls_bypass_expr()} OR "
        f"user_id IN (SELECT id FROM users WHERE {_company_id_expr()}))"
    )


def _enable_rls(table: str) -> str:
    return f"ALTER TABLE {table} ENABLE ROW LEVEL SECURITY"


def _force_rls(table: str) -> str:
    # FORCE makes the policies apply to the table owner as well.
    return f"ALTER TABLE {table} FORCE ROW LEVEL SECURITY"


def _disable_rls(table: str) -> str:
    return f"ALTER TABLE {table} DISABLE ROW LEVEL SECURITY"


def _no_force_rls(table: str) -> str:
    return f"ALTER TABLE {table} NO FORCE ROW LEVEL SECURITY"


def _create_policies_company_col(table: str, col: str = "company_id") -> list[str]:
    """Return SELECT/INSERT/UPDATE/DELETE policies keyed on a tenant column.

    Args:
        table: Table to protect.
        col: Column compared against ``app.company_id``. ``companies`` passes
            ``"id"`` because each row is itself a tenant.
    """
    using = _using(col)
    with_check = _with_check(col)
    return [
        f"CREATE POLICY rls_{table}_select ON {table} FOR SELECT USING {using}",
        f"CREATE POLICY rls_{table}_insert ON {table} FOR INSERT WITH CHECK {with_check}",
        f"CREATE POLICY rls_{table}_update ON {table} FOR UPDATE USING {using} WITH CHECK {with_check}",
        f"CREATE POLICY rls_{table}_delete ON {table} FOR DELETE USING {using}",
    ]


def _create_policies_user_join(table: str) -> list[str]:
    """Return policies for tables without ``company_id`` that reference users."""
    using = _user_join_expr()
    return [
        f"CREATE POLICY rls_{table}_select ON {table} FOR SELECT USING {using}",
        f"CREATE POLICY rls_{table}_insert ON {table} FOR INSERT WITH CHECK {using}",
        f"CREATE POLICY rls_{table}_update ON {table} FOR UPDATE USING {using} WITH CHECK {using}",
        f"CREATE POLICY rls_{table}_delete ON {table} FOR DELETE USING {using}",
    ]


def _drop_policies(table: str) -> list[str]:
    """Return idempotent DROP statements for the four ``rls_<table>_*`` policies."""
    return [
        f"DROP POLICY IF EXISTS rls_{table}_select ON {table}",
        f"DROP POLICY IF EXISTS rls_{table}_insert ON {table}",
        f"DROP POLICY IF EXISTS rls_{table}_update ON {table}",
        f"DROP POLICY IF EXISTS rls_{table}_delete ON {table}",
    ]


def _apply_rls(table: str, policies: list[str]) -> None:
    """Enable + force RLS on *table*, then (re)create *policies*.

    Existing policies with the same names are dropped first so re-running the
    upgrade does not fail on "policy already exists".
    """
    op.execute(_enable_rls(table))
    op.execute(_force_rls(table))
    for stmt in (*_drop_policies(table), *policies):
        op.execute(stmt)


def _remove_rls(table: str) -> None:
    """Drop this migration's policies on *table* and turn RLS off again."""
    for stmt in _drop_policies(table):
        op.execute(stmt)
    op.execute(_no_force_rls(table))
    op.execute(_disable_rls(table))


# ---------------------------------------------------------------------------
# Tables covered by RLS
# ---------------------------------------------------------------------------

# Tables with a direct company_id column
COMPANY_COL_TABLES = [
    "absence_types",
    "audit_logs",
    "caldav_company_configs",
    "departments",
    "kiosk_devices",
    "ldap_configs",
    "overtime_balances",
    "smtp_configs",
    "users",
    "work_schedules",
]

# Tables where the row references users (which belong to a company)
USER_JOIN_TABLES = [
    "absences",
    "caldav_user_configs",
    "password_resets",
    "sessions",
    "time_entries",
    "vacation_balances",
]

# The companies table itself - restricted by its own id column
COMPANIES_TABLE = "companies"

# public_holidays is global (no tenant column) - RLS not applied


def upgrade() -> None:
    """Enable and force RLS with tenant policies on every covered table."""
    # companies: the tenant row is matched on its primary key.
    _apply_rls(COMPANIES_TABLE, _create_policies_company_col(COMPANIES_TABLE, col="id"))

    for table in COMPANY_COL_TABLES:
        _apply_rls(table, _create_policies_company_col(table))

    for table in USER_JOIN_TABLES:
        _apply_rls(table, _create_policies_user_join(table))


def downgrade() -> None:
    """Drop the RLS policies and disable RLS on every covered table."""
    for table in (COMPANIES_TABLE, *COMPANY_COL_TABLES, *USER_JOIN_TABLES):
        _remove_rls(table)