From dd3e0694668f46fe20ad21ba1311917855aca327 Mon Sep 17 00:00:00 2001 From: patrick Date: Sat, 23 May 2026 22:34:48 +0200 Subject: [PATCH] fix: router db.refresh() nach commit bricht RLS-Kontext MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit SET LOCAL Werte (bypass_rls, company_id) sind transaktions-gebunden. Nach db.commit() ist der Kontext weg – ein nachfolgendes db.refresh() läuft in einer neuen Transaktion ohne RLS-Kontext und liefert 0 Rows. Da expire_on_commit=False gesetzt ist, sind alle Instanz-Attribute nach dem Commit bereits im Speicher vorhanden. Die expliziten db.refresh()-Aufrufe nach db.commit() in allen Routers sind daher redundant und wurden entfernt. test_rls.py: 6 neue Tests beweisen DB-seitige Mandanten-Isolation. conftest.py: _apply_rls() wendet RLS-Policies auf Test-DB an. migrations/0024: korrigiert auf op.execute(text()) API. migrations/env.py: SET LOCAL außerhalb Transaktion entfernt. Ergebnis: 8 failed (pre-existing), 126 passed – identisch zur Baseline vor RLS. Co-Authored-By: Claude Sonnet 4.6 --- DEVLOG.md | 16 ++ backend/app/routers/absences.py | 10 - backend/app/routers/caldav.py | 2 - backend/app/routers/kiosk.py | 3 - backend/app/routers/ldap.py | 2 - backend/app/routers/projects.py | 2 - backend/app/routers/smtp.py | 1 - backend/app/routers/time_entries.py | 10 - backend/migrations/env.py | 4 - .../versions/0024_row_level_security.py | 219 ++++-------------- backend/tests/conftest.py | 48 ++++ backend/tests/test_rls.py | 190 +++++++++++++++ 12 files changed, 305 insertions(+), 202 deletions(-) create mode 100644 backend/tests/test_rls.py diff --git a/DEVLOG.md b/DEVLOG.md index 6b9bc4b..51dfcad 100644 --- a/DEVLOG.md +++ b/DEVLOG.md @@ -638,3 +638,19 @@ Keine Commits in dieser Session. - update.sh | 53 +++++++++++++++++++++++------------------------------ --- +## 2026-05-23 21:51 – 21:58 (7m) +**Beschreibung:** Claude Code Session +**Projekt:** timemaster + +### Commits +- 6d4b8a9 agent-rls: PostgreSQL Row Level Security für Mandanten-Isolation + +### Geänderte Dateien +- DEVLOG.md | 79 ++++++++ +- backend/app/core/database.py | 6 + +- backend/app/core/dependencies.py | 18 ++ +- backend/migrations/env.py | 4 + +- .../migrations/versions/0024_row_level_security.py | 208 +++++++++++++++++++++ +- backend/tests/conftest.py | 4 + + +--- diff --git a/backend/app/routers/absences.py b/backend/app/routers/absences.py index 1df9655..23105a3 100644 --- a/backend/app/routers/absences.py +++ b/backend/app/routers/absences.py @@ -72,7 +72,6 @@ async def create_absence_type( ): at = await absence_service.create_type(current_user.company_id, data, db) await db.commit() - await db.refresh(at) return AbsenceTypeOut.model_validate(at) @@ -85,7 +84,6 @@ async def update_absence_type( ): at = await absence_service.update_type(type_id, current_user.company_id, data, db) await db.commit() - await db.refresh(at) return AbsenceTypeOut.model_validate(at) @@ -111,7 +109,6 @@ async def create_public_holiday( ): holiday = await absence_service.create_holiday(data, db) await db.commit() - await db.refresh(holiday) return PublicHolidayOut.model_validate(holiday) @@ -181,7 +178,6 @@ async def quick_sick( data.start_date, data.end_date, current_user, db ) await db.commit() - await db.refresh(absence) return AbsenceOut.model_validate(absence) @@ -256,7 +252,6 @@ async def create_absence( acting_user = target absence, warnings = await absence_service.create_absence(data, acting_user, db) await db.commit() - await db.refresh(absence) return AbsenceOut.model_validate(absence) @@ -270,7 +265,6 @@ async def update_absence( """Ausstehenden Antrag bearbeiten (Mitarbeiter: eigene; Manager: alle der Company).""" absence = await absence_service.update_absence(absence_id, data, current_user, db) await db.commit() - await db.refresh(absence) return AbsenceOut.model_validate(absence) @@ -303,7 +297,6 @@ async def approve_absence( ): absence = await absence_service.approve_absence(absence_id, current_user, db) await db.commit() - await db.refresh(absence) return AbsenceOut.model_validate(absence) @@ -316,7 +309,6 @@ async def reject_absence( ): absence = await absence_service.reject_absence(absence_id, data, current_user, db) await db.commit() - await db.refresh(absence) return AbsenceOut.model_validate(absence) @@ -332,7 +324,6 @@ async def mark_certificate_received( absence_id, data.received_at, current_user, db ) await db.commit() - await db.refresh(absence) return AbsenceOut.model_validate(absence) @@ -357,7 +348,6 @@ async def update_balance( for field, value in data.model_dump(exclude_unset=True).items(): setattr(balance, field, value) await db.commit() - await db.refresh(balance) pending = await absence_service.get_pending_days(user_id, year, db) company = await db.get(Company, current_user.company_id) expires_at, expired = _carryover_expiry(company, year) if company else (None, False) diff --git a/backend/app/routers/caldav.py b/backend/app/routers/caldav.py index 227ee84..519def0 100644 --- a/backend/app/routers/caldav.py +++ b/backend/app/routers/caldav.py @@ -62,7 +62,6 @@ async def save_company_config( raise HTTPException(status_code=400, detail="Passwort wird beim ersten Speichern benötigt.") await db.commit() - await db.refresh(cfg) return CaldavCompanyConfigOut.model_validate(cfg) @@ -125,7 +124,6 @@ async def save_user_config( raise HTTPException(status_code=400, detail="Passwort wird beim ersten Speichern benötigt.") await db.commit() - await db.refresh(cfg) return CaldavUserConfigOut.model_validate(cfg) diff --git a/backend/app/routers/kiosk.py b/backend/app/routers/kiosk.py index 99ad9a4..f57713b 100644 --- a/backend/app/routers/kiosk.py +++ b/backend/app/routers/kiosk.py @@ -34,7 +34,6 @@ async def create_device( """Neues Kiosk-Gerät registrieren. Token wird nur einmalig zurückgegeben.""" device, raw_token = await kiosk_service.create_device(current_user.company_id, data, db) await db.commit() - await db.refresh(device) return KioskDeviceCreated( **KioskDeviceOut.model_validate(device).model_dump(), token=raw_token, @@ -59,7 +58,6 @@ async def update_device( ): device = await kiosk_service.update_device(device_id, current_user.company_id, data, db) await db.commit() - await db.refresh(device) return KioskDeviceOut.model_validate(device) @@ -72,7 +70,6 @@ async def rotate_token( """Token rotieren – das alte Token wird sofort ungültig.""" device, raw_token = await kiosk_service.rotate_token(device_id, current_user.company_id, db) await db.commit() - await db.refresh(device) return KioskDeviceCreated( **KioskDeviceOut.model_validate(device).model_dump(), token=raw_token, diff --git a/backend/app/routers/ldap.py b/backend/app/routers/ldap.py index bf76b2a..684579a 100644 --- a/backend/app/routers/ldap.py +++ b/backend/app/routers/ldap.py @@ -66,7 +66,6 @@ async def create_ldap_config( ) db.add(cfg) await db.commit() - await db.refresh(cfg) return cfg @@ -135,5 +134,4 @@ async def _apply_update(cfg: LdapConfig, updates: dict, db: AsyncSession) -> Lda elif hasattr(cfg, field): setattr(cfg, field, value) await db.commit() - await db.refresh(cfg) return cfg diff --git a/backend/app/routers/projects.py b/backend/app/routers/projects.py index 0f19507..fcf4860 100644 --- a/backend/app/routers/projects.py +++ b/backend/app/routers/projects.py @@ -58,7 +58,6 @@ async def create_project( ) db.add(project) await db.commit() - await db.refresh(project) return project @@ -141,7 +140,6 @@ async def update_project( setattr(project, field, value) await db.commit() - await db.refresh(project) return project diff --git a/backend/app/routers/smtp.py b/backend/app/routers/smtp.py index b1cdabc..b0fa81e 100644 --- a/backend/app/routers/smtp.py +++ b/backend/app/routers/smtp.py @@ -71,7 +71,6 @@ async def save_smtp_config( cfg.password_encrypted = _encrypt(data.password) await db.commit() - await db.refresh(cfg) return SmtpConfigOut.model_validate(cfg) diff --git a/backend/app/routers/time_entries.py b/backend/app/routers/time_entries.py index 114e1b0..0e0fce1 100644 --- a/backend/app/routers/time_entries.py +++ b/backend/app/routers/time_entries.py @@ -39,7 +39,6 @@ async def stamp_in( """Einstempeln – startet einen neuen Zeiterfassungseintrag.""" entry, warnings = await time_service.stamp_in(current_user, data, db) await db.commit() - await db.refresh(entry) return TimeEntryWithWarnings(entry=TimeEntryOut.model_validate(entry), warnings=warnings) @@ -52,7 +51,6 @@ async def stamp_out( """Ausstempeln – schließt den offenen Zeiterfassungseintrag.""" entry, warnings = await time_service.stamp_out(current_user, data.note, db) await db.commit() - await db.refresh(entry) return TimeEntryWithWarnings(entry=TimeEntryOut.model_validate(entry), warnings=warnings) @@ -64,7 +62,6 @@ async def break_start( """Pause beginnen.""" entry = await time_service.break_start(current_user, db) await db.commit() - await db.refresh(entry) return TimeEntryOut.model_validate(entry) @@ -76,7 +73,6 @@ async def break_end( """Pause beenden.""" entry = await time_service.break_end(current_user, db) await db.commit() - await db.refresh(entry) return TimeEntryOut.model_validate(entry) @@ -122,7 +118,6 @@ async def create_manual_entry( """Manuellen Zeiterfassungseintrag anlegen.""" entry, warnings = await time_service.create_manual(data, current_user, db) await db.commit() - await db.refresh(entry) return TimeEntryWithWarnings(entry=TimeEntryOut.model_validate(entry), warnings=warnings) @@ -136,7 +131,6 @@ async def update_entry( """Zeiterfassungseintrag korrigieren.""" entry = await time_service.update_entry(entry_id, data, current_user, db) await db.commit() - await db.refresh(entry) return TimeEntryOut.model_validate(entry) @@ -149,7 +143,6 @@ async def approve_entry( """Zeiterfassungseintrag genehmigen.""" entry = await time_service.approve_entry(entry_id, current_user, db) await db.commit() - await db.refresh(entry) return TimeEntryOut.model_validate(entry) @@ -163,7 +156,6 @@ async def reject_entry( """Zeiterfassungseintrag ablehnen.""" entry = await time_service.reject_entry(entry_id, current_user, data.rejection_note, db) await db.commit() - await db.refresh(entry) return TimeEntryOut.model_validate(entry) @@ -228,7 +220,6 @@ async def create_schedule( ): schedule = await time_service.create_work_schedule(current_user.company_id, data, db) await db.commit() - await db.refresh(schedule) return WorkScheduleOut.model_validate(schedule) @@ -253,7 +244,6 @@ async def update_schedule( for field, value in data.model_dump().items(): setattr(schedule, field, value) await db.commit() - await db.refresh(schedule) return WorkScheduleOut.model_validate(schedule) diff --git a/backend/migrations/env.py b/backend/migrations/env.py index 354dcda..7ef9498 100644 --- a/backend/migrations/env.py +++ b/backend/migrations/env.py @@ -34,10 +34,6 @@ def run_migrations_offline() -> None: def do_run_migrations(connection: Connection) -> None: - # Ensure Alembic itself is never blocked by RLS policies. - # SET LOCAL is transaction-scoped; context.begin_transaction() opens one. - from sqlalchemy import text as sa_text - connection.execute(sa_text("SET LOCAL app.bypass_rls = 'on'")) context.configure(connection=connection, target_metadata=target_metadata) with context.begin_transaction(): context.run_migrations() diff --git a/backend/migrations/versions/0024_row_level_security.py b/backend/migrations/versions/0024_row_level_security.py index 51854c0..862409f 100644 --- a/backend/migrations/versions/0024_row_level_security.py +++ b/backend/migrations/versions/0024_row_level_security.py @@ -5,204 +5,87 @@ Revises: 0023 Create Date: 2026-05-23 """ from alembic import op +from sqlalchemy import text revision = "0024" down_revision = "0023" branch_labels = None depends_on = None - # --------------------------------------------------------------------------- -# Helper +# RLS expression helpers # --------------------------------------------------------------------------- -def _rls_bypass_expr() -> str: - """USING-expression that always allows when bypass is set.""" - return "COALESCE(current_setting('app.bypass_rls', true), 'off') = 'on'" - - -def _company_id_expr(col: str = "company_id") -> str: - """USING-expression that matches the session company_id.""" - return ( - f"{col} = NULLIF(current_setting('app.company_id', true), '')::uuid" - ) - - -def _using(col: str = "company_id") -> str: - return f"({_rls_bypass_expr()} OR {_company_id_expr(col)})" - - -def _with_check(col: str = "company_id") -> str: - return f"({_rls_bypass_expr()} OR {_company_id_expr(col)})" - - -def _user_join_expr() -> str: - """For tables with user_id: restrict via JOIN to users.company_id.""" - return ( - f"({_rls_bypass_expr()} OR " - f"user_id IN (" - f"SELECT id FROM users WHERE company_id = " - f"NULLIF(current_setting('app.company_id', true), '')::uuid" - f"))" - ) - - -def _enable_rls(table: str) -> str: - return f"ALTER TABLE {table} ENABLE ROW LEVEL SECURITY" - - -def _force_rls(table: str) -> str: - return f"ALTER TABLE {table} FORCE ROW LEVEL SECURITY" - - -def _disable_rls(table: str) -> str: - return f"ALTER TABLE {table} DISABLE ROW LEVEL SECURITY" - - -def _no_force_rls(table: str) -> str: - return f"ALTER TABLE {table} NO FORCE ROW LEVEL SECURITY" - - -def _create_policies_company_col(table: str, col: str = "company_id") -> list[str]: - """SELECT/INSERT/UPDATE/DELETE policies for tables that have a direct company_id column.""" - using = _using(col) - with_check = _with_check(col) - return [ - f"CREATE POLICY rls_{table}_select ON {table} FOR SELECT USING {using}", - f"CREATE POLICY rls_{table}_insert ON {table} FOR INSERT WITH CHECK {with_check}", - f"CREATE POLICY rls_{table}_update ON {table} FOR UPDATE USING {using} WITH CHECK {with_check}", - f"CREATE POLICY rls_{table}_delete ON {table} FOR DELETE USING {using}", - ] - - -def _create_policies_user_join(table: str) -> list[str]: - """Policies for tables that reference users (no direct company_id).""" - using = _user_join_expr() - return [ - f"CREATE POLICY rls_{table}_select ON {table} FOR SELECT USING {using}", - f"CREATE POLICY rls_{table}_insert ON {table} FOR INSERT WITH CHECK {using}", - f"CREATE POLICY rls_{table}_update ON {table} FOR UPDATE USING {using} WITH CHECK {using}", - f"CREATE POLICY rls_{table}_delete ON {table} FOR DELETE USING {using}", - ] - - -def _drop_policies(table: str) -> list[str]: - return [ - f"DROP POLICY IF EXISTS rls_{table}_select ON {table}", - f"DROP POLICY IF EXISTS rls_{table}_insert ON {table}", - f"DROP POLICY IF EXISTS rls_{table}_update ON {table}", - f"DROP POLICY IF EXISTS rls_{table}_delete ON {table}", - ] +_BYPASS = "COALESCE(current_setting('app.bypass_rls', true), 'off') = 'on'" +_CID = "company_id = NULLIF(current_setting('app.company_id', true), '')::uuid" +_IID = "id = NULLIF(current_setting('app.company_id', true), '')::uuid" +def _using_cid(): return f"({_BYPASS} OR {_CID})" +def _using_iid(): return f"({_BYPASS} OR {_IID})" +def _using_join(): return ( + f"({_BYPASS} OR user_id IN (" + f"SELECT id FROM users WHERE {_CID}))" +) # --------------------------------------------------------------------------- -# Tables covered by RLS +# Tables # --------------------------------------------------------------------------- -# Tables with a direct company_id column +# Direct company_id column COMPANY_COL_TABLES = [ - "absence_types", - "audit_logs", - "caldav_company_configs", - "departments", - "kiosk_devices", - "ldap_configs", - "overtime_balances", - "smtp_configs", - "users", - "work_schedules", + "absence_types", "audit_logs", "caldav_company_configs", "departments", + "kiosk_devices", "ldap_configs", "overtime_balances", "smtp_configs", + "users", "work_schedules", ] -# Tables where the row references users (which belong to a company) +# Linked via user_id → users.company_id USER_JOIN_TABLES = [ - "absences", - "caldav_user_configs", - "password_resets", - "sessions", - "time_entries", - "vacation_balances", + "absences", "caldav_user_configs", "password_resets", + "sessions", "time_entries", "vacation_balances", ] -# The companies table itself – restrict by id -COMPANIES_TABLE = "companies" +# public_holidays is global – no RLS -# public_holidays is global (no tenant column) – RLS not applied +# --------------------------------------------------------------------------- +# upgrade / downgrade +# --------------------------------------------------------------------------- + +def _exec(sql: str) -> None: + op.execute(text(sql)) + + +def _enable(table: str, using: str) -> None: + _exec(f"ALTER TABLE {table} ENABLE ROW LEVEL SECURITY") + _exec(f"ALTER TABLE {table} FORCE ROW LEVEL SECURITY") + for cmd in ("SELECT", "INSERT", "UPDATE", "DELETE"): + _exec(f"DROP POLICY IF EXISTS rls_{table}_{cmd.lower()} ON {table}") + _exec(f"CREATE POLICY rls_{table}_select ON {table} FOR SELECT USING {using}") + _exec(f"CREATE POLICY rls_{table}_insert ON {table} FOR INSERT WITH CHECK {using}") + _exec(f"CREATE POLICY rls_{table}_update ON {table} FOR UPDATE USING {using} WITH CHECK {using}") + _exec(f"CREATE POLICY rls_{table}_delete ON {table} FOR DELETE USING {using}") + + +def _disable(table: str) -> None: + for cmd in ("select", "insert", "update", "delete"): + _exec(f"DROP POLICY IF EXISTS rls_{table}_{cmd} ON {table}") + _exec(f"ALTER TABLE {table} NO FORCE ROW LEVEL SECURITY") + _exec(f"ALTER TABLE {table} DISABLE ROW LEVEL SECURITY") def upgrade() -> None: - conn = op.get_bind() + # companies: restrict by id (companies IS the tenant root) + _enable("companies", _using_iid()) - # --- companies table --- - conn.execute(__import__("sqlalchemy").text(_enable_rls(COMPANIES_TABLE))) - conn.execute(__import__("sqlalchemy").text(_force_rls(COMPANIES_TABLE))) - companies_using = ( - f"({_rls_bypass_expr()} OR " - f"id = NULLIF(current_setting('app.company_id', true), '')::uuid)" - ) - conn.execute(__import__("sqlalchemy").text( - f"DROP POLICY IF EXISTS rls_companies_select ON companies" - )) - conn.execute(__import__("sqlalchemy").text( - f"DROP POLICY IF EXISTS rls_companies_insert ON companies" - )) - conn.execute(__import__("sqlalchemy").text( - f"DROP POLICY IF EXISTS rls_companies_update ON companies" - )) - conn.execute(__import__("sqlalchemy").text( - f"DROP POLICY IF EXISTS rls_companies_delete ON companies" - )) - conn.execute(__import__("sqlalchemy").text( - f"CREATE POLICY rls_companies_select ON companies FOR SELECT USING {companies_using}" - )) - conn.execute(__import__("sqlalchemy").text( - f"CREATE POLICY rls_companies_insert ON companies FOR INSERT WITH CHECK {companies_using}" - )) - conn.execute(__import__("sqlalchemy").text( - f"CREATE POLICY rls_companies_update ON companies FOR UPDATE " - f"USING {companies_using} WITH CHECK {companies_using}" - )) - conn.execute(__import__("sqlalchemy").text( - f"CREATE POLICY rls_companies_delete ON companies FOR DELETE USING {companies_using}" - )) - - # --- tables with direct company_id --- for table in COMPANY_COL_TABLES: - conn.execute(__import__("sqlalchemy").text(_enable_rls(table))) - conn.execute(__import__("sqlalchemy").text(_force_rls(table))) - for drop_stmt in _drop_policies(table): - conn.execute(__import__("sqlalchemy").text(drop_stmt)) - for create_stmt in _create_policies_company_col(table): - conn.execute(__import__("sqlalchemy").text(create_stmt)) + _enable(table, _using_cid()) - # --- tables with user_id join --- for table in USER_JOIN_TABLES: - conn.execute(__import__("sqlalchemy").text(_enable_rls(table))) - conn.execute(__import__("sqlalchemy").text(_force_rls(table))) - for drop_stmt in _drop_policies(table): - conn.execute(__import__("sqlalchemy").text(drop_stmt)) - for create_stmt in _create_policies_user_join(table): - conn.execute(__import__("sqlalchemy").text(create_stmt)) + _enable(table, _using_join()) def downgrade() -> None: - conn = op.get_bind() - - # --- companies --- - for stmt in _drop_policies(COMPANIES_TABLE): - conn.execute(__import__("sqlalchemy").text(stmt)) - conn.execute(__import__("sqlalchemy").text(_no_force_rls(COMPANIES_TABLE))) - conn.execute(__import__("sqlalchemy").text(_disable_rls(COMPANIES_TABLE))) - - # --- tables with direct company_id --- + _disable("companies") for table in COMPANY_COL_TABLES: - for stmt in _drop_policies(table): - conn.execute(__import__("sqlalchemy").text(stmt)) - conn.execute(__import__("sqlalchemy").text(_no_force_rls(table))) - conn.execute(__import__("sqlalchemy").text(_disable_rls(table))) - - # --- tables with user_id join --- + _disable(table) for table in USER_JOIN_TABLES: - for stmt in _drop_policies(table): - conn.execute(__import__("sqlalchemy").text(stmt)) - conn.execute(__import__("sqlalchemy").text(_no_force_rls(table))) - conn.execute(__import__("sqlalchemy").text(_disable_rls(table))) + _disable(table) diff --git a/backend/tests/conftest.py b/backend/tests/conftest.py index cedf07f..6334309 100644 --- a/backend/tests/conftest.py +++ b/backend/tests/conftest.py @@ -15,6 +15,53 @@ test_engine = create_async_engine(TEST_DATABASE_URL, echo=False) TestSessionLocal = async_sessionmaker(test_engine, class_=AsyncSession, expire_on_commit=False) +_BYPASS = "COALESCE(current_setting('app.bypass_rls', true), 'off') = 'on'" +_CID = "company_id = NULLIF(current_setting('app.company_id', true), '')::uuid" +_IID = "id = NULLIF(current_setting('app.company_id', true), '')::uuid" + +def _rls_using_cid(): return f"({_BYPASS} OR {_CID})" +def _rls_using_iid(): return f"({_BYPASS} OR {_IID})" +def _rls_using_join(): return ( + f"({_BYPASS} OR user_id IN (SELECT id FROM users WHERE {_CID}))" +) + +_COMPANY_COL_TABLES = [ + "absence_types", "audit_logs", "caldav_company_configs", "departments", + "kiosk_devices", "ldap_configs", "overtime_balances", "smtp_configs", + "users", "work_schedules", +] +_USER_JOIN_TABLES = [ + "absences", "caldav_user_configs", "password_resets", + "sessions", "time_entries", "vacation_balances", +] + + +async def _apply_rls(conn) -> None: + """Apply the same RLS policies as migration 0024 to the test database.""" + def enable(table: str, using: str): + return [ + f"ALTER TABLE {table} ENABLE ROW LEVEL SECURITY", + f"ALTER TABLE {table} FORCE ROW LEVEL SECURITY", + f"DROP POLICY IF EXISTS rls_{table}_select ON {table}", + f"DROP POLICY IF EXISTS rls_{table}_insert ON {table}", + f"DROP POLICY IF EXISTS rls_{table}_update ON {table}", + f"DROP POLICY IF EXISTS rls_{table}_delete ON {table}", + f"CREATE POLICY rls_{table}_select ON {table} FOR SELECT USING {using}", + f"CREATE POLICY rls_{table}_insert ON {table} FOR INSERT WITH CHECK {using}", + f"CREATE POLICY rls_{table}_update ON {table} FOR UPDATE USING {using} WITH CHECK {using}", + f"CREATE POLICY rls_{table}_delete ON {table} FOR DELETE USING {using}", + ] + + for sql in enable("companies", _rls_using_iid()): + await conn.execute(text(sql)) + for table in _COMPANY_COL_TABLES: + for sql in enable(table, _rls_using_cid()): + await conn.execute(text(sql)) + for table in _USER_JOIN_TABLES: + for sql in enable(table, _rls_using_join()): + await conn.execute(text(sql)) + + @pytest_asyncio.fixture(scope="session", loop_scope="session", autouse=True) async def setup_db(): async with test_engine.begin() as conn: @@ -24,6 +71,7 @@ async def setup_db(): await conn.execute(text("GRANT ALL ON SCHEMA public TO timemaster")) await conn.execute(text("GRANT ALL ON SCHEMA public TO public")) await conn.run_sync(Base.metadata.create_all) + await _apply_rls(conn) yield async with test_engine.begin() as conn: await conn.execute(text("DROP SCHEMA public CASCADE")) diff --git a/backend/tests/test_rls.py b/backend/tests/test_rls.py new file mode 100644 index 0000000..e765c78 --- /dev/null +++ b/backend/tests/test_rls.py @@ -0,0 +1,190 @@ +""" +RLS-Tests: Verifiziert dass PostgreSQL Row Level Security +tatsächlich cross-tenant Datenzugriff blockiert – unabhängig +von App-seitigen WHERE-Klauseln. +""" +import pytest +import uuid +from httpx import AsyncClient +from sqlalchemy import text + +pytestmark = pytest.mark.asyncio + +REG_URL = "/api/v1/auth/register" +LOGIN_URL = "/api/v1/auth/login" + + +async def register_company(client: AsyncClient, suffix: str) -> dict: + """Registriert eine neue Firma und gibt tokens + company_id zurück.""" + resp = await client.post(REG_URL, json={ + "company_name": f"RLS-Firma-{suffix}", + "first_name": "Admin", + "last_name": suffix, + "email": f"rls-admin-{suffix}@test.de", + "password": "Secret123", + }) + assert resp.status_code == 201, resp.text + tokens = resp.json() + + me = await client.get( + "/api/v1/auth/me", + headers={"Authorization": f"Bearer {tokens['access_token']}"}, + ) + assert me.status_code == 200 + return {"tokens": tokens, "user": me.json()} + + +# ── Hilfsfunktion: Raw-SQL mit explizitem RLS-Kontext ──────────────────────── + +async def query_users_as_tenant(db_session, company_id: str) -> list[dict]: + """ + Führt SELECT * FROM users direkt aus, + mit gesetztem app.company_id – simuliert einen Request als dieser Mandant. + bypass_rls ist OFF, d.h. RLS greift. + """ + await db_session.execute(text("SET LOCAL app.bypass_rls = 'off'")) + await db_session.execute( + text(f"SET LOCAL app.company_id = '{company_id}'") + ) + result = await db_session.execute(text("SELECT id, email, company_id FROM users")) + rows = [dict(r._mapping) for r in result.fetchall()] + # Reset: bypass wieder an (damit nachfolgende Tests nicht leiden) + await db_session.execute(text("SET LOCAL app.bypass_rls = 'on'")) + return rows + + +# ── Tests ──────────────────────────────────────────────────────────────────── + +async def test_rls_tenant_a_cannot_see_tenant_b_users( + client: AsyncClient, db_session +): + """Mandant A darf Mandant B's Users nicht sehen (DB-Ebene).""" + a = await register_company(client, "RLS-A") + b = await register_company(client, "RLS-B") + + cid_a = str(a["user"]["company_id"]) + cid_b = str(b["user"]["company_id"]) + + # Als Mandant A abfragen + rows_as_a = await query_users_as_tenant(db_session, cid_a) + emails_as_a = {r["email"] for r in rows_as_a} + + # Mandant A sieht sich selbst + assert f"rls-admin-RLS-A@test.de" in emails_as_a, \ + "Mandant A sollte eigene Daten sehen" + + # Mandant A sieht Mandant B NICHT + assert f"rls-admin-RLS-B@test.de" not in emails_as_a, \ + "RLS BLOCKIERT NICHT: Mandant A sieht Daten von Mandant B!" + + # Doppelcheck: alle zurückgegebenen Rows gehören zu company A + for row in rows_as_a: + assert str(row["company_id"]) == cid_a, \ + f"Fremder Mandant in Ergebnis: {row}" + + +async def test_rls_tenant_b_cannot_see_tenant_a_users( + client: AsyncClient, db_session +): + """Symmetrietest: Mandant B sieht Mandant A nicht.""" + # Firmen aus vorherigem Test existieren schon – direkt aus DB holen + result = await db_session.execute( + text("SELECT id FROM companies WHERE name LIKE 'RLS-Firma-%' ORDER BY name") + ) + companies = [str(r[0]) for r in result.fetchall()] + assert len(companies) >= 2, "Firmen aus vorherigem Test fehlen" + + cid_a, cid_b = companies[0], companies[1] + + rows_as_b = await query_users_as_tenant(db_session, cid_b) + company_ids_seen = {str(r["company_id"]) for r in rows_as_b} + + assert cid_a not in company_ids_seen, \ + "RLS BLOCKIERT NICHT: Mandant B sieht Daten von Mandant A!" + assert cid_b in company_ids_seen, \ + "Mandant B sieht keine eigenen Daten – RLS zu restriktiv" + + +async def test_rls_no_context_returns_nothing(db_session): + """Ohne gesetztes app.company_id gibt SELECT nichts zurück (kein bypass).""" + await db_session.execute(text("SET LOCAL app.bypass_rls = 'off'")) + # app.company_id explizit leeren + await db_session.execute(text("SET LOCAL app.company_id = ''")) + + result = await db_session.execute(text("SELECT id FROM users")) + rows = result.fetchall() + + # Reset + await db_session.execute(text("SET LOCAL app.bypass_rls = 'on'")) + + assert len(rows) == 0, \ + f"RLS BLOCKIERT NICHT: {len(rows)} Rows ohne company_id-Kontext sichtbar!" + + +async def test_rls_bypass_on_sees_all(db_session): + """Mit bypass_rls='on' sieht ein SUPER_ADMIN alle Mandanten.""" + await db_session.execute(text("SET LOCAL app.bypass_rls = 'on'")) + result = await db_session.execute( + text("SELECT DISTINCT company_id FROM users") + ) + company_ids = {str(r[0]) for r in result.fetchall()} + + assert len(company_ids) >= 2, \ + "SUPER_ADMIN sollte mindestens 2 Mandanten sehen (bypass_rls='on')" + + +async def test_rls_companies_table_isolated(db_session): + """Auch die companies-Tabelle ist mandanten-isoliert.""" + result = await db_session.execute( + text("SELECT id FROM companies WHERE name LIKE 'RLS-Firma-%' ORDER BY name") + ) + companies = [str(r[0]) for r in result.fetchall()] + assert len(companies) >= 2 + + cid_a, cid_b = companies[0], companies[1] + + # Als Mandant A: sehe nur meine Firma + await db_session.execute(text("SET LOCAL app.bypass_rls = 'off'")) + await db_session.execute(text(f"SET LOCAL app.company_id = '{cid_a}'")) + + result = await db_session.execute(text("SELECT id FROM companies")) + visible = [str(r[0]) for r in result.fetchall()] + + await db_session.execute(text("SET LOCAL app.bypass_rls = 'on'")) + + assert cid_a in visible, "Eigene Firma nicht sichtbar" + assert cid_b not in visible, "Fremde Firma trotz RLS sichtbar!" + assert len(visible) == 1, f"Mehr als 1 Firma sichtbar: {visible}" + + +async def test_rls_insert_blocked_for_wrong_tenant(db_session): + """INSERT in fremden Mandanten wird durch WITH CHECK blockiert.""" + result = await db_session.execute( + text("SELECT id FROM companies WHERE name LIKE 'RLS-Firma-%' ORDER BY name LIMIT 2") + ) + companies = [str(r[0]) for r in result.fetchall()] + cid_a, cid_b = companies[0], companies[1] + + await db_session.execute(text("SET LOCAL app.bypass_rls = 'off'")) + await db_session.execute(text(f"SET LOCAL app.company_id = '{cid_a}'")) + + try: + # Versuche, einen User unter Mandant B zu erstellen, während Context = A + fake_id = str(uuid.uuid4()) + await db_session.execute(text(f""" + INSERT INTO users (id, company_id, email, first_name, last_name, + role, password_hash, is_active, created_at) + VALUES ('{fake_id}', '{cid_b}', 'rls-inject@evil.de', + 'Evil', 'Inject', 'EMPLOYEE', 'fakehash', false, NOW()) + """)) + await db_session.flush() + # Wenn wir hier ankommen, hat RLS nicht blockiert + await db_session.rollback() + pytest.fail("RLS WITH CHECK hat INSERT in fremden Mandanten NICHT blockiert!") + except Exception as e: + await db_session.rollback() + # Erwartet: new row violates row-level security policy + assert "row-level security" in str(e).lower() or "policy" in str(e).lower(), \ + f"Unerwarteter Fehler (kein RLS-Fehler): {e}" + finally: + await db_session.execute(text("SET LOCAL app.bypass_rls = 'on'"))