feat(PROJ-17): Admin Dashboard Systemauslastung immer anzeigen
- Systemauslastungs-Sektion wird immer gerendert (nicht nur bei Erfolg) - Fehlermeldung wenn /api/admin/system/stats nicht erreichbar ist - Feature-Status auf In Review gesetzt Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -0,0 +1,269 @@
|
||||
package api_test
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/jackc/pgx/v5"
|
||||
|
||||
"github.com/archivmail/config"
|
||||
"github.com/archivmail/internal/api"
|
||||
"github.com/archivmail/internal/audit"
|
||||
"github.com/archivmail/internal/auth"
|
||||
"github.com/archivmail/internal/index"
|
||||
"github.com/archivmail/internal/storage"
|
||||
"github.com/archivmail/internal/userstore"
|
||||
)
|
||||
|
||||
type testEnv struct {
|
||||
server *api.Server
|
||||
users *userstore.Store
|
||||
store *storage.Store
|
||||
idx index.Indexer
|
||||
}
|
||||
|
||||
func newTestEnv(t *testing.T) *testEnv {
|
||||
t.Helper()
|
||||
dir := t.TempDir()
|
||||
logger := slog.New(slog.NewTextHandler(os.Discard, nil))
|
||||
|
||||
store, err := storage.New(filepath.Join(dir, "store"))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
idx, err := index.New(filepath.Join(dir, "index"), 100, "xapian")
|
||||
if err != nil {
|
||||
t.Skip("xapian not available:", err)
|
||||
}
|
||||
|
||||
dsn := os.Getenv("TEST_DATABASE_URL")
|
||||
if dsn == "" {
|
||||
t.Skip("TEST_DATABASE_URL not set — skipping (needs PostgreSQL)")
|
||||
}
|
||||
|
||||
// Create isolated schemas for this test
|
||||
schemaUsers := "apitest_users_" + strings.ToLower(strings.ReplaceAll(t.Name(), "/", "_"))
|
||||
schemaAudit := "apitest_audit_" + strings.ToLower(strings.ReplaceAll(t.Name(), "/", "_"))
|
||||
if len(schemaUsers) > 63 {
|
||||
schemaUsers = schemaUsers[:63]
|
||||
}
|
||||
if len(schemaAudit) > 63 {
|
||||
schemaAudit = schemaAudit[:63]
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
conn, err := pgx.Connect(ctx, dsn)
|
||||
if err != nil {
|
||||
t.Fatalf("connect: %v", err)
|
||||
}
|
||||
conn.Exec(ctx, "CREATE SCHEMA IF NOT EXISTS "+schemaUsers)
|
||||
conn.Exec(ctx, "CREATE SCHEMA IF NOT EXISTS "+schemaAudit)
|
||||
conn.Close(ctx)
|
||||
|
||||
sep := "?"
|
||||
if strings.Contains(dsn, "?") {
|
||||
sep = "&"
|
||||
}
|
||||
usersDSN := dsn + sep + "search_path=" + schemaUsers
|
||||
auditDSN := dsn + sep + "search_path=" + schemaAudit
|
||||
|
||||
users, err := userstore.New(usersDSN)
|
||||
if err != nil {
|
||||
t.Fatalf("userstore.New: %v", err)
|
||||
}
|
||||
|
||||
audlog, err := audit.New(auditDSN, dir, logger)
|
||||
if err != nil {
|
||||
t.Fatalf("audit.New: %v", err)
|
||||
}
|
||||
|
||||
// Seed users
|
||||
users.Create(userstore.CreateUserRequest{Username: "admin", Email: "admin@x.com", Password: "adminpass", Role: userstore.RoleAdmin})
|
||||
users.Create(userstore.CreateUserRequest{Username: "auditor", Email: "auditor@x.com", Password: "auditorpass", Role: userstore.RoleAuditor})
|
||||
users.Create(userstore.CreateUserRequest{Username: "user1", Email: "user1@x.com", Password: "userpass", Role: userstore.RoleUser})
|
||||
|
||||
authMgr := auth.New(users, nil, "test-secret-must-be-long-enough-32")
|
||||
cfg := config.APIConfig{Bind: ":18080", Secret: "test-secret-must-be-long-enough-32"}
|
||||
srv := api.New(cfg, store, idx, authMgr, users, audlog, logger)
|
||||
|
||||
t.Cleanup(func() {
|
||||
idx.Close()
|
||||
users.Close()
|
||||
audlog.Close()
|
||||
conn2, _ := pgx.Connect(context.Background(), dsn)
|
||||
if conn2 != nil {
|
||||
conn2.Exec(context.Background(), "DROP SCHEMA "+schemaUsers+" CASCADE")
|
||||
conn2.Exec(context.Background(), "DROP SCHEMA "+schemaAudit+" CASCADE")
|
||||
conn2.Close(context.Background())
|
||||
}
|
||||
})
|
||||
|
||||
return &testEnv{server: srv, users: users, store: store, idx: idx}
|
||||
}
|
||||
|
||||
func (e *testEnv) do(t *testing.T, method, path string, body interface{}, token string) *httptest.ResponseRecorder {
|
||||
t.Helper()
|
||||
var buf bytes.Buffer
|
||||
if body != nil {
|
||||
json.NewEncoder(&buf).Encode(body)
|
||||
}
|
||||
req := httptest.NewRequest(method, path, &buf)
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
if token != "" {
|
||||
req.Header.Set("Authorization", "Bearer "+token)
|
||||
}
|
||||
w := httptest.NewRecorder()
|
||||
e.server.ServeHTTP(w, req)
|
||||
return w
|
||||
}
|
||||
|
||||
func (e *testEnv) login(t *testing.T, username, password string) string {
|
||||
t.Helper()
|
||||
w := e.do(t, "POST", "/api/auth/login",
|
||||
map[string]string{"username": username, "password": password}, "")
|
||||
if w.Code != 200 {
|
||||
t.Fatalf("login %s: status %d, body: %s", username, w.Code, w.Body.String())
|
||||
}
|
||||
var resp map[string]interface{}
|
||||
json.Unmarshal(w.Body.Bytes(), &resp)
|
||||
return resp["token"].(string)
|
||||
}
|
||||
|
||||
// ---- Tests ----
|
||||
|
||||
func TestHealth(t *testing.T) {
|
||||
env := newTestEnv(t)
|
||||
w := env.do(t, "GET", "/api/health", nil, "")
|
||||
if w.Code != 200 {
|
||||
t.Errorf("health: status %d", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoginAndMe(t *testing.T) {
|
||||
env := newTestEnv(t)
|
||||
token := env.login(t, "admin", "adminpass")
|
||||
|
||||
w := env.do(t, "GET", "/api/auth/me", nil, token)
|
||||
if w.Code != 200 {
|
||||
t.Fatalf("me: status %d", w.Code)
|
||||
}
|
||||
var resp map[string]interface{}
|
||||
json.Unmarshal(w.Body.Bytes(), &resp)
|
||||
if resp["username"] != "admin" {
|
||||
t.Errorf("me username = %q", resp["username"])
|
||||
}
|
||||
if resp["role"] != "admin" {
|
||||
t.Errorf("me role = %q", resp["role"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoginWrongCredentials(t *testing.T) {
|
||||
env := newTestEnv(t)
|
||||
w := env.do(t, "POST", "/api/auth/login",
|
||||
map[string]string{"username": "admin", "password": "wrong"}, "")
|
||||
if w.Code != 401 {
|
||||
t.Errorf("expected 401, got %d", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUnauthenticatedSearchBlocked(t *testing.T) {
|
||||
env := newTestEnv(t)
|
||||
w := env.do(t, "GET", "/api/search?q=test", nil, "")
|
||||
if w.Code != 401 {
|
||||
t.Errorf("expected 401, got %d", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLogout(t *testing.T) {
|
||||
env := newTestEnv(t)
|
||||
token := env.login(t, "admin", "adminpass")
|
||||
|
||||
w := env.do(t, "POST", "/api/auth/logout", nil, token)
|
||||
if w.Code != 200 {
|
||||
t.Fatalf("logout: status %d", w.Code)
|
||||
}
|
||||
|
||||
// Token should now be invalid
|
||||
w2 := env.do(t, "GET", "/api/auth/me", nil, token)
|
||||
if w2.Code != 401 {
|
||||
t.Errorf("after logout, me should return 401, got %d", w2.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAdminUserCRUD(t *testing.T) {
|
||||
env := newTestEnv(t)
|
||||
token := env.login(t, "admin", "adminpass")
|
||||
|
||||
// List users
|
||||
w := env.do(t, "GET", "/api/users", nil, token)
|
||||
if w.Code != 200 {
|
||||
t.Fatalf("list users: status %d", w.Code)
|
||||
}
|
||||
var users []map[string]interface{}
|
||||
json.Unmarshal(w.Body.Bytes(), &users)
|
||||
if len(users) != 3 { // admin + auditor + user1
|
||||
t.Errorf("expected 3 users, got %d", len(users))
|
||||
}
|
||||
|
||||
// Create user
|
||||
w = env.do(t, "POST", "/api/users",
|
||||
map[string]string{"username": "newuser", "email": "new@x.com", "password": "pw123", "role": "user"},
|
||||
token)
|
||||
if w.Code != 201 {
|
||||
t.Fatalf("create user: status %d, body: %s", w.Code, w.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestNonAdminCannotManageUsers(t *testing.T) {
|
||||
env := newTestEnv(t)
|
||||
token := env.login(t, "user1", "userpass")
|
||||
|
||||
w := env.do(t, "GET", "/api/users", nil, token)
|
||||
if w.Code != 403 {
|
||||
t.Errorf("user role should not list users, got %d", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuditorCanAccessAuditLog(t *testing.T) {
|
||||
env := newTestEnv(t)
|
||||
token := env.login(t, "auditor", "auditorpass")
|
||||
|
||||
w := env.do(t, "GET", "/api/audit", nil, token)
|
||||
if w.Code != 200 {
|
||||
t.Errorf("auditor should access audit log, got %d", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUserCannotAccessAuditLog(t *testing.T) {
|
||||
env := newTestEnv(t)
|
||||
token := env.login(t, "user1", "userpass")
|
||||
|
||||
w := env.do(t, "GET", "/api/audit", nil, token)
|
||||
if w.Code != 403 {
|
||||
t.Errorf("user role should not access audit log, got %d", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSearchReturnsResults(t *testing.T) {
|
||||
env := newTestEnv(t)
|
||||
token := env.login(t, "admin", "adminpass")
|
||||
|
||||
w := env.do(t, "GET", "/api/search?q=test", nil, token)
|
||||
if w.Code != 200 {
|
||||
t.Fatalf("search: status %d, body: %s", w.Code, w.Body.String())
|
||||
}
|
||||
var result map[string]interface{}
|
||||
json.Unmarshal(w.Body.Bytes(), &result)
|
||||
if _, ok := result["total"]; !ok {
|
||||
t.Error("search response missing 'total' field")
|
||||
}
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,196 @@
|
||||
package audit
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/jackc/pgx/v5/pgxpool"
|
||||
)
|
||||
|
||||
const (
|
||||
EventLogin = "login"
|
||||
EventLogout = "logout"
|
||||
EventSearch = "search"
|
||||
EventMailView = "mail_view"
|
||||
EventImport = "import"
|
||||
EventExport = "export"
|
||||
EventUserMgmt = "user_mgmt"
|
||||
)
|
||||
|
||||
// Entry is a single audit log record.
|
||||
type Entry struct {
|
||||
ID int64 `json:"id"`
|
||||
Timestamp time.Time `json:"timestamp"`
|
||||
EventType string `json:"event_type"`
|
||||
Username string `json:"username"`
|
||||
IPAddress string `json:"ip_address"`
|
||||
Query string `json:"query"`
|
||||
MailID string `json:"mail_id"`
|
||||
Success bool `json:"success"`
|
||||
Detail string `json:"detail"`
|
||||
}
|
||||
|
||||
// QueryFilter specifies filtering options for audit log queries.
|
||||
type QueryFilter struct {
|
||||
Username string
|
||||
EventType string
|
||||
MailID string
|
||||
From *time.Time
|
||||
To *time.Time
|
||||
PageSize int
|
||||
Page int
|
||||
}
|
||||
|
||||
// Logger is a PostgreSQL-backed, append-only audit log.
|
||||
type Logger struct {
|
||||
pool *pgxpool.Pool
|
||||
logger *slog.Logger
|
||||
}
|
||||
|
||||
// New connects to PostgreSQL using the given DSN and initialises the schema.
|
||||
// logDir is reserved for future flat-file logging.
|
||||
func New(dsn, logDir string, logger *slog.Logger) (*Logger, error) {
|
||||
ctx := context.Background()
|
||||
pool, err := pgxpool.New(ctx, dsn)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("audit: connect: %w", err)
|
||||
}
|
||||
|
||||
_, err = pool.Exec(ctx, `
|
||||
CREATE TABLE IF NOT EXISTS audit_log (
|
||||
id BIGSERIAL PRIMARY KEY,
|
||||
timestamp TIMESTAMPTZ NOT NULL DEFAULT NOW(),
|
||||
event_type VARCHAR(50) NOT NULL,
|
||||
username VARCHAR(255) NOT NULL DEFAULT '',
|
||||
ip_address VARCHAR(45) NOT NULL DEFAULT '',
|
||||
query TEXT NOT NULL DEFAULT '',
|
||||
mail_id VARCHAR(255) NOT NULL DEFAULT '',
|
||||
success BOOLEAN NOT NULL DEFAULT true,
|
||||
detail TEXT NOT NULL DEFAULT ''
|
||||
);
|
||||
`)
|
||||
if err != nil {
|
||||
pool.Close()
|
||||
return nil, fmt.Errorf("audit: create schema: %w", err)
|
||||
}
|
||||
|
||||
return &Logger{pool: pool, logger: logger}, nil
|
||||
}
|
||||
|
||||
// Log appends an entry to the audit log. Errors are logged but not returned.
|
||||
func (l *Logger) Log(entry Entry) {
|
||||
ts := entry.Timestamp
|
||||
if ts.IsZero() {
|
||||
ts = time.Now().UTC()
|
||||
}
|
||||
ctx := context.Background()
|
||||
_, err := l.pool.Exec(ctx,
|
||||
`INSERT INTO audit_log (timestamp, event_type, username, ip_address, query, mail_id, success, detail)
|
||||
VALUES ($1, $2, $3, $4, $5, $6, $7, $8)`,
|
||||
ts.UTC(),
|
||||
entry.EventType,
|
||||
entry.Username,
|
||||
entry.IPAddress,
|
||||
entry.Query,
|
||||
entry.MailID,
|
||||
entry.Success,
|
||||
entry.Detail,
|
||||
)
|
||||
if err != nil {
|
||||
l.logger.Error("audit: insert failed", "err", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Query retrieves audit entries matching the given filter, returning the
|
||||
// matched entries, the total count (ignoring pagination), and any error.
|
||||
func (l *Logger) Query(filter QueryFilter) ([]Entry, int, error) {
|
||||
pageSize := filter.PageSize
|
||||
if pageSize <= 0 {
|
||||
pageSize = 50
|
||||
}
|
||||
|
||||
where, args := buildWhere(filter)
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Count total
|
||||
countSQL := "SELECT COUNT(*) FROM audit_log" + where
|
||||
var total int
|
||||
if err := l.pool.QueryRow(ctx, countSQL, args...).Scan(&total); err != nil {
|
||||
return nil, 0, fmt.Errorf("audit: count: %w", err)
|
||||
}
|
||||
|
||||
offset := filter.Page * pageSize
|
||||
// Append limit and offset as next positional args
|
||||
limitArg := len(args) + 1
|
||||
offsetArg := len(args) + 2
|
||||
querySQL := fmt.Sprintf(
|
||||
"SELECT id, timestamp, event_type, username, ip_address, query, mail_id, success, detail FROM audit_log%s ORDER BY timestamp DESC LIMIT $%d OFFSET $%d",
|
||||
where, limitArg, offsetArg,
|
||||
)
|
||||
allArgs := append(args, pageSize, offset)
|
||||
|
||||
rows, err := l.pool.Query(ctx, querySQL, allArgs...)
|
||||
if err != nil {
|
||||
return nil, 0, fmt.Errorf("audit: query: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var entries []Entry
|
||||
for rows.Next() {
|
||||
var e Entry
|
||||
if err := rows.Scan(&e.ID, &e.Timestamp, &e.EventType, &e.Username, &e.IPAddress, &e.Query, &e.MailID, &e.Success, &e.Detail); err != nil {
|
||||
return nil, 0, fmt.Errorf("audit: scan: %w", err)
|
||||
}
|
||||
entries = append(entries, e)
|
||||
}
|
||||
return entries, total, rows.Err()
|
||||
}
|
||||
|
||||
// Close closes the audit connection pool.
|
||||
func (l *Logger) Close() error {
|
||||
l.pool.Close()
|
||||
return nil
|
||||
}
|
||||
|
||||
// buildWhere constructs a SQL WHERE clause from QueryFilter fields using
|
||||
// positional parameters ($1, $2, ...) for PostgreSQL.
|
||||
func buildWhere(f QueryFilter) (string, []interface{}) {
|
||||
var clauses []string
|
||||
var args []interface{}
|
||||
n := 1
|
||||
|
||||
if f.Username != "" {
|
||||
clauses = append(clauses, fmt.Sprintf("username = $%d", n))
|
||||
args = append(args, f.Username)
|
||||
n++
|
||||
}
|
||||
if f.EventType != "" {
|
||||
clauses = append(clauses, fmt.Sprintf("event_type = $%d", n))
|
||||
args = append(args, f.EventType)
|
||||
n++
|
||||
}
|
||||
if f.MailID != "" {
|
||||
clauses = append(clauses, fmt.Sprintf("mail_id = $%d", n))
|
||||
args = append(args, f.MailID)
|
||||
n++
|
||||
}
|
||||
if f.From != nil {
|
||||
clauses = append(clauses, fmt.Sprintf("timestamp >= $%d", n))
|
||||
args = append(args, f.From.UTC())
|
||||
n++
|
||||
}
|
||||
if f.To != nil {
|
||||
clauses = append(clauses, fmt.Sprintf("timestamp <= $%d", n))
|
||||
args = append(args, f.To.UTC())
|
||||
n++
|
||||
}
|
||||
|
||||
if len(clauses) == 0 {
|
||||
return "", args
|
||||
}
|
||||
return " WHERE " + strings.Join(clauses, " AND "), args
|
||||
}
|
||||
@@ -0,0 +1,179 @@
|
||||
package audit_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log/slog"
|
||||
"os"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/jackc/pgx/v5"
|
||||
|
||||
"github.com/archivmail/internal/audit"
|
||||
)
|
||||
|
||||
func newTestAudit(t *testing.T) *audit.Logger {
|
||||
t.Helper()
|
||||
dsn := os.Getenv("TEST_DATABASE_URL")
|
||||
if dsn == "" {
|
||||
t.Skip("TEST_DATABASE_URL not set — skipping (needs PostgreSQL)")
|
||||
}
|
||||
schema := "autest_" + strings.ToLower(strings.ReplaceAll(t.Name(), "/", "_"))
|
||||
// truncate schema name to 63 chars (PostgreSQL limit)
|
||||
if len(schema) > 63 {
|
||||
schema = schema[:63]
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
conn, err := pgx.Connect(ctx, dsn)
|
||||
if err != nil {
|
||||
t.Fatalf("connect: %v", err)
|
||||
}
|
||||
conn.Exec(ctx, "CREATE SCHEMA IF NOT EXISTS "+schema)
|
||||
conn.Close(ctx)
|
||||
|
||||
sep := "?"
|
||||
if strings.Contains(dsn, "?") {
|
||||
sep = "&"
|
||||
}
|
||||
schemaDSN := dsn + sep + "search_path=" + schema
|
||||
|
||||
logger := slog.New(slog.NewTextHandler(os.Discard, nil))
|
||||
l, err := audit.New(schemaDSN, t.TempDir(), logger)
|
||||
if err != nil {
|
||||
t.Fatalf("audit.New: %v", err)
|
||||
}
|
||||
t.Cleanup(func() {
|
||||
l.Close()
|
||||
conn2, _ := pgx.Connect(context.Background(), dsn)
|
||||
if conn2 != nil {
|
||||
conn2.Exec(context.Background(), "DROP SCHEMA "+schema+" CASCADE")
|
||||
conn2.Close(context.Background())
|
||||
}
|
||||
})
|
||||
return l
|
||||
}
|
||||
|
||||
func TestLogAndQuery(t *testing.T) {
|
||||
l := newTestAudit(t)
|
||||
|
||||
l.Log(audit.Entry{
|
||||
EventType: audit.EventLogin,
|
||||
Username: "alice",
|
||||
IPAddress: "192.168.1.1",
|
||||
Success: true,
|
||||
})
|
||||
l.Log(audit.Entry{
|
||||
EventType: audit.EventSearch,
|
||||
Username: "alice",
|
||||
IPAddress: "192.168.1.1",
|
||||
Query: "invoice",
|
||||
Success: true,
|
||||
})
|
||||
l.Log(audit.Entry{
|
||||
EventType: audit.EventLogin,
|
||||
Username: "bob",
|
||||
IPAddress: "10.0.0.1",
|
||||
Success: false,
|
||||
Detail: "wrong password",
|
||||
})
|
||||
|
||||
all, total, err := l.Query(audit.QueryFilter{PageSize: 50})
|
||||
if err != nil {
|
||||
t.Fatalf("Query all: %v", err)
|
||||
}
|
||||
if total != 3 {
|
||||
t.Errorf("expected 3 entries, got %d", total)
|
||||
}
|
||||
_ = all
|
||||
}
|
||||
|
||||
func TestQueryByUsername(t *testing.T) {
|
||||
l := newTestAudit(t)
|
||||
|
||||
l.Log(audit.Entry{EventType: audit.EventLogin, Username: "alice", Success: true})
|
||||
l.Log(audit.Entry{EventType: audit.EventSearch, Username: "alice", Success: true})
|
||||
l.Log(audit.Entry{EventType: audit.EventLogin, Username: "bob", Success: true})
|
||||
|
||||
entries, total, _ := l.Query(audit.QueryFilter{Username: "alice", PageSize: 50})
|
||||
if total != 2 {
|
||||
t.Errorf("alice: expected 2 entries, got %d", total)
|
||||
}
|
||||
for _, e := range entries {
|
||||
if e.Username != "alice" {
|
||||
t.Errorf("got entry for user %q in alice filter", e.Username)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestQueryByEventType(t *testing.T) {
|
||||
l := newTestAudit(t)
|
||||
|
||||
l.Log(audit.Entry{EventType: audit.EventLogin, Username: "alice", Success: true})
|
||||
l.Log(audit.Entry{EventType: audit.EventSearch, Username: "alice", Success: true})
|
||||
l.Log(audit.Entry{EventType: audit.EventMailView, Username: "alice", MailID: "abc123", Success: true})
|
||||
|
||||
_, total, _ := l.Query(audit.QueryFilter{EventType: audit.EventSearch, PageSize: 50})
|
||||
if total != 1 {
|
||||
t.Errorf("search event filter: expected 1, got %d", total)
|
||||
}
|
||||
}
|
||||
|
||||
func TestQueryByMailID(t *testing.T) {
|
||||
l := newTestAudit(t)
|
||||
|
||||
l.Log(audit.Entry{EventType: audit.EventMailView, Username: "alice", MailID: "mail-001", Success: true})
|
||||
l.Log(audit.Entry{EventType: audit.EventMailView, Username: "bob", MailID: "mail-001", Success: true})
|
||||
l.Log(audit.Entry{EventType: audit.EventMailView, Username: "alice", MailID: "mail-002", Success: true})
|
||||
|
||||
_, total, _ := l.Query(audit.QueryFilter{MailID: "mail-001", PageSize: 50})
|
||||
if total != 2 {
|
||||
t.Errorf("mailID filter: expected 2, got %d", total)
|
||||
}
|
||||
}
|
||||
|
||||
func TestQueryDateRange(t *testing.T) {
|
||||
l := newTestAudit(t)
|
||||
|
||||
l.Log(audit.Entry{EventType: audit.EventLogin, Username: "alice", Success: true})
|
||||
l.Log(audit.Entry{EventType: audit.EventLogin, Username: "bob", Success: true})
|
||||
|
||||
// Query with future date range — should return 0
|
||||
future := time.Now().Add(24 * time.Hour)
|
||||
futureEnd := time.Now().Add(48 * time.Hour)
|
||||
_, total, _ := l.Query(audit.QueryFilter{From: &future, To: &futureEnd, PageSize: 50})
|
||||
if total != 0 {
|
||||
t.Errorf("future date range should return 0, got %d", total)
|
||||
}
|
||||
|
||||
// Query with past-to-now range — should return all
|
||||
past := time.Now().Add(-1 * time.Minute)
|
||||
now := time.Now().Add(1 * time.Minute)
|
||||
_, total, _ = l.Query(audit.QueryFilter{From: &past, To: &now, PageSize: 50})
|
||||
if total != 2 {
|
||||
t.Errorf("current date range should return 2, got %d", total)
|
||||
}
|
||||
}
|
||||
|
||||
func TestQueryPagination(t *testing.T) {
|
||||
l := newTestAudit(t)
|
||||
|
||||
for i := 0; i < 10; i++ {
|
||||
l.Log(audit.Entry{EventType: audit.EventSearch, Username: "alice", Success: true})
|
||||
}
|
||||
|
||||
page0, total, _ := l.Query(audit.QueryFilter{PageSize: 4, Page: 0})
|
||||
_, _, _ = l.Query(audit.QueryFilter{PageSize: 4, Page: 1})
|
||||
page2, _, _ := l.Query(audit.QueryFilter{PageSize: 4, Page: 2})
|
||||
|
||||
if total != 10 {
|
||||
t.Errorf("total = %d, want 10", total)
|
||||
}
|
||||
if len(page0) != 4 {
|
||||
t.Errorf("page 0 len = %d, want 4", len(page0))
|
||||
}
|
||||
if len(page2) != 2 {
|
||||
t.Errorf("page 2 len = %d, want 2", len(page2))
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,156 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
"github.com/archivmail/internal/userstore"
|
||||
)
|
||||
|
||||
// Session holds the claims extracted from a validated JWT.
|
||||
type Session struct {
|
||||
UserID int64
|
||||
Username string
|
||||
Role string
|
||||
JTI string // unique JWT ID
|
||||
}
|
||||
|
||||
// Manager handles login, token issuance, validation, and logout.
|
||||
type Manager struct {
|
||||
store *userstore.Store
|
||||
ldap interface{} // placeholder for LDAP provider
|
||||
jwtSecret []byte
|
||||
}
|
||||
|
||||
// New creates a new auth Manager.
|
||||
func New(store *userstore.Store, ldap interface{}, jwtSecret string) *Manager {
|
||||
return &Manager{
|
||||
store: store,
|
||||
ldap: ldap,
|
||||
jwtSecret: []byte(jwtSecret),
|
||||
}
|
||||
}
|
||||
|
||||
// Login verifies credentials and returns a signed JWT token.
|
||||
func (m *Manager) Login(username, password string) (string, *userstore.User, error) {
|
||||
user, err := m.store.VerifyPassword(username, password)
|
||||
if err != nil {
|
||||
return "", nil, fmt.Errorf("auth: login: %w", err)
|
||||
}
|
||||
|
||||
jti := generateJTI()
|
||||
now := time.Now()
|
||||
claims := jwt.MapClaims{
|
||||
"sub": user.Username,
|
||||
"role": user.Role,
|
||||
"uid": user.ID,
|
||||
"jti": jti,
|
||||
"iat": now.Unix(),
|
||||
"exp": now.Add(8 * time.Hour).Unix(),
|
||||
}
|
||||
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
|
||||
signed, err := token.SignedString(m.jwtSecret)
|
||||
if err != nil {
|
||||
return "", nil, fmt.Errorf("auth: sign token: %w", err)
|
||||
}
|
||||
|
||||
return signed, user, nil
|
||||
}
|
||||
|
||||
// ValidateToken parses and validates the token, checking the blacklist.
|
||||
func (m *Manager) ValidateToken(tokenStr string) (*Session, error) {
|
||||
token, err := jwt.Parse(tokenStr, func(t *jwt.Token) (interface{}, error) {
|
||||
if _, ok := t.Method.(*jwt.SigningMethodHMAC); !ok {
|
||||
return nil, fmt.Errorf("auth: unexpected signing method: %v", t.Header["alg"])
|
||||
}
|
||||
return m.jwtSecret, nil
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("auth: invalid token: %w", err)
|
||||
}
|
||||
if !token.Valid {
|
||||
return nil, errors.New("auth: token not valid")
|
||||
}
|
||||
|
||||
claims, ok := token.Claims.(jwt.MapClaims)
|
||||
if !ok {
|
||||
return nil, errors.New("auth: bad claims")
|
||||
}
|
||||
|
||||
jti, _ := claims["jti"].(string)
|
||||
blacklisted, err := m.store.IsBlacklisted(jti)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("auth: blacklist check: %w", err)
|
||||
}
|
||||
if blacklisted {
|
||||
return nil, errors.New("auth: token revoked")
|
||||
}
|
||||
|
||||
username, _ := claims["sub"].(string)
|
||||
role, _ := claims["role"].(string)
|
||||
|
||||
var userID int64
|
||||
switch v := claims["uid"].(type) {
|
||||
case float64:
|
||||
userID = int64(v)
|
||||
case int64:
|
||||
userID = v
|
||||
}
|
||||
|
||||
return &Session{
|
||||
UserID: userID,
|
||||
Username: username,
|
||||
Role: role,
|
||||
JTI: jti,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Logout revokes the token by adding its JTI to the blacklist.
|
||||
func (m *Manager) Logout(tokenStr string) error {
|
||||
token, err := jwt.Parse(tokenStr, func(t *jwt.Token) (interface{}, error) {
|
||||
if _, ok := t.Method.(*jwt.SigningMethodHMAC); !ok {
|
||||
return nil, fmt.Errorf("auth: unexpected signing method")
|
||||
}
|
||||
return m.jwtSecret, nil
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("auth: logout parse: %w", err)
|
||||
}
|
||||
|
||||
claims, ok := token.Claims.(jwt.MapClaims)
|
||||
if !ok {
|
||||
return errors.New("auth: bad claims on logout")
|
||||
}
|
||||
|
||||
jti, _ := claims["jti"].(string)
|
||||
var exp time.Time
|
||||
switch v := claims["exp"].(type) {
|
||||
case float64:
|
||||
exp = time.Unix(int64(v), 0)
|
||||
case int64:
|
||||
exp = time.Unix(v, 0)
|
||||
default:
|
||||
exp = time.Now().Add(8 * time.Hour)
|
||||
}
|
||||
|
||||
return m.store.BlacklistToken(jti, exp)
|
||||
}
|
||||
|
||||
// HasRole returns true when userRole satisfies the required role level.
|
||||
// Hierarchy: admin > auditor > user
|
||||
func HasRole(userRole, required string) bool {
|
||||
levels := map[string]int{
|
||||
userstore.RoleUser: 1,
|
||||
userstore.RoleAuditor: 2,
|
||||
userstore.RoleAdmin: 3,
|
||||
}
|
||||
return levels[userRole] >= levels[required]
|
||||
}
|
||||
|
||||
// generateJTI returns a pseudo-unique identifier for a JWT.
|
||||
func generateJTI() string {
|
||||
return fmt.Sprintf("%d-%x", time.Now().UnixNano(), time.Now().UnixNano()^0xdeadbeef)
|
||||
}
|
||||
@@ -0,0 +1,161 @@
|
||||
package auth_test
|
||||
|
||||
import (
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"github.com/archivmail/internal/auth"
|
||||
"github.com/archivmail/internal/userstore"
|
||||
)
|
||||
|
||||
func newTestAuth(t *testing.T) (*auth.Manager, *userstore.Store) {
|
||||
t.Helper()
|
||||
store, err := userstore.New(filepath.Join(t.TempDir(), "users.db"))
|
||||
if err != nil {
|
||||
t.Fatalf("userstore.New: %v", err)
|
||||
}
|
||||
t.Cleanup(func() { store.Close() })
|
||||
|
||||
// Seed a test user
|
||||
store.Create(userstore.CreateUserRequest{
|
||||
Username: "testadmin",
|
||||
Email: "admin@example.com",
|
||||
Password: "adminpass",
|
||||
Role: userstore.RoleAdmin,
|
||||
})
|
||||
store.Create(userstore.CreateUserRequest{
|
||||
Username: "regularuser",
|
||||
Email: "user@example.com",
|
||||
Password: "userpass",
|
||||
Role: userstore.RoleUser,
|
||||
})
|
||||
|
||||
mgr := auth.New(store, nil, "test-jwt-secret-32chars-long-enough")
|
||||
return mgr, store
|
||||
}
|
||||
|
||||
func TestLoginSuccess(t *testing.T) {
|
||||
mgr, _ := newTestAuth(t)
|
||||
|
||||
token, user, err := mgr.Login("testadmin", "adminpass")
|
||||
if err != nil {
|
||||
t.Fatalf("Login: %v", err)
|
||||
}
|
||||
if token == "" {
|
||||
t.Error("expected non-empty token")
|
||||
}
|
||||
if user.Username != "testadmin" {
|
||||
t.Errorf("Username = %q", user.Username)
|
||||
}
|
||||
if user.Role != userstore.RoleAdmin {
|
||||
t.Errorf("Role = %q", user.Role)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoginWrongPassword(t *testing.T) {
|
||||
mgr, _ := newTestAuth(t)
|
||||
|
||||
if _, _, err := mgr.Login("testadmin", "wrongpass"); err == nil {
|
||||
t.Error("expected error for wrong password")
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoginUnknownUser(t *testing.T) {
|
||||
mgr, _ := newTestAuth(t)
|
||||
|
||||
if _, _, err := mgr.Login("nobody", "pw"); err == nil {
|
||||
t.Error("expected error for unknown user")
|
||||
}
|
||||
}
|
||||
|
||||
func TestTokenValidation(t *testing.T) {
|
||||
mgr, _ := newTestAuth(t)
|
||||
|
||||
token, _, _ := mgr.Login("testadmin", "adminpass")
|
||||
sess, err := mgr.ValidateToken(token)
|
||||
if err != nil {
|
||||
t.Fatalf("ValidateToken: %v", err)
|
||||
}
|
||||
if sess.Username != "testadmin" {
|
||||
t.Errorf("Session Username = %q", sess.Username)
|
||||
}
|
||||
if sess.Role != userstore.RoleAdmin {
|
||||
t.Errorf("Session Role = %q", sess.Role)
|
||||
}
|
||||
if sess.JTI == "" {
|
||||
t.Error("Session JTI should not be empty")
|
||||
}
|
||||
}
|
||||
|
||||
func TestTokenTampering(t *testing.T) {
|
||||
mgr, _ := newTestAuth(t)
|
||||
|
||||
token, _, _ := mgr.Login("testadmin", "adminpass")
|
||||
tampered := token + "x"
|
||||
|
||||
if _, err := mgr.ValidateToken(tampered); err == nil {
|
||||
t.Error("tampered token should fail validation")
|
||||
}
|
||||
}
|
||||
|
||||
func TestLogout(t *testing.T) {
|
||||
mgr, _ := newTestAuth(t)
|
||||
|
||||
token, _, _ := mgr.Login("testadmin", "adminpass")
|
||||
|
||||
// Token valid before logout
|
||||
if _, err := mgr.ValidateToken(token); err != nil {
|
||||
t.Fatalf("token should be valid before logout: %v", err)
|
||||
}
|
||||
|
||||
if err := mgr.Logout(token); err != nil {
|
||||
t.Fatalf("Logout: %v", err)
|
||||
}
|
||||
|
||||
// Token invalid after logout
|
||||
if _, err := mgr.ValidateToken(token); err == nil {
|
||||
t.Error("token should be invalid after logout")
|
||||
}
|
||||
}
|
||||
|
||||
func TestHasRole(t *testing.T) {
|
||||
tests := []struct {
|
||||
userRole string
|
||||
required string
|
||||
want bool
|
||||
}{
|
||||
{userstore.RoleAdmin, userstore.RoleAdmin, true},
|
||||
{userstore.RoleAdmin, userstore.RoleAuditor, true},
|
||||
{userstore.RoleAdmin, userstore.RoleUser, true},
|
||||
{userstore.RoleAuditor, userstore.RoleAdmin, false},
|
||||
{userstore.RoleAuditor, userstore.RoleAuditor, true},
|
||||
{userstore.RoleAuditor, userstore.RoleUser, true},
|
||||
{userstore.RoleUser, userstore.RoleAdmin, false},
|
||||
{userstore.RoleUser, userstore.RoleAuditor, false},
|
||||
{userstore.RoleUser, userstore.RoleUser, true},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
got := auth.HasRole(tt.userRole, tt.required)
|
||||
if got != tt.want {
|
||||
t.Errorf("HasRole(%q, %q) = %v, want %v", tt.userRole, tt.required, got, tt.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestMultipleSessionsIndependent(t *testing.T) {
|
||||
mgr, _ := newTestAuth(t)
|
||||
|
||||
token1, _, _ := mgr.Login("testadmin", "adminpass")
|
||||
token2, _, _ := mgr.Login("testadmin", "adminpass")
|
||||
|
||||
if token1 == token2 {
|
||||
t.Error("two logins should produce different tokens (different JTIs)")
|
||||
}
|
||||
|
||||
// Logout session 1, session 2 should still work
|
||||
mgr.Logout(token1)
|
||||
if _, err := mgr.ValidateToken(token2); err != nil {
|
||||
t.Errorf("session 2 should still be valid after session 1 logout: %v", err)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,99 @@
|
||||
package imap
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
imapv2 "github.com/emersion/go-imap/v2"
|
||||
"github.com/emersion/go-imap/v2/imapclient"
|
||||
)
|
||||
|
||||
// FolderInfo describes a single IMAP folder with exclusion metadata.
|
||||
type FolderInfo struct {
|
||||
Name string `json:"name"`
|
||||
Excluded bool `json:"excluded"`
|
||||
Reason string `json:"reason,omitempty"`
|
||||
}
|
||||
|
||||
// junkTrashNames lists well-known junk/trash folder names for fallback detection.
|
||||
var junkTrashNames = []string{
|
||||
"junk", "spam", "trash", "deleted items",
|
||||
"deleted messages", "papierkorb", "gelöschte elemente",
|
||||
}
|
||||
|
||||
// Connect establishes an IMAP client connection using the specified TLS mode.
|
||||
func Connect(host string, port int, tlsMode string) (*imapclient.Client, error) {
|
||||
addr := fmt.Sprintf("%s:%d", host, port)
|
||||
|
||||
switch tlsMode {
|
||||
case "ssl":
|
||||
c, err := imapclient.DialTLS(addr, &imapclient.Options{
|
||||
TLSConfig: &tls.Config{ServerName: host},
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("imap connect ssl: %w", err)
|
||||
}
|
||||
return c, nil
|
||||
case "starttls":
|
||||
c, err := imapclient.DialStartTLS(addr, &imapclient.Options{
|
||||
TLSConfig: &tls.Config{ServerName: host},
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("imap connect starttls: %w", err)
|
||||
}
|
||||
return c, nil
|
||||
case "none":
|
||||
c, err := imapclient.DialInsecure(addr, nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("imap connect plain: %w", err)
|
||||
}
|
||||
return c, nil
|
||||
default:
|
||||
return nil, fmt.Errorf("imap: unknown tls mode %q", tlsMode)
|
||||
}
|
||||
}
|
||||
|
||||
// ListFolders retrieves all mailbox folders and detects junk/trash folders.
|
||||
func ListFolders(c *imapclient.Client) ([]FolderInfo, error) {
|
||||
listCmd := c.List("", "*", nil)
|
||||
mailboxes, err := listCmd.Collect()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("imap list folders: %w", err)
|
||||
}
|
||||
|
||||
var folders []FolderInfo
|
||||
for _, mb := range mailboxes {
|
||||
fi := FolderInfo{Name: mb.Mailbox}
|
||||
|
||||
// Check special-use attributes (RFC 6154)
|
||||
for _, attr := range mb.Attrs {
|
||||
if attr == imapv2.MailboxAttrJunk {
|
||||
fi.Excluded = true
|
||||
fi.Reason = "special_use"
|
||||
break
|
||||
}
|
||||
if attr == imapv2.MailboxAttrTrash {
|
||||
fi.Excluded = true
|
||||
fi.Reason = "special_use"
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// Fallback: case-insensitive name matching
|
||||
if !fi.Excluded {
|
||||
lower := strings.ToLower(mb.Mailbox)
|
||||
for _, jt := range junkTrashNames {
|
||||
if lower == jt {
|
||||
fi.Excluded = true
|
||||
fi.Reason = "name_match"
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
folders = append(folders, fi)
|
||||
}
|
||||
|
||||
return folders, nil
|
||||
}
|
||||
@@ -0,0 +1,272 @@
|
||||
package imap
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
imapv2 "github.com/emersion/go-imap/v2"
|
||||
"github.com/emersion/go-imap/v2/imapclient"
|
||||
"github.com/archivmail/internal/index"
|
||||
"github.com/archivmail/internal/storage"
|
||||
"github.com/archivmail/pkg/mailparser"
|
||||
)
|
||||
|
||||
const batchSize = 50
|
||||
|
||||
// Importer runs background IMAP import jobs.
|
||||
type Importer struct {
|
||||
store *Store
|
||||
mailStore *storage.Store
|
||||
idx index.Indexer
|
||||
logger *slog.Logger
|
||||
}
|
||||
|
||||
// NewImporter creates a new Importer wired to the storage and index backends.
|
||||
func NewImporter(store *Store, mailStore *storage.Store, idx index.Indexer, logger *slog.Logger) *Importer {
|
||||
return &Importer{
|
||||
store: store,
|
||||
mailStore: mailStore,
|
||||
idx: idx,
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
// Run performs a full IMAP import for the given account. It is designed to be
|
||||
// called as a goroutine: go imp.Run(context.Background(), accountID)
|
||||
func (imp *Importer) Run(ctx context.Context, accountID int64) {
|
||||
log := imp.logger.With("component", "imap-importer", "account_id", accountID)
|
||||
|
||||
acc, err := imp.store.Get(ctx, accountID)
|
||||
if err != nil {
|
||||
log.Error("failed to get account", "err", err)
|
||||
return
|
||||
}
|
||||
|
||||
password, err := imp.store.GetPassword(ctx, accountID)
|
||||
if err != nil {
|
||||
log.Error("failed to decrypt password", "err", err)
|
||||
_ = imp.store.UpdateStatus(ctx, accountID, "error", "failed to decrypt password", 0, 0)
|
||||
return
|
||||
}
|
||||
|
||||
// Mark as running
|
||||
if err := imp.store.UpdateStatus(ctx, accountID, "running", "", 0, 0); err != nil {
|
||||
log.Error("failed to update status", "err", err)
|
||||
return
|
||||
}
|
||||
|
||||
imported, err := imp.doImport(ctx, acc, password, log)
|
||||
if err != nil {
|
||||
log.Error("import failed", "err", err)
|
||||
_ = imp.store.UpdateStatus(ctx, accountID, "error", err.Error(), 0, 0)
|
||||
return
|
||||
}
|
||||
|
||||
if err := imp.store.UpdateDone(ctx, accountID, imported); err != nil {
|
||||
log.Error("failed to update done", "err", err)
|
||||
}
|
||||
|
||||
log.Info("import completed", "imported", imported)
|
||||
}
|
||||
|
||||
// doImport handles the actual IMAP connection, folder iteration, and message fetching.
|
||||
func (imp *Importer) doImport(ctx context.Context, acc *Account, password string, log *slog.Logger) (int, error) {
|
||||
c, err := Connect(acc.Host, acc.Port, acc.TLS)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("connect: %w", err)
|
||||
}
|
||||
defer c.Close()
|
||||
|
||||
// Login
|
||||
if err := c.Login(acc.Username, password).Wait(); err != nil {
|
||||
return 0, fmt.Errorf("login: %w", err)
|
||||
}
|
||||
|
||||
// List all folders
|
||||
folders, err := ListFolders(c)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("list folders: %w", err)
|
||||
}
|
||||
|
||||
// Build excluded set from account config
|
||||
excluded := make(map[string]bool)
|
||||
for _, f := range acc.ExcludedFolders {
|
||||
excluded[f] = true
|
||||
}
|
||||
|
||||
// Collect included folders
|
||||
var includedFolders []string
|
||||
for _, f := range folders {
|
||||
if !excluded[f.Name] {
|
||||
includedFolders = append(includedFolders, f.Name)
|
||||
}
|
||||
}
|
||||
|
||||
// Count total messages across all folders first
|
||||
totalMsgs := 0
|
||||
folderUIDs := make(map[string][]imapv2.UID)
|
||||
|
||||
for _, folder := range includedFolders {
|
||||
selectData, err := c.Select(folder, nil).Wait()
|
||||
if err != nil {
|
||||
log.Warn("failed to select folder, skipping", "folder", folder, "err", err)
|
||||
continue
|
||||
}
|
||||
_ = selectData
|
||||
|
||||
searchCmd := c.UIDSearch(&imapv2.SearchCriteria{}, nil)
|
||||
searchData, err := searchCmd.Wait()
|
||||
if err != nil {
|
||||
log.Warn("failed to search folder, skipping", "folder", folder, "err", err)
|
||||
continue
|
||||
}
|
||||
|
||||
uids := searchData.AllUIDs()
|
||||
folderUIDs[folder] = uids
|
||||
totalMsgs += len(uids)
|
||||
}
|
||||
|
||||
log.Info("starting import", "folders", len(includedFolders), "total_messages", totalMsgs)
|
||||
_ = imp.store.UpdateStatus(ctx, acc.ID, "running", "", 0, totalMsgs)
|
||||
|
||||
imported := 0
|
||||
processed := 0
|
||||
|
||||
for _, folder := range includedFolders {
|
||||
uids, ok := folderUIDs[folder]
|
||||
if !ok || len(uids) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
// Need to re-select the folder before fetching
|
||||
if _, err := c.Select(folder, nil).Wait(); err != nil {
|
||||
log.Warn("failed to re-select folder", "folder", folder, "err", err)
|
||||
continue
|
||||
}
|
||||
|
||||
log.Info("importing folder", "folder", folder, "messages", len(uids))
|
||||
|
||||
// Process in batches
|
||||
for i := 0; i < len(uids); i += batchSize {
|
||||
end := i + batchSize
|
||||
if end > len(uids) {
|
||||
end = len(uids)
|
||||
}
|
||||
batch := uids[i:end]
|
||||
|
||||
count, err := imp.fetchBatch(ctx, c, batch, log)
|
||||
if err != nil {
|
||||
log.Error("batch fetch error", "folder", folder, "offset", i, "err", err)
|
||||
// Continue with the next batch rather than aborting entirely
|
||||
continue
|
||||
}
|
||||
|
||||
imported += count
|
||||
processed += len(batch)
|
||||
|
||||
_ = imp.store.UpdateStatus(ctx, acc.ID, "running", "", processed, totalMsgs)
|
||||
}
|
||||
}
|
||||
|
||||
return imported, nil
|
||||
}
|
||||
|
||||
// fetchBatch fetches and stores a batch of messages by UID.
|
||||
func (imp *Importer) fetchBatch(ctx context.Context, c *imapclient.Client, uids []imapv2.UID, log *slog.Logger) (int, error) {
|
||||
if len(uids) == 0 {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
fetchOptions := &imapv2.FetchOptions{
|
||||
UID: true,
|
||||
BodySection: []*imapv2.FetchItemBodySection{{}},
|
||||
}
|
||||
|
||||
seqSet := imapv2.UIDSetNum(uids...)
|
||||
fetchCmd := c.Fetch(seqSet, fetchOptions)
|
||||
|
||||
imported := 0
|
||||
for {
|
||||
msg := fetchCmd.Next()
|
||||
if msg == nil {
|
||||
break
|
||||
}
|
||||
|
||||
// Collect body sections from this message
|
||||
for {
|
||||
item := msg.Next()
|
||||
if item == nil {
|
||||
break
|
||||
}
|
||||
|
||||
switch body := item.(type) {
|
||||
case imapclient.FetchItemDataBodySection:
|
||||
raw, err := io.ReadAll(body.Literal)
|
||||
if err != nil {
|
||||
log.Warn("failed to read message body", "err", err)
|
||||
continue
|
||||
}
|
||||
|
||||
if err := imp.storeAndIndex(raw, log); err != nil {
|
||||
log.Warn("failed to store/index message", "err", err)
|
||||
continue
|
||||
}
|
||||
imported++
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if err := fetchCmd.Close(); err != nil {
|
||||
return imported, fmt.Errorf("fetch close: %w", err)
|
||||
}
|
||||
|
||||
return imported, nil
|
||||
}
|
||||
|
||||
// storeAndIndex saves a raw email to storage and indexes it.
|
||||
func (imp *Importer) storeAndIndex(raw []byte, log *slog.Logger) error {
|
||||
// Save to file storage (deduplicates by SHA256 automatically)
|
||||
id, err := imp.mailStore.Save(raw, time.Now())
|
||||
if err != nil {
|
||||
return fmt.Errorf("save: %w", err)
|
||||
}
|
||||
|
||||
// Parse for indexing
|
||||
pm, err := mailparser.Parse(raw)
|
||||
if err != nil {
|
||||
log.Warn("failed to parse mail for indexing", "id", id, "err", err)
|
||||
// Store succeeded, just skip indexing for unparseable mails
|
||||
return nil
|
||||
}
|
||||
|
||||
// Build attachment names string
|
||||
var attachNames []string
|
||||
for _, a := range pm.Attachments {
|
||||
if a.Filename != "" {
|
||||
attachNames = append(attachNames, a.Filename)
|
||||
}
|
||||
}
|
||||
|
||||
doc := index.MailDocument{
|
||||
ID: id,
|
||||
From: pm.From,
|
||||
To: strings.Join(pm.To, ", "),
|
||||
Subject: pm.Subject,
|
||||
Body: pm.TextBody,
|
||||
AttachNames: strings.Join(attachNames, " "),
|
||||
HasAttachment: len(pm.Attachments) > 0,
|
||||
Date: pm.Date,
|
||||
Size: int64(len(raw)),
|
||||
}
|
||||
|
||||
if err := imp.idx.IndexSync(doc); err != nil {
|
||||
log.Warn("failed to index mail", "id", id, "err", err)
|
||||
// Non-fatal: mail is stored, just not searchable yet
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -0,0 +1,259 @@
|
||||
package imap
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/aes"
|
||||
"crypto/cipher"
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"fmt"
|
||||
"io"
|
||||
"time"
|
||||
|
||||
"github.com/jackc/pgx/v5"
|
||||
"github.com/jackc/pgx/v5/pgxpool"
|
||||
)
|
||||
|
||||
// Account represents an IMAP account configuration stored in the database.
|
||||
type Account struct {
|
||||
ID int64 `json:"id"`
|
||||
Owner string `json:"owner"`
|
||||
Name string `json:"name"`
|
||||
Host string `json:"host"`
|
||||
Port int `json:"port"`
|
||||
TLS string `json:"tls"`
|
||||
Username string `json:"username"`
|
||||
ExcludedFolders []string `json:"excluded_folders"`
|
||||
Status string `json:"status"`
|
||||
ErrorMsg string `json:"error_msg"`
|
||||
LastImportAt *time.Time `json:"last_import_at,omitempty"`
|
||||
LastImportCount int `json:"last_import_count"`
|
||||
ProgressCurrent int `json:"progress_current"`
|
||||
ProgressTotal int `json:"progress_total"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
}
|
||||
|
||||
// Store manages IMAP account persistence in PostgreSQL.
|
||||
type Store struct {
|
||||
pool *pgxpool.Pool
|
||||
encKey [32]byte
|
||||
}
|
||||
|
||||
const createTableSQL = `
|
||||
CREATE TABLE IF NOT EXISTS imap_accounts (
|
||||
id SERIAL PRIMARY KEY,
|
||||
owner TEXT NOT NULL,
|
||||
name TEXT NOT NULL,
|
||||
host TEXT NOT NULL,
|
||||
port INTEGER NOT NULL DEFAULT 993,
|
||||
tls TEXT NOT NULL DEFAULT 'ssl',
|
||||
username TEXT NOT NULL,
|
||||
password_enc BYTEA NOT NULL,
|
||||
excluded_folders TEXT[] NOT NULL DEFAULT '{}',
|
||||
status TEXT NOT NULL DEFAULT 'idle',
|
||||
error_msg TEXT NOT NULL DEFAULT '',
|
||||
last_import_at TIMESTAMPTZ,
|
||||
last_import_count INTEGER NOT NULL DEFAULT 0,
|
||||
progress_current INTEGER NOT NULL DEFAULT 0,
|
||||
progress_total INTEGER NOT NULL DEFAULT 0,
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT now()
|
||||
);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS idx_imap_accounts_owner ON imap_accounts (owner);
|
||||
`
|
||||
|
||||
// New creates a new Store, connects to PostgreSQL, and runs the migration.
|
||||
func New(dsn, secret string) (*Store, error) {
|
||||
pool, err := pgxpool.New(context.Background(), dsn)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("imap store: connect: %w", err)
|
||||
}
|
||||
|
||||
if _, err := pool.Exec(context.Background(), createTableSQL); err != nil {
|
||||
pool.Close()
|
||||
return nil, fmt.Errorf("imap store: migrate: %w", err)
|
||||
}
|
||||
|
||||
key := sha256.Sum256([]byte(secret))
|
||||
return &Store{pool: pool, encKey: key}, nil
|
||||
}
|
||||
|
||||
// Close releases the database connection pool.
|
||||
func (s *Store) Close() {
|
||||
s.pool.Close()
|
||||
}
|
||||
|
||||
// Create inserts a new IMAP account with an encrypted password.
|
||||
func (s *Store) Create(ctx context.Context, acc Account, password string) (*Account, error) {
|
||||
enc, err := encryptPassword(password, s.encKey)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("imap store: encrypt password: %w", err)
|
||||
}
|
||||
|
||||
row := s.pool.QueryRow(ctx, `
|
||||
INSERT INTO imap_accounts (owner, name, host, port, tls, username, password_enc, excluded_folders)
|
||||
VALUES ($1, $2, $3, $4, $5, $6, $7, $8)
|
||||
RETURNING id, created_at`,
|
||||
acc.Owner, acc.Name, acc.Host, acc.Port, acc.TLS, acc.Username, enc, acc.ExcludedFolders,
|
||||
)
|
||||
|
||||
if err := row.Scan(&acc.ID, &acc.CreatedAt); err != nil {
|
||||
return nil, fmt.Errorf("imap store: create: %w", err)
|
||||
}
|
||||
|
||||
acc.Status = "idle"
|
||||
acc.ErrorMsg = ""
|
||||
return &acc, nil
|
||||
}
|
||||
|
||||
// List returns IMAP accounts. Admins see all accounts; regular users see only their own.
|
||||
func (s *Store) List(ctx context.Context, owner string, isAdmin bool) ([]Account, error) {
|
||||
var rows pgx.Rows
|
||||
var err error
|
||||
|
||||
if isAdmin {
|
||||
rows, err = s.pool.Query(ctx, `
|
||||
SELECT id, owner, name, host, port, tls, username, excluded_folders,
|
||||
status, error_msg, last_import_at, last_import_count,
|
||||
progress_current, progress_total, created_at
|
||||
FROM imap_accounts ORDER BY id`)
|
||||
} else {
|
||||
rows, err = s.pool.Query(ctx, `
|
||||
SELECT id, owner, name, host, port, tls, username, excluded_folders,
|
||||
status, error_msg, last_import_at, last_import_count,
|
||||
progress_current, progress_total, created_at
|
||||
FROM imap_accounts WHERE owner = $1 ORDER BY id`, owner)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("imap store: list: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var accounts []Account
|
||||
for rows.Next() {
|
||||
var a Account
|
||||
if err := rows.Scan(
|
||||
&a.ID, &a.Owner, &a.Name, &a.Host, &a.Port, &a.TLS, &a.Username,
|
||||
&a.ExcludedFolders, &a.Status, &a.ErrorMsg, &a.LastImportAt,
|
||||
&a.LastImportCount, &a.ProgressCurrent, &a.ProgressTotal, &a.CreatedAt,
|
||||
); err != nil {
|
||||
return nil, fmt.Errorf("imap store: scan: %w", err)
|
||||
}
|
||||
accounts = append(accounts, a)
|
||||
}
|
||||
return accounts, rows.Err()
|
||||
}
|
||||
|
||||
// Get returns a single IMAP account by ID.
|
||||
func (s *Store) Get(ctx context.Context, id int64) (*Account, error) {
|
||||
var a Account
|
||||
err := s.pool.QueryRow(ctx, `
|
||||
SELECT id, owner, name, host, port, tls, username, excluded_folders,
|
||||
status, error_msg, last_import_at, last_import_count,
|
||||
progress_current, progress_total, created_at
|
||||
FROM imap_accounts WHERE id = $1`, id,
|
||||
).Scan(
|
||||
&a.ID, &a.Owner, &a.Name, &a.Host, &a.Port, &a.TLS, &a.Username,
|
||||
&a.ExcludedFolders, &a.Status, &a.ErrorMsg, &a.LastImportAt,
|
||||
&a.LastImportCount, &a.ProgressCurrent, &a.ProgressTotal, &a.CreatedAt,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("imap store: get %d: %w", id, err)
|
||||
}
|
||||
return &a, nil
|
||||
}
|
||||
|
||||
// GetPassword retrieves and decrypts the stored password for an IMAP account.
|
||||
func (s *Store) GetPassword(ctx context.Context, id int64) (string, error) {
|
||||
var enc []byte
|
||||
err := s.pool.QueryRow(ctx, `SELECT password_enc FROM imap_accounts WHERE id = $1`, id).Scan(&enc)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("imap store: get password: %w", err)
|
||||
}
|
||||
return decryptPassword(enc, s.encKey)
|
||||
}
|
||||
|
||||
// Delete removes an IMAP account by ID.
|
||||
func (s *Store) Delete(ctx context.Context, id int64) error {
|
||||
tag, err := s.pool.Exec(ctx, `DELETE FROM imap_accounts WHERE id = $1`, id)
|
||||
if err != nil {
|
||||
return fmt.Errorf("imap store: delete: %w", err)
|
||||
}
|
||||
if tag.RowsAffected() == 0 {
|
||||
return fmt.Errorf("imap store: account %d not found", id)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// UpdateExcluded sets the list of excluded folders for an account.
|
||||
func (s *Store) UpdateExcluded(ctx context.Context, id int64, excluded []string) error {
|
||||
_, err := s.pool.Exec(ctx, `UPDATE imap_accounts SET excluded_folders = $1 WHERE id = $2`, excluded, id)
|
||||
if err != nil {
|
||||
return fmt.Errorf("imap store: update excluded: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// UpdateStatus updates the import progress and status of an account.
|
||||
func (s *Store) UpdateStatus(ctx context.Context, id int64, status, errMsg string, current, total int) error {
|
||||
_, err := s.pool.Exec(ctx, `
|
||||
UPDATE imap_accounts
|
||||
SET status = $1, error_msg = $2, progress_current = $3, progress_total = $4
|
||||
WHERE id = $5`, status, errMsg, current, total, id)
|
||||
if err != nil {
|
||||
return fmt.Errorf("imap store: update status: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// UpdateDone marks an import as completed, setting status back to idle.
|
||||
func (s *Store) UpdateDone(ctx context.Context, id int64, count int) error {
|
||||
_, err := s.pool.Exec(ctx, `
|
||||
UPDATE imap_accounts
|
||||
SET status = 'idle', error_msg = '', last_import_at = now(),
|
||||
last_import_count = $1, progress_current = 0, progress_total = 0
|
||||
WHERE id = $2`, count, id)
|
||||
if err != nil {
|
||||
return fmt.Errorf("imap store: update done: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// encryptPassword encrypts a plaintext password using AES-256-GCM.
|
||||
func encryptPassword(plaintext string, key [32]byte) ([]byte, error) {
|
||||
block, err := aes.NewCipher(key[:])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
gcm, err := cipher.NewGCM(block)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
nonce := make([]byte, gcm.NonceSize())
|
||||
if _, err := io.ReadFull(rand.Reader, nonce); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return gcm.Seal(nonce, nonce, []byte(plaintext), nil), nil
|
||||
}
|
||||
|
||||
// decryptPassword decrypts a password previously encrypted with encryptPassword.
|
||||
func decryptPassword(ciphertext []byte, key [32]byte) (string, error) {
|
||||
block, err := aes.NewCipher(key[:])
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
gcm, err := cipher.NewGCM(block)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
nonceSize := gcm.NonceSize()
|
||||
if len(ciphertext) < nonceSize {
|
||||
return "", fmt.Errorf("ciphertext too short")
|
||||
}
|
||||
nonce, ct := ciphertext[:nonceSize], ciphertext[nonceSize:]
|
||||
plaintext, err := gcm.Open(nil, nonce, ct, nil)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("decrypt failed: %w", err)
|
||||
}
|
||||
return string(plaintext), nil
|
||||
}
|
||||
@@ -0,0 +1,61 @@
|
||||
package index
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"time"
|
||||
)
|
||||
|
||||
// MailDocument is the indexed representation of a stored email.
|
||||
type MailDocument struct {
|
||||
ID string
|
||||
From string
|
||||
To string
|
||||
Subject string
|
||||
Body string
|
||||
AttachNames string
|
||||
HasAttachment bool
|
||||
Date time.Time
|
||||
Size int64
|
||||
}
|
||||
|
||||
// SearchRequest specifies search parameters.
|
||||
type SearchRequest struct {
|
||||
Query string
|
||||
From string
|
||||
To string
|
||||
OwnEmail string
|
||||
DateFrom *time.Time
|
||||
DateTo *time.Time
|
||||
PageSize int
|
||||
Page int
|
||||
}
|
||||
|
||||
// Hit is a single search result.
|
||||
type Hit struct {
|
||||
ID string `json:"id"`
|
||||
Score float64 `json:"score"`
|
||||
}
|
||||
|
||||
// SearchResult holds paginated search results.
|
||||
type SearchResult struct {
|
||||
Total int
|
||||
Hits []Hit
|
||||
}
|
||||
|
||||
// Indexer is the interface for full-text email indexing.
|
||||
type Indexer interface {
|
||||
IndexSync(doc MailDocument) error
|
||||
Search(req SearchRequest) (*SearchResult, error)
|
||||
Delete(id string) error
|
||||
Close() error
|
||||
}
|
||||
|
||||
// New creates an Indexer for the specified backend.
|
||||
func New(dir string, batchSize int, backend string) (Indexer, error) {
|
||||
switch backend {
|
||||
case "xapian":
|
||||
return newXapian(dir)
|
||||
default:
|
||||
return nil, fmt.Errorf("unknown index backend: %q (supported: xapian)", backend)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,192 @@
|
||||
package index_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/archivmail/internal/index"
|
||||
)
|
||||
|
||||
// newXapianIndex creates a temporary Xapian index for testing.
|
||||
func newXapianIndex(t *testing.T) index.Indexer {
|
||||
t.Helper()
|
||||
idx, err := index.New(t.TempDir(), 100, "xapian")
|
||||
if err != nil {
|
||||
t.Skip("xapian not available:", err)
|
||||
}
|
||||
t.Cleanup(func() { idx.Close() })
|
||||
return idx
|
||||
}
|
||||
|
||||
func seedDocs(t *testing.T, idx index.Indexer) {
|
||||
t.Helper()
|
||||
docs := []index.MailDocument{
|
||||
{
|
||||
ID: "aaa111",
|
||||
From: "alice@example.com",
|
||||
To: "bob@example.com",
|
||||
Subject: "Invoice Q1-2026",
|
||||
Body: "Please find attached the invoice for January.",
|
||||
Date: time.Date(2026, 1, 15, 10, 0, 0, 0, time.UTC),
|
||||
Size: 1024,
|
||||
},
|
||||
{
|
||||
ID: "bbb222",
|
||||
From: "bob@example.com",
|
||||
To: "alice@example.com charlie@example.com",
|
||||
Subject: "Meeting Agenda",
|
||||
Body: "Agenda for the quarterly review meeting.",
|
||||
Date: time.Date(2026, 2, 1, 9, 0, 0, 0, time.UTC),
|
||||
Size: 512,
|
||||
},
|
||||
{
|
||||
ID: "ccc333",
|
||||
From: "charlie@example.com",
|
||||
To: "alice@example.com",
|
||||
Subject: "Offer with attachment",
|
||||
Body: "Please review the attached offer document.",
|
||||
AttachNames: "offer.pdf",
|
||||
HasAttachment: true,
|
||||
Date: time.Date(2026, 3, 1, 14, 0, 0, 0, time.UTC),
|
||||
Size: 8192,
|
||||
},
|
||||
}
|
||||
for _, d := range docs {
|
||||
if err := idx.IndexSync(d); err != nil {
|
||||
t.Fatalf("IndexSync %s: %v", d.ID, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestIndexAndSearchFulltext(t *testing.T) {
|
||||
idx := newXapianIndex(t)
|
||||
seedDocs(t, idx)
|
||||
|
||||
result, err := idx.Search(index.SearchRequest{Query: "invoice", PageSize: 10})
|
||||
if err != nil {
|
||||
t.Fatalf("Search: %v", err)
|
||||
}
|
||||
if result.Total == 0 {
|
||||
t.Error("expected at least 1 hit for 'invoice'")
|
||||
}
|
||||
if result.Hits[0].ID != "aaa111" {
|
||||
t.Errorf("top hit = %q, want aaa111", result.Hits[0].ID)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSearchMatchAll(t *testing.T) {
|
||||
idx := newXapianIndex(t)
|
||||
seedDocs(t, idx)
|
||||
|
||||
result, err := idx.Search(index.SearchRequest{PageSize: 25})
|
||||
if err != nil {
|
||||
t.Fatalf("Search all: %v", err)
|
||||
}
|
||||
if result.Total != 3 {
|
||||
t.Errorf("expected 3 total hits, got %d", result.Total)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSearchFromFilter(t *testing.T) {
|
||||
idx := newXapianIndex(t)
|
||||
seedDocs(t, idx)
|
||||
|
||||
result, err := idx.Search(index.SearchRequest{
|
||||
From: "alice@example.com",
|
||||
PageSize: 25,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Search from: %v", err)
|
||||
}
|
||||
if result.Total != 1 {
|
||||
t.Errorf("expected 1 hit from alice, got %d", result.Total)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSearchDateRange(t *testing.T) {
|
||||
idx := newXapianIndex(t)
|
||||
seedDocs(t, idx)
|
||||
|
||||
from := time.Date(2026, 1, 1, 0, 0, 0, 0, time.UTC)
|
||||
to := time.Date(2026, 2, 1, 23, 59, 59, 0, time.UTC)
|
||||
result, err := idx.Search(index.SearchRequest{
|
||||
DateFrom: &from,
|
||||
DateTo: &to,
|
||||
PageSize: 25,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Search date range: %v", err)
|
||||
}
|
||||
if result.Total != 2 {
|
||||
t.Errorf("expected 2 hits in Jan-Feb 2026, got %d", result.Total)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSearchOwnEmail(t *testing.T) {
|
||||
idx := newXapianIndex(t)
|
||||
seedDocs(t, idx)
|
||||
|
||||
// charlie@example.com sent 1 mail and received 1 mail = should see 2
|
||||
result, err := idx.Search(index.SearchRequest{
|
||||
OwnEmail: "charlie@example.com",
|
||||
PageSize: 25,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Search OwnEmail: %v", err)
|
||||
}
|
||||
if result.Total < 1 {
|
||||
t.Errorf("charlie should see at least 1 mail, got %d", result.Total)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSearchPagination(t *testing.T) {
|
||||
idx := newXapianIndex(t)
|
||||
seedDocs(t, idx)
|
||||
|
||||
page0, _ := idx.Search(index.SearchRequest{PageSize: 2, Page: 0})
|
||||
page1, _ := idx.Search(index.SearchRequest{PageSize: 2, Page: 1})
|
||||
|
||||
if len(page0.Hits) != 2 {
|
||||
t.Errorf("page 0: expected 2 hits, got %d", len(page0.Hits))
|
||||
}
|
||||
if len(page1.Hits) != 1 {
|
||||
t.Errorf("page 1: expected 1 hit, got %d", len(page1.Hits))
|
||||
}
|
||||
// No overlap
|
||||
if page0.Hits[0].ID == page1.Hits[0].ID {
|
||||
t.Error("pagination returned duplicate results")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDelete(t *testing.T) {
|
||||
idx := newXapianIndex(t)
|
||||
seedDocs(t, idx)
|
||||
|
||||
if err := idx.Delete("aaa111"); err != nil {
|
||||
t.Fatalf("Delete: %v", err)
|
||||
}
|
||||
|
||||
result, _ := idx.Search(index.SearchRequest{Query: "invoice", PageSize: 10})
|
||||
for _, h := range result.Hits {
|
||||
if h.ID == "aaa111" {
|
||||
t.Error("deleted document still in results")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestUnknownBackend(t *testing.T) {
|
||||
_, err := index.New(t.TempDir(), 10, "elasticsearch")
|
||||
if err == nil {
|
||||
t.Error("expected error for unknown backend")
|
||||
}
|
||||
}
|
||||
|
||||
func TestXapianNotCompiledError(t *testing.T) {
|
||||
_, err := index.New(t.TempDir(), 10, "xapian")
|
||||
// Without -tags xapian this must return a helpful error
|
||||
if err == nil {
|
||||
t.Log("xapian compiled in — skipping stub error test")
|
||||
} else {
|
||||
t.Logf("xapian stub error (expected): %v", err)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,126 @@
|
||||
//go:build xapian
|
||||
|
||||
package index
|
||||
|
||||
/*
|
||||
#cgo pkg-config: xapian-core
|
||||
#cgo LDFLAGS: -lstdc++
|
||||
#include "xapian_wrapper.h"
|
||||
#include <stdlib.h>
|
||||
*/
|
||||
import "C"
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"unsafe"
|
||||
)
|
||||
|
||||
type xapianIndex struct {
|
||||
db *C.XapianDB
|
||||
}
|
||||
|
||||
func newXapian(dir string) (Indexer, error) {
|
||||
cdir := C.CString(dir)
|
||||
defer C.free(unsafe.Pointer(cdir))
|
||||
var cerr *C.char
|
||||
db := C.xapian_open(cdir, 1, &cerr)
|
||||
if db == nil {
|
||||
msg := C.GoString(cerr)
|
||||
C.xapian_free_string(cerr)
|
||||
return nil, fmt.Errorf("xapian open: %s", msg)
|
||||
}
|
||||
return &xapianIndex{db: db}, nil
|
||||
}
|
||||
|
||||
func (x *xapianIndex) IndexSync(doc MailDocument) error {
|
||||
cid := C.CString(doc.ID)
|
||||
defer C.free(unsafe.Pointer(cid))
|
||||
cfrom := C.CString(doc.From)
|
||||
defer C.free(unsafe.Pointer(cfrom))
|
||||
cto := C.CString(doc.To)
|
||||
defer C.free(unsafe.Pointer(cto))
|
||||
csubj := C.CString(doc.Subject)
|
||||
defer C.free(unsafe.Pointer(csubj))
|
||||
cbody := C.CString(doc.Body)
|
||||
defer C.free(unsafe.Pointer(cbody))
|
||||
var cerr *C.char
|
||||
rc := C.xapian_index(x.db, cid, cfrom, cto, csubj, cbody, C.longlong(doc.Date.Unix()), &cerr)
|
||||
if rc != 0 {
|
||||
msg := C.GoString(cerr)
|
||||
C.xapian_free_string(cerr)
|
||||
return fmt.Errorf("xapian index: %s", msg)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (x *xapianIndex) Delete(id string) error {
|
||||
cid := C.CString(id)
|
||||
defer C.free(unsafe.Pointer(cid))
|
||||
var cerr *C.char
|
||||
rc := C.xapian_delete(x.db, cid, &cerr)
|
||||
if rc != 0 {
|
||||
msg := C.GoString(cerr)
|
||||
C.xapian_free_string(cerr)
|
||||
return fmt.Errorf("xapian delete: %s", msg)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (x *xapianIndex) Search(req SearchRequest) (*SearchResult, error) {
|
||||
cquery := C.CString(req.Query)
|
||||
defer C.free(unsafe.Pointer(cquery))
|
||||
cfrom := C.CString(req.From)
|
||||
defer C.free(unsafe.Pointer(cfrom))
|
||||
cown := C.CString(req.OwnEmail)
|
||||
defer C.free(unsafe.Pointer(cown))
|
||||
cto := C.CString(req.To)
|
||||
defer C.free(unsafe.Pointer(cto))
|
||||
|
||||
var dateFrom, dateTo C.longlong
|
||||
if req.DateFrom != nil {
|
||||
dateFrom = C.longlong(req.DateFrom.Unix())
|
||||
}
|
||||
if req.DateTo != nil {
|
||||
dateTo = C.longlong(req.DateTo.Unix())
|
||||
}
|
||||
page := req.Page
|
||||
if page < 1 {
|
||||
page = 1
|
||||
}
|
||||
offset := C.int((page - 1) * req.PageSize)
|
||||
limit := C.int(req.PageSize)
|
||||
if limit <= 0 {
|
||||
limit = 25
|
||||
}
|
||||
|
||||
var cerr *C.char
|
||||
cresult := C.xapian_search(x.db, cquery, cfrom, cown, cto, dateFrom, dateTo, offset, limit, &cerr)
|
||||
if cresult == nil {
|
||||
msg := C.GoString(cerr)
|
||||
C.xapian_free_string(cerr)
|
||||
return nil, fmt.Errorf("xapian search: %s", msg)
|
||||
}
|
||||
defer C.xapian_free_string(cresult)
|
||||
jsonStr := C.GoString(cresult)
|
||||
|
||||
var raw struct {
|
||||
Total int `json:"total"`
|
||||
Hits []struct {
|
||||
ID string `json:"id"`
|
||||
Score float64 `json:"score"`
|
||||
} `json:"hits"`
|
||||
}
|
||||
if err := json.Unmarshal([]byte(jsonStr), &raw); err != nil {
|
||||
return nil, fmt.Errorf("xapian parse result: %w", err)
|
||||
}
|
||||
hits := make([]Hit, len(raw.Hits))
|
||||
for i, h := range raw.Hits {
|
||||
hits[i] = Hit{ID: h.ID, Score: h.Score}
|
||||
}
|
||||
return &SearchResult{Total: raw.Total, Hits: hits}, nil
|
||||
}
|
||||
|
||||
func (x *xapianIndex) Close() error {
|
||||
C.xapian_close(x.db)
|
||||
return nil
|
||||
}
|
||||
@@ -0,0 +1,9 @@
|
||||
//go:build !xapian
|
||||
|
||||
package index
|
||||
|
||||
import "errors"
|
||||
|
||||
func newXapian(dir string) (Indexer, error) {
|
||||
return nil, errors.New("xapian: not compiled in — rebuild with: go build -tags xapian")
|
||||
}
|
||||
@@ -0,0 +1,199 @@
|
||||
#include "xapian_wrapper.h"
|
||||
#include <xapian.h>
|
||||
#include <cstring>
|
||||
#include <cstdlib>
|
||||
#include <string>
|
||||
#include <sstream>
|
||||
#include <stdexcept>
|
||||
|
||||
struct XapianDB {
|
||||
Xapian::WritableDatabase* wdb;
|
||||
Xapian::Database* rdb;
|
||||
bool writable;
|
||||
};
|
||||
|
||||
static char* dup_error(const std::string& msg) {
|
||||
char* s = (char*)malloc(msg.size() + 1);
|
||||
if (s) memcpy(s, msg.c_str(), msg.size() + 1);
|
||||
return s;
|
||||
}
|
||||
|
||||
extern "C" {
|
||||
|
||||
XapianDB* xapian_open(const char* path, int writable, char** err) {
|
||||
try {
|
||||
XapianDB* db = new XapianDB{nullptr, nullptr, (bool)writable};
|
||||
if (writable) {
|
||||
db->wdb = new Xapian::WritableDatabase(path, Xapian::DB_CREATE_OR_OPEN);
|
||||
} else {
|
||||
db->rdb = new Xapian::Database(path);
|
||||
}
|
||||
return db;
|
||||
} catch (const std::exception& e) {
|
||||
if (err) *err = dup_error(e.what());
|
||||
return nullptr;
|
||||
}
|
||||
}
|
||||
|
||||
void xapian_close(XapianDB* db) {
|
||||
if (!db) return;
|
||||
if (db->wdb) { db->wdb->close(); delete db->wdb; }
|
||||
if (db->rdb) { db->rdb->close(); delete db->rdb; }
|
||||
delete db;
|
||||
}
|
||||
|
||||
int xapian_index(XapianDB* db, const char* id, const char* from,
|
||||
const char* to, const char* subject, const char* body,
|
||||
long long timestamp, char** err) {
|
||||
try {
|
||||
Xapian::Document doc;
|
||||
Xapian::TermGenerator gen;
|
||||
gen.set_document(doc);
|
||||
gen.set_stemmer(Xapian::Stem("en"));
|
||||
|
||||
// Prefix-indexed fields for filtering
|
||||
gen.index_text(from, 1, "XF");
|
||||
gen.index_text(to, 1, "XT");
|
||||
gen.index_text(subject, 1, "XS");
|
||||
|
||||
// Free-text indexed fields
|
||||
gen.index_text(subject);
|
||||
gen.increase_termpos();
|
||||
gen.index_text(body);
|
||||
gen.increase_termpos();
|
||||
gen.index_text(from);
|
||||
gen.increase_termpos();
|
||||
gen.index_text(to);
|
||||
|
||||
// Store timestamp for date range queries (value slot 0)
|
||||
doc.add_value(0, Xapian::sortable_serialise((double)timestamp));
|
||||
|
||||
// Store ID as document data
|
||||
doc.set_data(id);
|
||||
doc.add_boolean_term(std::string("Q") + id);
|
||||
|
||||
db->wdb->replace_document(std::string("Q") + id, doc);
|
||||
db->wdb->commit();
|
||||
return 0;
|
||||
} catch (const std::exception& e) {
|
||||
if (err) *err = dup_error(e.what());
|
||||
return -1;
|
||||
}
|
||||
}
|
||||
|
||||
int xapian_delete(XapianDB* db, const char* id, char** err) {
|
||||
try {
|
||||
db->wdb->delete_document(std::string("Q") + id);
|
||||
db->wdb->commit();
|
||||
return 0;
|
||||
} catch (const std::exception& e) {
|
||||
if (err) *err = dup_error(e.what());
|
||||
return -1;
|
||||
}
|
||||
}
|
||||
|
||||
char* xapian_search(XapianDB* db, const char* query_str,
|
||||
const char* from_filter, const char* own_email,
|
||||
const char* to_filter,
|
||||
long long date_from, long long date_to,
|
||||
int offset, int limit, char** err) {
|
||||
try {
|
||||
Xapian::Database& xdb = db->wdb ? (Xapian::Database&)*db->wdb : *db->rdb;
|
||||
Xapian::Enquire enquire(xdb);
|
||||
|
||||
Xapian::Query main_query;
|
||||
|
||||
// Full-text query
|
||||
if (query_str && query_str[0] != '\0') {
|
||||
Xapian::QueryParser qp;
|
||||
qp.set_database(xdb);
|
||||
qp.set_stemmer(Xapian::Stem("en"));
|
||||
qp.set_stemming_strategy(Xapian::QueryParser::STEM_SOME);
|
||||
qp.add_prefix("from", "XF");
|
||||
qp.add_prefix("to", "XT");
|
||||
qp.add_prefix("subject", "XS");
|
||||
main_query = qp.parse_query(query_str,
|
||||
Xapian::QueryParser::FLAG_DEFAULT |
|
||||
Xapian::QueryParser::FLAG_PARTIAL);
|
||||
} else {
|
||||
main_query = Xapian::Query::MatchAll;
|
||||
}
|
||||
|
||||
// From filter
|
||||
if (from_filter && from_filter[0] != '\0') {
|
||||
Xapian::QueryParser qp;
|
||||
qp.set_database(xdb);
|
||||
Xapian::Query fq = qp.parse_query(from_filter,
|
||||
Xapian::QueryParser::FLAG_DEFAULT, "XF");
|
||||
main_query = Xapian::Query(Xapian::Query::OP_AND, main_query, fq);
|
||||
}
|
||||
|
||||
// OwnEmail filter: (from=own OR to=own)
|
||||
if (own_email && own_email[0] != '\0') {
|
||||
Xapian::QueryParser qp;
|
||||
qp.set_database(xdb);
|
||||
Xapian::Query fq = qp.parse_query(own_email,
|
||||
Xapian::QueryParser::FLAG_DEFAULT, "XF");
|
||||
Xapian::Query tq = qp.parse_query(own_email,
|
||||
Xapian::QueryParser::FLAG_DEFAULT, "XT");
|
||||
Xapian::Query owq(Xapian::Query::OP_OR, fq, tq);
|
||||
main_query = Xapian::Query(Xapian::Query::OP_AND, main_query, owq);
|
||||
}
|
||||
|
||||
// To filter
|
||||
if (to_filter && to_filter[0] != '\0') {
|
||||
Xapian::QueryParser qp;
|
||||
qp.set_database(xdb);
|
||||
Xapian::Query tq = qp.parse_query(to_filter,
|
||||
Xapian::QueryParser::FLAG_DEFAULT, "XT");
|
||||
main_query = Xapian::Query(Xapian::Query::OP_AND, main_query, tq);
|
||||
}
|
||||
|
||||
// Date range
|
||||
if (date_from > 0 || date_to > 0) {
|
||||
double lo = date_from > 0 ? (double)date_from : 0.0;
|
||||
double hi = date_to > 0 ? (double)date_to : 1e18;
|
||||
Xapian::Query drq(Xapian::Query::OP_VALUE_RANGE, 0,
|
||||
Xapian::sortable_serialise(lo),
|
||||
Xapian::sortable_serialise(hi));
|
||||
main_query = Xapian::Query(Xapian::Query::OP_AND, main_query, drq);
|
||||
}
|
||||
|
||||
enquire.set_query(main_query);
|
||||
enquire.set_sort_by_value(0, true); // sort by date desc
|
||||
|
||||
// Get total count
|
||||
Xapian::MSet all = enquire.get_mset(0, xdb.get_doccount());
|
||||
int total = (int)all.get_matches_estimated();
|
||||
|
||||
// Get page
|
||||
Xapian::MSet mset = enquire.get_mset(offset, limit);
|
||||
|
||||
std::ostringstream json;
|
||||
json << "{\"total\":" << total << ",\"hits\":[";
|
||||
bool first = true;
|
||||
for (auto it = mset.begin(); it != mset.end(); ++it) {
|
||||
if (!first) json << ",";
|
||||
first = false;
|
||||
std::string id = it.get_document().get_data();
|
||||
double score = it.get_weight();
|
||||
json << "{\"id\":\"" << id << "\",\"score\":" << score << "}";
|
||||
}
|
||||
json << "]}";
|
||||
|
||||
std::string result = json.str();
|
||||
char* out = (char*)malloc(result.size() + 1);
|
||||
memcpy(out, result.c_str(), result.size() + 1);
|
||||
return out;
|
||||
|
||||
} catch (const std::exception& e) {
|
||||
if (err) *err = dup_error(e.what());
|
||||
return nullptr;
|
||||
}
|
||||
}
|
||||
|
||||
void xapian_free_string(char* s) {
|
||||
free(s);
|
||||
}
|
||||
|
||||
} // extern "C"
|
||||
@@ -0,0 +1,32 @@
|
||||
#ifndef XAPIAN_WRAPPER_H
|
||||
#define XAPIAN_WRAPPER_H
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
typedef struct XapianDB XapianDB;
|
||||
|
||||
XapianDB* xapian_open(const char* path, int writable, char** err);
|
||||
void xapian_close(XapianDB* db);
|
||||
|
||||
int xapian_index(XapianDB* db, const char* id, const char* from,
|
||||
const char* to, const char* subject, const char* body,
|
||||
long long timestamp, char** err);
|
||||
|
||||
int xapian_delete(XapianDB* db, const char* id, char** err);
|
||||
|
||||
/* Returns JSON string: {"total":N,"hits":[{"id":"...","score":0.9},...]}
|
||||
Returns NULL on error, sets *err. Caller must free with xapian_free_string. */
|
||||
char* xapian_search(XapianDB* db, const char* query,
|
||||
const char* from_filter, const char* own_email,
|
||||
const char* to_filter,
|
||||
long long date_from, long long date_to,
|
||||
int offset, int limit, char** err);
|
||||
|
||||
void xapian_free_string(char* s);
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
#endif
|
||||
@@ -0,0 +1,283 @@
|
||||
// Package smtpd implements an embedded receive-only SMTP daemon for archivmail.
|
||||
// It accepts incoming emails (e.g. from Postfix via always_bcc) and hands them
|
||||
// off to the storage coordinator. No AUTH, no relay, no outbound mail.
|
||||
package smtpd
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"net"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/emersion/go-smtp"
|
||||
|
||||
"github.com/archivmail/config"
|
||||
"github.com/archivmail/internal/storage"
|
||||
)
|
||||
|
||||
// Stats holds runtime statistics for the SMTP daemon.
|
||||
type Stats struct {
|
||||
Received atomic.Int64 // total emails successfully stored
|
||||
Rejected atomic.Int64 // rejected (IP, size, etc.)
|
||||
LastMailAt atomic.Value // time.Time of last accepted mail
|
||||
}
|
||||
|
||||
// Daemon is the embedded receive-only SMTP server.
|
||||
type Daemon struct {
|
||||
cfg config.SMTPConfig
|
||||
store *storage.Store
|
||||
logger *slog.Logger
|
||||
stats Stats
|
||||
server *smtp.Server
|
||||
mu sync.Mutex
|
||||
running bool
|
||||
}
|
||||
|
||||
// New creates a new SMTP Daemon. Call Start() to begin accepting connections.
|
||||
func New(cfg config.SMTPConfig, store *storage.Store, logger *slog.Logger) *Daemon {
|
||||
d := &Daemon{
|
||||
cfg: cfg,
|
||||
store: store,
|
||||
logger: logger,
|
||||
}
|
||||
d.stats.LastMailAt.Store(time.Time{})
|
||||
return d
|
||||
}
|
||||
|
||||
// Start launches the SMTP daemon in a background goroutine.
|
||||
// It returns immediately; use Stop() for graceful shutdown.
|
||||
func (d *Daemon) Start() error {
|
||||
if !d.cfg.Enabled {
|
||||
d.logger.Info("SMTP daemon disabled via config")
|
||||
return nil
|
||||
}
|
||||
|
||||
bind := d.cfg.Bind
|
||||
if bind == "" {
|
||||
bind = ":2525"
|
||||
}
|
||||
domain := d.cfg.Domain
|
||||
if domain == "" {
|
||||
domain = "archivmail"
|
||||
}
|
||||
maxBytes := int64(d.cfg.MaxSizeMB) * 1024 * 1024
|
||||
if maxBytes <= 0 {
|
||||
maxBytes = 50 * 1024 * 1024 // 50 MB default
|
||||
}
|
||||
|
||||
backend := &backend{daemon: d}
|
||||
srv := smtp.NewServer(backend)
|
||||
srv.Addr = bind
|
||||
srv.Domain = domain
|
||||
srv.MaxMessageBytes = maxBytes
|
||||
srv.ReadTimeout = 5 * time.Minute
|
||||
srv.WriteTimeout = 30 * time.Second
|
||||
srv.AllowInsecureAuth = false // no AUTH offered at all
|
||||
|
||||
// TLS / STARTTLS
|
||||
if d.cfg.TLSCert != "" && d.cfg.TLSKey != "" {
|
||||
cert, err := tls.LoadX509KeyPair(d.cfg.TLSCert, d.cfg.TLSKey)
|
||||
if err != nil {
|
||||
return fmt.Errorf("smtpd: load TLS cert: %w", err)
|
||||
}
|
||||
srv.TLSConfig = &tls.Config{Certificates: []tls.Certificate{cert}}
|
||||
}
|
||||
|
||||
d.mu.Lock()
|
||||
d.server = srv
|
||||
d.running = true
|
||||
d.mu.Unlock()
|
||||
|
||||
go func() {
|
||||
d.logger.Info("SMTP daemon starting", "addr", bind, "domain", domain,
|
||||
"max_size_mb", d.cfg.MaxSizeMB, "tls", d.cfg.TLSCert != "")
|
||||
if err := srv.ListenAndServe(); err != nil {
|
||||
if !errors.Is(err, smtp.ErrServerClosed) {
|
||||
d.logger.Error("SMTP daemon error", "err", err)
|
||||
}
|
||||
}
|
||||
d.mu.Lock()
|
||||
d.running = false
|
||||
d.mu.Unlock()
|
||||
}()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Stop shuts down the SMTP daemon gracefully.
|
||||
func (d *Daemon) Stop() {
|
||||
d.mu.Lock()
|
||||
srv := d.server
|
||||
d.mu.Unlock()
|
||||
if srv != nil {
|
||||
srv.Close()
|
||||
}
|
||||
}
|
||||
|
||||
// Status returns a snapshot of the daemon's current state.
|
||||
func (d *Daemon) Status() StatusResponse {
|
||||
d.mu.Lock()
|
||||
running := d.running
|
||||
d.mu.Unlock()
|
||||
|
||||
lastMail, _ := d.stats.LastMailAt.Load().(time.Time)
|
||||
var lastMailStr string
|
||||
if !lastMail.IsZero() {
|
||||
lastMailStr = lastMail.UTC().Format(time.RFC3339)
|
||||
}
|
||||
|
||||
bind := d.cfg.Bind
|
||||
if bind == "" {
|
||||
bind = ":2525"
|
||||
}
|
||||
|
||||
return StatusResponse{
|
||||
Running: running,
|
||||
Enabled: d.cfg.Enabled,
|
||||
Bind: bind,
|
||||
Domain: d.cfg.Domain,
|
||||
TLS: d.cfg.TLSCert != "",
|
||||
MaxSizeMB: d.cfg.MaxSizeMB,
|
||||
AllowedIPs: d.cfg.AllowedIPs,
|
||||
Received: d.stats.Received.Load(),
|
||||
Rejected: d.stats.Rejected.Load(),
|
||||
LastMailAt: lastMailStr,
|
||||
}
|
||||
}
|
||||
|
||||
// StatusResponse is returned by GET /api/admin/smtp/status.
|
||||
type StatusResponse struct {
|
||||
Running bool `json:"running"`
|
||||
Enabled bool `json:"enabled"`
|
||||
Bind string `json:"bind"`
|
||||
Domain string `json:"domain"`
|
||||
TLS bool `json:"tls"`
|
||||
MaxSizeMB int `json:"max_size_mb"`
|
||||
AllowedIPs []string `json:"allowed_ips"`
|
||||
Received int64 `json:"received"`
|
||||
Rejected int64 `json:"rejected"`
|
||||
LastMailAt string `json:"last_mail_at,omitempty"`
|
||||
}
|
||||
|
||||
// ── go-smtp Backend / Session ─────────────────────────────────────────────
|
||||
|
||||
type backend struct {
|
||||
daemon *Daemon
|
||||
}
|
||||
|
||||
func (b *backend) NewSession(c *smtp.Conn) (smtp.Session, error) {
|
||||
remoteIP := extractIP(c.Conn().RemoteAddr().String())
|
||||
|
||||
if !b.daemon.isAllowed(remoteIP) {
|
||||
b.daemon.stats.Rejected.Add(1)
|
||||
b.daemon.logger.Warn("SMTP: rejected connection from unlisted IP", "ip", remoteIP)
|
||||
return nil, &smtp.SMTPError{
|
||||
Code: 554,
|
||||
EnhancedCode: smtp.EnhancedCode{5, 7, 1},
|
||||
Message: "IP not in allowlist",
|
||||
}
|
||||
}
|
||||
|
||||
b.daemon.logger.Debug("SMTP: new session", "ip", remoteIP)
|
||||
return &session{
|
||||
daemon: b.daemon,
|
||||
remoteIP: remoteIP,
|
||||
}, nil
|
||||
}
|
||||
|
||||
type session struct {
|
||||
daemon *Daemon
|
||||
remoteIP string
|
||||
from string
|
||||
rcpts []string
|
||||
}
|
||||
|
||||
// AuthPlain – never called because server doesn't advertise AUTH.
|
||||
func (s *session) AuthPlain(_, _ string) error {
|
||||
return smtp.ErrAuthUnsupported
|
||||
}
|
||||
|
||||
func (s *session) Mail(from string, _ *smtp.MailOptions) error {
|
||||
s.from = from
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *session) Rcpt(to string, _ *smtp.RcptOptions) error {
|
||||
s.rcpts = append(s.rcpts, to)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *session) Data(r io.Reader) error {
|
||||
var buf bytes.Buffer
|
||||
if _, err := io.Copy(&buf, r); err != nil {
|
||||
s.daemon.stats.Rejected.Add(1)
|
||||
return fmt.Errorf("smtpd: read data: %w", err)
|
||||
}
|
||||
raw := buf.Bytes()
|
||||
|
||||
id, err := s.daemon.store.Save(raw, time.Now())
|
||||
if err != nil {
|
||||
s.daemon.stats.Rejected.Add(1)
|
||||
s.daemon.logger.Error("SMTP: storage failed", "from", s.from, "err", err)
|
||||
return &smtp.SMTPError{
|
||||
Code: 554,
|
||||
EnhancedCode: smtp.EnhancedCode{4, 6, 0},
|
||||
Message: "Storage failure, please retry",
|
||||
}
|
||||
}
|
||||
|
||||
s.daemon.stats.Received.Add(1)
|
||||
s.daemon.stats.LastMailAt.Store(time.Now())
|
||||
s.daemon.logger.Info("SMTP: mail stored", "id", id, "from", s.from,
|
||||
"rcpts", strings.Join(s.rcpts, ","), "bytes", len(raw), "ip", s.remoteIP)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *session) Reset() {
|
||||
s.from = ""
|
||||
s.rcpts = nil
|
||||
}
|
||||
|
||||
func (s *session) Logout() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// ── Helpers ───────────────────────────────────────────────────────────────
|
||||
|
||||
// isAllowed returns true if the IP is in the allowlist, or if the allowlist
|
||||
// is empty (allow-all mode for development).
|
||||
func (d *Daemon) isAllowed(ip string) bool {
|
||||
if len(d.cfg.AllowedIPs) == 0 {
|
||||
return true // no restriction configured
|
||||
}
|
||||
for _, allowed := range d.cfg.AllowedIPs {
|
||||
// Support CIDR notation (e.g. 192.168.1.0/24)
|
||||
if strings.Contains(allowed, "/") {
|
||||
_, network, err := net.ParseCIDR(allowed)
|
||||
if err == nil && network.Contains(net.ParseIP(ip)) {
|
||||
return true
|
||||
}
|
||||
continue
|
||||
}
|
||||
if allowed == ip {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// extractIP strips port from "ip:port" or "[::1]:port" strings.
|
||||
func extractIP(addr string) string {
|
||||
host, _, err := net.SplitHostPort(addr)
|
||||
if err != nil {
|
||||
return addr
|
||||
}
|
||||
return host
|
||||
}
|
||||
@@ -0,0 +1,145 @@
|
||||
package storage
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io/fs"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Store is a file-based email storage using SHA256 for deduplication.
|
||||
type Store struct {
|
||||
dir string
|
||||
}
|
||||
|
||||
// StoreStats reports total mail count and size in bytes.
|
||||
type StoreStats struct {
|
||||
TotalMails int64
|
||||
TotalBytes int64
|
||||
}
|
||||
|
||||
// New initialises the storage directory, creating required subdirectories.
|
||||
func New(dir string) (*Store, error) {
|
||||
for _, sub := range []string{"store", "attachments", "meta"} {
|
||||
if err := os.MkdirAll(filepath.Join(dir, sub), 0o755); err != nil {
|
||||
return nil, fmt.Errorf("storage: mkdir %s: %w", sub, err)
|
||||
}
|
||||
}
|
||||
return &Store{dir: dir}, nil
|
||||
}
|
||||
|
||||
// Save writes raw email bytes to storage. The ID is the hex-encoded SHA256 of
|
||||
// the content. If the file already exists, Save is a no-op (deduplication).
|
||||
func (s *Store) Save(raw []byte, _ time.Time) (string, error) {
|
||||
sum := sha256.Sum256(raw)
|
||||
id := fmt.Sprintf("%x", sum[:]) // 64 hex chars
|
||||
|
||||
path := s.filePath(id)
|
||||
if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil {
|
||||
return "", fmt.Errorf("storage: mkdir shard: %w", err)
|
||||
}
|
||||
|
||||
// If file already exists, dedup: return same id without error.
|
||||
if _, err := os.Stat(path); err == nil {
|
||||
return id, nil
|
||||
}
|
||||
|
||||
if err := os.WriteFile(path, raw, 0o644); err != nil {
|
||||
return "", fmt.Errorf("storage: write: %w", err)
|
||||
}
|
||||
|
||||
return id, nil
|
||||
}
|
||||
|
||||
// Load reads a stored email by its ID.
|
||||
func (s *Store) Load(id string) ([]byte, error) {
|
||||
path := s.filePath(id)
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
if errors.Is(err, os.ErrNotExist) {
|
||||
return nil, fmt.Errorf("storage: not found: %s", id)
|
||||
}
|
||||
return nil, fmt.Errorf("storage: read: %w", err)
|
||||
}
|
||||
return data, nil
|
||||
}
|
||||
|
||||
// Delete removes a stored email by its ID.
|
||||
func (s *Store) Delete(id string) error {
|
||||
path := s.filePath(id)
|
||||
if err := os.Remove(path); err != nil {
|
||||
if errors.Is(err, os.ErrNotExist) {
|
||||
return fmt.Errorf("storage: not found: %s", id)
|
||||
}
|
||||
return fmt.Errorf("storage: delete: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Stats walks the store directory and returns aggregate statistics.
|
||||
func (s *Store) Stats() (*StoreStats, error) {
|
||||
var stats StoreStats
|
||||
err := filepath.WalkDir(filepath.Join(s.dir, "store"), func(_ string, d fs.DirEntry, err error) error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if d.IsDir() {
|
||||
return nil
|
||||
}
|
||||
info, err := d.Info()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
stats.TotalMails++
|
||||
stats.TotalBytes += info.Size()
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("storage: stats: %w", err)
|
||||
}
|
||||
return &stats, nil
|
||||
}
|
||||
|
||||
// MailRef holds the ID and modification time of a stored mail.
|
||||
type MailRef struct {
|
||||
ID string
|
||||
ModTime time.Time
|
||||
}
|
||||
|
||||
// FirstAndLastMail walks the store and returns the oldest and newest mail by
|
||||
// file modification time. Returns nil for either if the store is empty.
|
||||
func (s *Store) FirstAndLastMail() (first, last *MailRef, err error) {
|
||||
err = filepath.WalkDir(filepath.Join(s.dir, "store"), func(path string, d fs.DirEntry, werr error) error {
|
||||
if werr != nil {
|
||||
return werr
|
||||
}
|
||||
if d.IsDir() {
|
||||
return nil
|
||||
}
|
||||
info, err := d.Info()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
ref := &MailRef{ID: d.Name(), ModTime: info.ModTime()}
|
||||
if first == nil || ref.ModTime.Before(first.ModTime) {
|
||||
first = ref
|
||||
}
|
||||
if last == nil || ref.ModTime.After(last.ModTime) {
|
||||
last = ref
|
||||
}
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("storage: first/last: %w", err)
|
||||
}
|
||||
return first, last, nil
|
||||
}
|
||||
|
||||
// filePath returns the on-disk path for a given mail ID.
|
||||
// Uses 2-char prefix sharding: {dir}/store/{id[:2]}/{id}
|
||||
func (s *Store) filePath(id string) string {
|
||||
return filepath.Join(s.dir, "store", id[:2], id)
|
||||
}
|
||||
@@ -0,0 +1,126 @@
|
||||
package storage_test
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/archivmail/internal/storage"
|
||||
)
|
||||
|
||||
func TestSaveAndLoad(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
store, err := storage.New(dir)
|
||||
if err != nil {
|
||||
t.Fatalf("New: %v", err)
|
||||
}
|
||||
|
||||
raw := []byte("From: alice@example.com\r\nSubject: Test\r\n\r\nHello World")
|
||||
id, err := store.Save(raw, time.Now())
|
||||
if err != nil {
|
||||
t.Fatalf("Save: %v", err)
|
||||
}
|
||||
if len(id) != 64 {
|
||||
t.Errorf("expected 64-char SHA256 hex, got %d chars", len(id))
|
||||
}
|
||||
|
||||
got, err := store.Load(id)
|
||||
if err != nil {
|
||||
t.Fatalf("Load: %v", err)
|
||||
}
|
||||
if !bytes.Equal(raw, got) {
|
||||
t.Errorf("loaded content mismatch")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDeduplication(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
store, err := storage.New(dir)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
raw := []byte("From: alice@example.com\r\n\r\nDuplicate test")
|
||||
id1, err := store.Save(raw, time.Now())
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
id2, err := store.Save(raw, time.Now())
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if id1 != id2 {
|
||||
t.Errorf("duplicate mail produced different IDs: %s vs %s", id1, id2)
|
||||
}
|
||||
|
||||
// Only one file should exist
|
||||
count := 0
|
||||
filepath.Walk(filepath.Join(dir, "store"), func(p string, info os.FileInfo, _ error) error {
|
||||
if !info.IsDir() { count++ }
|
||||
return nil
|
||||
})
|
||||
if count != 1 {
|
||||
t.Errorf("expected 1 stored file after dedup, got %d", count)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDelete(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
store, err := storage.New(dir)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
raw := []byte("From: alice@example.com\r\n\r\nDelete me")
|
||||
id, _ := store.Save(raw, time.Now())
|
||||
|
||||
if err := store.Delete(id); err != nil {
|
||||
t.Fatalf("Delete: %v", err)
|
||||
}
|
||||
if _, err := store.Load(id); err == nil {
|
||||
t.Error("Load after Delete should return error")
|
||||
}
|
||||
}
|
||||
|
||||
func TestStats(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
store, err := storage.New(dir)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
mails := [][]byte{
|
||||
[]byte("From: a@x.com\r\n\r\nMail 1"),
|
||||
[]byte("From: b@x.com\r\n\r\nMail 2"),
|
||||
[]byte("From: c@x.com\r\n\r\nMail 3"),
|
||||
}
|
||||
for _, m := range mails {
|
||||
store.Save(m, time.Now())
|
||||
}
|
||||
|
||||
stats, err := store.Stats()
|
||||
if err != nil {
|
||||
t.Fatalf("Stats: %v", err)
|
||||
}
|
||||
if stats.TotalMails != 3 {
|
||||
t.Errorf("expected 3 mails, got %d", stats.TotalMails)
|
||||
}
|
||||
if stats.TotalBytes <= 0 {
|
||||
t.Error("expected positive TotalBytes")
|
||||
}
|
||||
}
|
||||
|
||||
func TestStorageDirectoryCreation(t *testing.T) {
|
||||
dir := filepath.Join(t.TempDir(), "nested", "path")
|
||||
_, err := storage.New(dir)
|
||||
if err != nil {
|
||||
t.Fatalf("New with nested path: %v", err)
|
||||
}
|
||||
for _, sub := range []string{"store", "attachments", "meta"} {
|
||||
if _, err := os.Stat(filepath.Join(dir, sub)); os.IsNotExist(err) {
|
||||
t.Errorf("expected subdirectory %s to be created", sub)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,304 @@
|
||||
package userstore
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/jackc/pgx/v5"
|
||||
"github.com/jackc/pgx/v5/pgxpool"
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
)
|
||||
|
||||
const (
|
||||
RoleUser = "user"
|
||||
RoleAdmin = "admin"
|
||||
RoleAuditor = "auditor"
|
||||
)
|
||||
|
||||
// User represents a user account in the system.
|
||||
type User struct {
|
||||
ID int64
|
||||
Username string
|
||||
Email string
|
||||
Role string
|
||||
Source string // "local" or "ldap"
|
||||
Active bool
|
||||
CreatedAt time.Time
|
||||
}
|
||||
|
||||
// CreateUserRequest holds parameters for creating a new user.
|
||||
type CreateUserRequest struct {
|
||||
Username string
|
||||
Email string
|
||||
Password string
|
||||
Role string
|
||||
}
|
||||
|
||||
// UpdateUserRequest holds optional fields for updating a user.
|
||||
type UpdateUserRequest struct {
|
||||
Email *string
|
||||
Role *string
|
||||
Active *bool
|
||||
Password *string
|
||||
}
|
||||
|
||||
// Store is a PostgreSQL-backed user store.
|
||||
type Store struct {
|
||||
pool *pgxpool.Pool
|
||||
}
|
||||
|
||||
// New connects to PostgreSQL using the given DSN and initialises the schema.
|
||||
func New(dsn string) (*Store, error) {
|
||||
ctx := context.Background()
|
||||
pool, err := pgxpool.New(ctx, dsn)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("userstore: connect: %w", err)
|
||||
}
|
||||
|
||||
s := &Store{pool: pool}
|
||||
if err := s.initSchema(ctx); err != nil {
|
||||
pool.Close()
|
||||
return nil, fmt.Errorf("userstore: init schema: %w", err)
|
||||
}
|
||||
|
||||
return s, nil
|
||||
}
|
||||
|
||||
func (s *Store) initSchema(ctx context.Context) error {
|
||||
_, err := s.pool.Exec(ctx, `
|
||||
CREATE TABLE IF NOT EXISTS users (
|
||||
id BIGSERIAL PRIMARY KEY,
|
||||
username VARCHAR(100) UNIQUE NOT NULL,
|
||||
email VARCHAR(255) UNIQUE NOT NULL,
|
||||
password_hash VARCHAR(255) NOT NULL DEFAULT '',
|
||||
role VARCHAR(20) NOT NULL CHECK (role IN ('user','auditor','admin')),
|
||||
source VARCHAR(20) NOT NULL DEFAULT 'local',
|
||||
active BOOLEAN NOT NULL DEFAULT true,
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
|
||||
);
|
||||
CREATE TABLE IF NOT EXISTS token_blacklist (
|
||||
jti VARCHAR(255) PRIMARY KEY,
|
||||
expires_at TIMESTAMPTZ NOT NULL
|
||||
);
|
||||
`)
|
||||
return err
|
||||
}
|
||||
|
||||
// Close closes the underlying connection pool.
|
||||
func (s *Store) Close() error {
|
||||
s.pool.Close()
|
||||
return nil
|
||||
}
|
||||
|
||||
// Create inserts a new local user with a bcrypt-hashed password.
|
||||
func (s *Store) Create(req CreateUserRequest) (*User, error) {
|
||||
hash, err := bcrypt.GenerateFromPassword([]byte(req.Password), bcrypt.DefaultCost)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("userstore: bcrypt: %w", err)
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
var id int64
|
||||
err = s.pool.QueryRow(ctx,
|
||||
`INSERT INTO users (username, email, password_hash, role, source, active, created_at)
|
||||
VALUES ($1, $2, $3, $4, 'local', true, NOW())
|
||||
RETURNING id`,
|
||||
req.Username, req.Email, string(hash), req.Role,
|
||||
).Scan(&id)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("userstore: create: %w", err)
|
||||
}
|
||||
|
||||
return s.GetByID(id)
|
||||
}
|
||||
|
||||
// GetByID retrieves a user by their numeric ID.
|
||||
func (s *Store) GetByID(id int64) (*User, error) {
|
||||
ctx := context.Background()
|
||||
row := s.pool.QueryRow(ctx,
|
||||
`SELECT id, username, email, role, source, active, created_at FROM users WHERE id = $1`, id,
|
||||
)
|
||||
return scanUser(row)
|
||||
}
|
||||
|
||||
// GetByUsername retrieves a user by their username.
|
||||
func (s *Store) GetByUsername(username string) (*User, error) {
|
||||
ctx := context.Background()
|
||||
row := s.pool.QueryRow(ctx,
|
||||
`SELECT id, username, email, role, source, active, created_at FROM users WHERE username = $1`, username,
|
||||
)
|
||||
return scanUser(row)
|
||||
}
|
||||
|
||||
// VerifyPassword checks credentials and returns the user, or an error if the
|
||||
// password is wrong or the account is disabled.
|
||||
func (s *Store) VerifyPassword(username, password string) (*User, error) {
|
||||
ctx := context.Background()
|
||||
row := s.pool.QueryRow(ctx,
|
||||
`SELECT id, username, email, role, source, active, created_at, password_hash FROM users WHERE username = $1`,
|
||||
username,
|
||||
)
|
||||
|
||||
var u User
|
||||
var hash string
|
||||
err := row.Scan(&u.ID, &u.Username, &u.Email, &u.Role, &u.Source, &u.Active, &u.CreatedAt, &hash)
|
||||
if errors.Is(err, pgx.ErrNoRows) {
|
||||
return nil, errors.New("userstore: user not found")
|
||||
}
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("userstore: scan: %w", err)
|
||||
}
|
||||
|
||||
if !u.Active {
|
||||
return nil, errors.New("userstore: account disabled")
|
||||
}
|
||||
|
||||
if err := bcrypt.CompareHashAndPassword([]byte(hash), []byte(password)); err != nil {
|
||||
return nil, errors.New("userstore: wrong password")
|
||||
}
|
||||
|
||||
return &u, nil
|
||||
}
|
||||
|
||||
// Update applies a partial update to a user record.
|
||||
func (s *Store) Update(id int64, req UpdateUserRequest) (*User, error) {
|
||||
ctx := context.Background()
|
||||
|
||||
if req.Email != nil {
|
||||
if _, err := s.pool.Exec(ctx, `UPDATE users SET email = $1 WHERE id = $2`, *req.Email, id); err != nil {
|
||||
return nil, fmt.Errorf("userstore: update email: %w", err)
|
||||
}
|
||||
}
|
||||
if req.Role != nil {
|
||||
if _, err := s.pool.Exec(ctx, `UPDATE users SET role = $1 WHERE id = $2`, *req.Role, id); err != nil {
|
||||
return nil, fmt.Errorf("userstore: update role: %w", err)
|
||||
}
|
||||
}
|
||||
if req.Active != nil {
|
||||
if _, err := s.pool.Exec(ctx, `UPDATE users SET active = $1 WHERE id = $2`, *req.Active, id); err != nil {
|
||||
return nil, fmt.Errorf("userstore: update active: %w", err)
|
||||
}
|
||||
}
|
||||
if req.Password != nil {
|
||||
hash, err := bcrypt.GenerateFromPassword([]byte(*req.Password), bcrypt.DefaultCost)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("userstore: bcrypt: %w", err)
|
||||
}
|
||||
if _, err := s.pool.Exec(ctx, `UPDATE users SET password_hash = $1 WHERE id = $2`, string(hash), id); err != nil {
|
||||
return nil, fmt.Errorf("userstore: update password: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return s.GetByID(id)
|
||||
}
|
||||
|
||||
// Delete removes a user by ID. Returns an error if the user does not exist.
|
||||
func (s *Store) Delete(id int64) error {
|
||||
ctx := context.Background()
|
||||
tag, err := s.pool.Exec(ctx, `DELETE FROM users WHERE id = $1`, id)
|
||||
if err != nil {
|
||||
return fmt.Errorf("userstore: delete: %w", err)
|
||||
}
|
||||
if tag.RowsAffected() == 0 {
|
||||
return fmt.Errorf("userstore: user %d not found", id)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// List returns all users, optionally filtered by role. Pass role="" to list all.
|
||||
func (s *Store) List(role string) ([]*User, error) {
|
||||
ctx := context.Background()
|
||||
var rows pgx.Rows
|
||||
var err error
|
||||
|
||||
if role == "" {
|
||||
rows, err = s.pool.Query(ctx,
|
||||
`SELECT id, username, email, role, source, active, created_at FROM users ORDER BY id`)
|
||||
} else {
|
||||
rows, err = s.pool.Query(ctx,
|
||||
`SELECT id, username, email, role, source, active, created_at FROM users WHERE role = $1 ORDER BY id`, role)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("userstore: list: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var users []*User
|
||||
for rows.Next() {
|
||||
u, err := scanUserRow(rows)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
users = append(users, u)
|
||||
}
|
||||
return users, rows.Err()
|
||||
}
|
||||
|
||||
// BlacklistToken adds a JWT ID to the token blacklist.
|
||||
func (s *Store) BlacklistToken(jti string, expires time.Time) error {
|
||||
ctx := context.Background()
|
||||
_, err := s.pool.Exec(ctx,
|
||||
`INSERT INTO token_blacklist (jti, expires_at) VALUES ($1, $2)
|
||||
ON CONFLICT (jti) DO UPDATE SET expires_at = EXCLUDED.expires_at`,
|
||||
jti, expires.UTC(),
|
||||
)
|
||||
return err
|
||||
}
|
||||
|
||||
// IsBlacklisted returns true if the given JTI is in the blacklist.
|
||||
func (s *Store) IsBlacklisted(jti string) (bool, error) {
|
||||
ctx := context.Background()
|
||||
var count int
|
||||
err := s.pool.QueryRow(ctx,
|
||||
`SELECT COUNT(*) FROM token_blacklist WHERE jti = $1`, jti,
|
||||
).Scan(&count)
|
||||
return count > 0, err
|
||||
}
|
||||
|
||||
// CleanExpiredTokens removes blacklist entries whose expiry has passed.
|
||||
func (s *Store) CleanExpiredTokens() error {
|
||||
ctx := context.Background()
|
||||
_, err := s.pool.Exec(ctx, `DELETE FROM token_blacklist WHERE expires_at < NOW()`)
|
||||
return err
|
||||
}
|
||||
|
||||
// UpsertLDAPUser creates or updates an LDAP-sourced user.
|
||||
func (s *Store) UpsertLDAPUser(username, email, role string) (*User, error) {
|
||||
ctx := context.Background()
|
||||
_, err := s.pool.Exec(ctx, `
|
||||
INSERT INTO users (username, email, password_hash, role, source, active, created_at)
|
||||
VALUES ($1, $2, '', $3, 'ldap', true, NOW())
|
||||
ON CONFLICT (username) DO UPDATE SET
|
||||
email = EXCLUDED.email,
|
||||
role = EXCLUDED.role,
|
||||
source = 'ldap'
|
||||
`, username, email, role)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("userstore: upsert ldap: %w", err)
|
||||
}
|
||||
return s.GetByUsername(username)
|
||||
}
|
||||
|
||||
// --- helpers ---
|
||||
|
||||
func scanUser(row pgx.Row) (*User, error) {
|
||||
var u User
|
||||
err := row.Scan(&u.ID, &u.Username, &u.Email, &u.Role, &u.Source, &u.Active, &u.CreatedAt)
|
||||
if errors.Is(err, pgx.ErrNoRows) {
|
||||
return nil, fmt.Errorf("userstore: not found")
|
||||
}
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("userstore: scan: %w", err)
|
||||
}
|
||||
return &u, nil
|
||||
}
|
||||
|
||||
func scanUserRow(rows pgx.Rows) (*User, error) {
|
||||
var u User
|
||||
if err := rows.Scan(&u.ID, &u.Username, &u.Email, &u.Role, &u.Source, &u.Active, &u.CreatedAt); err != nil {
|
||||
return nil, fmt.Errorf("userstore: scan row: %w", err)
|
||||
}
|
||||
return &u, nil
|
||||
}
|
||||
@@ -0,0 +1,279 @@
|
||||
package userstore_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/jackc/pgx/v5"
|
||||
|
||||
"github.com/archivmail/internal/userstore"
|
||||
)
|
||||
|
||||
func newTestStore(t *testing.T) *userstore.Store {
|
||||
t.Helper()
|
||||
dsn := os.Getenv("TEST_DATABASE_URL")
|
||||
if dsn == "" {
|
||||
t.Skip("TEST_DATABASE_URL not set — skipping (needs PostgreSQL)")
|
||||
}
|
||||
// Use a unique schema per test to isolate
|
||||
schema := "test_" + strings.ReplaceAll(t.Name(), "/", "_")
|
||||
schema = strings.ToLower(schema)
|
||||
// Append schema to DSN
|
||||
sep := "?"
|
||||
if strings.Contains(dsn, "?") {
|
||||
sep = "&"
|
||||
}
|
||||
schemaDSN := dsn + sep + "search_path=" + schema
|
||||
|
||||
// Create schema
|
||||
ctx := context.Background()
|
||||
conn, err := pgx.Connect(ctx, dsn)
|
||||
if err != nil {
|
||||
t.Fatalf("connect: %v", err)
|
||||
}
|
||||
conn.Exec(ctx, "CREATE SCHEMA IF NOT EXISTS "+schema)
|
||||
conn.Close(ctx)
|
||||
|
||||
s, err := userstore.New(schemaDSN)
|
||||
if err != nil {
|
||||
t.Fatalf("userstore.New: %v", err)
|
||||
}
|
||||
t.Cleanup(func() {
|
||||
s.Close()
|
||||
conn2, _ := pgx.Connect(context.Background(), dsn)
|
||||
if conn2 != nil {
|
||||
conn2.Exec(context.Background(), "DROP SCHEMA "+schema+" CASCADE")
|
||||
conn2.Close(context.Background())
|
||||
}
|
||||
})
|
||||
return s
|
||||
}
|
||||
|
||||
func TestCreateAndGetUser(t *testing.T) {
|
||||
s := newTestStore(t)
|
||||
|
||||
u, err := s.Create(userstore.CreateUserRequest{
|
||||
Username: "alice",
|
||||
Email: "alice@example.com",
|
||||
Password: "secret123",
|
||||
Role: userstore.RoleAdmin,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Create: %v", err)
|
||||
}
|
||||
if u.ID == 0 {
|
||||
t.Error("expected non-zero ID")
|
||||
}
|
||||
if u.Username != "alice" {
|
||||
t.Errorf("Username = %q", u.Username)
|
||||
}
|
||||
if u.Role != userstore.RoleAdmin {
|
||||
t.Errorf("Role = %q", u.Role)
|
||||
}
|
||||
if u.Source != "local" {
|
||||
t.Errorf("Source = %q, want local", u.Source)
|
||||
}
|
||||
|
||||
got, err := s.GetByID(u.ID)
|
||||
if err != nil {
|
||||
t.Fatalf("GetByID: %v", err)
|
||||
}
|
||||
if got.Email != "alice@example.com" {
|
||||
t.Errorf("Email = %q", got.Email)
|
||||
}
|
||||
}
|
||||
|
||||
func TestVerifyPassword(t *testing.T) {
|
||||
s := newTestStore(t)
|
||||
_, err := s.Create(userstore.CreateUserRequest{
|
||||
Username: "bob", Email: "bob@example.com",
|
||||
Password: "correcthorse", Role: userstore.RoleUser,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Correct password
|
||||
u, err := s.VerifyPassword("bob", "correcthorse")
|
||||
if err != nil {
|
||||
t.Errorf("VerifyPassword correct: %v", err)
|
||||
}
|
||||
if u.Username != "bob" {
|
||||
t.Errorf("Username = %q", u.Username)
|
||||
}
|
||||
|
||||
// Wrong password
|
||||
if _, err := s.VerifyPassword("bob", "wrongpassword"); err == nil {
|
||||
t.Error("expected error for wrong password")
|
||||
}
|
||||
|
||||
// Non-existent user
|
||||
if _, err := s.VerifyPassword("nobody", "x"); err == nil {
|
||||
t.Error("expected error for unknown user")
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpdateUser(t *testing.T) {
|
||||
s := newTestStore(t)
|
||||
u, _ := s.Create(userstore.CreateUserRequest{
|
||||
Username: "carol", Email: "carol@old.com",
|
||||
Password: "pw", Role: userstore.RoleUser,
|
||||
})
|
||||
|
||||
newEmail := "carol@new.com"
|
||||
newRole := userstore.RoleAuditor
|
||||
updated, err := s.Update(u.ID, userstore.UpdateUserRequest{
|
||||
Email: &newEmail,
|
||||
Role: &newRole,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Update: %v", err)
|
||||
}
|
||||
if updated.Email != "carol@new.com" {
|
||||
t.Errorf("Email after update = %q", updated.Email)
|
||||
}
|
||||
if updated.Role != userstore.RoleAuditor {
|
||||
t.Errorf("Role after update = %q", updated.Role)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDisableUser(t *testing.T) {
|
||||
s := newTestStore(t)
|
||||
u, _ := s.Create(userstore.CreateUserRequest{
|
||||
Username: "dave", Email: "dave@x.com",
|
||||
Password: "pw", Role: userstore.RoleUser,
|
||||
})
|
||||
|
||||
active := false
|
||||
s.Update(u.ID, userstore.UpdateUserRequest{Active: &active})
|
||||
|
||||
if _, err := s.VerifyPassword("dave", "pw"); err == nil {
|
||||
t.Error("disabled user should not be able to login")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDeleteUser(t *testing.T) {
|
||||
s := newTestStore(t)
|
||||
u, _ := s.Create(userstore.CreateUserRequest{
|
||||
Username: "eve", Email: "eve@x.com",
|
||||
Password: "pw", Role: userstore.RoleUser,
|
||||
})
|
||||
|
||||
if err := s.Delete(u.ID); err != nil {
|
||||
t.Fatalf("Delete: %v", err)
|
||||
}
|
||||
if _, err := s.GetByID(u.ID); err == nil {
|
||||
t.Error("GetByID should error after delete")
|
||||
}
|
||||
// Delete non-existent should error
|
||||
if err := s.Delete(u.ID); err == nil {
|
||||
t.Error("second delete should return error")
|
||||
}
|
||||
}
|
||||
|
||||
func TestListUsers(t *testing.T) {
|
||||
s := newTestStore(t)
|
||||
users := []userstore.CreateUserRequest{
|
||||
{Username: "u1", Email: "u1@x.com", Password: "pw", Role: userstore.RoleUser},
|
||||
{Username: "u2", Email: "u2@x.com", Password: "pw", Role: userstore.RoleAdmin},
|
||||
{Username: "u3", Email: "u3@x.com", Password: "pw", Role: userstore.RoleAuditor},
|
||||
{Username: "u4", Email: "u4@x.com", Password: "pw", Role: userstore.RoleUser},
|
||||
}
|
||||
for _, req := range users {
|
||||
s.Create(req)
|
||||
}
|
||||
|
||||
all, err := s.List("")
|
||||
if err != nil {
|
||||
t.Fatalf("List all: %v", err)
|
||||
}
|
||||
if len(all) != 4 {
|
||||
t.Errorf("List all: got %d, want 4", len(all))
|
||||
}
|
||||
|
||||
admins, _ := s.List(userstore.RoleAdmin)
|
||||
if len(admins) != 1 {
|
||||
t.Errorf("List admin: got %d, want 1", len(admins))
|
||||
}
|
||||
|
||||
regular, _ := s.List(userstore.RoleUser)
|
||||
if len(regular) != 2 {
|
||||
t.Errorf("List user: got %d, want 2", len(regular))
|
||||
}
|
||||
}
|
||||
|
||||
func TestTokenBlacklist(t *testing.T) {
|
||||
s := newTestStore(t)
|
||||
|
||||
jti := "test-jti-12345"
|
||||
expires := time.Now().Add(1 * time.Hour)
|
||||
|
||||
if err := s.BlacklistToken(jti, expires); err != nil {
|
||||
t.Fatalf("BlacklistToken: %v", err)
|
||||
}
|
||||
|
||||
blacklisted, err := s.IsBlacklisted(jti)
|
||||
if err != nil {
|
||||
t.Fatalf("IsBlacklisted: %v", err)
|
||||
}
|
||||
if !blacklisted {
|
||||
t.Error("token should be blacklisted")
|
||||
}
|
||||
|
||||
// Non-blacklisted token
|
||||
bl2, _ := s.IsBlacklisted("other-jti")
|
||||
if bl2 {
|
||||
t.Error("unknown token should not be blacklisted")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCleanExpiredTokens(t *testing.T) {
|
||||
s := newTestStore(t)
|
||||
|
||||
// Add an already-expired token
|
||||
s.BlacklistToken("expired-jti", time.Now().Add(-1*time.Hour))
|
||||
// Add a valid token
|
||||
s.BlacklistToken("valid-jti", time.Now().Add(1*time.Hour))
|
||||
|
||||
if err := s.CleanExpiredTokens(); err != nil {
|
||||
t.Fatalf("CleanExpiredTokens: %v", err)
|
||||
}
|
||||
|
||||
bl, _ := s.IsBlacklisted("expired-jti")
|
||||
if bl {
|
||||
t.Error("expired token should be cleaned up")
|
||||
}
|
||||
bl2, _ := s.IsBlacklisted("valid-jti")
|
||||
if !bl2 {
|
||||
t.Error("valid token should still be blacklisted")
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpsertLDAPUser(t *testing.T) {
|
||||
s := newTestStore(t)
|
||||
|
||||
u, err := s.UpsertLDAPUser("ldapuser", "ldap@corp.com", userstore.RoleAuditor)
|
||||
if err != nil {
|
||||
t.Fatalf("UpsertLDAPUser: %v", err)
|
||||
}
|
||||
if u.Source != "ldap" {
|
||||
t.Errorf("Source = %q, want ldap", u.Source)
|
||||
}
|
||||
|
||||
// Second upsert should update, not duplicate
|
||||
u2, err := s.UpsertLDAPUser("ldapuser", "ldap@corp.com", userstore.RoleAuditor)
|
||||
if err != nil {
|
||||
t.Fatalf("UpsertLDAPUser second: %v", err)
|
||||
}
|
||||
if u2.ID != u.ID {
|
||||
t.Error("second upsert should not create a new record")
|
||||
}
|
||||
|
||||
all, _ := s.List("")
|
||||
if len(all) != 1 {
|
||||
t.Errorf("expected 1 user after double upsert, got %d", len(all))
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user