package audit_test

import (
	"bufio"
	"context"
	"encoding/json"
	"log/slog"
	"os"
	"path/filepath"
	"strings"
	"testing"
	"time"

	"github.com/jackc/pgx/v5"

	"archivmail/internal/audit"
)

// testDSN creates a fresh, isolated PostgreSQL schema for the calling test.
// It returns the base DSN (from TEST_DATABASE_URL) and schemaDSN, which is
// the same DSN with search_path pointing at that schema. The schema is
// dropped via t.Cleanup. The test is skipped when TEST_DATABASE_URL is unset.
func testDSN(t *testing.T) (dsn, schemaDSN string) {
	t.Helper()
	dsn = os.Getenv("TEST_DATABASE_URL")
	if dsn == "" {
		t.Skip("TEST_DATABASE_URL not set — skipping (needs PostgreSQL)")
	}

	// The schema name is spliced into SQL and the DSN unquoted, so reduce the
	// test name to [a-z0-9_]. "/" (subtests) and "#" (deduplicated subtest
	// names) become "_". This also guarantees pure ASCII, so the byte
	// truncation below cannot split a multi-byte rune.
	identRune := func(r rune) rune {
		switch {
		case r >= 'a' && r <= 'z', r >= '0' && r <= '9', r == '_':
			return r
		default:
			return '_'
		}
	}
	schema := "autest_" + strings.Map(identRune, strings.ToLower(t.Name()))
	// PostgreSQL identifiers are limited to 63 bytes (NAMEDATALEN - 1).
	if len(schema) > 63 {
		schema = schema[:63]
	}

	ctx := context.Background()
	conn, err := pgx.Connect(ctx, dsn)
	if err != nil {
		t.Fatalf("connect: %v", err)
	}
	// Drop any leftover schema from an aborted earlier run (cleanup never ran)
	// so stale audit rows cannot skew the counts the tests assert on.
	if _, err := conn.Exec(ctx, "DROP SCHEMA IF EXISTS "+schema+" CASCADE"); err != nil {
		conn.Close(ctx)
		t.Fatalf("drop stale schema %s: %v", schema, err)
	}
	if _, err := conn.Exec(ctx, "CREATE SCHEMA IF NOT EXISTS "+schema); err != nil {
		conn.Close(ctx)
		t.Fatalf("create schema %s: %v", schema, err)
	}
	conn.Close(ctx)

	t.Cleanup(func() {
		ctx := context.Background()
		conn, err := pgx.Connect(ctx, dsn)
		if err != nil {
			t.Logf("cleanup: connect to drop schema %s: %v", schema, err)
			return
		}
		defer conn.Close(ctx)
		if _, err := conn.Exec(ctx, "DROP SCHEMA "+schema+" CASCADE"); err != nil {
			t.Logf("cleanup: drop schema %s: %v", schema, err)
		}
	})

	sep := "?"
	if strings.Contains(dsn, "?") {
		sep = "&"
	}
	return dsn, dsn + sep + "search_path=" + schema
}

// newTestAudit returns an audit.Logger backed by a fresh per-test schema
// (see testDSN). The logger is closed via t.Cleanup.
//
// The log path passed to audit.New is a directory. Per
// TestFileLoggingUnwritableContinues, file mirroring is then presumably
// disabled, so these tests exercise only the DB side.
func newTestAudit(t *testing.T) *audit.Logger {
	t.Helper()
	_, schemaDSN := testDSN(t)

	logger := slog.New(slog.NewTextHandler(os.Discard, nil))
	l, err := audit.New(schemaDSN, t.TempDir(), logger)
	if err != nil {
		t.Fatalf("audit.New: %v", err)
	}
	// Cleanups run LIFO. This one was registered after testDSN's, so the
	// logger is closed before its schema is dropped.
	t.Cleanup(func() { l.Close() })
	return l
}

// TestLogAndQuery logs three mixed events and checks an unfiltered query
// reports and returns all of them.
func TestLogAndQuery(t *testing.T) {
	l := newTestAudit(t)

	l.Log(audit.Entry{
		EventType: audit.EventLogin,
		Username:  "alice",
		IPAddress: "192.168.1.1",
		Success:   true,
	})
	l.Log(audit.Entry{
		EventType: audit.EventSearch,
		Username:  "alice",
		IPAddress: "192.168.1.1",
		Query:     "invoice",
		Success:   true,
	})
	l.Log(audit.Entry{
		EventType: audit.EventLogin,
		Username:  "bob",
		IPAddress: "10.0.0.1",
		Success:   false,
		Detail:    "wrong password",
	})

	all, total, err := l.Query(audit.QueryFilter{PageSize: 50})
	if err != nil {
		t.Fatalf("Query all: %v", err)
	}
	if total != 3 {
		t.Errorf("expected 3 entries, got %d", total)
	}
	if len(all) != 3 {
		t.Errorf("expected 3 returned rows, got %d", len(all))
	}
}

