feat(PROJ-17): Admin Dashboard Systemauslastung immer anzeigen

- Systemauslastungs-Sektion wird immer gerendert (nicht nur bei Erfolg)
- Fehlermeldung wenn /api/admin/system/stats nicht erreichbar ist
- Feature-Status auf In Review gesetzt

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
sysops
2026-03-14 11:43:19 +01:00
parent a893084a88
commit d360c9a5ba
68 changed files with 11938 additions and 435 deletions
+269
View File
@@ -0,0 +1,269 @@
package api_test
import (
"bytes"
"context"
"encoding/json"
"log/slog"
"net/http"
"net/http/httptest"
"os"
"path/filepath"
"strings"
"testing"
"github.com/jackc/pgx/v5"
"github.com/archivmail/config"
"github.com/archivmail/internal/api"
"github.com/archivmail/internal/audit"
"github.com/archivmail/internal/auth"
"github.com/archivmail/internal/index"
"github.com/archivmail/internal/storage"
"github.com/archivmail/internal/userstore"
)
type testEnv struct {
server *api.Server
users *userstore.Store
store *storage.Store
idx index.Indexer
}
func newTestEnv(t *testing.T) *testEnv {
t.Helper()
dir := t.TempDir()
logger := slog.New(slog.NewTextHandler(os.Discard, nil))
store, err := storage.New(filepath.Join(dir, "store"))
if err != nil {
t.Fatal(err)
}
idx, err := index.New(filepath.Join(dir, "index"), 100, "xapian")
if err != nil {
t.Skip("xapian not available:", err)
}
dsn := os.Getenv("TEST_DATABASE_URL")
if dsn == "" {
t.Skip("TEST_DATABASE_URL not set — skipping (needs PostgreSQL)")
}
// Create isolated schemas for this test
schemaUsers := "apitest_users_" + strings.ToLower(strings.ReplaceAll(t.Name(), "/", "_"))
schemaAudit := "apitest_audit_" + strings.ToLower(strings.ReplaceAll(t.Name(), "/", "_"))
if len(schemaUsers) > 63 {
schemaUsers = schemaUsers[:63]
}
if len(schemaAudit) > 63 {
schemaAudit = schemaAudit[:63]
}
ctx := context.Background()
conn, err := pgx.Connect(ctx, dsn)
if err != nil {
t.Fatalf("connect: %v", err)
}
conn.Exec(ctx, "CREATE SCHEMA IF NOT EXISTS "+schemaUsers)
conn.Exec(ctx, "CREATE SCHEMA IF NOT EXISTS "+schemaAudit)
conn.Close(ctx)
sep := "?"
if strings.Contains(dsn, "?") {
sep = "&"
}
usersDSN := dsn + sep + "search_path=" + schemaUsers
auditDSN := dsn + sep + "search_path=" + schemaAudit
users, err := userstore.New(usersDSN)
if err != nil {
t.Fatalf("userstore.New: %v", err)
}
audlog, err := audit.New(auditDSN, dir, logger)
if err != nil {
t.Fatalf("audit.New: %v", err)
}
// Seed users
users.Create(userstore.CreateUserRequest{Username: "admin", Email: "admin@x.com", Password: "adminpass", Role: userstore.RoleAdmin})
users.Create(userstore.CreateUserRequest{Username: "auditor", Email: "auditor@x.com", Password: "auditorpass", Role: userstore.RoleAuditor})
users.Create(userstore.CreateUserRequest{Username: "user1", Email: "user1@x.com", Password: "userpass", Role: userstore.RoleUser})
authMgr := auth.New(users, nil, "test-secret-must-be-long-enough-32")
cfg := config.APIConfig{Bind: ":18080", Secret: "test-secret-must-be-long-enough-32"}
srv := api.New(cfg, store, idx, authMgr, users, audlog, logger)
t.Cleanup(func() {
idx.Close()
users.Close()
audlog.Close()
conn2, _ := pgx.Connect(context.Background(), dsn)
if conn2 != nil {
conn2.Exec(context.Background(), "DROP SCHEMA "+schemaUsers+" CASCADE")
conn2.Exec(context.Background(), "DROP SCHEMA "+schemaAudit+" CASCADE")
conn2.Close(context.Background())
}
})
return &testEnv{server: srv, users: users, store: store, idx: idx}
}
func (e *testEnv) do(t *testing.T, method, path string, body interface{}, token string) *httptest.ResponseRecorder {
t.Helper()
var buf bytes.Buffer
if body != nil {
json.NewEncoder(&buf).Encode(body)
}
req := httptest.NewRequest(method, path, &buf)
req.Header.Set("Content-Type", "application/json")
if token != "" {
req.Header.Set("Authorization", "Bearer "+token)
}
w := httptest.NewRecorder()
e.server.ServeHTTP(w, req)
return w
}
func (e *testEnv) login(t *testing.T, username, password string) string {
t.Helper()
w := e.do(t, "POST", "/api/auth/login",
map[string]string{"username": username, "password": password}, "")
if w.Code != 200 {
t.Fatalf("login %s: status %d, body: %s", username, w.Code, w.Body.String())
}
var resp map[string]interface{}
json.Unmarshal(w.Body.Bytes(), &resp)
return resp["token"].(string)
}
// ---- Tests ----
func TestHealth(t *testing.T) {
env := newTestEnv(t)
w := env.do(t, "GET", "/api/health", nil, "")
if w.Code != 200 {
t.Errorf("health: status %d", w.Code)
}
}
func TestLoginAndMe(t *testing.T) {
env := newTestEnv(t)
token := env.login(t, "admin", "adminpass")
w := env.do(t, "GET", "/api/auth/me", nil, token)
if w.Code != 200 {
t.Fatalf("me: status %d", w.Code)
}
var resp map[string]interface{}
json.Unmarshal(w.Body.Bytes(), &resp)
if resp["username"] != "admin" {
t.Errorf("me username = %q", resp["username"])
}
if resp["role"] != "admin" {
t.Errorf("me role = %q", resp["role"])
}
}
func TestLoginWrongCredentials(t *testing.T) {
env := newTestEnv(t)
w := env.do(t, "POST", "/api/auth/login",
map[string]string{"username": "admin", "password": "wrong"}, "")
if w.Code != 401 {
t.Errorf("expected 401, got %d", w.Code)
}
}
func TestUnauthenticatedSearchBlocked(t *testing.T) {
env := newTestEnv(t)
w := env.do(t, "GET", "/api/search?q=test", nil, "")
if w.Code != 401 {
t.Errorf("expected 401, got %d", w.Code)
}
}
func TestLogout(t *testing.T) {
env := newTestEnv(t)
token := env.login(t, "admin", "adminpass")
w := env.do(t, "POST", "/api/auth/logout", nil, token)
if w.Code != 200 {
t.Fatalf("logout: status %d", w.Code)
}
// Token should now be invalid
w2 := env.do(t, "GET", "/api/auth/me", nil, token)
if w2.Code != 401 {
t.Errorf("after logout, me should return 401, got %d", w2.Code)
}
}
func TestAdminUserCRUD(t *testing.T) {
env := newTestEnv(t)
token := env.login(t, "admin", "adminpass")
// List users
w := env.do(t, "GET", "/api/users", nil, token)
if w.Code != 200 {
t.Fatalf("list users: status %d", w.Code)
}
var users []map[string]interface{}
json.Unmarshal(w.Body.Bytes(), &users)
if len(users) != 3 { // admin + auditor + user1
t.Errorf("expected 3 users, got %d", len(users))
}
// Create user
w = env.do(t, "POST", "/api/users",
map[string]string{"username": "newuser", "email": "new@x.com", "password": "pw123", "role": "user"},
token)
if w.Code != 201 {
t.Fatalf("create user: status %d, body: %s", w.Code, w.Body.String())
}
}
func TestNonAdminCannotManageUsers(t *testing.T) {
env := newTestEnv(t)
token := env.login(t, "user1", "userpass")
w := env.do(t, "GET", "/api/users", nil, token)
if w.Code != 403 {
t.Errorf("user role should not list users, got %d", w.Code)
}
}
func TestAuditorCanAccessAuditLog(t *testing.T) {
env := newTestEnv(t)
token := env.login(t, "auditor", "auditorpass")
w := env.do(t, "GET", "/api/audit", nil, token)
if w.Code != 200 {
t.Errorf("auditor should access audit log, got %d", w.Code)
}
}
func TestUserCannotAccessAuditLog(t *testing.T) {
env := newTestEnv(t)
token := env.login(t, "user1", "userpass")
w := env.do(t, "GET", "/api/audit", nil, token)
if w.Code != 403 {
t.Errorf("user role should not access audit log, got %d", w.Code)
}
}
func TestSearchReturnsResults(t *testing.T) {
env := newTestEnv(t)
token := env.login(t, "admin", "adminpass")
w := env.do(t, "GET", "/api/search?q=test", nil, token)
if w.Code != 200 {
t.Fatalf("search: status %d, body: %s", w.Code, w.Body.String())
}
var result map[string]interface{}
json.Unmarshal(w.Body.Bytes(), &result)
if _, ok := result["total"]; !ok {
t.Error("search response missing 'total' field")
}
}
File diff suppressed because it is too large Load Diff
+196
View File
@@ -0,0 +1,196 @@
package audit
import (
"context"
"fmt"
"log/slog"
"strings"
"time"
"github.com/jackc/pgx/v5/pgxpool"
)
const (
EventLogin = "login"
EventLogout = "logout"
EventSearch = "search"
EventMailView = "mail_view"
EventImport = "import"
EventExport = "export"
EventUserMgmt = "user_mgmt"
)
// Entry is a single audit log record.
type Entry struct {
ID int64 `json:"id"`
Timestamp time.Time `json:"timestamp"`
EventType string `json:"event_type"`
Username string `json:"username"`
IPAddress string `json:"ip_address"`
Query string `json:"query"`
MailID string `json:"mail_id"`
Success bool `json:"success"`
Detail string `json:"detail"`
}
// QueryFilter specifies filtering options for audit log queries.
type QueryFilter struct {
Username string
EventType string
MailID string
From *time.Time
To *time.Time
PageSize int
Page int
}
// Logger is a PostgreSQL-backed, append-only audit log.
type Logger struct {
pool *pgxpool.Pool
logger *slog.Logger
}
// New connects to PostgreSQL using the given DSN and initialises the schema.
// logDir is reserved for future flat-file logging.
func New(dsn, logDir string, logger *slog.Logger) (*Logger, error) {
ctx := context.Background()
pool, err := pgxpool.New(ctx, dsn)
if err != nil {
return nil, fmt.Errorf("audit: connect: %w", err)
}
_, err = pool.Exec(ctx, `
CREATE TABLE IF NOT EXISTS audit_log (
id BIGSERIAL PRIMARY KEY,
timestamp TIMESTAMPTZ NOT NULL DEFAULT NOW(),
event_type VARCHAR(50) NOT NULL,
username VARCHAR(255) NOT NULL DEFAULT '',
ip_address VARCHAR(45) NOT NULL DEFAULT '',
query TEXT NOT NULL DEFAULT '',
mail_id VARCHAR(255) NOT NULL DEFAULT '',
success BOOLEAN NOT NULL DEFAULT true,
detail TEXT NOT NULL DEFAULT ''
);
`)
if err != nil {
pool.Close()
return nil, fmt.Errorf("audit: create schema: %w", err)
}
return &Logger{pool: pool, logger: logger}, nil
}
// Log appends an entry to the audit log. Errors are logged but not returned.
func (l *Logger) Log(entry Entry) {
ts := entry.Timestamp
if ts.IsZero() {
ts = time.Now().UTC()
}
ctx := context.Background()
_, err := l.pool.Exec(ctx,
`INSERT INTO audit_log (timestamp, event_type, username, ip_address, query, mail_id, success, detail)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8)`,
ts.UTC(),
entry.EventType,
entry.Username,
entry.IPAddress,
entry.Query,
entry.MailID,
entry.Success,
entry.Detail,
)
if err != nil {
l.logger.Error("audit: insert failed", "err", err)
}
}
// Query retrieves audit entries matching the given filter, returning the
// matched entries, the total count (ignoring pagination), and any error.
func (l *Logger) Query(filter QueryFilter) ([]Entry, int, error) {
pageSize := filter.PageSize
if pageSize <= 0 {
pageSize = 50
}
where, args := buildWhere(filter)
ctx := context.Background()
// Count total
countSQL := "SELECT COUNT(*) FROM audit_log" + where
var total int
if err := l.pool.QueryRow(ctx, countSQL, args...).Scan(&total); err != nil {
return nil, 0, fmt.Errorf("audit: count: %w", err)
}
offset := filter.Page * pageSize
// Append limit and offset as next positional args
limitArg := len(args) + 1
offsetArg := len(args) + 2
querySQL := fmt.Sprintf(
"SELECT id, timestamp, event_type, username, ip_address, query, mail_id, success, detail FROM audit_log%s ORDER BY timestamp DESC LIMIT $%d OFFSET $%d",
where, limitArg, offsetArg,
)
allArgs := append(args, pageSize, offset)
rows, err := l.pool.Query(ctx, querySQL, allArgs...)
if err != nil {
return nil, 0, fmt.Errorf("audit: query: %w", err)
}
defer rows.Close()
var entries []Entry
for rows.Next() {
var e Entry
if err := rows.Scan(&e.ID, &e.Timestamp, &e.EventType, &e.Username, &e.IPAddress, &e.Query, &e.MailID, &e.Success, &e.Detail); err != nil {
return nil, 0, fmt.Errorf("audit: scan: %w", err)
}
entries = append(entries, e)
}
return entries, total, rows.Err()
}
// Close closes the audit connection pool.
func (l *Logger) Close() error {
l.pool.Close()
return nil
}
// buildWhere constructs a SQL WHERE clause from QueryFilter fields using
// positional parameters ($1, $2, ...) for PostgreSQL.
func buildWhere(f QueryFilter) (string, []interface{}) {
var clauses []string
var args []interface{}
n := 1
if f.Username != "" {
clauses = append(clauses, fmt.Sprintf("username = $%d", n))
args = append(args, f.Username)
n++
}
if f.EventType != "" {
clauses = append(clauses, fmt.Sprintf("event_type = $%d", n))
args = append(args, f.EventType)
n++
}
if f.MailID != "" {
clauses = append(clauses, fmt.Sprintf("mail_id = $%d", n))
args = append(args, f.MailID)
n++
}
if f.From != nil {
clauses = append(clauses, fmt.Sprintf("timestamp >= $%d", n))
args = append(args, f.From.UTC())
n++
}
if f.To != nil {
clauses = append(clauses, fmt.Sprintf("timestamp <= $%d", n))
args = append(args, f.To.UTC())
n++
}
if len(clauses) == 0 {
return "", args
}
return " WHERE " + strings.Join(clauses, " AND "), args
}
+179
View File
@@ -0,0 +1,179 @@
package audit_test
import (
"context"
"log/slog"
"os"
"strings"
"testing"
"time"
"github.com/jackc/pgx/v5"
"github.com/archivmail/internal/audit"
)
func newTestAudit(t *testing.T) *audit.Logger {
t.Helper()
dsn := os.Getenv("TEST_DATABASE_URL")
if dsn == "" {
t.Skip("TEST_DATABASE_URL not set — skipping (needs PostgreSQL)")
}
schema := "autest_" + strings.ToLower(strings.ReplaceAll(t.Name(), "/", "_"))
// truncate schema name to 63 chars (PostgreSQL limit)
if len(schema) > 63 {
schema = schema[:63]
}
ctx := context.Background()
conn, err := pgx.Connect(ctx, dsn)
if err != nil {
t.Fatalf("connect: %v", err)
}
conn.Exec(ctx, "CREATE SCHEMA IF NOT EXISTS "+schema)
conn.Close(ctx)
sep := "?"
if strings.Contains(dsn, "?") {
sep = "&"
}
schemaDSN := dsn + sep + "search_path=" + schema
logger := slog.New(slog.NewTextHandler(os.Discard, nil))
l, err := audit.New(schemaDSN, t.TempDir(), logger)
if err != nil {
t.Fatalf("audit.New: %v", err)
}
t.Cleanup(func() {
l.Close()
conn2, _ := pgx.Connect(context.Background(), dsn)
if conn2 != nil {
conn2.Exec(context.Background(), "DROP SCHEMA "+schema+" CASCADE")
conn2.Close(context.Background())
}
})
return l
}
func TestLogAndQuery(t *testing.T) {
l := newTestAudit(t)
l.Log(audit.Entry{
EventType: audit.EventLogin,
Username: "alice",
IPAddress: "192.168.1.1",
Success: true,
})
l.Log(audit.Entry{
EventType: audit.EventSearch,
Username: "alice",
IPAddress: "192.168.1.1",
Query: "invoice",
Success: true,
})
l.Log(audit.Entry{
EventType: audit.EventLogin,
Username: "bob",
IPAddress: "10.0.0.1",
Success: false,
Detail: "wrong password",
})
all, total, err := l.Query(audit.QueryFilter{PageSize: 50})
if err != nil {
t.Fatalf("Query all: %v", err)
}
if total != 3 {
t.Errorf("expected 3 entries, got %d", total)
}
_ = all
}
func TestQueryByUsername(t *testing.T) {
l := newTestAudit(t)
l.Log(audit.Entry{EventType: audit.EventLogin, Username: "alice", Success: true})
l.Log(audit.Entry{EventType: audit.EventSearch, Username: "alice", Success: true})
l.Log(audit.Entry{EventType: audit.EventLogin, Username: "bob", Success: true})
entries, total, _ := l.Query(audit.QueryFilter{Username: "alice", PageSize: 50})
if total != 2 {
t.Errorf("alice: expected 2 entries, got %d", total)
}
for _, e := range entries {
if e.Username != "alice" {
t.Errorf("got entry for user %q in alice filter", e.Username)
}
}
}
func TestQueryByEventType(t *testing.T) {
l := newTestAudit(t)
l.Log(audit.Entry{EventType: audit.EventLogin, Username: "alice", Success: true})
l.Log(audit.Entry{EventType: audit.EventSearch, Username: "alice", Success: true})
l.Log(audit.Entry{EventType: audit.EventMailView, Username: "alice", MailID: "abc123", Success: true})
_, total, _ := l.Query(audit.QueryFilter{EventType: audit.EventSearch, PageSize: 50})
if total != 1 {
t.Errorf("search event filter: expected 1, got %d", total)
}
}
func TestQueryByMailID(t *testing.T) {
l := newTestAudit(t)
l.Log(audit.Entry{EventType: audit.EventMailView, Username: "alice", MailID: "mail-001", Success: true})
l.Log(audit.Entry{EventType: audit.EventMailView, Username: "bob", MailID: "mail-001", Success: true})
l.Log(audit.Entry{EventType: audit.EventMailView, Username: "alice", MailID: "mail-002", Success: true})
_, total, _ := l.Query(audit.QueryFilter{MailID: "mail-001", PageSize: 50})
if total != 2 {
t.Errorf("mailID filter: expected 2, got %d", total)
}
}
func TestQueryDateRange(t *testing.T) {
l := newTestAudit(t)
l.Log(audit.Entry{EventType: audit.EventLogin, Username: "alice", Success: true})
l.Log(audit.Entry{EventType: audit.EventLogin, Username: "bob", Success: true})
// Query with future date range — should return 0
future := time.Now().Add(24 * time.Hour)
futureEnd := time.Now().Add(48 * time.Hour)
_, total, _ := l.Query(audit.QueryFilter{From: &future, To: &futureEnd, PageSize: 50})
if total != 0 {
t.Errorf("future date range should return 0, got %d", total)
}
// Query with past-to-now range — should return all
past := time.Now().Add(-1 * time.Minute)
now := time.Now().Add(1 * time.Minute)
_, total, _ = l.Query(audit.QueryFilter{From: &past, To: &now, PageSize: 50})
if total != 2 {
t.Errorf("current date range should return 2, got %d", total)
}
}
func TestQueryPagination(t *testing.T) {
l := newTestAudit(t)
for i := 0; i < 10; i++ {
l.Log(audit.Entry{EventType: audit.EventSearch, Username: "alice", Success: true})
}
page0, total, _ := l.Query(audit.QueryFilter{PageSize: 4, Page: 0})
_, _, _ = l.Query(audit.QueryFilter{PageSize: 4, Page: 1})
page2, _, _ := l.Query(audit.QueryFilter{PageSize: 4, Page: 2})
if total != 10 {
t.Errorf("total = %d, want 10", total)
}
if len(page0) != 4 {
t.Errorf("page 0 len = %d, want 4", len(page0))
}
if len(page2) != 2 {
t.Errorf("page 2 len = %d, want 2", len(page2))
}
}
+156
View File
@@ -0,0 +1,156 @@
package auth
import (
"errors"
"fmt"
"time"
"github.com/golang-jwt/jwt/v5"
"github.com/archivmail/internal/userstore"
)
// Session holds the claims extracted from a validated JWT.
type Session struct {
UserID int64
Username string
Role string
JTI string // unique JWT ID
}
// Manager handles login, token issuance, validation, and logout.
type Manager struct {
store *userstore.Store
ldap interface{} // placeholder for LDAP provider
jwtSecret []byte
}
// New creates a new auth Manager.
func New(store *userstore.Store, ldap interface{}, jwtSecret string) *Manager {
return &Manager{
store: store,
ldap: ldap,
jwtSecret: []byte(jwtSecret),
}
}
// Login verifies credentials and returns a signed JWT token.
func (m *Manager) Login(username, password string) (string, *userstore.User, error) {
user, err := m.store.VerifyPassword(username, password)
if err != nil {
return "", nil, fmt.Errorf("auth: login: %w", err)
}
jti := generateJTI()
now := time.Now()
claims := jwt.MapClaims{
"sub": user.Username,
"role": user.Role,
"uid": user.ID,
"jti": jti,
"iat": now.Unix(),
"exp": now.Add(8 * time.Hour).Unix(),
}
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
signed, err := token.SignedString(m.jwtSecret)
if err != nil {
return "", nil, fmt.Errorf("auth: sign token: %w", err)
}
return signed, user, nil
}
// ValidateToken parses and validates the token, checking the blacklist.
func (m *Manager) ValidateToken(tokenStr string) (*Session, error) {
token, err := jwt.Parse(tokenStr, func(t *jwt.Token) (interface{}, error) {
if _, ok := t.Method.(*jwt.SigningMethodHMAC); !ok {
return nil, fmt.Errorf("auth: unexpected signing method: %v", t.Header["alg"])
}
return m.jwtSecret, nil
})
if err != nil {
return nil, fmt.Errorf("auth: invalid token: %w", err)
}
if !token.Valid {
return nil, errors.New("auth: token not valid")
}
claims, ok := token.Claims.(jwt.MapClaims)
if !ok {
return nil, errors.New("auth: bad claims")
}
jti, _ := claims["jti"].(string)
blacklisted, err := m.store.IsBlacklisted(jti)
if err != nil {
return nil, fmt.Errorf("auth: blacklist check: %w", err)
}
if blacklisted {
return nil, errors.New("auth: token revoked")
}
username, _ := claims["sub"].(string)
role, _ := claims["role"].(string)
var userID int64
switch v := claims["uid"].(type) {
case float64:
userID = int64(v)
case int64:
userID = v
}
return &Session{
UserID: userID,
Username: username,
Role: role,
JTI: jti,
}, nil
}
// Logout revokes the token by adding its JTI to the blacklist.
func (m *Manager) Logout(tokenStr string) error {
token, err := jwt.Parse(tokenStr, func(t *jwt.Token) (interface{}, error) {
if _, ok := t.Method.(*jwt.SigningMethodHMAC); !ok {
return nil, fmt.Errorf("auth: unexpected signing method")
}
return m.jwtSecret, nil
})
if err != nil {
return fmt.Errorf("auth: logout parse: %w", err)
}
claims, ok := token.Claims.(jwt.MapClaims)
if !ok {
return errors.New("auth: bad claims on logout")
}
jti, _ := claims["jti"].(string)
var exp time.Time
switch v := claims["exp"].(type) {
case float64:
exp = time.Unix(int64(v), 0)
case int64:
exp = time.Unix(v, 0)
default:
exp = time.Now().Add(8 * time.Hour)
}
return m.store.BlacklistToken(jti, exp)
}
// HasRole returns true when userRole satisfies the required role level.
// Hierarchy: admin > auditor > user
func HasRole(userRole, required string) bool {
levels := map[string]int{
userstore.RoleUser: 1,
userstore.RoleAuditor: 2,
userstore.RoleAdmin: 3,
}
return levels[userRole] >= levels[required]
}
// generateJTI returns a pseudo-unique identifier for a JWT.
func generateJTI() string {
return fmt.Sprintf("%d-%x", time.Now().UnixNano(), time.Now().UnixNano()^0xdeadbeef)
}
+161
View File
@@ -0,0 +1,161 @@
package auth_test
import (
"path/filepath"
"testing"
"github.com/archivmail/internal/auth"
"github.com/archivmail/internal/userstore"
)
func newTestAuth(t *testing.T) (*auth.Manager, *userstore.Store) {
t.Helper()
store, err := userstore.New(filepath.Join(t.TempDir(), "users.db"))
if err != nil {
t.Fatalf("userstore.New: %v", err)
}
t.Cleanup(func() { store.Close() })
// Seed a test user
store.Create(userstore.CreateUserRequest{
Username: "testadmin",
Email: "admin@example.com",
Password: "adminpass",
Role: userstore.RoleAdmin,
})
store.Create(userstore.CreateUserRequest{
Username: "regularuser",
Email: "user@example.com",
Password: "userpass",
Role: userstore.RoleUser,
})
mgr := auth.New(store, nil, "test-jwt-secret-32chars-long-enough")
return mgr, store
}
func TestLoginSuccess(t *testing.T) {
mgr, _ := newTestAuth(t)
token, user, err := mgr.Login("testadmin", "adminpass")
if err != nil {
t.Fatalf("Login: %v", err)
}
if token == "" {
t.Error("expected non-empty token")
}
if user.Username != "testadmin" {
t.Errorf("Username = %q", user.Username)
}
if user.Role != userstore.RoleAdmin {
t.Errorf("Role = %q", user.Role)
}
}
func TestLoginWrongPassword(t *testing.T) {
mgr, _ := newTestAuth(t)
if _, _, err := mgr.Login("testadmin", "wrongpass"); err == nil {
t.Error("expected error for wrong password")
}
}
func TestLoginUnknownUser(t *testing.T) {
mgr, _ := newTestAuth(t)
if _, _, err := mgr.Login("nobody", "pw"); err == nil {
t.Error("expected error for unknown user")
}
}
func TestTokenValidation(t *testing.T) {
mgr, _ := newTestAuth(t)
token, _, _ := mgr.Login("testadmin", "adminpass")
sess, err := mgr.ValidateToken(token)
if err != nil {
t.Fatalf("ValidateToken: %v", err)
}
if sess.Username != "testadmin" {
t.Errorf("Session Username = %q", sess.Username)
}
if sess.Role != userstore.RoleAdmin {
t.Errorf("Session Role = %q", sess.Role)
}
if sess.JTI == "" {
t.Error("Session JTI should not be empty")
}
}
func TestTokenTampering(t *testing.T) {
mgr, _ := newTestAuth(t)
token, _, _ := mgr.Login("testadmin", "adminpass")
tampered := token + "x"
if _, err := mgr.ValidateToken(tampered); err == nil {
t.Error("tampered token should fail validation")
}
}
func TestLogout(t *testing.T) {
mgr, _ := newTestAuth(t)
token, _, _ := mgr.Login("testadmin", "adminpass")
// Token valid before logout
if _, err := mgr.ValidateToken(token); err != nil {
t.Fatalf("token should be valid before logout: %v", err)
}
if err := mgr.Logout(token); err != nil {
t.Fatalf("Logout: %v", err)
}
// Token invalid after logout
if _, err := mgr.ValidateToken(token); err == nil {
t.Error("token should be invalid after logout")
}
}
func TestHasRole(t *testing.T) {
tests := []struct {
userRole string
required string
want bool
}{
{userstore.RoleAdmin, userstore.RoleAdmin, true},
{userstore.RoleAdmin, userstore.RoleAuditor, true},
{userstore.RoleAdmin, userstore.RoleUser, true},
{userstore.RoleAuditor, userstore.RoleAdmin, false},
{userstore.RoleAuditor, userstore.RoleAuditor, true},
{userstore.RoleAuditor, userstore.RoleUser, true},
{userstore.RoleUser, userstore.RoleAdmin, false},
{userstore.RoleUser, userstore.RoleAuditor, false},
{userstore.RoleUser, userstore.RoleUser, true},
}
for _, tt := range tests {
got := auth.HasRole(tt.userRole, tt.required)
if got != tt.want {
t.Errorf("HasRole(%q, %q) = %v, want %v", tt.userRole, tt.required, got, tt.want)
}
}
}
func TestMultipleSessionsIndependent(t *testing.T) {
mgr, _ := newTestAuth(t)
token1, _, _ := mgr.Login("testadmin", "adminpass")
token2, _, _ := mgr.Login("testadmin", "adminpass")
if token1 == token2 {
t.Error("two logins should produce different tokens (different JTIs)")
}
// Logout session 1, session 2 should still work
mgr.Logout(token1)
if _, err := mgr.ValidateToken(token2); err != nil {
t.Errorf("session 2 should still be valid after session 1 logout: %v", err)
}
}
+99
View File
@@ -0,0 +1,99 @@
package imap
import (
"crypto/tls"
"fmt"
"strings"
imapv2 "github.com/emersion/go-imap/v2"
"github.com/emersion/go-imap/v2/imapclient"
)
// FolderInfo describes a single IMAP folder with exclusion metadata.
type FolderInfo struct {
Name string `json:"name"`
Excluded bool `json:"excluded"`
Reason string `json:"reason,omitempty"`
}
// junkTrashNames lists well-known junk/trash folder names for fallback detection.
var junkTrashNames = []string{
"junk", "spam", "trash", "deleted items",
"deleted messages", "papierkorb", "gelöschte elemente",
}
// Connect establishes an IMAP client connection using the specified TLS mode.
func Connect(host string, port int, tlsMode string) (*imapclient.Client, error) {
addr := fmt.Sprintf("%s:%d", host, port)
switch tlsMode {
case "ssl":
c, err := imapclient.DialTLS(addr, &imapclient.Options{
TLSConfig: &tls.Config{ServerName: host},
})
if err != nil {
return nil, fmt.Errorf("imap connect ssl: %w", err)
}
return c, nil
case "starttls":
c, err := imapclient.DialStartTLS(addr, &imapclient.Options{
TLSConfig: &tls.Config{ServerName: host},
})
if err != nil {
return nil, fmt.Errorf("imap connect starttls: %w", err)
}
return c, nil
case "none":
c, err := imapclient.DialInsecure(addr, nil)
if err != nil {
return nil, fmt.Errorf("imap connect plain: %w", err)
}
return c, nil
default:
return nil, fmt.Errorf("imap: unknown tls mode %q", tlsMode)
}
}
// ListFolders retrieves all mailbox folders and detects junk/trash folders.
func ListFolders(c *imapclient.Client) ([]FolderInfo, error) {
listCmd := c.List("", "*", nil)
mailboxes, err := listCmd.Collect()
if err != nil {
return nil, fmt.Errorf("imap list folders: %w", err)
}
var folders []FolderInfo
for _, mb := range mailboxes {
fi := FolderInfo{Name: mb.Mailbox}
// Check special-use attributes (RFC 6154)
for _, attr := range mb.Attrs {
if attr == imapv2.MailboxAttrJunk {
fi.Excluded = true
fi.Reason = "special_use"
break
}
if attr == imapv2.MailboxAttrTrash {
fi.Excluded = true
fi.Reason = "special_use"
break
}
}
// Fallback: case-insensitive name matching
if !fi.Excluded {
lower := strings.ToLower(mb.Mailbox)
for _, jt := range junkTrashNames {
if lower == jt {
fi.Excluded = true
fi.Reason = "name_match"
break
}
}
}
folders = append(folders, fi)
}
return folders, nil
}
+272
View File
@@ -0,0 +1,272 @@
package imap
import (
"context"
"fmt"
"io"
"log/slog"
"strings"
"time"
imapv2 "github.com/emersion/go-imap/v2"
"github.com/emersion/go-imap/v2/imapclient"
"github.com/archivmail/internal/index"
"github.com/archivmail/internal/storage"
"github.com/archivmail/pkg/mailparser"
)
const batchSize = 50
// Importer runs background IMAP import jobs.
type Importer struct {
store *Store
mailStore *storage.Store
idx index.Indexer
logger *slog.Logger
}
// NewImporter creates a new Importer wired to the storage and index backends.
func NewImporter(store *Store, mailStore *storage.Store, idx index.Indexer, logger *slog.Logger) *Importer {
return &Importer{
store: store,
mailStore: mailStore,
idx: idx,
logger: logger,
}
}
// Run performs a full IMAP import for the given account. It is designed to be
// called as a goroutine: go imp.Run(context.Background(), accountID)
func (imp *Importer) Run(ctx context.Context, accountID int64) {
log := imp.logger.With("component", "imap-importer", "account_id", accountID)
acc, err := imp.store.Get(ctx, accountID)
if err != nil {
log.Error("failed to get account", "err", err)
return
}
password, err := imp.store.GetPassword(ctx, accountID)
if err != nil {
log.Error("failed to decrypt password", "err", err)
_ = imp.store.UpdateStatus(ctx, accountID, "error", "failed to decrypt password", 0, 0)
return
}
// Mark as running
if err := imp.store.UpdateStatus(ctx, accountID, "running", "", 0, 0); err != nil {
log.Error("failed to update status", "err", err)
return
}
imported, err := imp.doImport(ctx, acc, password, log)
if err != nil {
log.Error("import failed", "err", err)
_ = imp.store.UpdateStatus(ctx, accountID, "error", err.Error(), 0, 0)
return
}
if err := imp.store.UpdateDone(ctx, accountID, imported); err != nil {
log.Error("failed to update done", "err", err)
}
log.Info("import completed", "imported", imported)
}
// doImport handles the actual IMAP connection, folder iteration, and message fetching.
func (imp *Importer) doImport(ctx context.Context, acc *Account, password string, log *slog.Logger) (int, error) {
c, err := Connect(acc.Host, acc.Port, acc.TLS)
if err != nil {
return 0, fmt.Errorf("connect: %w", err)
}
defer c.Close()
// Login
if err := c.Login(acc.Username, password).Wait(); err != nil {
return 0, fmt.Errorf("login: %w", err)
}
// List all folders
folders, err := ListFolders(c)
if err != nil {
return 0, fmt.Errorf("list folders: %w", err)
}
// Build excluded set from account config
excluded := make(map[string]bool)
for _, f := range acc.ExcludedFolders {
excluded[f] = true
}
// Collect included folders
var includedFolders []string
for _, f := range folders {
if !excluded[f.Name] {
includedFolders = append(includedFolders, f.Name)
}
}
// Count total messages across all folders first
totalMsgs := 0
folderUIDs := make(map[string][]imapv2.UID)
for _, folder := range includedFolders {
selectData, err := c.Select(folder, nil).Wait()
if err != nil {
log.Warn("failed to select folder, skipping", "folder", folder, "err", err)
continue
}
_ = selectData
searchCmd := c.UIDSearch(&imapv2.SearchCriteria{}, nil)
searchData, err := searchCmd.Wait()
if err != nil {
log.Warn("failed to search folder, skipping", "folder", folder, "err", err)
continue
}
uids := searchData.AllUIDs()
folderUIDs[folder] = uids
totalMsgs += len(uids)
}
log.Info("starting import", "folders", len(includedFolders), "total_messages", totalMsgs)
_ = imp.store.UpdateStatus(ctx, acc.ID, "running", "", 0, totalMsgs)
imported := 0
processed := 0
for _, folder := range includedFolders {
uids, ok := folderUIDs[folder]
if !ok || len(uids) == 0 {
continue
}
// Need to re-select the folder before fetching
if _, err := c.Select(folder, nil).Wait(); err != nil {
log.Warn("failed to re-select folder", "folder", folder, "err", err)
continue
}
log.Info("importing folder", "folder", folder, "messages", len(uids))
// Process in batches
for i := 0; i < len(uids); i += batchSize {
end := i + batchSize
if end > len(uids) {
end = len(uids)
}
batch := uids[i:end]
count, err := imp.fetchBatch(ctx, c, batch, log)
if err != nil {
log.Error("batch fetch error", "folder", folder, "offset", i, "err", err)
// Continue with the next batch rather than aborting entirely
continue
}
imported += count
processed += len(batch)
_ = imp.store.UpdateStatus(ctx, acc.ID, "running", "", processed, totalMsgs)
}
}
return imported, nil
}
// fetchBatch fetches and stores a batch of messages by UID.
func (imp *Importer) fetchBatch(ctx context.Context, c *imapclient.Client, uids []imapv2.UID, log *slog.Logger) (int, error) {
if len(uids) == 0 {
return 0, nil
}
fetchOptions := &imapv2.FetchOptions{
UID: true,
BodySection: []*imapv2.FetchItemBodySection{{}},
}
seqSet := imapv2.UIDSetNum(uids...)
fetchCmd := c.Fetch(seqSet, fetchOptions)
imported := 0
for {
msg := fetchCmd.Next()
if msg == nil {
break
}
// Collect body sections from this message
for {
item := msg.Next()
if item == nil {
break
}
switch body := item.(type) {
case imapclient.FetchItemDataBodySection:
raw, err := io.ReadAll(body.Literal)
if err != nil {
log.Warn("failed to read message body", "err", err)
continue
}
if err := imp.storeAndIndex(raw, log); err != nil {
log.Warn("failed to store/index message", "err", err)
continue
}
imported++
}
}
}
if err := fetchCmd.Close(); err != nil {
return imported, fmt.Errorf("fetch close: %w", err)
}
return imported, nil
}
// storeAndIndex saves a raw email to storage and indexes it.
func (imp *Importer) storeAndIndex(raw []byte, log *slog.Logger) error {
// Save to file storage (deduplicates by SHA256 automatically)
id, err := imp.mailStore.Save(raw, time.Now())
if err != nil {
return fmt.Errorf("save: %w", err)
}
// Parse for indexing
pm, err := mailparser.Parse(raw)
if err != nil {
log.Warn("failed to parse mail for indexing", "id", id, "err", err)
// Store succeeded, just skip indexing for unparseable mails
return nil
}
// Build attachment names string
var attachNames []string
for _, a := range pm.Attachments {
if a.Filename != "" {
attachNames = append(attachNames, a.Filename)
}
}
doc := index.MailDocument{
ID: id,
From: pm.From,
To: strings.Join(pm.To, ", "),
Subject: pm.Subject,
Body: pm.TextBody,
AttachNames: strings.Join(attachNames, " "),
HasAttachment: len(pm.Attachments) > 0,
Date: pm.Date,
Size: int64(len(raw)),
}
if err := imp.idx.IndexSync(doc); err != nil {
log.Warn("failed to index mail", "id", id, "err", err)
// Non-fatal: mail is stored, just not searchable yet
}
return nil
}
+259
View File
@@ -0,0 +1,259 @@
package imap
import (
"context"
"crypto/aes"
"crypto/cipher"
"crypto/rand"
"crypto/sha256"
"fmt"
"io"
"time"
"github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgxpool"
)
// Account represents an IMAP account configuration stored in the database.
type Account struct {
ID int64 `json:"id"`
Owner string `json:"owner"`
Name string `json:"name"`
Host string `json:"host"`
Port int `json:"port"`
TLS string `json:"tls"`
Username string `json:"username"`
ExcludedFolders []string `json:"excluded_folders"`
Status string `json:"status"`
ErrorMsg string `json:"error_msg"`
LastImportAt *time.Time `json:"last_import_at,omitempty"`
LastImportCount int `json:"last_import_count"`
ProgressCurrent int `json:"progress_current"`
ProgressTotal int `json:"progress_total"`
CreatedAt time.Time `json:"created_at"`
}
// Store manages IMAP account persistence in PostgreSQL.
type Store struct {
pool *pgxpool.Pool
encKey [32]byte
}
const createTableSQL = `
CREATE TABLE IF NOT EXISTS imap_accounts (
id SERIAL PRIMARY KEY,
owner TEXT NOT NULL,
name TEXT NOT NULL,
host TEXT NOT NULL,
port INTEGER NOT NULL DEFAULT 993,
tls TEXT NOT NULL DEFAULT 'ssl',
username TEXT NOT NULL,
password_enc BYTEA NOT NULL,
excluded_folders TEXT[] NOT NULL DEFAULT '{}',
status TEXT NOT NULL DEFAULT 'idle',
error_msg TEXT NOT NULL DEFAULT '',
last_import_at TIMESTAMPTZ,
last_import_count INTEGER NOT NULL DEFAULT 0,
progress_current INTEGER NOT NULL DEFAULT 0,
progress_total INTEGER NOT NULL DEFAULT 0,
created_at TIMESTAMPTZ NOT NULL DEFAULT now()
);
CREATE INDEX IF NOT EXISTS idx_imap_accounts_owner ON imap_accounts (owner);
`
// New creates a new Store, connects to PostgreSQL, and runs the migration.
func New(dsn, secret string) (*Store, error) {
pool, err := pgxpool.New(context.Background(), dsn)
if err != nil {
return nil, fmt.Errorf("imap store: connect: %w", err)
}
if _, err := pool.Exec(context.Background(), createTableSQL); err != nil {
pool.Close()
return nil, fmt.Errorf("imap store: migrate: %w", err)
}
key := sha256.Sum256([]byte(secret))
return &Store{pool: pool, encKey: key}, nil
}
// Close releases the database connection pool.
func (s *Store) Close() {
s.pool.Close()
}
// Create inserts a new IMAP account with an encrypted password.
func (s *Store) Create(ctx context.Context, acc Account, password string) (*Account, error) {
enc, err := encryptPassword(password, s.encKey)
if err != nil {
return nil, fmt.Errorf("imap store: encrypt password: %w", err)
}
row := s.pool.QueryRow(ctx, `
INSERT INTO imap_accounts (owner, name, host, port, tls, username, password_enc, excluded_folders)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8)
RETURNING id, created_at`,
acc.Owner, acc.Name, acc.Host, acc.Port, acc.TLS, acc.Username, enc, acc.ExcludedFolders,
)
if err := row.Scan(&acc.ID, &acc.CreatedAt); err != nil {
return nil, fmt.Errorf("imap store: create: %w", err)
}
acc.Status = "idle"
acc.ErrorMsg = ""
return &acc, nil
}
// List returns IMAP accounts. Admins see all accounts; regular users see only their own.
func (s *Store) List(ctx context.Context, owner string, isAdmin bool) ([]Account, error) {
var rows pgx.Rows
var err error
if isAdmin {
rows, err = s.pool.Query(ctx, `
SELECT id, owner, name, host, port, tls, username, excluded_folders,
status, error_msg, last_import_at, last_import_count,
progress_current, progress_total, created_at
FROM imap_accounts ORDER BY id`)
} else {
rows, err = s.pool.Query(ctx, `
SELECT id, owner, name, host, port, tls, username, excluded_folders,
status, error_msg, last_import_at, last_import_count,
progress_current, progress_total, created_at
FROM imap_accounts WHERE owner = $1 ORDER BY id`, owner)
}
if err != nil {
return nil, fmt.Errorf("imap store: list: %w", err)
}
defer rows.Close()
var accounts []Account
for rows.Next() {
var a Account
if err := rows.Scan(
&a.ID, &a.Owner, &a.Name, &a.Host, &a.Port, &a.TLS, &a.Username,
&a.ExcludedFolders, &a.Status, &a.ErrorMsg, &a.LastImportAt,
&a.LastImportCount, &a.ProgressCurrent, &a.ProgressTotal, &a.CreatedAt,
); err != nil {
return nil, fmt.Errorf("imap store: scan: %w", err)
}
accounts = append(accounts, a)
}
return accounts, rows.Err()
}
// Get returns a single IMAP account by ID.
func (s *Store) Get(ctx context.Context, id int64) (*Account, error) {
var a Account
err := s.pool.QueryRow(ctx, `
SELECT id, owner, name, host, port, tls, username, excluded_folders,
status, error_msg, last_import_at, last_import_count,
progress_current, progress_total, created_at
FROM imap_accounts WHERE id = $1`, id,
).Scan(
&a.ID, &a.Owner, &a.Name, &a.Host, &a.Port, &a.TLS, &a.Username,
&a.ExcludedFolders, &a.Status, &a.ErrorMsg, &a.LastImportAt,
&a.LastImportCount, &a.ProgressCurrent, &a.ProgressTotal, &a.CreatedAt,
)
if err != nil {
return nil, fmt.Errorf("imap store: get %d: %w", id, err)
}
return &a, nil
}
// GetPassword retrieves and decrypts the stored password for an IMAP account.
func (s *Store) GetPassword(ctx context.Context, id int64) (string, error) {
var enc []byte
err := s.pool.QueryRow(ctx, `SELECT password_enc FROM imap_accounts WHERE id = $1`, id).Scan(&enc)
if err != nil {
return "", fmt.Errorf("imap store: get password: %w", err)
}
return decryptPassword(enc, s.encKey)
}
// Delete removes an IMAP account by ID.
func (s *Store) Delete(ctx context.Context, id int64) error {
tag, err := s.pool.Exec(ctx, `DELETE FROM imap_accounts WHERE id = $1`, id)
if err != nil {
return fmt.Errorf("imap store: delete: %w", err)
}
if tag.RowsAffected() == 0 {
return fmt.Errorf("imap store: account %d not found", id)
}
return nil
}
// UpdateExcluded sets the list of excluded folders for an account.
func (s *Store) UpdateExcluded(ctx context.Context, id int64, excluded []string) error {
_, err := s.pool.Exec(ctx, `UPDATE imap_accounts SET excluded_folders = $1 WHERE id = $2`, excluded, id)
if err != nil {
return fmt.Errorf("imap store: update excluded: %w", err)
}
return nil
}
// UpdateStatus updates the import progress and status of an account.
func (s *Store) UpdateStatus(ctx context.Context, id int64, status, errMsg string, current, total int) error {
_, err := s.pool.Exec(ctx, `
UPDATE imap_accounts
SET status = $1, error_msg = $2, progress_current = $3, progress_total = $4
WHERE id = $5`, status, errMsg, current, total, id)
if err != nil {
return fmt.Errorf("imap store: update status: %w", err)
}
return nil
}
// UpdateDone marks an import as completed, setting status back to idle.
func (s *Store) UpdateDone(ctx context.Context, id int64, count int) error {
_, err := s.pool.Exec(ctx, `
UPDATE imap_accounts
SET status = 'idle', error_msg = '', last_import_at = now(),
last_import_count = $1, progress_current = 0, progress_total = 0
WHERE id = $2`, count, id)
if err != nil {
return fmt.Errorf("imap store: update done: %w", err)
}
return nil
}
// encryptPassword encrypts a plaintext password using AES-256-GCM.
func encryptPassword(plaintext string, key [32]byte) ([]byte, error) {
block, err := aes.NewCipher(key[:])
if err != nil {
return nil, err
}
gcm, err := cipher.NewGCM(block)
if err != nil {
return nil, err
}
nonce := make([]byte, gcm.NonceSize())
if _, err := io.ReadFull(rand.Reader, nonce); err != nil {
return nil, err
}
return gcm.Seal(nonce, nonce, []byte(plaintext), nil), nil
}
// decryptPassword decrypts a password previously encrypted with encryptPassword.
func decryptPassword(ciphertext []byte, key [32]byte) (string, error) {
block, err := aes.NewCipher(key[:])
if err != nil {
return "", err
}
gcm, err := cipher.NewGCM(block)
if err != nil {
return "", err
}
nonceSize := gcm.NonceSize()
if len(ciphertext) < nonceSize {
return "", fmt.Errorf("ciphertext too short")
}
nonce, ct := ciphertext[:nonceSize], ciphertext[nonceSize:]
plaintext, err := gcm.Open(nil, nonce, ct, nil)
if err != nil {
return "", fmt.Errorf("decrypt failed: %w", err)
}
return string(plaintext), nil
}
+61
View File
@@ -0,0 +1,61 @@
package index
import (
"fmt"
"time"
)
// MailDocument is the indexed representation of a stored email.
type MailDocument struct {
ID string
From string
To string
Subject string
Body string
AttachNames string
HasAttachment bool
Date time.Time
Size int64
}
// SearchRequest specifies search parameters.
type SearchRequest struct {
Query string
From string
To string
OwnEmail string
DateFrom *time.Time
DateTo *time.Time
PageSize int
Page int
}
// Hit is a single search result.
type Hit struct {
ID string `json:"id"`
Score float64 `json:"score"`
}
// SearchResult holds paginated search results.
type SearchResult struct {
Total int
Hits []Hit
}
// Indexer is the interface for full-text email indexing.
type Indexer interface {
IndexSync(doc MailDocument) error
Search(req SearchRequest) (*SearchResult, error)
Delete(id string) error
Close() error
}
// New creates an Indexer for the specified backend.
func New(dir string, batchSize int, backend string) (Indexer, error) {
switch backend {
case "xapian":
return newXapian(dir)
default:
return nil, fmt.Errorf("unknown index backend: %q (supported: xapian)", backend)
}
}
+192
View File
@@ -0,0 +1,192 @@
package index_test
import (
"testing"
"time"
"github.com/archivmail/internal/index"
)
// newXapianIndex creates a temporary Xapian index for testing.
func newXapianIndex(t *testing.T) index.Indexer {
t.Helper()
idx, err := index.New(t.TempDir(), 100, "xapian")
if err != nil {
t.Skip("xapian not available:", err)
}
t.Cleanup(func() { idx.Close() })
return idx
}
func seedDocs(t *testing.T, idx index.Indexer) {
t.Helper()
docs := []index.MailDocument{
{
ID: "aaa111",
From: "alice@example.com",
To: "bob@example.com",
Subject: "Invoice Q1-2026",
Body: "Please find attached the invoice for January.",
Date: time.Date(2026, 1, 15, 10, 0, 0, 0, time.UTC),
Size: 1024,
},
{
ID: "bbb222",
From: "bob@example.com",
To: "alice@example.com charlie@example.com",
Subject: "Meeting Agenda",
Body: "Agenda for the quarterly review meeting.",
Date: time.Date(2026, 2, 1, 9, 0, 0, 0, time.UTC),
Size: 512,
},
{
ID: "ccc333",
From: "charlie@example.com",
To: "alice@example.com",
Subject: "Offer with attachment",
Body: "Please review the attached offer document.",
AttachNames: "offer.pdf",
HasAttachment: true,
Date: time.Date(2026, 3, 1, 14, 0, 0, 0, time.UTC),
Size: 8192,
},
}
for _, d := range docs {
if err := idx.IndexSync(d); err != nil {
t.Fatalf("IndexSync %s: %v", d.ID, err)
}
}
}
func TestIndexAndSearchFulltext(t *testing.T) {
idx := newXapianIndex(t)
seedDocs(t, idx)
result, err := idx.Search(index.SearchRequest{Query: "invoice", PageSize: 10})
if err != nil {
t.Fatalf("Search: %v", err)
}
if result.Total == 0 {
t.Error("expected at least 1 hit for 'invoice'")
}
if result.Hits[0].ID != "aaa111" {
t.Errorf("top hit = %q, want aaa111", result.Hits[0].ID)
}
}
func TestSearchMatchAll(t *testing.T) {
idx := newXapianIndex(t)
seedDocs(t, idx)
result, err := idx.Search(index.SearchRequest{PageSize: 25})
if err != nil {
t.Fatalf("Search all: %v", err)
}
if result.Total != 3 {
t.Errorf("expected 3 total hits, got %d", result.Total)
}
}
func TestSearchFromFilter(t *testing.T) {
idx := newXapianIndex(t)
seedDocs(t, idx)
result, err := idx.Search(index.SearchRequest{
From: "alice@example.com",
PageSize: 25,
})
if err != nil {
t.Fatalf("Search from: %v", err)
}
if result.Total != 1 {
t.Errorf("expected 1 hit from alice, got %d", result.Total)
}
}
func TestSearchDateRange(t *testing.T) {
idx := newXapianIndex(t)
seedDocs(t, idx)
from := time.Date(2026, 1, 1, 0, 0, 0, 0, time.UTC)
to := time.Date(2026, 2, 1, 23, 59, 59, 0, time.UTC)
result, err := idx.Search(index.SearchRequest{
DateFrom: &from,
DateTo: &to,
PageSize: 25,
})
if err != nil {
t.Fatalf("Search date range: %v", err)
}
if result.Total != 2 {
t.Errorf("expected 2 hits in Jan-Feb 2026, got %d", result.Total)
}
}
func TestSearchOwnEmail(t *testing.T) {
idx := newXapianIndex(t)
seedDocs(t, idx)
// charlie@example.com sent 1 mail and received 1 mail = should see 2
result, err := idx.Search(index.SearchRequest{
OwnEmail: "charlie@example.com",
PageSize: 25,
})
if err != nil {
t.Fatalf("Search OwnEmail: %v", err)
}
if result.Total < 1 {
t.Errorf("charlie should see at least 1 mail, got %d", result.Total)
}
}
func TestSearchPagination(t *testing.T) {
idx := newXapianIndex(t)
seedDocs(t, idx)
page0, _ := idx.Search(index.SearchRequest{PageSize: 2, Page: 0})
page1, _ := idx.Search(index.SearchRequest{PageSize: 2, Page: 1})
if len(page0.Hits) != 2 {
t.Errorf("page 0: expected 2 hits, got %d", len(page0.Hits))
}
if len(page1.Hits) != 1 {
t.Errorf("page 1: expected 1 hit, got %d", len(page1.Hits))
}
// No overlap
if page0.Hits[0].ID == page1.Hits[0].ID {
t.Error("pagination returned duplicate results")
}
}
func TestDelete(t *testing.T) {
idx := newXapianIndex(t)
seedDocs(t, idx)
if err := idx.Delete("aaa111"); err != nil {
t.Fatalf("Delete: %v", err)
}
result, _ := idx.Search(index.SearchRequest{Query: "invoice", PageSize: 10})
for _, h := range result.Hits {
if h.ID == "aaa111" {
t.Error("deleted document still in results")
}
}
}
func TestUnknownBackend(t *testing.T) {
_, err := index.New(t.TempDir(), 10, "elasticsearch")
if err == nil {
t.Error("expected error for unknown backend")
}
}
func TestXapianNotCompiledError(t *testing.T) {
_, err := index.New(t.TempDir(), 10, "xapian")
// Without -tags xapian this must return a helpful error
if err == nil {
t.Log("xapian compiled in — skipping stub error test")
} else {
t.Logf("xapian stub error (expected): %v", err)
}
}
+126
View File
@@ -0,0 +1,126 @@
//go:build xapian
package index
/*
#cgo pkg-config: xapian-core
#cgo LDFLAGS: -lstdc++
#include "xapian_wrapper.h"
#include <stdlib.h>
*/
import "C"
import (
"encoding/json"
"fmt"
"unsafe"
)
type xapianIndex struct {
db *C.XapianDB
}
func newXapian(dir string) (Indexer, error) {
cdir := C.CString(dir)
defer C.free(unsafe.Pointer(cdir))
var cerr *C.char
db := C.xapian_open(cdir, 1, &cerr)
if db == nil {
msg := C.GoString(cerr)
C.xapian_free_string(cerr)
return nil, fmt.Errorf("xapian open: %s", msg)
}
return &xapianIndex{db: db}, nil
}
func (x *xapianIndex) IndexSync(doc MailDocument) error {
cid := C.CString(doc.ID)
defer C.free(unsafe.Pointer(cid))
cfrom := C.CString(doc.From)
defer C.free(unsafe.Pointer(cfrom))
cto := C.CString(doc.To)
defer C.free(unsafe.Pointer(cto))
csubj := C.CString(doc.Subject)
defer C.free(unsafe.Pointer(csubj))
cbody := C.CString(doc.Body)
defer C.free(unsafe.Pointer(cbody))
var cerr *C.char
rc := C.xapian_index(x.db, cid, cfrom, cto, csubj, cbody, C.longlong(doc.Date.Unix()), &cerr)
if rc != 0 {
msg := C.GoString(cerr)
C.xapian_free_string(cerr)
return fmt.Errorf("xapian index: %s", msg)
}
return nil
}
func (x *xapianIndex) Delete(id string) error {
cid := C.CString(id)
defer C.free(unsafe.Pointer(cid))
var cerr *C.char
rc := C.xapian_delete(x.db, cid, &cerr)
if rc != 0 {
msg := C.GoString(cerr)
C.xapian_free_string(cerr)
return fmt.Errorf("xapian delete: %s", msg)
}
return nil
}
func (x *xapianIndex) Search(req SearchRequest) (*SearchResult, error) {
cquery := C.CString(req.Query)
defer C.free(unsafe.Pointer(cquery))
cfrom := C.CString(req.From)
defer C.free(unsafe.Pointer(cfrom))
cown := C.CString(req.OwnEmail)
defer C.free(unsafe.Pointer(cown))
cto := C.CString(req.To)
defer C.free(unsafe.Pointer(cto))
var dateFrom, dateTo C.longlong
if req.DateFrom != nil {
dateFrom = C.longlong(req.DateFrom.Unix())
}
if req.DateTo != nil {
dateTo = C.longlong(req.DateTo.Unix())
}
page := req.Page
if page < 1 {
page = 1
}
offset := C.int((page - 1) * req.PageSize)
limit := C.int(req.PageSize)
if limit <= 0 {
limit = 25
}
var cerr *C.char
cresult := C.xapian_search(x.db, cquery, cfrom, cown, cto, dateFrom, dateTo, offset, limit, &cerr)
if cresult == nil {
msg := C.GoString(cerr)
C.xapian_free_string(cerr)
return nil, fmt.Errorf("xapian search: %s", msg)
}
defer C.xapian_free_string(cresult)
jsonStr := C.GoString(cresult)
var raw struct {
Total int `json:"total"`
Hits []struct {
ID string `json:"id"`
Score float64 `json:"score"`
} `json:"hits"`
}
if err := json.Unmarshal([]byte(jsonStr), &raw); err != nil {
return nil, fmt.Errorf("xapian parse result: %w", err)
}
hits := make([]Hit, len(raw.Hits))
for i, h := range raw.Hits {
hits[i] = Hit{ID: h.ID, Score: h.Score}
}
return &SearchResult{Total: raw.Total, Hits: hits}, nil
}
func (x *xapianIndex) Close() error {
C.xapian_close(x.db)
return nil
}
+9
View File
@@ -0,0 +1,9 @@
//go:build !xapian
package index
import "errors"
func newXapian(dir string) (Indexer, error) {
return nil, errors.New("xapian: not compiled in — rebuild with: go build -tags xapian")
}
+199
View File
@@ -0,0 +1,199 @@
#include "xapian_wrapper.h"
#include <xapian.h>
#include <cstring>
#include <cstdlib>
#include <string>
#include <sstream>
#include <stdexcept>
struct XapianDB {
Xapian::WritableDatabase* wdb;
Xapian::Database* rdb;
bool writable;
};
static char* dup_error(const std::string& msg) {
char* s = (char*)malloc(msg.size() + 1);
if (s) memcpy(s, msg.c_str(), msg.size() + 1);
return s;
}
extern "C" {
XapianDB* xapian_open(const char* path, int writable, char** err) {
try {
XapianDB* db = new XapianDB{nullptr, nullptr, (bool)writable};
if (writable) {
db->wdb = new Xapian::WritableDatabase(path, Xapian::DB_CREATE_OR_OPEN);
} else {
db->rdb = new Xapian::Database(path);
}
return db;
} catch (const std::exception& e) {
if (err) *err = dup_error(e.what());
return nullptr;
}
}
void xapian_close(XapianDB* db) {
if (!db) return;
if (db->wdb) { db->wdb->close(); delete db->wdb; }
if (db->rdb) { db->rdb->close(); delete db->rdb; }
delete db;
}
int xapian_index(XapianDB* db, const char* id, const char* from,
const char* to, const char* subject, const char* body,
long long timestamp, char** err) {
try {
Xapian::Document doc;
Xapian::TermGenerator gen;
gen.set_document(doc);
gen.set_stemmer(Xapian::Stem("en"));
// Prefix-indexed fields for filtering
gen.index_text(from, 1, "XF");
gen.index_text(to, 1, "XT");
gen.index_text(subject, 1, "XS");
// Free-text indexed fields
gen.index_text(subject);
gen.increase_termpos();
gen.index_text(body);
gen.increase_termpos();
gen.index_text(from);
gen.increase_termpos();
gen.index_text(to);
// Store timestamp for date range queries (value slot 0)
doc.add_value(0, Xapian::sortable_serialise((double)timestamp));
// Store ID as document data
doc.set_data(id);
doc.add_boolean_term(std::string("Q") + id);
db->wdb->replace_document(std::string("Q") + id, doc);
db->wdb->commit();
return 0;
} catch (const std::exception& e) {
if (err) *err = dup_error(e.what());
return -1;
}
}
int xapian_delete(XapianDB* db, const char* id, char** err) {
try {
db->wdb->delete_document(std::string("Q") + id);
db->wdb->commit();
return 0;
} catch (const std::exception& e) {
if (err) *err = dup_error(e.what());
return -1;
}
}
char* xapian_search(XapianDB* db, const char* query_str,
const char* from_filter, const char* own_email,
const char* to_filter,
long long date_from, long long date_to,
int offset, int limit, char** err) {
try {
Xapian::Database& xdb = db->wdb ? (Xapian::Database&)*db->wdb : *db->rdb;
Xapian::Enquire enquire(xdb);
Xapian::Query main_query;
// Full-text query
if (query_str && query_str[0] != '\0') {
Xapian::QueryParser qp;
qp.set_database(xdb);
qp.set_stemmer(Xapian::Stem("en"));
qp.set_stemming_strategy(Xapian::QueryParser::STEM_SOME);
qp.add_prefix("from", "XF");
qp.add_prefix("to", "XT");
qp.add_prefix("subject", "XS");
main_query = qp.parse_query(query_str,
Xapian::QueryParser::FLAG_DEFAULT |
Xapian::QueryParser::FLAG_PARTIAL);
} else {
main_query = Xapian::Query::MatchAll;
}
// From filter
if (from_filter && from_filter[0] != '\0') {
Xapian::QueryParser qp;
qp.set_database(xdb);
Xapian::Query fq = qp.parse_query(from_filter,
Xapian::QueryParser::FLAG_DEFAULT, "XF");
main_query = Xapian::Query(Xapian::Query::OP_AND, main_query, fq);
}
// OwnEmail filter: (from=own OR to=own)
if (own_email && own_email[0] != '\0') {
Xapian::QueryParser qp;
qp.set_database(xdb);
Xapian::Query fq = qp.parse_query(own_email,
Xapian::QueryParser::FLAG_DEFAULT, "XF");
Xapian::Query tq = qp.parse_query(own_email,
Xapian::QueryParser::FLAG_DEFAULT, "XT");
Xapian::Query owq(Xapian::Query::OP_OR, fq, tq);
main_query = Xapian::Query(Xapian::Query::OP_AND, main_query, owq);
}
// To filter
if (to_filter && to_filter[0] != '\0') {
Xapian::QueryParser qp;
qp.set_database(xdb);
Xapian::Query tq = qp.parse_query(to_filter,
Xapian::QueryParser::FLAG_DEFAULT, "XT");
main_query = Xapian::Query(Xapian::Query::OP_AND, main_query, tq);
}
// Date range
if (date_from > 0 || date_to > 0) {
double lo = date_from > 0 ? (double)date_from : 0.0;
double hi = date_to > 0 ? (double)date_to : 1e18;
Xapian::Query drq(Xapian::Query::OP_VALUE_RANGE, 0,
Xapian::sortable_serialise(lo),
Xapian::sortable_serialise(hi));
main_query = Xapian::Query(Xapian::Query::OP_AND, main_query, drq);
}
enquire.set_query(main_query);
enquire.set_sort_by_value(0, true); // sort by date desc
// Get total count
Xapian::MSet all = enquire.get_mset(0, xdb.get_doccount());
int total = (int)all.get_matches_estimated();
// Get page
Xapian::MSet mset = enquire.get_mset(offset, limit);
std::ostringstream json;
json << "{\"total\":" << total << ",\"hits\":[";
bool first = true;
for (auto it = mset.begin(); it != mset.end(); ++it) {
if (!first) json << ",";
first = false;
std::string id = it.get_document().get_data();
double score = it.get_weight();
json << "{\"id\":\"" << id << "\",\"score\":" << score << "}";
}
json << "]}";
std::string result = json.str();
char* out = (char*)malloc(result.size() + 1);
memcpy(out, result.c_str(), result.size() + 1);
return out;
} catch (const std::exception& e) {
if (err) *err = dup_error(e.what());
return nullptr;
}
}
void xapian_free_string(char* s) {
free(s);
}
} // extern "C"
+32
View File
@@ -0,0 +1,32 @@
#ifndef XAPIAN_WRAPPER_H
#define XAPIAN_WRAPPER_H
#ifdef __cplusplus
extern "C" {
#endif
typedef struct XapianDB XapianDB;
XapianDB* xapian_open(const char* path, int writable, char** err);
void xapian_close(XapianDB* db);
int xapian_index(XapianDB* db, const char* id, const char* from,
const char* to, const char* subject, const char* body,
long long timestamp, char** err);
int xapian_delete(XapianDB* db, const char* id, char** err);
/* Returns JSON string: {"total":N,"hits":[{"id":"...","score":0.9},...]}
Returns NULL on error, sets *err. Caller must free with xapian_free_string. */
char* xapian_search(XapianDB* db, const char* query,
const char* from_filter, const char* own_email,
const char* to_filter,
long long date_from, long long date_to,
int offset, int limit, char** err);
void xapian_free_string(char* s);
#ifdef __cplusplus
}
#endif
#endif
+283
View File
@@ -0,0 +1,283 @@
// Package smtpd implements an embedded receive-only SMTP daemon for archivmail.
// It accepts incoming emails (e.g. from Postfix via always_bcc) and hands them
// off to the storage coordinator. No AUTH, no relay, no outbound mail.
package smtpd
import (
"bytes"
"crypto/tls"
"errors"
"fmt"
"io"
"log/slog"
"net"
"strings"
"sync"
"sync/atomic"
"time"
"github.com/emersion/go-smtp"
"github.com/archivmail/config"
"github.com/archivmail/internal/storage"
)
// Stats holds runtime statistics for the SMTP daemon.
type Stats struct {
Received atomic.Int64 // total emails successfully stored
Rejected atomic.Int64 // rejected (IP, size, etc.)
LastMailAt atomic.Value // time.Time of last accepted mail
}
// Daemon is the embedded receive-only SMTP server.
type Daemon struct {
cfg config.SMTPConfig
store *storage.Store
logger *slog.Logger
stats Stats
server *smtp.Server
mu sync.Mutex
running bool
}
// New creates a new SMTP Daemon. Call Start() to begin accepting connections.
func New(cfg config.SMTPConfig, store *storage.Store, logger *slog.Logger) *Daemon {
d := &Daemon{
cfg: cfg,
store: store,
logger: logger,
}
d.stats.LastMailAt.Store(time.Time{})
return d
}
// Start launches the SMTP daemon in a background goroutine.
// It returns immediately; use Stop() for graceful shutdown.
func (d *Daemon) Start() error {
if !d.cfg.Enabled {
d.logger.Info("SMTP daemon disabled via config")
return nil
}
bind := d.cfg.Bind
if bind == "" {
bind = ":2525"
}
domain := d.cfg.Domain
if domain == "" {
domain = "archivmail"
}
maxBytes := int64(d.cfg.MaxSizeMB) * 1024 * 1024
if maxBytes <= 0 {
maxBytes = 50 * 1024 * 1024 // 50 MB default
}
backend := &backend{daemon: d}
srv := smtp.NewServer(backend)
srv.Addr = bind
srv.Domain = domain
srv.MaxMessageBytes = maxBytes
srv.ReadTimeout = 5 * time.Minute
srv.WriteTimeout = 30 * time.Second
srv.AllowInsecureAuth = false // no AUTH offered at all
// TLS / STARTTLS
if d.cfg.TLSCert != "" && d.cfg.TLSKey != "" {
cert, err := tls.LoadX509KeyPair(d.cfg.TLSCert, d.cfg.TLSKey)
if err != nil {
return fmt.Errorf("smtpd: load TLS cert: %w", err)
}
srv.TLSConfig = &tls.Config{Certificates: []tls.Certificate{cert}}
}
d.mu.Lock()
d.server = srv
d.running = true
d.mu.Unlock()
go func() {
d.logger.Info("SMTP daemon starting", "addr", bind, "domain", domain,
"max_size_mb", d.cfg.MaxSizeMB, "tls", d.cfg.TLSCert != "")
if err := srv.ListenAndServe(); err != nil {
if !errors.Is(err, smtp.ErrServerClosed) {
d.logger.Error("SMTP daemon error", "err", err)
}
}
d.mu.Lock()
d.running = false
d.mu.Unlock()
}()
return nil
}
// Stop shuts down the SMTP daemon gracefully.
func (d *Daemon) Stop() {
d.mu.Lock()
srv := d.server
d.mu.Unlock()
if srv != nil {
srv.Close()
}
}
// Status returns a snapshot of the daemon's current state.
func (d *Daemon) Status() StatusResponse {
d.mu.Lock()
running := d.running
d.mu.Unlock()
lastMail, _ := d.stats.LastMailAt.Load().(time.Time)
var lastMailStr string
if !lastMail.IsZero() {
lastMailStr = lastMail.UTC().Format(time.RFC3339)
}
bind := d.cfg.Bind
if bind == "" {
bind = ":2525"
}
return StatusResponse{
Running: running,
Enabled: d.cfg.Enabled,
Bind: bind,
Domain: d.cfg.Domain,
TLS: d.cfg.TLSCert != "",
MaxSizeMB: d.cfg.MaxSizeMB,
AllowedIPs: d.cfg.AllowedIPs,
Received: d.stats.Received.Load(),
Rejected: d.stats.Rejected.Load(),
LastMailAt: lastMailStr,
}
}
// StatusResponse is returned by GET /api/admin/smtp/status.
type StatusResponse struct {
Running bool `json:"running"`
Enabled bool `json:"enabled"`
Bind string `json:"bind"`
Domain string `json:"domain"`
TLS bool `json:"tls"`
MaxSizeMB int `json:"max_size_mb"`
AllowedIPs []string `json:"allowed_ips"`
Received int64 `json:"received"`
Rejected int64 `json:"rejected"`
LastMailAt string `json:"last_mail_at,omitempty"`
}
// ── go-smtp Backend / Session ─────────────────────────────────────────────
type backend struct {
daemon *Daemon
}
func (b *backend) NewSession(c *smtp.Conn) (smtp.Session, error) {
remoteIP := extractIP(c.Conn().RemoteAddr().String())
if !b.daemon.isAllowed(remoteIP) {
b.daemon.stats.Rejected.Add(1)
b.daemon.logger.Warn("SMTP: rejected connection from unlisted IP", "ip", remoteIP)
return nil, &smtp.SMTPError{
Code: 554,
EnhancedCode: smtp.EnhancedCode{5, 7, 1},
Message: "IP not in allowlist",
}
}
b.daemon.logger.Debug("SMTP: new session", "ip", remoteIP)
return &session{
daemon: b.daemon,
remoteIP: remoteIP,
}, nil
}
type session struct {
daemon *Daemon
remoteIP string
from string
rcpts []string
}
// AuthPlain never called because server doesn't advertise AUTH.
func (s *session) AuthPlain(_, _ string) error {
return smtp.ErrAuthUnsupported
}
func (s *session) Mail(from string, _ *smtp.MailOptions) error {
s.from = from
return nil
}
func (s *session) Rcpt(to string, _ *smtp.RcptOptions) error {
s.rcpts = append(s.rcpts, to)
return nil
}
func (s *session) Data(r io.Reader) error {
var buf bytes.Buffer
if _, err := io.Copy(&buf, r); err != nil {
s.daemon.stats.Rejected.Add(1)
return fmt.Errorf("smtpd: read data: %w", err)
}
raw := buf.Bytes()
id, err := s.daemon.store.Save(raw, time.Now())
if err != nil {
s.daemon.stats.Rejected.Add(1)
s.daemon.logger.Error("SMTP: storage failed", "from", s.from, "err", err)
return &smtp.SMTPError{
Code: 554,
EnhancedCode: smtp.EnhancedCode{4, 6, 0},
Message: "Storage failure, please retry",
}
}
s.daemon.stats.Received.Add(1)
s.daemon.stats.LastMailAt.Store(time.Now())
s.daemon.logger.Info("SMTP: mail stored", "id", id, "from", s.from,
"rcpts", strings.Join(s.rcpts, ","), "bytes", len(raw), "ip", s.remoteIP)
return nil
}
func (s *session) Reset() {
s.from = ""
s.rcpts = nil
}
func (s *session) Logout() error {
return nil
}
// ── Helpers ───────────────────────────────────────────────────────────────
// isAllowed returns true if the IP is in the allowlist, or if the allowlist
// is empty (allow-all mode for development).
func (d *Daemon) isAllowed(ip string) bool {
if len(d.cfg.AllowedIPs) == 0 {
return true // no restriction configured
}
for _, allowed := range d.cfg.AllowedIPs {
// Support CIDR notation (e.g. 192.168.1.0/24)
if strings.Contains(allowed, "/") {
_, network, err := net.ParseCIDR(allowed)
if err == nil && network.Contains(net.ParseIP(ip)) {
return true
}
continue
}
if allowed == ip {
return true
}
}
return false
}
// extractIP strips port from "ip:port" or "[::1]:port" strings.
func extractIP(addr string) string {
host, _, err := net.SplitHostPort(addr)
if err != nil {
return addr
}
return host
}
+145
View File
@@ -0,0 +1,145 @@
package storage
import (
"crypto/sha256"
"errors"
"fmt"
"io/fs"
"os"
"path/filepath"
"time"
)
// Store is a file-based email storage using SHA256 for deduplication.
type Store struct {
dir string
}
// StoreStats reports total mail count and size in bytes.
type StoreStats struct {
TotalMails int64
TotalBytes int64
}
// New initialises the storage directory, creating required subdirectories.
func New(dir string) (*Store, error) {
for _, sub := range []string{"store", "attachments", "meta"} {
if err := os.MkdirAll(filepath.Join(dir, sub), 0o755); err != nil {
return nil, fmt.Errorf("storage: mkdir %s: %w", sub, err)
}
}
return &Store{dir: dir}, nil
}
// Save writes raw email bytes to storage. The ID is the hex-encoded SHA256 of
// the content. If the file already exists, Save is a no-op (deduplication).
func (s *Store) Save(raw []byte, _ time.Time) (string, error) {
sum := sha256.Sum256(raw)
id := fmt.Sprintf("%x", sum[:]) // 64 hex chars
path := s.filePath(id)
if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil {
return "", fmt.Errorf("storage: mkdir shard: %w", err)
}
// If file already exists, dedup: return same id without error.
if _, err := os.Stat(path); err == nil {
return id, nil
}
if err := os.WriteFile(path, raw, 0o644); err != nil {
return "", fmt.Errorf("storage: write: %w", err)
}
return id, nil
}
// Load reads a stored email by its ID.
func (s *Store) Load(id string) ([]byte, error) {
path := s.filePath(id)
data, err := os.ReadFile(path)
if err != nil {
if errors.Is(err, os.ErrNotExist) {
return nil, fmt.Errorf("storage: not found: %s", id)
}
return nil, fmt.Errorf("storage: read: %w", err)
}
return data, nil
}
// Delete removes a stored email by its ID.
func (s *Store) Delete(id string) error {
path := s.filePath(id)
if err := os.Remove(path); err != nil {
if errors.Is(err, os.ErrNotExist) {
return fmt.Errorf("storage: not found: %s", id)
}
return fmt.Errorf("storage: delete: %w", err)
}
return nil
}
// Stats walks the store directory and returns aggregate statistics.
func (s *Store) Stats() (*StoreStats, error) {
var stats StoreStats
err := filepath.WalkDir(filepath.Join(s.dir, "store"), func(_ string, d fs.DirEntry, err error) error {
if err != nil {
return err
}
if d.IsDir() {
return nil
}
info, err := d.Info()
if err != nil {
return err
}
stats.TotalMails++
stats.TotalBytes += info.Size()
return nil
})
if err != nil {
return nil, fmt.Errorf("storage: stats: %w", err)
}
return &stats, nil
}
// MailRef holds the ID and modification time of a stored mail.
type MailRef struct {
ID string
ModTime time.Time
}
// FirstAndLastMail walks the store and returns the oldest and newest mail by
// file modification time. Returns nil for either if the store is empty.
func (s *Store) FirstAndLastMail() (first, last *MailRef, err error) {
err = filepath.WalkDir(filepath.Join(s.dir, "store"), func(path string, d fs.DirEntry, werr error) error {
if werr != nil {
return werr
}
if d.IsDir() {
return nil
}
info, err := d.Info()
if err != nil {
return err
}
ref := &MailRef{ID: d.Name(), ModTime: info.ModTime()}
if first == nil || ref.ModTime.Before(first.ModTime) {
first = ref
}
if last == nil || ref.ModTime.After(last.ModTime) {
last = ref
}
return nil
})
if err != nil {
return nil, nil, fmt.Errorf("storage: first/last: %w", err)
}
return first, last, nil
}
// filePath returns the on-disk path for a given mail ID.
// Uses 2-char prefix sharding: {dir}/store/{id[:2]}/{id}
func (s *Store) filePath(id string) string {
return filepath.Join(s.dir, "store", id[:2], id)
}
+126
View File
@@ -0,0 +1,126 @@
package storage_test
import (
"bytes"
"os"
"path/filepath"
"testing"
"time"
"github.com/archivmail/internal/storage"
)
func TestSaveAndLoad(t *testing.T) {
dir := t.TempDir()
store, err := storage.New(dir)
if err != nil {
t.Fatalf("New: %v", err)
}
raw := []byte("From: alice@example.com\r\nSubject: Test\r\n\r\nHello World")
id, err := store.Save(raw, time.Now())
if err != nil {
t.Fatalf("Save: %v", err)
}
if len(id) != 64 {
t.Errorf("expected 64-char SHA256 hex, got %d chars", len(id))
}
got, err := store.Load(id)
if err != nil {
t.Fatalf("Load: %v", err)
}
if !bytes.Equal(raw, got) {
t.Errorf("loaded content mismatch")
}
}
func TestDeduplication(t *testing.T) {
dir := t.TempDir()
store, err := storage.New(dir)
if err != nil {
t.Fatal(err)
}
raw := []byte("From: alice@example.com\r\n\r\nDuplicate test")
id1, err := store.Save(raw, time.Now())
if err != nil {
t.Fatal(err)
}
id2, err := store.Save(raw, time.Now())
if err != nil {
t.Fatal(err)
}
if id1 != id2 {
t.Errorf("duplicate mail produced different IDs: %s vs %s", id1, id2)
}
// Only one file should exist
count := 0
filepath.Walk(filepath.Join(dir, "store"), func(p string, info os.FileInfo, _ error) error {
if !info.IsDir() { count++ }
return nil
})
if count != 1 {
t.Errorf("expected 1 stored file after dedup, got %d", count)
}
}
func TestDelete(t *testing.T) {
dir := t.TempDir()
store, err := storage.New(dir)
if err != nil {
t.Fatal(err)
}
raw := []byte("From: alice@example.com\r\n\r\nDelete me")
id, _ := store.Save(raw, time.Now())
if err := store.Delete(id); err != nil {
t.Fatalf("Delete: %v", err)
}
if _, err := store.Load(id); err == nil {
t.Error("Load after Delete should return error")
}
}
func TestStats(t *testing.T) {
dir := t.TempDir()
store, err := storage.New(dir)
if err != nil {
t.Fatal(err)
}
mails := [][]byte{
[]byte("From: a@x.com\r\n\r\nMail 1"),
[]byte("From: b@x.com\r\n\r\nMail 2"),
[]byte("From: c@x.com\r\n\r\nMail 3"),
}
for _, m := range mails {
store.Save(m, time.Now())
}
stats, err := store.Stats()
if err != nil {
t.Fatalf("Stats: %v", err)
}
if stats.TotalMails != 3 {
t.Errorf("expected 3 mails, got %d", stats.TotalMails)
}
if stats.TotalBytes <= 0 {
t.Error("expected positive TotalBytes")
}
}
func TestStorageDirectoryCreation(t *testing.T) {
dir := filepath.Join(t.TempDir(), "nested", "path")
_, err := storage.New(dir)
if err != nil {
t.Fatalf("New with nested path: %v", err)
}
for _, sub := range []string{"store", "attachments", "meta"} {
if _, err := os.Stat(filepath.Join(dir, sub)); os.IsNotExist(err) {
t.Errorf("expected subdirectory %s to be created", sub)
}
}
}
+304
View File
@@ -0,0 +1,304 @@
package userstore
import (
"context"
"errors"
"fmt"
"time"
"github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgxpool"
"golang.org/x/crypto/bcrypt"
)
const (
RoleUser = "user"
RoleAdmin = "admin"
RoleAuditor = "auditor"
)
// User represents a user account in the system.
type User struct {
ID int64
Username string
Email string
Role string
Source string // "local" or "ldap"
Active bool
CreatedAt time.Time
}
// CreateUserRequest holds parameters for creating a new user.
type CreateUserRequest struct {
Username string
Email string
Password string
Role string
}
// UpdateUserRequest holds optional fields for updating a user.
type UpdateUserRequest struct {
Email *string
Role *string
Active *bool
Password *string
}
// Store is a PostgreSQL-backed user store.
type Store struct {
pool *pgxpool.Pool
}
// New connects to PostgreSQL using the given DSN and initialises the schema.
func New(dsn string) (*Store, error) {
ctx := context.Background()
pool, err := pgxpool.New(ctx, dsn)
if err != nil {
return nil, fmt.Errorf("userstore: connect: %w", err)
}
s := &Store{pool: pool}
if err := s.initSchema(ctx); err != nil {
pool.Close()
return nil, fmt.Errorf("userstore: init schema: %w", err)
}
return s, nil
}
func (s *Store) initSchema(ctx context.Context) error {
_, err := s.pool.Exec(ctx, `
CREATE TABLE IF NOT EXISTS users (
id BIGSERIAL PRIMARY KEY,
username VARCHAR(100) UNIQUE NOT NULL,
email VARCHAR(255) UNIQUE NOT NULL,
password_hash VARCHAR(255) NOT NULL DEFAULT '',
role VARCHAR(20) NOT NULL CHECK (role IN ('user','auditor','admin')),
source VARCHAR(20) NOT NULL DEFAULT 'local',
active BOOLEAN NOT NULL DEFAULT true,
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
);
CREATE TABLE IF NOT EXISTS token_blacklist (
jti VARCHAR(255) PRIMARY KEY,
expires_at TIMESTAMPTZ NOT NULL
);
`)
return err
}
// Close closes the underlying connection pool.
func (s *Store) Close() error {
s.pool.Close()
return nil
}
// Create inserts a new local user with a bcrypt-hashed password.
func (s *Store) Create(req CreateUserRequest) (*User, error) {
hash, err := bcrypt.GenerateFromPassword([]byte(req.Password), bcrypt.DefaultCost)
if err != nil {
return nil, fmt.Errorf("userstore: bcrypt: %w", err)
}
ctx := context.Background()
var id int64
err = s.pool.QueryRow(ctx,
`INSERT INTO users (username, email, password_hash, role, source, active, created_at)
VALUES ($1, $2, $3, $4, 'local', true, NOW())
RETURNING id`,
req.Username, req.Email, string(hash), req.Role,
).Scan(&id)
if err != nil {
return nil, fmt.Errorf("userstore: create: %w", err)
}
return s.GetByID(id)
}
// GetByID retrieves a user by their numeric ID.
func (s *Store) GetByID(id int64) (*User, error) {
ctx := context.Background()
row := s.pool.QueryRow(ctx,
`SELECT id, username, email, role, source, active, created_at FROM users WHERE id = $1`, id,
)
return scanUser(row)
}
// GetByUsername retrieves a user by their username.
func (s *Store) GetByUsername(username string) (*User, error) {
ctx := context.Background()
row := s.pool.QueryRow(ctx,
`SELECT id, username, email, role, source, active, created_at FROM users WHERE username = $1`, username,
)
return scanUser(row)
}
// VerifyPassword checks credentials and returns the user, or an error if the
// password is wrong or the account is disabled.
func (s *Store) VerifyPassword(username, password string) (*User, error) {
ctx := context.Background()
row := s.pool.QueryRow(ctx,
`SELECT id, username, email, role, source, active, created_at, password_hash FROM users WHERE username = $1`,
username,
)
var u User
var hash string
err := row.Scan(&u.ID, &u.Username, &u.Email, &u.Role, &u.Source, &u.Active, &u.CreatedAt, &hash)
if errors.Is(err, pgx.ErrNoRows) {
return nil, errors.New("userstore: user not found")
}
if err != nil {
return nil, fmt.Errorf("userstore: scan: %w", err)
}
if !u.Active {
return nil, errors.New("userstore: account disabled")
}
if err := bcrypt.CompareHashAndPassword([]byte(hash), []byte(password)); err != nil {
return nil, errors.New("userstore: wrong password")
}
return &u, nil
}
// Update applies a partial update to a user record.
func (s *Store) Update(id int64, req UpdateUserRequest) (*User, error) {
ctx := context.Background()
if req.Email != nil {
if _, err := s.pool.Exec(ctx, `UPDATE users SET email = $1 WHERE id = $2`, *req.Email, id); err != nil {
return nil, fmt.Errorf("userstore: update email: %w", err)
}
}
if req.Role != nil {
if _, err := s.pool.Exec(ctx, `UPDATE users SET role = $1 WHERE id = $2`, *req.Role, id); err != nil {
return nil, fmt.Errorf("userstore: update role: %w", err)
}
}
if req.Active != nil {
if _, err := s.pool.Exec(ctx, `UPDATE users SET active = $1 WHERE id = $2`, *req.Active, id); err != nil {
return nil, fmt.Errorf("userstore: update active: %w", err)
}
}
if req.Password != nil {
hash, err := bcrypt.GenerateFromPassword([]byte(*req.Password), bcrypt.DefaultCost)
if err != nil {
return nil, fmt.Errorf("userstore: bcrypt: %w", err)
}
if _, err := s.pool.Exec(ctx, `UPDATE users SET password_hash = $1 WHERE id = $2`, string(hash), id); err != nil {
return nil, fmt.Errorf("userstore: update password: %w", err)
}
}
return s.GetByID(id)
}
// Delete removes a user by ID. Returns an error if the user does not exist.
func (s *Store) Delete(id int64) error {
ctx := context.Background()
tag, err := s.pool.Exec(ctx, `DELETE FROM users WHERE id = $1`, id)
if err != nil {
return fmt.Errorf("userstore: delete: %w", err)
}
if tag.RowsAffected() == 0 {
return fmt.Errorf("userstore: user %d not found", id)
}
return nil
}
// List returns all users, optionally filtered by role. Pass role="" to list all.
func (s *Store) List(role string) ([]*User, error) {
ctx := context.Background()
var rows pgx.Rows
var err error
if role == "" {
rows, err = s.pool.Query(ctx,
`SELECT id, username, email, role, source, active, created_at FROM users ORDER BY id`)
} else {
rows, err = s.pool.Query(ctx,
`SELECT id, username, email, role, source, active, created_at FROM users WHERE role = $1 ORDER BY id`, role)
}
if err != nil {
return nil, fmt.Errorf("userstore: list: %w", err)
}
defer rows.Close()
var users []*User
for rows.Next() {
u, err := scanUserRow(rows)
if err != nil {
return nil, err
}
users = append(users, u)
}
return users, rows.Err()
}
// BlacklistToken adds a JWT ID to the token blacklist.
func (s *Store) BlacklistToken(jti string, expires time.Time) error {
ctx := context.Background()
_, err := s.pool.Exec(ctx,
`INSERT INTO token_blacklist (jti, expires_at) VALUES ($1, $2)
ON CONFLICT (jti) DO UPDATE SET expires_at = EXCLUDED.expires_at`,
jti, expires.UTC(),
)
return err
}
// IsBlacklisted returns true if the given JTI is in the blacklist.
func (s *Store) IsBlacklisted(jti string) (bool, error) {
ctx := context.Background()
var count int
err := s.pool.QueryRow(ctx,
`SELECT COUNT(*) FROM token_blacklist WHERE jti = $1`, jti,
).Scan(&count)
return count > 0, err
}
// CleanExpiredTokens removes blacklist entries whose expiry has passed.
func (s *Store) CleanExpiredTokens() error {
ctx := context.Background()
_, err := s.pool.Exec(ctx, `DELETE FROM token_blacklist WHERE expires_at < NOW()`)
return err
}
// UpsertLDAPUser creates or updates an LDAP-sourced user.
func (s *Store) UpsertLDAPUser(username, email, role string) (*User, error) {
ctx := context.Background()
_, err := s.pool.Exec(ctx, `
INSERT INTO users (username, email, password_hash, role, source, active, created_at)
VALUES ($1, $2, '', $3, 'ldap', true, NOW())
ON CONFLICT (username) DO UPDATE SET
email = EXCLUDED.email,
role = EXCLUDED.role,
source = 'ldap'
`, username, email, role)
if err != nil {
return nil, fmt.Errorf("userstore: upsert ldap: %w", err)
}
return s.GetByUsername(username)
}
// --- helpers ---
func scanUser(row pgx.Row) (*User, error) {
var u User
err := row.Scan(&u.ID, &u.Username, &u.Email, &u.Role, &u.Source, &u.Active, &u.CreatedAt)
if errors.Is(err, pgx.ErrNoRows) {
return nil, fmt.Errorf("userstore: not found")
}
if err != nil {
return nil, fmt.Errorf("userstore: scan: %w", err)
}
return &u, nil
}
func scanUserRow(rows pgx.Rows) (*User, error) {
var u User
if err := rows.Scan(&u.ID, &u.Username, &u.Email, &u.Role, &u.Source, &u.Active, &u.CreatedAt); err != nil {
return nil, fmt.Errorf("userstore: scan row: %w", err)
}
return &u, nil
}
+279
View File
@@ -0,0 +1,279 @@
package userstore_test
import (
"context"
"os"
"strings"
"testing"
"time"
"github.com/jackc/pgx/v5"
"github.com/archivmail/internal/userstore"
)
func newTestStore(t *testing.T) *userstore.Store {
t.Helper()
dsn := os.Getenv("TEST_DATABASE_URL")
if dsn == "" {
t.Skip("TEST_DATABASE_URL not set — skipping (needs PostgreSQL)")
}
// Use a unique schema per test to isolate
schema := "test_" + strings.ReplaceAll(t.Name(), "/", "_")
schema = strings.ToLower(schema)
// Append schema to DSN
sep := "?"
if strings.Contains(dsn, "?") {
sep = "&"
}
schemaDSN := dsn + sep + "search_path=" + schema
// Create schema
ctx := context.Background()
conn, err := pgx.Connect(ctx, dsn)
if err != nil {
t.Fatalf("connect: %v", err)
}
conn.Exec(ctx, "CREATE SCHEMA IF NOT EXISTS "+schema)
conn.Close(ctx)
s, err := userstore.New(schemaDSN)
if err != nil {
t.Fatalf("userstore.New: %v", err)
}
t.Cleanup(func() {
s.Close()
conn2, _ := pgx.Connect(context.Background(), dsn)
if conn2 != nil {
conn2.Exec(context.Background(), "DROP SCHEMA "+schema+" CASCADE")
conn2.Close(context.Background())
}
})
return s
}
func TestCreateAndGetUser(t *testing.T) {
s := newTestStore(t)
u, err := s.Create(userstore.CreateUserRequest{
Username: "alice",
Email: "alice@example.com",
Password: "secret123",
Role: userstore.RoleAdmin,
})
if err != nil {
t.Fatalf("Create: %v", err)
}
if u.ID == 0 {
t.Error("expected non-zero ID")
}
if u.Username != "alice" {
t.Errorf("Username = %q", u.Username)
}
if u.Role != userstore.RoleAdmin {
t.Errorf("Role = %q", u.Role)
}
if u.Source != "local" {
t.Errorf("Source = %q, want local", u.Source)
}
got, err := s.GetByID(u.ID)
if err != nil {
t.Fatalf("GetByID: %v", err)
}
if got.Email != "alice@example.com" {
t.Errorf("Email = %q", got.Email)
}
}
func TestVerifyPassword(t *testing.T) {
s := newTestStore(t)
_, err := s.Create(userstore.CreateUserRequest{
Username: "bob", Email: "bob@example.com",
Password: "correcthorse", Role: userstore.RoleUser,
})
if err != nil {
t.Fatal(err)
}
// Correct password
u, err := s.VerifyPassword("bob", "correcthorse")
if err != nil {
t.Errorf("VerifyPassword correct: %v", err)
}
if u.Username != "bob" {
t.Errorf("Username = %q", u.Username)
}
// Wrong password
if _, err := s.VerifyPassword("bob", "wrongpassword"); err == nil {
t.Error("expected error for wrong password")
}
// Non-existent user
if _, err := s.VerifyPassword("nobody", "x"); err == nil {
t.Error("expected error for unknown user")
}
}
func TestUpdateUser(t *testing.T) {
s := newTestStore(t)
u, _ := s.Create(userstore.CreateUserRequest{
Username: "carol", Email: "carol@old.com",
Password: "pw", Role: userstore.RoleUser,
})
newEmail := "carol@new.com"
newRole := userstore.RoleAuditor
updated, err := s.Update(u.ID, userstore.UpdateUserRequest{
Email: &newEmail,
Role: &newRole,
})
if err != nil {
t.Fatalf("Update: %v", err)
}
if updated.Email != "carol@new.com" {
t.Errorf("Email after update = %q", updated.Email)
}
if updated.Role != userstore.RoleAuditor {
t.Errorf("Role after update = %q", updated.Role)
}
}
func TestDisableUser(t *testing.T) {
s := newTestStore(t)
u, _ := s.Create(userstore.CreateUserRequest{
Username: "dave", Email: "dave@x.com",
Password: "pw", Role: userstore.RoleUser,
})
active := false
s.Update(u.ID, userstore.UpdateUserRequest{Active: &active})
if _, err := s.VerifyPassword("dave", "pw"); err == nil {
t.Error("disabled user should not be able to login")
}
}
func TestDeleteUser(t *testing.T) {
s := newTestStore(t)
u, _ := s.Create(userstore.CreateUserRequest{
Username: "eve", Email: "eve@x.com",
Password: "pw", Role: userstore.RoleUser,
})
if err := s.Delete(u.ID); err != nil {
t.Fatalf("Delete: %v", err)
}
if _, err := s.GetByID(u.ID); err == nil {
t.Error("GetByID should error after delete")
}
// Delete non-existent should error
if err := s.Delete(u.ID); err == nil {
t.Error("second delete should return error")
}
}
func TestListUsers(t *testing.T) {
s := newTestStore(t)
users := []userstore.CreateUserRequest{
{Username: "u1", Email: "u1@x.com", Password: "pw", Role: userstore.RoleUser},
{Username: "u2", Email: "u2@x.com", Password: "pw", Role: userstore.RoleAdmin},
{Username: "u3", Email: "u3@x.com", Password: "pw", Role: userstore.RoleAuditor},
{Username: "u4", Email: "u4@x.com", Password: "pw", Role: userstore.RoleUser},
}
for _, req := range users {
s.Create(req)
}
all, err := s.List("")
if err != nil {
t.Fatalf("List all: %v", err)
}
if len(all) != 4 {
t.Errorf("List all: got %d, want 4", len(all))
}
admins, _ := s.List(userstore.RoleAdmin)
if len(admins) != 1 {
t.Errorf("List admin: got %d, want 1", len(admins))
}
regular, _ := s.List(userstore.RoleUser)
if len(regular) != 2 {
t.Errorf("List user: got %d, want 2", len(regular))
}
}
func TestTokenBlacklist(t *testing.T) {
s := newTestStore(t)
jti := "test-jti-12345"
expires := time.Now().Add(1 * time.Hour)
if err := s.BlacklistToken(jti, expires); err != nil {
t.Fatalf("BlacklistToken: %v", err)
}
blacklisted, err := s.IsBlacklisted(jti)
if err != nil {
t.Fatalf("IsBlacklisted: %v", err)
}
if !blacklisted {
t.Error("token should be blacklisted")
}
// Non-blacklisted token
bl2, _ := s.IsBlacklisted("other-jti")
if bl2 {
t.Error("unknown token should not be blacklisted")
}
}
func TestCleanExpiredTokens(t *testing.T) {
s := newTestStore(t)
// Add an already-expired token
s.BlacklistToken("expired-jti", time.Now().Add(-1*time.Hour))
// Add a valid token
s.BlacklistToken("valid-jti", time.Now().Add(1*time.Hour))
if err := s.CleanExpiredTokens(); err != nil {
t.Fatalf("CleanExpiredTokens: %v", err)
}
bl, _ := s.IsBlacklisted("expired-jti")
if bl {
t.Error("expired token should be cleaned up")
}
bl2, _ := s.IsBlacklisted("valid-jti")
if !bl2 {
t.Error("valid token should still be blacklisted")
}
}
func TestUpsertLDAPUser(t *testing.T) {
s := newTestStore(t)
u, err := s.UpsertLDAPUser("ldapuser", "ldap@corp.com", userstore.RoleAuditor)
if err != nil {
t.Fatalf("UpsertLDAPUser: %v", err)
}
if u.Source != "ldap" {
t.Errorf("Source = %q, want ldap", u.Source)
}
// Second upsert should update, not duplicate
u2, err := s.UpsertLDAPUser("ldapuser", "ldap@corp.com", userstore.RoleAuditor)
if err != nil {
t.Fatalf("UpsertLDAPUser second: %v", err)
}
if u2.ID != u.ID {
t.Error("second upsert should not create a new record")
}
all, _ := s.List("")
if len(all) != 1 {
t.Errorf("expected 1 user after double upsert, got %d", len(all))
}
}