fix: router db.refresh() nach commit bricht RLS-Kontext

SET LOCAL Werte (bypass_rls, company_id) sind transaktions-gebunden.
Nach db.commit() ist der Kontext weg – ein nachfolgendes db.refresh()
läuft in einer neuen Transaktion ohne RLS-Kontext und liefert 0 Rows.

Da expire_on_commit=False gesetzt ist, sind alle Instanz-Attribute
nach dem Commit bereits im Speicher vorhanden. Die expliziten
db.refresh()-Aufrufe nach db.commit() in allen Routers sind daher
redundant und wurden entfernt.

test_rls.py: 6 neue Tests beweisen DB-seitige Mandanten-Isolation.
conftest.py: _apply_rls() wendet RLS-Policies auf Test-DB an.
migrations/0024: korrigiert auf op.execute(text()) API.
migrations/env.py: SET LOCAL außerhalb Transaktion entfernt.

Ergebnis: 8 failed (pre-existing), 126 passed – identisch zur Baseline vor RLS.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
2026-05-23 22:34:48 +02:00
parent 6d4b8a9f17
commit dd3e069466
12 changed files with 305 additions and 202 deletions
+16
View File
@@ -638,3 +638,19 @@ Keine Commits in dieser Session.
- update.sh | 53 +++++++++++++++++++++++------------------------------ - update.sh | 53 +++++++++++++++++++++++------------------------------
--- ---
## 2026-05-23 21:51 21:58 (7m)
**Beschreibung:** Claude Code Session
**Projekt:** timemaster
### Commits
- 6d4b8a9 agent-rls: PostgreSQL Row Level Security für Mandanten-Isolation
### Geänderte Dateien
- DEVLOG.md | 79 ++++++++
- backend/app/core/database.py | 6 +
- backend/app/core/dependencies.py | 18 ++
- backend/migrations/env.py | 4 +
- .../migrations/versions/0024_row_level_security.py | 208 +++++++++++++++++++++
- backend/tests/conftest.py | 4 +
---
-10
View File
@@ -72,7 +72,6 @@ async def create_absence_type(
): ):
at = await absence_service.create_type(current_user.company_id, data, db) at = await absence_service.create_type(current_user.company_id, data, db)
await db.commit() await db.commit()
await db.refresh(at)
return AbsenceTypeOut.model_validate(at) return AbsenceTypeOut.model_validate(at)
@@ -85,7 +84,6 @@ async def update_absence_type(
): ):
at = await absence_service.update_type(type_id, current_user.company_id, data, db) at = await absence_service.update_type(type_id, current_user.company_id, data, db)
await db.commit() await db.commit()
await db.refresh(at)
return AbsenceTypeOut.model_validate(at) return AbsenceTypeOut.model_validate(at)
@@ -111,7 +109,6 @@ async def create_public_holiday(
): ):
holiday = await absence_service.create_holiday(data, db) holiday = await absence_service.create_holiday(data, db)
await db.commit() await db.commit()
await db.refresh(holiday)
return PublicHolidayOut.model_validate(holiday) return PublicHolidayOut.model_validate(holiday)
@@ -181,7 +178,6 @@ async def quick_sick(
data.start_date, data.end_date, current_user, db data.start_date, data.end_date, current_user, db
) )
await db.commit() await db.commit()
await db.refresh(absence)
return AbsenceOut.model_validate(absence) return AbsenceOut.model_validate(absence)
@@ -256,7 +252,6 @@ async def create_absence(
acting_user = target acting_user = target
absence, warnings = await absence_service.create_absence(data, acting_user, db) absence, warnings = await absence_service.create_absence(data, acting_user, db)
await db.commit() await db.commit()
await db.refresh(absence)
return AbsenceOut.model_validate(absence) return AbsenceOut.model_validate(absence)
@@ -270,7 +265,6 @@ async def update_absence(
"""Ausstehenden Antrag bearbeiten (Mitarbeiter: eigene; Manager: alle der Company).""" """Ausstehenden Antrag bearbeiten (Mitarbeiter: eigene; Manager: alle der Company)."""
absence = await absence_service.update_absence(absence_id, data, current_user, db) absence = await absence_service.update_absence(absence_id, data, current_user, db)
await db.commit() await db.commit()
await db.refresh(absence)
return AbsenceOut.model_validate(absence) return AbsenceOut.model_validate(absence)
@@ -303,7 +297,6 @@ async def approve_absence(
): ):
absence = await absence_service.approve_absence(absence_id, current_user, db) absence = await absence_service.approve_absence(absence_id, current_user, db)
await db.commit() await db.commit()
await db.refresh(absence)
return AbsenceOut.model_validate(absence) return AbsenceOut.model_validate(absence)
@@ -316,7 +309,6 @@ async def reject_absence(
): ):
absence = await absence_service.reject_absence(absence_id, data, current_user, db) absence = await absence_service.reject_absence(absence_id, data, current_user, db)
await db.commit() await db.commit()
await db.refresh(absence)
return AbsenceOut.model_validate(absence) return AbsenceOut.model_validate(absence)
@@ -332,7 +324,6 @@ async def mark_certificate_received(
absence_id, data.received_at, current_user, db absence_id, data.received_at, current_user, db
) )
await db.commit() await db.commit()
await db.refresh(absence)
return AbsenceOut.model_validate(absence) return AbsenceOut.model_validate(absence)
@@ -357,7 +348,6 @@ async def update_balance(
for field, value in data.model_dump(exclude_unset=True).items(): for field, value in data.model_dump(exclude_unset=True).items():
setattr(balance, field, value) setattr(balance, field, value)
await db.commit() await db.commit()
await db.refresh(balance)
pending = await absence_service.get_pending_days(user_id, year, db) pending = await absence_service.get_pending_days(user_id, year, db)
company = await db.get(Company, current_user.company_id) company = await db.get(Company, current_user.company_id)
expires_at, expired = _carryover_expiry(company, year) if company else (None, False) expires_at, expired = _carryover_expiry(company, year) if company else (None, False)
-2
View File
@@ -62,7 +62,6 @@ async def save_company_config(
raise HTTPException(status_code=400, detail="Passwort wird beim ersten Speichern benötigt.") raise HTTPException(status_code=400, detail="Passwort wird beim ersten Speichern benötigt.")
await db.commit() await db.commit()
await db.refresh(cfg)
return CaldavCompanyConfigOut.model_validate(cfg) return CaldavCompanyConfigOut.model_validate(cfg)
@@ -125,7 +124,6 @@ async def save_user_config(
raise HTTPException(status_code=400, detail="Passwort wird beim ersten Speichern benötigt.") raise HTTPException(status_code=400, detail="Passwort wird beim ersten Speichern benötigt.")
await db.commit() await db.commit()
await db.refresh(cfg)
return CaldavUserConfigOut.model_validate(cfg) return CaldavUserConfigOut.model_validate(cfg)
-3
View File
@@ -34,7 +34,6 @@ async def create_device(
"""Neues Kiosk-Gerät registrieren. Token wird nur einmalig zurückgegeben.""" """Neues Kiosk-Gerät registrieren. Token wird nur einmalig zurückgegeben."""
device, raw_token = await kiosk_service.create_device(current_user.company_id, data, db) device, raw_token = await kiosk_service.create_device(current_user.company_id, data, db)
await db.commit() await db.commit()
await db.refresh(device)
return KioskDeviceCreated( return KioskDeviceCreated(
**KioskDeviceOut.model_validate(device).model_dump(), **KioskDeviceOut.model_validate(device).model_dump(),
token=raw_token, token=raw_token,
@@ -59,7 +58,6 @@ async def update_device(
): ):
device = await kiosk_service.update_device(device_id, current_user.company_id, data, db) device = await kiosk_service.update_device(device_id, current_user.company_id, data, db)
await db.commit() await db.commit()
await db.refresh(device)
return KioskDeviceOut.model_validate(device) return KioskDeviceOut.model_validate(device)
@@ -72,7 +70,6 @@ async def rotate_token(
"""Token rotieren das alte Token wird sofort ungültig.""" """Token rotieren das alte Token wird sofort ungültig."""
device, raw_token = await kiosk_service.rotate_token(device_id, current_user.company_id, db) device, raw_token = await kiosk_service.rotate_token(device_id, current_user.company_id, db)
await db.commit() await db.commit()
await db.refresh(device)
return KioskDeviceCreated( return KioskDeviceCreated(
**KioskDeviceOut.model_validate(device).model_dump(), **KioskDeviceOut.model_validate(device).model_dump(),
token=raw_token, token=raw_token,
-2
View File
@@ -66,7 +66,6 @@ async def create_ldap_config(
) )
db.add(cfg) db.add(cfg)
await db.commit() await db.commit()
await db.refresh(cfg)
return cfg return cfg
@@ -135,5 +134,4 @@ async def _apply_update(cfg: LdapConfig, updates: dict, db: AsyncSession) -> Lda
elif hasattr(cfg, field): elif hasattr(cfg, field):
setattr(cfg, field, value) setattr(cfg, field, value)
await db.commit() await db.commit()
await db.refresh(cfg)
return cfg return cfg
-2
View File
@@ -58,7 +58,6 @@ async def create_project(
) )
db.add(project) db.add(project)
await db.commit() await db.commit()
await db.refresh(project)
return project return project
@@ -141,7 +140,6 @@ async def update_project(
setattr(project, field, value) setattr(project, field, value)
await db.commit() await db.commit()
await db.refresh(project)
return project return project
-1
View File
@@ -71,7 +71,6 @@ async def save_smtp_config(
cfg.password_encrypted = _encrypt(data.password) cfg.password_encrypted = _encrypt(data.password)
await db.commit() await db.commit()
await db.refresh(cfg)
return SmtpConfigOut.model_validate(cfg) return SmtpConfigOut.model_validate(cfg)
-10
View File
@@ -39,7 +39,6 @@ async def stamp_in(
"""Einstempeln startet einen neuen Zeiterfassungseintrag.""" """Einstempeln startet einen neuen Zeiterfassungseintrag."""
entry, warnings = await time_service.stamp_in(current_user, data, db) entry, warnings = await time_service.stamp_in(current_user, data, db)
await db.commit() await db.commit()
await db.refresh(entry)
return TimeEntryWithWarnings(entry=TimeEntryOut.model_validate(entry), warnings=warnings) return TimeEntryWithWarnings(entry=TimeEntryOut.model_validate(entry), warnings=warnings)
@@ -52,7 +51,6 @@ async def stamp_out(
"""Ausstempeln schließt den offenen Zeiterfassungseintrag.""" """Ausstempeln schließt den offenen Zeiterfassungseintrag."""
entry, warnings = await time_service.stamp_out(current_user, data.note, db) entry, warnings = await time_service.stamp_out(current_user, data.note, db)
await db.commit() await db.commit()
await db.refresh(entry)
return TimeEntryWithWarnings(entry=TimeEntryOut.model_validate(entry), warnings=warnings) return TimeEntryWithWarnings(entry=TimeEntryOut.model_validate(entry), warnings=warnings)
@@ -64,7 +62,6 @@ async def break_start(
"""Pause beginnen.""" """Pause beginnen."""
entry = await time_service.break_start(current_user, db) entry = await time_service.break_start(current_user, db)
await db.commit() await db.commit()
await db.refresh(entry)
return TimeEntryOut.model_validate(entry) return TimeEntryOut.model_validate(entry)
@@ -76,7 +73,6 @@ async def break_end(
"""Pause beenden.""" """Pause beenden."""
entry = await time_service.break_end(current_user, db) entry = await time_service.break_end(current_user, db)
await db.commit() await db.commit()
await db.refresh(entry)
return TimeEntryOut.model_validate(entry) return TimeEntryOut.model_validate(entry)
@@ -122,7 +118,6 @@ async def create_manual_entry(
"""Manuellen Zeiterfassungseintrag anlegen.""" """Manuellen Zeiterfassungseintrag anlegen."""
entry, warnings = await time_service.create_manual(data, current_user, db) entry, warnings = await time_service.create_manual(data, current_user, db)
await db.commit() await db.commit()
await db.refresh(entry)
return TimeEntryWithWarnings(entry=TimeEntryOut.model_validate(entry), warnings=warnings) return TimeEntryWithWarnings(entry=TimeEntryOut.model_validate(entry), warnings=warnings)
@@ -136,7 +131,6 @@ async def update_entry(
"""Zeiterfassungseintrag korrigieren.""" """Zeiterfassungseintrag korrigieren."""
entry = await time_service.update_entry(entry_id, data, current_user, db) entry = await time_service.update_entry(entry_id, data, current_user, db)
await db.commit() await db.commit()
await db.refresh(entry)
return TimeEntryOut.model_validate(entry) return TimeEntryOut.model_validate(entry)
@@ -149,7 +143,6 @@ async def approve_entry(
"""Zeiterfassungseintrag genehmigen.""" """Zeiterfassungseintrag genehmigen."""
entry = await time_service.approve_entry(entry_id, current_user, db) entry = await time_service.approve_entry(entry_id, current_user, db)
await db.commit() await db.commit()
await db.refresh(entry)
return TimeEntryOut.model_validate(entry) return TimeEntryOut.model_validate(entry)
@@ -163,7 +156,6 @@ async def reject_entry(
"""Zeiterfassungseintrag ablehnen.""" """Zeiterfassungseintrag ablehnen."""
entry = await time_service.reject_entry(entry_id, current_user, data.rejection_note, db) entry = await time_service.reject_entry(entry_id, current_user, data.rejection_note, db)
await db.commit() await db.commit()
await db.refresh(entry)
return TimeEntryOut.model_validate(entry) return TimeEntryOut.model_validate(entry)
@@ -228,7 +220,6 @@ async def create_schedule(
): ):
schedule = await time_service.create_work_schedule(current_user.company_id, data, db) schedule = await time_service.create_work_schedule(current_user.company_id, data, db)
await db.commit() await db.commit()
await db.refresh(schedule)
return WorkScheduleOut.model_validate(schedule) return WorkScheduleOut.model_validate(schedule)
@@ -253,7 +244,6 @@ async def update_schedule(
for field, value in data.model_dump().items(): for field, value in data.model_dump().items():
setattr(schedule, field, value) setattr(schedule, field, value)
await db.commit() await db.commit()
await db.refresh(schedule)
return WorkScheduleOut.model_validate(schedule) return WorkScheduleOut.model_validate(schedule)
-4
View File
@@ -34,10 +34,6 @@ def run_migrations_offline() -> None:
def do_run_migrations(connection: Connection) -> None: def do_run_migrations(connection: Connection) -> None:
# Ensure Alembic itself is never blocked by RLS policies.
# SET LOCAL is transaction-scoped; context.begin_transaction() opens one.
from sqlalchemy import text as sa_text
connection.execute(sa_text("SET LOCAL app.bypass_rls = 'on'"))
context.configure(connection=connection, target_metadata=target_metadata) context.configure(connection=connection, target_metadata=target_metadata)
with context.begin_transaction(): with context.begin_transaction():
context.run_migrations() context.run_migrations()
@@ -5,204 +5,87 @@ Revises: 0023
Create Date: 2026-05-23 Create Date: 2026-05-23
""" """
from alembic import op from alembic import op
from sqlalchemy import text
revision = "0024" revision = "0024"
down_revision = "0023" down_revision = "0023"
branch_labels = None branch_labels = None
depends_on = None depends_on = None
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
# Helper # RLS expression helpers
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
def _rls_bypass_expr() -> str: _BYPASS = "COALESCE(current_setting('app.bypass_rls', true), 'off') = 'on'"
"""USING-expression that always allows when bypass is set.""" _CID = "company_id = NULLIF(current_setting('app.company_id', true), '')::uuid"
return "COALESCE(current_setting('app.bypass_rls', true), 'off') = 'on'" _IID = "id = NULLIF(current_setting('app.company_id', true), '')::uuid"
def _using_cid(): return f"({_BYPASS} OR {_CID})"
def _company_id_expr(col: str = "company_id") -> str: def _using_iid(): return f"({_BYPASS} OR {_IID})"
"""USING-expression that matches the session company_id.""" def _using_join(): return (
return ( f"({_BYPASS} OR user_id IN ("
f"{col} = NULLIF(current_setting('app.company_id', true), '')::uuid" f"SELECT id FROM users WHERE {_CID}))"
) )
def _using(col: str = "company_id") -> str:
return f"({_rls_bypass_expr()} OR {_company_id_expr(col)})"
def _with_check(col: str = "company_id") -> str:
return f"({_rls_bypass_expr()} OR {_company_id_expr(col)})"
def _user_join_expr() -> str:
"""For tables with user_id: restrict via JOIN to users.company_id."""
return (
f"({_rls_bypass_expr()} OR "
f"user_id IN ("
f"SELECT id FROM users WHERE company_id = "
f"NULLIF(current_setting('app.company_id', true), '')::uuid"
f"))"
)
def _enable_rls(table: str) -> str:
return f"ALTER TABLE {table} ENABLE ROW LEVEL SECURITY"
def _force_rls(table: str) -> str:
return f"ALTER TABLE {table} FORCE ROW LEVEL SECURITY"
def _disable_rls(table: str) -> str:
return f"ALTER TABLE {table} DISABLE ROW LEVEL SECURITY"
def _no_force_rls(table: str) -> str:
return f"ALTER TABLE {table} NO FORCE ROW LEVEL SECURITY"
def _create_policies_company_col(table: str, col: str = "company_id") -> list[str]:
"""SELECT/INSERT/UPDATE/DELETE policies for tables that have a direct company_id column."""
using = _using(col)
with_check = _with_check(col)
return [
f"CREATE POLICY rls_{table}_select ON {table} FOR SELECT USING {using}",
f"CREATE POLICY rls_{table}_insert ON {table} FOR INSERT WITH CHECK {with_check}",
f"CREATE POLICY rls_{table}_update ON {table} FOR UPDATE USING {using} WITH CHECK {with_check}",
f"CREATE POLICY rls_{table}_delete ON {table} FOR DELETE USING {using}",
]
def _create_policies_user_join(table: str) -> list[str]:
"""Policies for tables that reference users (no direct company_id)."""
using = _user_join_expr()
return [
f"CREATE POLICY rls_{table}_select ON {table} FOR SELECT USING {using}",
f"CREATE POLICY rls_{table}_insert ON {table} FOR INSERT WITH CHECK {using}",
f"CREATE POLICY rls_{table}_update ON {table} FOR UPDATE USING {using} WITH CHECK {using}",
f"CREATE POLICY rls_{table}_delete ON {table} FOR DELETE USING {using}",
]
def _drop_policies(table: str) -> list[str]:
return [
f"DROP POLICY IF EXISTS rls_{table}_select ON {table}",
f"DROP POLICY IF EXISTS rls_{table}_insert ON {table}",
f"DROP POLICY IF EXISTS rls_{table}_update ON {table}",
f"DROP POLICY IF EXISTS rls_{table}_delete ON {table}",
]
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
# Tables covered by RLS # Tables
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
# Tables with a direct company_id column # Direct company_id column
COMPANY_COL_TABLES = [ COMPANY_COL_TABLES = [
"absence_types", "absence_types", "audit_logs", "caldav_company_configs", "departments",
"audit_logs", "kiosk_devices", "ldap_configs", "overtime_balances", "smtp_configs",
"caldav_company_configs", "users", "work_schedules",
"departments",
"kiosk_devices",
"ldap_configs",
"overtime_balances",
"smtp_configs",
"users",
"work_schedules",
] ]
# Tables where the row references users (which belong to a company) # Linked via user_id → users.company_id
USER_JOIN_TABLES = [ USER_JOIN_TABLES = [
"absences", "absences", "caldav_user_configs", "password_resets",
"caldav_user_configs", "sessions", "time_entries", "vacation_balances",
"password_resets",
"sessions",
"time_entries",
"vacation_balances",
] ]
# The companies table itself restrict by id # public_holidays is global no RLS
COMPANIES_TABLE = "companies"
# public_holidays is global (no tenant column) RLS not applied # ---------------------------------------------------------------------------
# upgrade / downgrade
# ---------------------------------------------------------------------------
def _exec(sql: str) -> None:
op.execute(text(sql))
def _enable(table: str, using: str) -> None:
_exec(f"ALTER TABLE {table} ENABLE ROW LEVEL SECURITY")
_exec(f"ALTER TABLE {table} FORCE ROW LEVEL SECURITY")
for cmd in ("SELECT", "INSERT", "UPDATE", "DELETE"):
_exec(f"DROP POLICY IF EXISTS rls_{table}_{cmd.lower()} ON {table}")
_exec(f"CREATE POLICY rls_{table}_select ON {table} FOR SELECT USING {using}")
_exec(f"CREATE POLICY rls_{table}_insert ON {table} FOR INSERT WITH CHECK {using}")
_exec(f"CREATE POLICY rls_{table}_update ON {table} FOR UPDATE USING {using} WITH CHECK {using}")
_exec(f"CREATE POLICY rls_{table}_delete ON {table} FOR DELETE USING {using}")
def _disable(table: str) -> None:
for cmd in ("select", "insert", "update", "delete"):
_exec(f"DROP POLICY IF EXISTS rls_{table}_{cmd} ON {table}")
_exec(f"ALTER TABLE {table} NO FORCE ROW LEVEL SECURITY")
_exec(f"ALTER TABLE {table} DISABLE ROW LEVEL SECURITY")
def upgrade() -> None: def upgrade() -> None:
conn = op.get_bind() # companies: restrict by id (companies IS the tenant root)
_enable("companies", _using_iid())
# --- companies table ---
conn.execute(__import__("sqlalchemy").text(_enable_rls(COMPANIES_TABLE)))
conn.execute(__import__("sqlalchemy").text(_force_rls(COMPANIES_TABLE)))
companies_using = (
f"({_rls_bypass_expr()} OR "
f"id = NULLIF(current_setting('app.company_id', true), '')::uuid)"
)
conn.execute(__import__("sqlalchemy").text(
f"DROP POLICY IF EXISTS rls_companies_select ON companies"
))
conn.execute(__import__("sqlalchemy").text(
f"DROP POLICY IF EXISTS rls_companies_insert ON companies"
))
conn.execute(__import__("sqlalchemy").text(
f"DROP POLICY IF EXISTS rls_companies_update ON companies"
))
conn.execute(__import__("sqlalchemy").text(
f"DROP POLICY IF EXISTS rls_companies_delete ON companies"
))
conn.execute(__import__("sqlalchemy").text(
f"CREATE POLICY rls_companies_select ON companies FOR SELECT USING {companies_using}"
))
conn.execute(__import__("sqlalchemy").text(
f"CREATE POLICY rls_companies_insert ON companies FOR INSERT WITH CHECK {companies_using}"
))
conn.execute(__import__("sqlalchemy").text(
f"CREATE POLICY rls_companies_update ON companies FOR UPDATE "
f"USING {companies_using} WITH CHECK {companies_using}"
))
conn.execute(__import__("sqlalchemy").text(
f"CREATE POLICY rls_companies_delete ON companies FOR DELETE USING {companies_using}"
))
# --- tables with direct company_id ---
for table in COMPANY_COL_TABLES: for table in COMPANY_COL_TABLES:
conn.execute(__import__("sqlalchemy").text(_enable_rls(table))) _enable(table, _using_cid())
conn.execute(__import__("sqlalchemy").text(_force_rls(table)))
for drop_stmt in _drop_policies(table):
conn.execute(__import__("sqlalchemy").text(drop_stmt))
for create_stmt in _create_policies_company_col(table):
conn.execute(__import__("sqlalchemy").text(create_stmt))
# --- tables with user_id join ---
for table in USER_JOIN_TABLES: for table in USER_JOIN_TABLES:
conn.execute(__import__("sqlalchemy").text(_enable_rls(table))) _enable(table, _using_join())
conn.execute(__import__("sqlalchemy").text(_force_rls(table)))
for drop_stmt in _drop_policies(table):
conn.execute(__import__("sqlalchemy").text(drop_stmt))
for create_stmt in _create_policies_user_join(table):
conn.execute(__import__("sqlalchemy").text(create_stmt))
def downgrade() -> None: def downgrade() -> None:
conn = op.get_bind() _disable("companies")
# --- companies ---
for stmt in _drop_policies(COMPANIES_TABLE):
conn.execute(__import__("sqlalchemy").text(stmt))
conn.execute(__import__("sqlalchemy").text(_no_force_rls(COMPANIES_TABLE)))
conn.execute(__import__("sqlalchemy").text(_disable_rls(COMPANIES_TABLE)))
# --- tables with direct company_id ---
for table in COMPANY_COL_TABLES: for table in COMPANY_COL_TABLES:
for stmt in _drop_policies(table): _disable(table)
conn.execute(__import__("sqlalchemy").text(stmt))
conn.execute(__import__("sqlalchemy").text(_no_force_rls(table)))
conn.execute(__import__("sqlalchemy").text(_disable_rls(table)))
# --- tables with user_id join ---
for table in USER_JOIN_TABLES: for table in USER_JOIN_TABLES:
for stmt in _drop_policies(table): _disable(table)
conn.execute(__import__("sqlalchemy").text(stmt))
conn.execute(__import__("sqlalchemy").text(_no_force_rls(table)))
conn.execute(__import__("sqlalchemy").text(_disable_rls(table)))
+48
View File
@@ -15,6 +15,53 @@ test_engine = create_async_engine(TEST_DATABASE_URL, echo=False)
TestSessionLocal = async_sessionmaker(test_engine, class_=AsyncSession, expire_on_commit=False) TestSessionLocal = async_sessionmaker(test_engine, class_=AsyncSession, expire_on_commit=False)
_BYPASS = "COALESCE(current_setting('app.bypass_rls', true), 'off') = 'on'"
_CID = "company_id = NULLIF(current_setting('app.company_id', true), '')::uuid"
_IID = "id = NULLIF(current_setting('app.company_id', true), '')::uuid"
def _rls_using_cid(): return f"({_BYPASS} OR {_CID})"
def _rls_using_iid(): return f"({_BYPASS} OR {_IID})"
def _rls_using_join(): return (
f"({_BYPASS} OR user_id IN (SELECT id FROM users WHERE {_CID}))"
)
_COMPANY_COL_TABLES = [
"absence_types", "audit_logs", "caldav_company_configs", "departments",
"kiosk_devices", "ldap_configs", "overtime_balances", "smtp_configs",
"users", "work_schedules",
]
_USER_JOIN_TABLES = [
"absences", "caldav_user_configs", "password_resets",
"sessions", "time_entries", "vacation_balances",
]
async def _apply_rls(conn) -> None:
"""Apply the same RLS policies as migration 0024 to the test database."""
def enable(table: str, using: str):
return [
f"ALTER TABLE {table} ENABLE ROW LEVEL SECURITY",
f"ALTER TABLE {table} FORCE ROW LEVEL SECURITY",
f"DROP POLICY IF EXISTS rls_{table}_select ON {table}",
f"DROP POLICY IF EXISTS rls_{table}_insert ON {table}",
f"DROP POLICY IF EXISTS rls_{table}_update ON {table}",
f"DROP POLICY IF EXISTS rls_{table}_delete ON {table}",
f"CREATE POLICY rls_{table}_select ON {table} FOR SELECT USING {using}",
f"CREATE POLICY rls_{table}_insert ON {table} FOR INSERT WITH CHECK {using}",
f"CREATE POLICY rls_{table}_update ON {table} FOR UPDATE USING {using} WITH CHECK {using}",
f"CREATE POLICY rls_{table}_delete ON {table} FOR DELETE USING {using}",
]
for sql in enable("companies", _rls_using_iid()):
await conn.execute(text(sql))
for table in _COMPANY_COL_TABLES:
for sql in enable(table, _rls_using_cid()):
await conn.execute(text(sql))
for table in _USER_JOIN_TABLES:
for sql in enable(table, _rls_using_join()):
await conn.execute(text(sql))
@pytest_asyncio.fixture(scope="session", loop_scope="session", autouse=True) @pytest_asyncio.fixture(scope="session", loop_scope="session", autouse=True)
async def setup_db(): async def setup_db():
async with test_engine.begin() as conn: async with test_engine.begin() as conn:
@@ -24,6 +71,7 @@ async def setup_db():
await conn.execute(text("GRANT ALL ON SCHEMA public TO timemaster")) await conn.execute(text("GRANT ALL ON SCHEMA public TO timemaster"))
await conn.execute(text("GRANT ALL ON SCHEMA public TO public")) await conn.execute(text("GRANT ALL ON SCHEMA public TO public"))
await conn.run_sync(Base.metadata.create_all) await conn.run_sync(Base.metadata.create_all)
await _apply_rls(conn)
yield yield
async with test_engine.begin() as conn: async with test_engine.begin() as conn:
await conn.execute(text("DROP SCHEMA public CASCADE")) await conn.execute(text("DROP SCHEMA public CASCADE"))
+190
View File
@@ -0,0 +1,190 @@
"""
RLS-Tests: Verifiziert dass PostgreSQL Row Level Security
tatsächlich cross-tenant Datenzugriff blockiert unabhängig
von App-seitigen WHERE-Klauseln.
"""
import pytest
import uuid
from httpx import AsyncClient
from sqlalchemy import text
pytestmark = pytest.mark.asyncio
REG_URL = "/api/v1/auth/register"
LOGIN_URL = "/api/v1/auth/login"
async def register_company(client: AsyncClient, suffix: str) -> dict:
"""Registriert eine neue Firma und gibt tokens + company_id zurück."""
resp = await client.post(REG_URL, json={
"company_name": f"RLS-Firma-{suffix}",
"first_name": "Admin",
"last_name": suffix,
"email": f"rls-admin-{suffix}@test.de",
"password": "Secret123",
})
assert resp.status_code == 201, resp.text
tokens = resp.json()
me = await client.get(
"/api/v1/auth/me",
headers={"Authorization": f"Bearer {tokens['access_token']}"},
)
assert me.status_code == 200
return {"tokens": tokens, "user": me.json()}
# ── Hilfsfunktion: Raw-SQL mit explizitem RLS-Kontext ────────────────────────
async def query_users_as_tenant(db_session, company_id: str) -> list[dict]:
"""
Führt SELECT * FROM users direkt aus,
mit gesetztem app.company_id simuliert einen Request als dieser Mandant.
bypass_rls ist OFF, d.h. RLS greift.
"""
await db_session.execute(text("SET LOCAL app.bypass_rls = 'off'"))
await db_session.execute(
text(f"SET LOCAL app.company_id = '{company_id}'")
)
result = await db_session.execute(text("SELECT id, email, company_id FROM users"))
rows = [dict(r._mapping) for r in result.fetchall()]
# Reset: bypass wieder an (damit nachfolgende Tests nicht leiden)
await db_session.execute(text("SET LOCAL app.bypass_rls = 'on'"))
return rows
# ── Tests ────────────────────────────────────────────────────────────────────
async def test_rls_tenant_a_cannot_see_tenant_b_users(
client: AsyncClient, db_session
):
"""Mandant A darf Mandant B's Users nicht sehen (DB-Ebene)."""
a = await register_company(client, "RLS-A")
b = await register_company(client, "RLS-B")
cid_a = str(a["user"]["company_id"])
cid_b = str(b["user"]["company_id"])
# Als Mandant A abfragen
rows_as_a = await query_users_as_tenant(db_session, cid_a)
emails_as_a = {r["email"] for r in rows_as_a}
# Mandant A sieht sich selbst
assert f"rls-admin-RLS-A@test.de" in emails_as_a, \
"Mandant A sollte eigene Daten sehen"
# Mandant A sieht Mandant B NICHT
assert f"rls-admin-RLS-B@test.de" not in emails_as_a, \
"RLS BLOCKIERT NICHT: Mandant A sieht Daten von Mandant B!"
# Doppelcheck: alle zurückgegebenen Rows gehören zu company A
for row in rows_as_a:
assert str(row["company_id"]) == cid_a, \
f"Fremder Mandant in Ergebnis: {row}"
async def test_rls_tenant_b_cannot_see_tenant_a_users(
client: AsyncClient, db_session
):
"""Symmetrietest: Mandant B sieht Mandant A nicht."""
# Firmen aus vorherigem Test existieren schon direkt aus DB holen
result = await db_session.execute(
text("SELECT id FROM companies WHERE name LIKE 'RLS-Firma-%' ORDER BY name")
)
companies = [str(r[0]) for r in result.fetchall()]
assert len(companies) >= 2, "Firmen aus vorherigem Test fehlen"
cid_a, cid_b = companies[0], companies[1]
rows_as_b = await query_users_as_tenant(db_session, cid_b)
company_ids_seen = {str(r["company_id"]) for r in rows_as_b}
assert cid_a not in company_ids_seen, \
"RLS BLOCKIERT NICHT: Mandant B sieht Daten von Mandant A!"
assert cid_b in company_ids_seen, \
"Mandant B sieht keine eigenen Daten RLS zu restriktiv"
async def test_rls_no_context_returns_nothing(db_session):
"""Ohne gesetztes app.company_id gibt SELECT nichts zurück (kein bypass)."""
await db_session.execute(text("SET LOCAL app.bypass_rls = 'off'"))
# app.company_id explizit leeren
await db_session.execute(text("SET LOCAL app.company_id = ''"))
result = await db_session.execute(text("SELECT id FROM users"))
rows = result.fetchall()
# Reset
await db_session.execute(text("SET LOCAL app.bypass_rls = 'on'"))
assert len(rows) == 0, \
f"RLS BLOCKIERT NICHT: {len(rows)} Rows ohne company_id-Kontext sichtbar!"
async def test_rls_bypass_on_sees_all(db_session):
"""Mit bypass_rls='on' sieht ein SUPER_ADMIN alle Mandanten."""
await db_session.execute(text("SET LOCAL app.bypass_rls = 'on'"))
result = await db_session.execute(
text("SELECT DISTINCT company_id FROM users")
)
company_ids = {str(r[0]) for r in result.fetchall()}
assert len(company_ids) >= 2, \
"SUPER_ADMIN sollte mindestens 2 Mandanten sehen (bypass_rls='on')"
async def test_rls_companies_table_isolated(db_session):
"""Auch die companies-Tabelle ist mandanten-isoliert."""
result = await db_session.execute(
text("SELECT id FROM companies WHERE name LIKE 'RLS-Firma-%' ORDER BY name")
)
companies = [str(r[0]) for r in result.fetchall()]
assert len(companies) >= 2
cid_a, cid_b = companies[0], companies[1]
# Als Mandant A: sehe nur meine Firma
await db_session.execute(text("SET LOCAL app.bypass_rls = 'off'"))
await db_session.execute(text(f"SET LOCAL app.company_id = '{cid_a}'"))
result = await db_session.execute(text("SELECT id FROM companies"))
visible = [str(r[0]) for r in result.fetchall()]
await db_session.execute(text("SET LOCAL app.bypass_rls = 'on'"))
assert cid_a in visible, "Eigene Firma nicht sichtbar"
assert cid_b not in visible, "Fremde Firma trotz RLS sichtbar!"
assert len(visible) == 1, f"Mehr als 1 Firma sichtbar: {visible}"
async def test_rls_insert_blocked_for_wrong_tenant(db_session):
"""INSERT in fremden Mandanten wird durch WITH CHECK blockiert."""
result = await db_session.execute(
text("SELECT id FROM companies WHERE name LIKE 'RLS-Firma-%' ORDER BY name LIMIT 2")
)
companies = [str(r[0]) for r in result.fetchall()]
cid_a, cid_b = companies[0], companies[1]
await db_session.execute(text("SET LOCAL app.bypass_rls = 'off'"))
await db_session.execute(text(f"SET LOCAL app.company_id = '{cid_a}'"))
try:
# Versuche, einen User unter Mandant B zu erstellen, während Context = A
fake_id = str(uuid.uuid4())
await db_session.execute(text(f"""
INSERT INTO users (id, company_id, email, first_name, last_name,
role, password_hash, is_active, created_at)
VALUES ('{fake_id}', '{cid_b}', 'rls-inject@evil.de',
'Evil', 'Inject', 'EMPLOYEE', 'fakehash', false, NOW())
"""))
await db_session.flush()
# Wenn wir hier ankommen, hat RLS nicht blockiert
await db_session.rollback()
pytest.fail("RLS WITH CHECK hat INSERT in fremden Mandanten NICHT blockiert!")
except Exception as e:
await db_session.rollback()
# Erwartet: new row violates row-level security policy
assert "row-level security" in str(e).lower() or "policy" in str(e).lower(), \
f"Unerwarteter Fehler (kein RLS-Fehler): {e}"
finally:
await db_session.execute(text("SET LOCAL app.bypass_rls = 'on'"))