// TestQueryByUsername checks that the Username filter counts and returns only
// that user's entries.
func TestQueryByUsername(t *testing.T) {
	l := newTestAudit(t)

	l.Log(audit.Entry{EventType: audit.EventLogin, Username: "alice", Success: true})
	l.Log(audit.Entry{EventType: audit.EventSearch, Username: "alice", Success: true})
	l.Log(audit.Entry{EventType: audit.EventLogin, Username: "bob", Success: true})

	entries, total, err := l.Query(audit.QueryFilter{Username: "alice", PageSize: 50})
	if err != nil {
		t.Fatalf("Query by username: %v", err)
	}
	if total != 2 {
		t.Errorf("alice: expected 2 entries, got %d", total)
	}
	for _, e := range entries {
		if e.Username != "alice" {
			t.Errorf("got entry for user %q in alice filter", e.Username)
		}
	}
}

// TestQueryByEventType checks that the EventType filter matches exactly one
// of three distinct event types.
func TestQueryByEventType(t *testing.T) {
	l := newTestAudit(t)

	l.Log(audit.Entry{EventType: audit.EventLogin, Username: "alice", Success: true})
	l.Log(audit.Entry{EventType: audit.EventSearch, Username: "alice", Success: true})
	l.Log(audit.Entry{EventType: audit.EventMailView, Username: "alice", MailID: "abc123", Success: true})

	_, total, err := l.Query(audit.QueryFilter{EventType: audit.EventSearch, PageSize: 50})
	if err != nil {
		t.Fatalf("Query by event type: %v", err)
	}
	if total != 1 {
		t.Errorf("search event filter: expected 1, got %d", total)
	}
}

// TestQueryByMailID checks that the MailID filter counts views of one mail
// across users.
func TestQueryByMailID(t *testing.T) {
	l := newTestAudit(t)

	l.Log(audit.Entry{EventType: audit.EventMailView, Username: "alice", MailID: "mail-001", Success: true})
	l.Log(audit.Entry{EventType: audit.EventMailView, Username: "bob", MailID: "mail-001", Success: true})
	l.Log(audit.Entry{EventType: audit.EventMailView, Username: "alice", MailID: "mail-002", Success: true})

	_, total, err := l.Query(audit.QueryFilter{MailID: "mail-001", PageSize: 50})
	if err != nil {
		t.Fatalf("Query by mail ID: %v", err)
	}
	if total != 2 {
		t.Errorf("mailID filter: expected 2, got %d", total)
	}
}

// TestQueryDateRange checks the From/To filters. A window entirely in the
// future matches nothing; a window around "now" matches everything just
// logged.
func TestQueryDateRange(t *testing.T) {
	l := newTestAudit(t)

	l.Log(audit.Entry{EventType: audit.EventLogin, Username: "alice", Success: true})
	l.Log(audit.Entry{EventType: audit.EventLogin, Username: "bob", Success: true})

	// Query with future date range — should return 0
	future := time.Now().Add(24 * time.Hour)
	futureEnd := time.Now().Add(48 * time.Hour)
	_, total, err := l.Query(audit.QueryFilter{From: &future, To: &futureEnd, PageSize: 50})
	if err != nil {
		t.Fatalf("Query future range: %v", err)
	}
	if total != 0 {
		t.Errorf("future date range should return 0, got %d", total)
	}

	// Query with past-to-now range — should return all
	past := time.Now().Add(-1 * time.Minute)
	now := time.Now().Add(1 * time.Minute)
	_, total, err = l.Query(audit.QueryFilter{From: &past, To: &now, PageSize: 50})
	if err != nil {
		t.Fatalf("Query current range: %v", err)
	}
	if total != 2 {
		t.Errorf("current date range should return 2, got %d", total)
	}
}

