fix: router db.refresh() nach commit bricht RLS-Kontext

SET LOCAL Werte (bypass_rls, company_id) sind transaktions-gebunden.
Nach db.commit() ist der Kontext weg – ein nachfolgendes db.refresh()
läuft in einer neuen Transaktion ohne RLS-Kontext und liefert 0 Rows.

Da expire_on_commit=False gesetzt ist, sind alle Instanz-Attribute
nach dem Commit bereits im Speicher vorhanden. Die expliziten
db.refresh()-Aufrufe nach db.commit() in allen Routers sind daher
redundant und wurden entfernt.

test_rls.py: 6 neue Tests beweisen DB-seitige Mandanten-Isolation.
conftest.py: _apply_rls() wendet RLS-Policies auf Test-DB an.
migrations/0024: korrigiert auf op.execute(text()) API.
migrations/env.py: SET LOCAL außerhalb Transaktion entfernt.

Ergebnis: 8 failed (pre-existing), 126 passed – identisch zur Baseline vor RLS.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
2026-05-23 22:34:48 +02:00
parent 6d4b8a9f17
commit dd3e069466
12 changed files with 305 additions and 202 deletions
+16
View File
@@ -638,3 +638,19 @@ Keine Commits in dieser Session.
- update.sh | 53 +++++++++++++++++++++++------------------------------
---
## 2026-05-23 21:51 21:58 (7m)
**Beschreibung:** Claude Code Session
**Projekt:** timemaster
### Commits
- 6d4b8a9 agent-rls: PostgreSQL Row Level Security für Mandanten-Isolation
### Geänderte Dateien
- DEVLOG.md | 79 ++++++++
- backend/app/core/database.py | 6 +
- backend/app/core/dependencies.py | 18 ++
- backend/migrations/env.py | 4 +
- .../migrations/versions/0024_row_level_security.py | 208 +++++++++++++++++++++
- backend/tests/conftest.py | 4 +
---
-10
View File
@@ -72,7 +72,6 @@ async def create_absence_type(
):
at = await absence_service.create_type(current_user.company_id, data, db)
await db.commit()
await db.refresh(at)
return AbsenceTypeOut.model_validate(at)
@@ -85,7 +84,6 @@ async def update_absence_type(
):
at = await absence_service.update_type(type_id, current_user.company_id, data, db)
await db.commit()
await db.refresh(at)
return AbsenceTypeOut.model_validate(at)
@@ -111,7 +109,6 @@ async def create_public_holiday(
):
holiday = await absence_service.create_holiday(data, db)
await db.commit()
await db.refresh(holiday)
return PublicHolidayOut.model_validate(holiday)
@@ -181,7 +178,6 @@ async def quick_sick(
data.start_date, data.end_date, current_user, db
)
await db.commit()
await db.refresh(absence)
return AbsenceOut.model_validate(absence)
@@ -256,7 +252,6 @@ async def create_absence(
acting_user = target
absence, warnings = await absence_service.create_absence(data, acting_user, db)
await db.commit()
await db.refresh(absence)
return AbsenceOut.model_validate(absence)
@@ -270,7 +265,6 @@ async def update_absence(
"""Ausstehenden Antrag bearbeiten (Mitarbeiter: eigene; Manager: alle der Company)."""
absence = await absence_service.update_absence(absence_id, data, current_user, db)
await db.commit()
await db.refresh(absence)
return AbsenceOut.model_validate(absence)
@@ -303,7 +297,6 @@ async def approve_absence(
):
absence = await absence_service.approve_absence(absence_id, current_user, db)
await db.commit()
await db.refresh(absence)
return AbsenceOut.model_validate(absence)
@@ -316,7 +309,6 @@ async def reject_absence(
):
absence = await absence_service.reject_absence(absence_id, data, current_user, db)
await db.commit()
await db.refresh(absence)
return AbsenceOut.model_validate(absence)
@@ -332,7 +324,6 @@ async def mark_certificate_received(
absence_id, data.received_at, current_user, db
)
await db.commit()
await db.refresh(absence)
return AbsenceOut.model_validate(absence)
@@ -357,7 +348,6 @@ async def update_balance(
for field, value in data.model_dump(exclude_unset=True).items():
setattr(balance, field, value)
await db.commit()
await db.refresh(balance)
pending = await absence_service.get_pending_days(user_id, year, db)
company = await db.get(Company, current_user.company_id)
expires_at, expired = _carryover_expiry(company, year) if company else (None, False)
-2
View File
@@ -62,7 +62,6 @@ async def save_company_config(
raise HTTPException(status_code=400, detail="Passwort wird beim ersten Speichern benötigt.")
await db.commit()
await db.refresh(cfg)
return CaldavCompanyConfigOut.model_validate(cfg)
@@ -125,7 +124,6 @@ async def save_user_config(
raise HTTPException(status_code=400, detail="Passwort wird beim ersten Speichern benötigt.")
await db.commit()
await db.refresh(cfg)
return CaldavUserConfigOut.model_validate(cfg)
-3
View File
@@ -34,7 +34,6 @@ async def create_device(
"""Neues Kiosk-Gerät registrieren. Token wird nur einmalig zurückgegeben."""
device, raw_token = await kiosk_service.create_device(current_user.company_id, data, db)
await db.commit()
await db.refresh(device)
return KioskDeviceCreated(
**KioskDeviceOut.model_validate(device).model_dump(),
token=raw_token,
@@ -59,7 +58,6 @@ async def update_device(
):
device = await kiosk_service.update_device(device_id, current_user.company_id, data, db)
await db.commit()
await db.refresh(device)
return KioskDeviceOut.model_validate(device)
@@ -72,7 +70,6 @@ async def rotate_token(
"""Token rotieren das alte Token wird sofort ungültig."""
device, raw_token = await kiosk_service.rotate_token(device_id, current_user.company_id, db)
await db.commit()
await db.refresh(device)
return KioskDeviceCreated(
**KioskDeviceOut.model_validate(device).model_dump(),
token=raw_token,
-2
View File
@@ -66,7 +66,6 @@ async def create_ldap_config(
)
db.add(cfg)
await db.commit()
await db.refresh(cfg)
return cfg
@@ -135,5 +134,4 @@ async def _apply_update(cfg: LdapConfig, updates: dict, db: AsyncSession) -> Lda
elif hasattr(cfg, field):
setattr(cfg, field, value)
await db.commit()
await db.refresh(cfg)
return cfg
-2
View File
@@ -58,7 +58,6 @@ async def create_project(
)
db.add(project)
await db.commit()
await db.refresh(project)
return project
@@ -141,7 +140,6 @@ async def update_project(
setattr(project, field, value)
await db.commit()
await db.refresh(project)
return project
-1
View File
@@ -71,7 +71,6 @@ async def save_smtp_config(
cfg.password_encrypted = _encrypt(data.password)
await db.commit()
await db.refresh(cfg)
return SmtpConfigOut.model_validate(cfg)
-10
View File
@@ -39,7 +39,6 @@ async def stamp_in(
"""Einstempeln startet einen neuen Zeiterfassungseintrag."""
entry, warnings = await time_service.stamp_in(current_user, data, db)
await db.commit()
await db.refresh(entry)
return TimeEntryWithWarnings(entry=TimeEntryOut.model_validate(entry), warnings=warnings)
@@ -52,7 +51,6 @@ async def stamp_out(
"""Ausstempeln schließt den offenen Zeiterfassungseintrag."""
entry, warnings = await time_service.stamp_out(current_user, data.note, db)
await db.commit()
await db.refresh(entry)
return TimeEntryWithWarnings(entry=TimeEntryOut.model_validate(entry), warnings=warnings)
@@ -64,7 +62,6 @@ async def break_start(
"""Pause beginnen."""
entry = await time_service.break_start(current_user, db)
await db.commit()
await db.refresh(entry)
return TimeEntryOut.model_validate(entry)
@@ -76,7 +73,6 @@ async def break_end(
"""Pause beenden."""
entry = await time_service.break_end(current_user, db)
await db.commit()
await db.refresh(entry)
return TimeEntryOut.model_validate(entry)
@@ -122,7 +118,6 @@ async def create_manual_entry(
"""Manuellen Zeiterfassungseintrag anlegen."""
entry, warnings = await time_service.create_manual(data, current_user, db)
await db.commit()
await db.refresh(entry)
return TimeEntryWithWarnings(entry=TimeEntryOut.model_validate(entry), warnings=warnings)
@@ -136,7 +131,6 @@ async def update_entry(
"""Zeiterfassungseintrag korrigieren."""
entry = await time_service.update_entry(entry_id, data, current_user, db)
await db.commit()
await db.refresh(entry)
return TimeEntryOut.model_validate(entry)
@@ -149,7 +143,6 @@ async def approve_entry(
"""Zeiterfassungseintrag genehmigen."""
entry = await time_service.approve_entry(entry_id, current_user, db)
await db.commit()
await db.refresh(entry)
return TimeEntryOut.model_validate(entry)
@@ -163,7 +156,6 @@ async def reject_entry(
"""Zeiterfassungseintrag ablehnen."""
entry = await time_service.reject_entry(entry_id, current_user, data.rejection_note, db)
await db.commit()
await db.refresh(entry)
return TimeEntryOut.model_validate(entry)
@@ -228,7 +220,6 @@ async def create_schedule(
):
schedule = await time_service.create_work_schedule(current_user.company_id, data, db)
await db.commit()
await db.refresh(schedule)
return WorkScheduleOut.model_validate(schedule)
@@ -253,7 +244,6 @@ async def update_schedule(
for field, value in data.model_dump().items():
setattr(schedule, field, value)
await db.commit()
await db.refresh(schedule)
return WorkScheduleOut.model_validate(schedule)
-4
View File
@@ -34,10 +34,6 @@ def run_migrations_offline() -> None:
def do_run_migrations(connection: Connection) -> None:
# Ensure Alembic itself is never blocked by RLS policies.
# SET LOCAL is transaction-scoped; context.begin_transaction() opens one.
from sqlalchemy import text as sa_text
connection.execute(sa_text("SET LOCAL app.bypass_rls = 'on'"))
context.configure(connection=connection, target_metadata=target_metadata)
with context.begin_transaction():
context.run_migrations()
@@ -5,204 +5,87 @@ Revises: 0023
Create Date: 2026-05-23
"""
from alembic import op
from sqlalchemy import text
revision = "0024"
down_revision = "0023"
branch_labels = None
depends_on = None
# ---------------------------------------------------------------------------
# Helper
# RLS expression helpers
# ---------------------------------------------------------------------------
def _rls_bypass_expr() -> str:
"""USING-expression that always allows when bypass is set."""
return "COALESCE(current_setting('app.bypass_rls', true), 'off') = 'on'"
_BYPASS = "COALESCE(current_setting('app.bypass_rls', true), 'off') = 'on'"
_CID = "company_id = NULLIF(current_setting('app.company_id', true), '')::uuid"
_IID = "id = NULLIF(current_setting('app.company_id', true), '')::uuid"
def _company_id_expr(col: str = "company_id") -> str:
"""USING-expression that matches the session company_id."""
return (
f"{col} = NULLIF(current_setting('app.company_id', true), '')::uuid"
def _using_cid(): return f"({_BYPASS} OR {_CID})"
def _using_iid(): return f"({_BYPASS} OR {_IID})"
def _using_join(): return (
f"({_BYPASS} OR user_id IN ("
f"SELECT id FROM users WHERE {_CID}))"
)
def _using(col: str = "company_id") -> str:
return f"({_rls_bypass_expr()} OR {_company_id_expr(col)})"
def _with_check(col: str = "company_id") -> str:
return f"({_rls_bypass_expr()} OR {_company_id_expr(col)})"
def _user_join_expr() -> str:
"""For tables with user_id: restrict via JOIN to users.company_id."""
return (
f"({_rls_bypass_expr()} OR "
f"user_id IN ("
f"SELECT id FROM users WHERE company_id = "
f"NULLIF(current_setting('app.company_id', true), '')::uuid"
f"))"
)
def _enable_rls(table: str) -> str:
return f"ALTER TABLE {table} ENABLE ROW LEVEL SECURITY"
def _force_rls(table: str) -> str:
return f"ALTER TABLE {table} FORCE ROW LEVEL SECURITY"
def _disable_rls(table: str) -> str:
return f"ALTER TABLE {table} DISABLE ROW LEVEL SECURITY"
def _no_force_rls(table: str) -> str:
return f"ALTER TABLE {table} NO FORCE ROW LEVEL SECURITY"
def _create_policies_company_col(table: str, col: str = "company_id") -> list[str]:
"""SELECT/INSERT/UPDATE/DELETE policies for tables that have a direct company_id column."""
using = _using(col)
with_check = _with_check(col)
return [
f"CREATE POLICY rls_{table}_select ON {table} FOR SELECT USING {using}",
f"CREATE POLICY rls_{table}_insert ON {table} FOR INSERT WITH CHECK {with_check}",
f"CREATE POLICY rls_{table}_update ON {table} FOR UPDATE USING {using} WITH CHECK {with_check}",
f"CREATE POLICY rls_{table}_delete ON {table} FOR DELETE USING {using}",
]
def _create_policies_user_join(table: str) -> list[str]:
"""Policies for tables that reference users (no direct company_id)."""
using = _user_join_expr()
return [
f"CREATE POLICY rls_{table}_select ON {table} FOR SELECT USING {using}",
f"CREATE POLICY rls_{table}_insert ON {table} FOR INSERT WITH CHECK {using}",
f"CREATE POLICY rls_{table}_update ON {table} FOR UPDATE USING {using} WITH CHECK {using}",
f"CREATE POLICY rls_{table}_delete ON {table} FOR DELETE USING {using}",
]
def _drop_policies(table: str) -> list[str]:
return [
f"DROP POLICY IF EXISTS rls_{table}_select ON {table}",
f"DROP POLICY IF EXISTS rls_{table}_insert ON {table}",
f"DROP POLICY IF EXISTS rls_{table}_update ON {table}",
f"DROP POLICY IF EXISTS rls_{table}_delete ON {table}",
]
# ---------------------------------------------------------------------------
# Tables covered by RLS
# Tables
# ---------------------------------------------------------------------------
# Tables with a direct company_id column
# Direct company_id column
COMPANY_COL_TABLES = [
"absence_types",
"audit_logs",
"caldav_company_configs",
"departments",
"kiosk_devices",
"ldap_configs",
"overtime_balances",
"smtp_configs",
"users",
"work_schedules",
"absence_types", "audit_logs", "caldav_company_configs", "departments",
"kiosk_devices", "ldap_configs", "overtime_balances", "smtp_configs",
"users", "work_schedules",
]
# Tables where the row references users (which belong to a company)
# Linked via user_id → users.company_id
USER_JOIN_TABLES = [
"absences",
"caldav_user_configs",
"password_resets",
"sessions",
"time_entries",
"vacation_balances",
"absences", "caldav_user_configs", "password_resets",
"sessions", "time_entries", "vacation_balances",
]
# The companies table itself restrict by id
COMPANIES_TABLE = "companies"
# public_holidays is global no RLS
# public_holidays is global (no tenant column) RLS not applied
# ---------------------------------------------------------------------------
# upgrade / downgrade
# ---------------------------------------------------------------------------
def _exec(sql: str) -> None:
op.execute(text(sql))
def _enable(table: str, using: str) -> None:
_exec(f"ALTER TABLE {table} ENABLE ROW LEVEL SECURITY")
_exec(f"ALTER TABLE {table} FORCE ROW LEVEL SECURITY")
for cmd in ("SELECT", "INSERT", "UPDATE", "DELETE"):
_exec(f"DROP POLICY IF EXISTS rls_{table}_{cmd.lower()} ON {table}")
_exec(f"CREATE POLICY rls_{table}_select ON {table} FOR SELECT USING {using}")
_exec(f"CREATE POLICY rls_{table}_insert ON {table} FOR INSERT WITH CHECK {using}")
_exec(f"CREATE POLICY rls_{table}_update ON {table} FOR UPDATE USING {using} WITH CHECK {using}")
_exec(f"CREATE POLICY rls_{table}_delete ON {table} FOR DELETE USING {using}")
def _disable(table: str) -> None:
for cmd in ("select", "insert", "update", "delete"):
_exec(f"DROP POLICY IF EXISTS rls_{table}_{cmd} ON {table}")
_exec(f"ALTER TABLE {table} NO FORCE ROW LEVEL SECURITY")
_exec(f"ALTER TABLE {table} DISABLE ROW LEVEL SECURITY")
def upgrade() -> None:
conn = op.get_bind()
# companies: restrict by id (companies IS the tenant root)
_enable("companies", _using_iid())
# --- companies table ---
conn.execute(__import__("sqlalchemy").text(_enable_rls(COMPANIES_TABLE)))
conn.execute(__import__("sqlalchemy").text(_force_rls(COMPANIES_TABLE)))
companies_using = (
f"({_rls_bypass_expr()} OR "
f"id = NULLIF(current_setting('app.company_id', true), '')::uuid)"
)
conn.execute(__import__("sqlalchemy").text(
f"DROP POLICY IF EXISTS rls_companies_select ON companies"
))
conn.execute(__import__("sqlalchemy").text(
f"DROP POLICY IF EXISTS rls_companies_insert ON companies"
))
conn.execute(__import__("sqlalchemy").text(
f"DROP POLICY IF EXISTS rls_companies_update ON companies"
))
conn.execute(__import__("sqlalchemy").text(
f"DROP POLICY IF EXISTS rls_companies_delete ON companies"
))
conn.execute(__import__("sqlalchemy").text(
f"CREATE POLICY rls_companies_select ON companies FOR SELECT USING {companies_using}"
))
conn.execute(__import__("sqlalchemy").text(
f"CREATE POLICY rls_companies_insert ON companies FOR INSERT WITH CHECK {companies_using}"
))
conn.execute(__import__("sqlalchemy").text(
f"CREATE POLICY rls_companies_update ON companies FOR UPDATE "
f"USING {companies_using} WITH CHECK {companies_using}"
))
conn.execute(__import__("sqlalchemy").text(
f"CREATE POLICY rls_companies_delete ON companies FOR DELETE USING {companies_using}"
))
# --- tables with direct company_id ---
for table in COMPANY_COL_TABLES:
conn.execute(__import__("sqlalchemy").text(_enable_rls(table)))
conn.execute(__import__("sqlalchemy").text(_force_rls(table)))
for drop_stmt in _drop_policies(table):
conn.execute(__import__("sqlalchemy").text(drop_stmt))
for create_stmt in _create_policies_company_col(table):
conn.execute(__import__("sqlalchemy").text(create_stmt))
_enable(table, _using_cid())
# --- tables with user_id join ---
for table in USER_JOIN_TABLES:
conn.execute(__import__("sqlalchemy").text(_enable_rls(table)))
conn.execute(__import__("sqlalchemy").text(_force_rls(table)))
for drop_stmt in _drop_policies(table):
conn.execute(__import__("sqlalchemy").text(drop_stmt))
for create_stmt in _create_policies_user_join(table):
conn.execute(__import__("sqlalchemy").text(create_stmt))
_enable(table, _using_join())
def downgrade() -> None:
conn = op.get_bind()
# --- companies ---
for stmt in _drop_policies(COMPANIES_TABLE):
conn.execute(__import__("sqlalchemy").text(stmt))
conn.execute(__import__("sqlalchemy").text(_no_force_rls(COMPANIES_TABLE)))
conn.execute(__import__("sqlalchemy").text(_disable_rls(COMPANIES_TABLE)))
# --- tables with direct company_id ---
_disable("companies")
for table in COMPANY_COL_TABLES:
for stmt in _drop_policies(table):
conn.execute(__import__("sqlalchemy").text(stmt))
conn.execute(__import__("sqlalchemy").text(_no_force_rls(table)))
conn.execute(__import__("sqlalchemy").text(_disable_rls(table)))
# --- tables with user_id join ---
_disable(table)
for table in USER_JOIN_TABLES:
for stmt in _drop_policies(table):
conn.execute(__import__("sqlalchemy").text(stmt))
conn.execute(__import__("sqlalchemy").text(_no_force_rls(table)))
conn.execute(__import__("sqlalchemy").text(_disable_rls(table)))
_disable(table)
+48
View File
@@ -15,6 +15,53 @@ test_engine = create_async_engine(TEST_DATABASE_URL, echo=False)
TestSessionLocal = async_sessionmaker(test_engine, class_=AsyncSession, expire_on_commit=False)
_BYPASS = "COALESCE(current_setting('app.bypass_rls', true), 'off') = 'on'"
_CID = "company_id = NULLIF(current_setting('app.company_id', true), '')::uuid"
_IID = "id = NULLIF(current_setting('app.company_id', true), '')::uuid"
def _rls_using_cid(): return f"({_BYPASS} OR {_CID})"
def _rls_using_iid(): return f"({_BYPASS} OR {_IID})"
def _rls_using_join(): return (
f"({_BYPASS} OR user_id IN (SELECT id FROM users WHERE {_CID}))"
)
_COMPANY_COL_TABLES = [
"absence_types", "audit_logs", "caldav_company_configs", "departments",
"kiosk_devices", "ldap_configs", "overtime_balances", "smtp_configs",
"users", "work_schedules",
]
_USER_JOIN_TABLES = [
"absences", "caldav_user_configs", "password_resets",
"sessions", "time_entries", "vacation_balances",
]
async def _apply_rls(conn) -> None:
"""Apply the same RLS policies as migration 0024 to the test database."""
def enable(table: str, using: str):
return [
f"ALTER TABLE {table} ENABLE ROW LEVEL SECURITY",
f"ALTER TABLE {table} FORCE ROW LEVEL SECURITY",
f"DROP POLICY IF EXISTS rls_{table}_select ON {table}",
f"DROP POLICY IF EXISTS rls_{table}_insert ON {table}",
f"DROP POLICY IF EXISTS rls_{table}_update ON {table}",
f"DROP POLICY IF EXISTS rls_{table}_delete ON {table}",
f"CREATE POLICY rls_{table}_select ON {table} FOR SELECT USING {using}",
f"CREATE POLICY rls_{table}_insert ON {table} FOR INSERT WITH CHECK {using}",
f"CREATE POLICY rls_{table}_update ON {table} FOR UPDATE USING {using} WITH CHECK {using}",
f"CREATE POLICY rls_{table}_delete ON {table} FOR DELETE USING {using}",
]
for sql in enable("companies", _rls_using_iid()):
await conn.execute(text(sql))
for table in _COMPANY_COL_TABLES:
for sql in enable(table, _rls_using_cid()):
await conn.execute(text(sql))
for table in _USER_JOIN_TABLES:
for sql in enable(table, _rls_using_join()):
await conn.execute(text(sql))
@pytest_asyncio.fixture(scope="session", loop_scope="session", autouse=True)
async def setup_db():
async with test_engine.begin() as conn:
@@ -24,6 +71,7 @@ async def setup_db():
await conn.execute(text("GRANT ALL ON SCHEMA public TO timemaster"))
await conn.execute(text("GRANT ALL ON SCHEMA public TO public"))
await conn.run_sync(Base.metadata.create_all)
await _apply_rls(conn)
yield
async with test_engine.begin() as conn:
await conn.execute(text("DROP SCHEMA public CASCADE"))
+190
View File
@@ -0,0 +1,190 @@
"""
RLS-Tests: Verifiziert dass PostgreSQL Row Level Security
tatsächlich cross-tenant Datenzugriff blockiert unabhängig
von App-seitigen WHERE-Klauseln.
"""
import pytest
import uuid
from httpx import AsyncClient
from sqlalchemy import text
pytestmark = pytest.mark.asyncio
REG_URL = "/api/v1/auth/register"
LOGIN_URL = "/api/v1/auth/login"
async def register_company(client: AsyncClient, suffix: str) -> dict:
"""Registriert eine neue Firma und gibt tokens + company_id zurück."""
resp = await client.post(REG_URL, json={
"company_name": f"RLS-Firma-{suffix}",
"first_name": "Admin",
"last_name": suffix,
"email": f"rls-admin-{suffix}@test.de",
"password": "Secret123",
})
assert resp.status_code == 201, resp.text
tokens = resp.json()
me = await client.get(
"/api/v1/auth/me",
headers={"Authorization": f"Bearer {tokens['access_token']}"},
)
assert me.status_code == 200
return {"tokens": tokens, "user": me.json()}
# ── Hilfsfunktion: Raw-SQL mit explizitem RLS-Kontext ────────────────────────
async def query_users_as_tenant(db_session, company_id: str) -> list[dict]:
"""
Führt SELECT * FROM users direkt aus,
mit gesetztem app.company_id simuliert einen Request als dieser Mandant.
bypass_rls ist OFF, d.h. RLS greift.
"""
await db_session.execute(text("SET LOCAL app.bypass_rls = 'off'"))
await db_session.execute(
text(f"SET LOCAL app.company_id = '{company_id}'")
)
result = await db_session.execute(text("SELECT id, email, company_id FROM users"))
rows = [dict(r._mapping) for r in result.fetchall()]
# Reset: bypass wieder an (damit nachfolgende Tests nicht leiden)
await db_session.execute(text("SET LOCAL app.bypass_rls = 'on'"))
return rows
# ── Tests ────────────────────────────────────────────────────────────────────
async def test_rls_tenant_a_cannot_see_tenant_b_users(
client: AsyncClient, db_session
):
"""Mandant A darf Mandant B's Users nicht sehen (DB-Ebene)."""
a = await register_company(client, "RLS-A")
b = await register_company(client, "RLS-B")
cid_a = str(a["user"]["company_id"])
cid_b = str(b["user"]["company_id"])
# Als Mandant A abfragen
rows_as_a = await query_users_as_tenant(db_session, cid_a)
emails_as_a = {r["email"] for r in rows_as_a}
# Mandant A sieht sich selbst
assert f"rls-admin-RLS-A@test.de" in emails_as_a, \
"Mandant A sollte eigene Daten sehen"
# Mandant A sieht Mandant B NICHT
assert f"rls-admin-RLS-B@test.de" not in emails_as_a, \
"RLS BLOCKIERT NICHT: Mandant A sieht Daten von Mandant B!"
# Doppelcheck: alle zurückgegebenen Rows gehören zu company A
for row in rows_as_a:
assert str(row["company_id"]) == cid_a, \
f"Fremder Mandant in Ergebnis: {row}"
async def test_rls_tenant_b_cannot_see_tenant_a_users(
client: AsyncClient, db_session
):
"""Symmetrietest: Mandant B sieht Mandant A nicht."""
# Firmen aus vorherigem Test existieren schon direkt aus DB holen
result = await db_session.execute(
text("SELECT id FROM companies WHERE name LIKE 'RLS-Firma-%' ORDER BY name")
)
companies = [str(r[0]) for r in result.fetchall()]
assert len(companies) >= 2, "Firmen aus vorherigem Test fehlen"
cid_a, cid_b = companies[0], companies[1]
rows_as_b = await query_users_as_tenant(db_session, cid_b)
company_ids_seen = {str(r["company_id"]) for r in rows_as_b}
assert cid_a not in company_ids_seen, \
"RLS BLOCKIERT NICHT: Mandant B sieht Daten von Mandant A!"
assert cid_b in company_ids_seen, \
"Mandant B sieht keine eigenen Daten RLS zu restriktiv"
async def test_rls_no_context_returns_nothing(db_session):
"""Ohne gesetztes app.company_id gibt SELECT nichts zurück (kein bypass)."""
await db_session.execute(text("SET LOCAL app.bypass_rls = 'off'"))
# app.company_id explizit leeren
await db_session.execute(text("SET LOCAL app.company_id = ''"))
result = await db_session.execute(text("SELECT id FROM users"))
rows = result.fetchall()
# Reset
await db_session.execute(text("SET LOCAL app.bypass_rls = 'on'"))
assert len(rows) == 0, \
f"RLS BLOCKIERT NICHT: {len(rows)} Rows ohne company_id-Kontext sichtbar!"
async def test_rls_bypass_on_sees_all(db_session):
"""Mit bypass_rls='on' sieht ein SUPER_ADMIN alle Mandanten."""
await db_session.execute(text("SET LOCAL app.bypass_rls = 'on'"))
result = await db_session.execute(
text("SELECT DISTINCT company_id FROM users")
)
company_ids = {str(r[0]) for r in result.fetchall()}
assert len(company_ids) >= 2, \
"SUPER_ADMIN sollte mindestens 2 Mandanten sehen (bypass_rls='on')"
async def test_rls_companies_table_isolated(db_session):
"""Auch die companies-Tabelle ist mandanten-isoliert."""
result = await db_session.execute(
text("SELECT id FROM companies WHERE name LIKE 'RLS-Firma-%' ORDER BY name")
)
companies = [str(r[0]) for r in result.fetchall()]
assert len(companies) >= 2
cid_a, cid_b = companies[0], companies[1]
# Als Mandant A: sehe nur meine Firma
await db_session.execute(text("SET LOCAL app.bypass_rls = 'off'"))
await db_session.execute(text(f"SET LOCAL app.company_id = '{cid_a}'"))
result = await db_session.execute(text("SELECT id FROM companies"))
visible = [str(r[0]) for r in result.fetchall()]
await db_session.execute(text("SET LOCAL app.bypass_rls = 'on'"))
assert cid_a in visible, "Eigene Firma nicht sichtbar"
assert cid_b not in visible, "Fremde Firma trotz RLS sichtbar!"
assert len(visible) == 1, f"Mehr als 1 Firma sichtbar: {visible}"
async def test_rls_insert_blocked_for_wrong_tenant(db_session):
"""INSERT in fremden Mandanten wird durch WITH CHECK blockiert."""
result = await db_session.execute(
text("SELECT id FROM companies WHERE name LIKE 'RLS-Firma-%' ORDER BY name LIMIT 2")
)
companies = [str(r[0]) for r in result.fetchall()]
cid_a, cid_b = companies[0], companies[1]
await db_session.execute(text("SET LOCAL app.bypass_rls = 'off'"))
await db_session.execute(text(f"SET LOCAL app.company_id = '{cid_a}'"))
try:
# Versuche, einen User unter Mandant B zu erstellen, während Context = A
fake_id = str(uuid.uuid4())
await db_session.execute(text(f"""
INSERT INTO users (id, company_id, email, first_name, last_name,
role, password_hash, is_active, created_at)
VALUES ('{fake_id}', '{cid_b}', 'rls-inject@evil.de',
'Evil', 'Inject', 'EMPLOYEE', 'fakehash', false, NOW())
"""))
await db_session.flush()
# Wenn wir hier ankommen, hat RLS nicht blockiert
await db_session.rollback()
pytest.fail("RLS WITH CHECK hat INSERT in fremden Mandanten NICHT blockiert!")
except Exception as e:
await db_session.rollback()
# Erwartet: new row violates row-level security policy
assert "row-level security" in str(e).lower() or "policy" in str(e).lower(), \
f"Unerwarteter Fehler (kein RLS-Fehler): {e}"
finally:
await db_session.execute(text("SET LOCAL app.bypass_rls = 'on'"))