// TestQueryPagination splits 10 entries into pages of 4 (4, 4, 2). It checks
// the total is independent of the page and every page has the expected size.
func TestQueryPagination(t *testing.T) {
	l := newTestAudit(t)

	for i := 0; i < 10; i++ {
		l.Log(audit.Entry{EventType: audit.EventSearch, Username: "alice", Success: true})
	}

	page0, total, err := l.Query(audit.QueryFilter{PageSize: 4, Page: 0})
	if err != nil {
		t.Fatalf("Query page 0: %v", err)
	}
	page1, _, err := l.Query(audit.QueryFilter{PageSize: 4, Page: 1})
	if err != nil {
		t.Fatalf("Query page 1: %v", err)
	}
	page2, _, err := l.Query(audit.QueryFilter{PageSize: 4, Page: 2})
	if err != nil {
		t.Fatalf("Query page 2: %v", err)
	}

	if total != 10 {
		t.Errorf("total = %d, want 10", total)
	}
	if len(page0) != 4 {
		t.Errorf("page 0 len = %d, want 4", len(page0))
	}
	if len(page1) != 4 {
		t.Errorf("page 1 len = %d, want 4", len(page1))
	}
	if len(page2) != 2 {
		t.Errorf("page 2 len = %d, want 2", len(page2))
	}
}

// TestImmutableTrigger verifies the PROJ-48 DB trigger rejects UPDATE and
// DELETE on audit_log through the normal application connection.
func TestImmutableTrigger(t *testing.T) {
	_, schemaDSN := testDSN(t)

	logger := slog.New(slog.NewTextHandler(os.Discard, nil))
	l, err := audit.New(schemaDSN, t.TempDir(), logger)
	if err != nil {
		t.Fatalf("audit.New: %v", err)
	}
	defer l.Close()

	l.Log(audit.Entry{EventType: audit.EventLogin, Username: "alice", Success: true})

	// Connect through the same search_path to hit the protected table.
	ctx := context.Background()
	conn, err := pgx.Connect(ctx, schemaDSN)
	if err != nil {
		t.Fatalf("connect: %v", err)
	}
	defer conn.Close(ctx)

	if _, err := conn.Exec(ctx, "UPDATE audit_log SET detail = 'tampered'"); err == nil {
		t.Error("UPDATE on audit_log should fail but succeeded")
	}
	if _, err := conn.Exec(ctx, "DELETE FROM audit_log"); err == nil {
		t.Error("DELETE on audit_log should fail but succeeded")
	}

	// INSERT must still work and the row must survive the DELETE attempt
	// (verified indirectly via the Query count).
	_, total, err := l.Query(audit.QueryFilter{PageSize: 50})
	if err != nil {
		t.Fatalf("Query: %v", err)
	}
	if total != 1 {
		t.Errorf("expected 1 surviving entry, got %d", total)
	}
}

// TestFileLogging verifies that each event is mirrored to the append-only
// JSON Lines file (PROJ-48).
func TestFileLogging(t *testing.T) { _, schemaDSN := testDSN(t) logPath := filepath.Join(t.TempDir(), "audit.log") logger := slog.New(slog.NewTextHandler(os.Discard, nil)) l, err := audit.New(schemaDSN, logPath, logger) if err != nil { t.Fatalf("audit.New: %v", err) } l.Log(audit.Entry{EventType: audit.EventLogin, Username: "alice", IPAddress: "10.0.0.1", Success: true}) l.Log(audit.Entry{EventType: audit.EventSearch, Username: "bob", Query: "invoice", Success: false, Detail: "denied"}) l.Close() f, err := os.Open(logPath) if err != nil { t.Fatalf("open log file: %v", err) } defer f.Close() var lines []map[string]any sc := bufio.NewScanner(f) for sc.Scan() { var m map[string]any if err := json.Unmarshal(sc.Bytes(), &m); err != nil { t.Fatalf("invalid JSON line %q: %v", sc.Text(), err) } lines = append(lines, m) } if len(lines) != 2 { t.Fatalf("expected 2 log lines, got %d", len(lines)) } if lines[0]["username"] != "alice" || lines[0]["event_type"] != "login" { t.Errorf("line 0 unexpected: %v", lines[0]) } if _, err := time.Parse(time.RFC3339, lines[0]["timestamp"].(string)); err != nil { t.Errorf("timestamp not RFC3339: %v", lines[0]["timestamp"]) } } // TestFileLoggingUnwritableContinues verifies the service tolerates an // unwritable log path (DB-only logging continues, no panic). func TestFileLoggingUnwritableContinues(t *testing.T) { _, schemaDSN := testDSN(t) // A directory cannot be opened for writing → file logging disabled. logPath := t.TempDir() logger := slog.New(slog.NewTextHandler(os.Discard, nil)) l, err := audit.New(schemaDSN, logPath, logger) if err != nil { t.Fatalf("audit.New should not fail on unwritable path: %v", err) } defer l.Close() l.Log(audit.Entry{EventType: audit.EventLogin, Username: "alice", Success: true}) _, total, _ := l.Query(audit.QueryFilter{PageSize: 50}) if total != 1 { t.Errorf("DB logging must continue, got total=%d", total) } }