feat(PROJ-21): Phase 2+3+5+8 Multi-Tenancy + PROJ-2 EML/MBOX Upload
Phase 2a: userstore domain_admin/superadmin Rollen, User.TenantID,
ListByTenant, UpsertLDAPUser mit tenantID
Phase 2b: storage.Save() mit tenantID *int64, email_refs Tabelle,
GetTenantForMail, GetAllIDsByTenant, StatsByTenant
Phase 2c: JWT-Claims tenant_id/tenant_slug, Session.TenantID,
Login Domain-Erkennung via E-Mail-Domain
Phase 3: tenantMiddleware, Handler-Filterung (Users, Mail, Stats)
Phase 5: SMTP Domain-Routing via DomainToTenantFunc Callback,
config smtp.tenant_routing + default_tenant_id
Phase 8: archivmail migrate-tenants Subkommando
PROJ-2: Upload-Seite /admin/upload mit DropZone + Progress-Polling
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
+121
-45
@@ -33,7 +33,10 @@ import (
|
||||
|
||||
type contextKey string
|
||||
|
||||
const sessionKey contextKey = "session"
|
||||
const (
|
||||
sessionKey contextKey = "session"
|
||||
tenantKey contextKey = "tenant_id"
|
||||
)
|
||||
|
||||
// Server is the archivmail HTTP API server.
|
||||
type Server struct {
|
||||
@@ -98,58 +101,68 @@ func New(
|
||||
return s
|
||||
}
|
||||
|
||||
// auth wraps a handler with authentication + tenant context propagation.
|
||||
func (s *Server) auth(h http.HandlerFunc) http.HandlerFunc {
|
||||
return s.authMiddleware(s.tenantMiddleware(h))
|
||||
}
|
||||
|
||||
// authAdmin wraps a handler requiring at least admin role.
|
||||
func (s *Server) authAdmin(h http.HandlerFunc) http.HandlerFunc {
|
||||
return s.authMiddleware(s.tenantMiddleware(s.requireRole(userstore.RoleDomainAdmin, h)))
|
||||
}
|
||||
|
||||
func (s *Server) routes() {
|
||||
s.mux.HandleFunc("GET /api/health", s.handleHealth)
|
||||
s.mux.HandleFunc("POST /api/auth/login", s.handleLogin)
|
||||
s.mux.HandleFunc("GET /api/auth/me", s.authMiddleware(s.handleMe))
|
||||
s.mux.HandleFunc("POST /api/auth/logout", s.authMiddleware(s.handleLogout))
|
||||
s.mux.HandleFunc("GET /api/users", s.authMiddleware(s.requireRole(userstore.RoleAdmin, s.handleListUsers)))
|
||||
s.mux.HandleFunc("POST /api/users", s.authMiddleware(s.requireRole(userstore.RoleAdmin, s.handleCreateUser)))
|
||||
s.mux.HandleFunc("PATCH /api/users/{id}", s.authMiddleware(s.requireRole(userstore.RoleAdmin, s.handleUpdateUser)))
|
||||
s.mux.HandleFunc("DELETE /api/users/{id}", s.authMiddleware(s.requireRole(userstore.RoleAdmin, s.handleDeleteUser)))
|
||||
s.mux.HandleFunc("GET /api/search", s.authMiddleware(s.handleSearch))
|
||||
s.mux.HandleFunc("GET /api/audit", s.authMiddleware(s.requireRole(userstore.RoleAuditor, s.handleAuditLog)))
|
||||
s.mux.HandleFunc("GET /api/admin/smtp/status", s.authMiddleware(s.requireRole(userstore.RoleAdmin, s.handleSMTPStatus)))
|
||||
s.mux.HandleFunc("GET /api/admin/storage/stats", s.authMiddleware(s.requireRole(userstore.RoleAdmin, s.handleStorageStats)))
|
||||
s.mux.HandleFunc("GET /api/mails/{id}", s.authMiddleware(s.requireMailAccess(s.handleGetMail)))
|
||||
s.mux.HandleFunc("GET /api/mails/{id}/attachments/{index}", s.authMiddleware(s.requireMailAccess(s.handleGetAttachment)))
|
||||
s.mux.HandleFunc("GET /api/mails/{id}/raw", s.authMiddleware(s.requireMailAccess(s.handleGetRaw)))
|
||||
s.mux.HandleFunc("GET /api/admin/services", s.authMiddleware(s.requireRole(userstore.RoleAdmin, s.handleListServices)))
|
||||
s.mux.HandleFunc("POST /api/admin/services/{name}/action", s.authMiddleware(s.requireRole(userstore.RoleAdmin, s.handleServiceAction)))
|
||||
s.mux.HandleFunc("GET /api/auth/me", s.auth(s.handleMe))
|
||||
s.mux.HandleFunc("POST /api/auth/logout", s.auth(s.handleLogout))
|
||||
s.mux.HandleFunc("GET /api/users", s.authAdmin(s.handleListUsers))
|
||||
s.mux.HandleFunc("POST /api/users", s.authAdmin(s.handleCreateUser))
|
||||
s.mux.HandleFunc("PATCH /api/users/{id}", s.authAdmin(s.handleUpdateUser))
|
||||
s.mux.HandleFunc("DELETE /api/users/{id}", s.authAdmin(s.handleDeleteUser))
|
||||
s.mux.HandleFunc("GET /api/search", s.auth(s.handleSearch))
|
||||
s.mux.HandleFunc("GET /api/audit", s.auth(s.requireRole(userstore.RoleAuditor, s.handleAuditLog)))
|
||||
s.mux.HandleFunc("GET /api/admin/smtp/status", s.authAdmin(s.handleSMTPStatus))
|
||||
s.mux.HandleFunc("GET /api/admin/storage/stats", s.authAdmin(s.handleStorageStats))
|
||||
s.mux.HandleFunc("GET /api/mails/{id}", s.auth(s.requireMailAccess(s.handleGetMail)))
|
||||
s.mux.HandleFunc("GET /api/mails/{id}/attachments/{index}", s.auth(s.requireMailAccess(s.handleGetAttachment)))
|
||||
s.mux.HandleFunc("GET /api/mails/{id}/raw", s.auth(s.requireMailAccess(s.handleGetRaw)))
|
||||
s.mux.HandleFunc("GET /api/admin/services", s.authAdmin(s.handleListServices))
|
||||
s.mux.HandleFunc("POST /api/admin/services/{name}/action", s.authAdmin(s.handleServiceAction))
|
||||
|
||||
s.mux.HandleFunc("GET /api/admin/system/stats", s.authMiddleware(s.requireRole(userstore.RoleAdmin, s.handleSystemStats)))
|
||||
s.mux.HandleFunc("GET /api/admin/security/audit", s.authMiddleware(s.requireRole(userstore.RoleAdmin, s.handleSecurityAudit)))
|
||||
s.mux.HandleFunc("POST /api/admin/security/fix", s.authMiddleware(s.requireRole(userstore.RoleAdmin, s.handleSecurityFix)))
|
||||
s.mux.HandleFunc("GET /api/admin/system/stats", s.authAdmin(s.handleSystemStats))
|
||||
s.mux.HandleFunc("GET /api/admin/security/audit", s.authAdmin(s.handleSecurityAudit))
|
||||
s.mux.HandleFunc("POST /api/admin/security/fix", s.authAdmin(s.handleSecurityFix))
|
||||
|
||||
// Export routes
|
||||
s.mux.HandleFunc("GET /api/export/pdf/{id}", s.authMiddleware(s.requireMailAccess(s.handleExportPDF)))
|
||||
s.mux.HandleFunc("POST /api/export/zip", s.authMiddleware(s.requireMailAccess(s.handleExportZIP)))
|
||||
s.mux.HandleFunc("GET /api/export/pdf/{id}", s.auth(s.requireMailAccess(s.handleExportPDF)))
|
||||
s.mux.HandleFunc("POST /api/export/zip", s.auth(s.requireMailAccess(s.handleExportZIP)))
|
||||
|
||||
// Upload routes (admin only)
|
||||
s.mux.HandleFunc("POST /api/admin/upload", s.authMiddleware(s.requireRole(userstore.RoleAdmin, s.handleUpload)))
|
||||
s.mux.HandleFunc("GET /api/admin/upload/{jobID}/progress", s.authMiddleware(s.requireRole(userstore.RoleAdmin, s.handleUploadProgress)))
|
||||
s.mux.HandleFunc("POST /api/admin/upload", s.authAdmin(s.handleUpload))
|
||||
s.mux.HandleFunc("GET /api/admin/upload/{jobID}/progress", s.authAdmin(s.handleUploadProgress))
|
||||
|
||||
// Upload routes (all authenticated users)
|
||||
s.mux.HandleFunc("POST /api/upload", s.authMiddleware(s.handleUpload))
|
||||
s.mux.HandleFunc("GET /api/upload/{jobID}/progress", s.authMiddleware(s.handleUploadProgress))
|
||||
s.mux.HandleFunc("POST /api/upload", s.auth(s.handleUpload))
|
||||
s.mux.HandleFunc("GET /api/upload/{jobID}/progress", s.auth(s.handleUploadProgress))
|
||||
|
||||
// IMAP routes (accessible to all authenticated users)
|
||||
s.mux.HandleFunc("GET /api/imap", s.authMiddleware(s.handleListImap))
|
||||
s.mux.HandleFunc("POST /api/imap", s.authMiddleware(s.handleCreateImap))
|
||||
s.mux.HandleFunc("DELETE /api/imap/{id}", s.authMiddleware(s.handleDeleteImap))
|
||||
s.mux.HandleFunc("PATCH /api/imap/{id}", s.authMiddleware(s.handleUpdateImapInterval))
|
||||
s.mux.HandleFunc("POST /api/imap/test", s.authMiddleware(s.handleTestImap))
|
||||
s.mux.HandleFunc("POST /api/imap/{id}/import", s.authMiddleware(s.handleStartImport))
|
||||
s.mux.HandleFunc("GET /api/imap/{id}/progress", s.authMiddleware(s.handleImapProgress))
|
||||
s.mux.HandleFunc("POST /api/imap/{id}/sync", s.authMiddleware(s.handleSyncNow))
|
||||
s.mux.HandleFunc("GET /api/imap", s.auth(s.handleListImap))
|
||||
s.mux.HandleFunc("POST /api/imap", s.auth(s.handleCreateImap))
|
||||
s.mux.HandleFunc("DELETE /api/imap/{id}", s.auth(s.handleDeleteImap))
|
||||
s.mux.HandleFunc("PATCH /api/imap/{id}", s.auth(s.handleUpdateImapInterval))
|
||||
s.mux.HandleFunc("POST /api/imap/test", s.auth(s.handleTestImap))
|
||||
s.mux.HandleFunc("POST /api/imap/{id}/import", s.auth(s.handleStartImport))
|
||||
s.mux.HandleFunc("GET /api/imap/{id}/progress", s.auth(s.handleImapProgress))
|
||||
s.mux.HandleFunc("POST /api/imap/{id}/sync", s.auth(s.handleSyncNow))
|
||||
|
||||
// POP3 routes (accessible to all authenticated users)
|
||||
s.mux.HandleFunc("GET /api/pop3", s.authMiddleware(s.handleListPop3))
|
||||
s.mux.HandleFunc("POST /api/pop3", s.authMiddleware(s.handleCreatePop3))
|
||||
s.mux.HandleFunc("DELETE /api/pop3/{id}", s.authMiddleware(s.handleDeletePop3))
|
||||
s.mux.HandleFunc("POST /api/pop3/test", s.authMiddleware(s.handleTestPop3))
|
||||
s.mux.HandleFunc("POST /api/pop3/{id}/import", s.authMiddleware(s.handleStartPop3Import))
|
||||
s.mux.HandleFunc("GET /api/pop3/{id}/progress", s.authMiddleware(s.handlePop3Progress))
|
||||
s.mux.HandleFunc("GET /api/pop3", s.auth(s.handleListPop3))
|
||||
s.mux.HandleFunc("POST /api/pop3", s.auth(s.handleCreatePop3))
|
||||
s.mux.HandleFunc("DELETE /api/pop3/{id}", s.auth(s.handleDeletePop3))
|
||||
s.mux.HandleFunc("POST /api/pop3/test", s.auth(s.handleTestPop3))
|
||||
s.mux.HandleFunc("POST /api/pop3/{id}/import", s.auth(s.handleStartPop3Import))
|
||||
s.mux.HandleFunc("GET /api/pop3/{id}/progress", s.auth(s.handlePop3Progress))
|
||||
}
|
||||
|
||||
// ServeHTTP implements http.Handler.
|
||||
@@ -286,7 +299,17 @@ func (s *Server) handleLogout(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
func (s *Server) handleListUsers(w http.ResponseWriter, r *http.Request) {
|
||||
users, err := s.users.List("")
|
||||
tenantID := tenantFromCtx(r.Context())
|
||||
|
||||
var (
|
||||
users []*userstore.User
|
||||
err error
|
||||
)
|
||||
if tenantID != nil {
|
||||
users, err = s.users.ListByTenant(r.Context(), *tenantID)
|
||||
} else {
|
||||
users, err = s.users.List("")
|
||||
}
|
||||
if err != nil {
|
||||
writeError(w, http.StatusInternalServerError, "failed to list users")
|
||||
return
|
||||
@@ -298,6 +321,7 @@ func (s *Server) handleListUsers(w http.ResponseWriter, r *http.Request) {
|
||||
Email string `json:"email"`
|
||||
Role string `json:"role"`
|
||||
Active bool `json:"active"`
|
||||
TenantID *int64 `json:"tenant_id,omitempty"`
|
||||
}
|
||||
|
||||
resp := make([]userResp, 0, len(users))
|
||||
@@ -308,6 +332,7 @@ func (s *Server) handleListUsers(w http.ResponseWriter, r *http.Request) {
|
||||
Email: u.Email,
|
||||
Role: u.Role,
|
||||
Active: u.Active,
|
||||
TenantID: u.TenantID,
|
||||
})
|
||||
}
|
||||
writeJSON(w, http.StatusOK, resp)
|
||||
@@ -572,14 +597,15 @@ func (s *Server) handleSMTPStatus(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
func (s *Server) handleStorageStats(w http.ResponseWriter, r *http.Request) {
|
||||
stats, err := s.store.Stats()
|
||||
tenantID := tenantFromCtx(r.Context())
|
||||
stats, err := s.store.StatsByTenant(r.Context(), tenantID)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusInternalServerError, "failed to read storage stats")
|
||||
return
|
||||
}
|
||||
writeJSON(w, http.StatusOK, map[string]interface{}{
|
||||
"total_mails": stats.TotalMails,
|
||||
"total_bytes": stats.TotalBytes,
|
||||
"total_mails": stats["count"],
|
||||
"total_bytes": stats["total_size"],
|
||||
})
|
||||
}
|
||||
|
||||
@@ -684,6 +710,26 @@ func sessionFromCtx(ctx context.Context) *auth.Session {
|
||||
return &auth.Session{}
|
||||
}
|
||||
|
||||
// tenantMiddleware extracts the tenant_id from the session and stores it in
|
||||
// the request context, making it available to all downstream handlers.
|
||||
func (s *Server) tenantMiddleware(next http.HandlerFunc) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
session := sessionFromCtx(r.Context())
|
||||
if session != nil && session.TenantID != nil {
|
||||
ctx := context.WithValue(r.Context(), tenantKey, session.TenantID)
|
||||
next(w, r.WithContext(ctx))
|
||||
return
|
||||
}
|
||||
next(w, r)
|
||||
}
|
||||
}
|
||||
|
||||
// tenantFromCtx extracts the tenant_id from context. Returns nil for global (superadmin) context.
|
||||
func tenantFromCtx(ctx context.Context) *int64 {
|
||||
v, _ := ctx.Value(tenantKey).(*int64)
|
||||
return v
|
||||
}
|
||||
|
||||
func remoteIP(r *http.Request) string {
|
||||
if fwd := r.Header.Get("X-Forwarded-For"); fwd != "" {
|
||||
return strings.Split(fwd, ",")[0]
|
||||
@@ -722,8 +768,18 @@ func (s *Server) handleGetMail(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
// user role: only own mailbox
|
||||
sess := sessionFromCtx(r.Context())
|
||||
|
||||
// Tenant isolation: domain_admin sees only own tenant's mail
|
||||
if sess.TenantID != nil {
|
||||
mailTenant, _ := s.store.GetTenantForMail(r.Context(), id)
|
||||
if mailTenant == nil || *mailTenant != *sess.TenantID {
|
||||
writeError(w, http.StatusForbidden, "access denied")
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// user role: only own mailbox
|
||||
if sess.Role == userstore.RoleUser {
|
||||
u, err := s.users.GetByUsername(sess.Username)
|
||||
if err != nil || !mailBelongsToUser(pm, u.Email) {
|
||||
@@ -803,6 +859,16 @@ func (s *Server) handleGetAttachment(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
sess := sessionFromCtx(r.Context())
|
||||
|
||||
// Tenant isolation
|
||||
if sess.TenantID != nil {
|
||||
mailTenant, _ := s.store.GetTenantForMail(r.Context(), id)
|
||||
if mailTenant == nil || *mailTenant != *sess.TenantID {
|
||||
writeError(w, http.StatusForbidden, "access denied")
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
if sess.Role == userstore.RoleUser {
|
||||
u, err := s.users.GetByUsername(sess.Username)
|
||||
if err != nil || !mailBelongsToUser(pm, u.Email) {
|
||||
@@ -838,8 +904,18 @@ func (s *Server) handleGetRaw(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
// Access check for user role
|
||||
sess := sessionFromCtx(r.Context())
|
||||
|
||||
// Tenant isolation
|
||||
if sess.TenantID != nil {
|
||||
mailTenant, _ := s.store.GetTenantForMail(r.Context(), id)
|
||||
if mailTenant == nil || *mailTenant != *sess.TenantID {
|
||||
writeError(w, http.StatusForbidden, "access denied")
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// Access check for user role
|
||||
if sess.Role == userstore.RoleUser {
|
||||
pm, err := mailparser.Parse(raw)
|
||||
if err == nil {
|
||||
|
||||
@@ -108,8 +108,11 @@ func (s *Server) handleUpload(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
s.uploadJobs.Store(jobID, job)
|
||||
|
||||
// Propagate tenant from session context
|
||||
tenantID := tenantFromCtx(r.Context())
|
||||
|
||||
// Run import in background
|
||||
go s.runUploadJob(job, allMessages)
|
||||
go s.runUploadJob(job, allMessages, tenantID)
|
||||
|
||||
writeJSON(w, http.StatusAccepted, map[string]string{"job_id": jobID})
|
||||
}
|
||||
@@ -126,11 +129,11 @@ func (s *Server) handleUploadProgress(w http.ResponseWriter, r *http.Request) {
|
||||
writeJSON(w, http.StatusOK, job.snapshot())
|
||||
}
|
||||
|
||||
func (s *Server) runUploadJob(job *UploadJob, messages [][]byte) {
|
||||
func (s *Server) runUploadJob(job *UploadJob, messages [][]byte, tenantID *int64) {
|
||||
ctx := context.Background()
|
||||
|
||||
for _, raw := range messages {
|
||||
result := s.importRawMessage(ctx, raw)
|
||||
result := s.importRawMessage(ctx, raw, tenantID)
|
||||
job.mu.Lock()
|
||||
switch result {
|
||||
case "imported":
|
||||
@@ -150,14 +153,14 @@ func (s *Server) runUploadJob(job *UploadJob, messages [][]byte) {
|
||||
|
||||
// importRawMessage stores and indexes a single raw message.
|
||||
// Returns "imported", "skipped", or "error".
|
||||
func (s *Server) importRawMessage(ctx context.Context, raw []byte) string {
|
||||
func (s *Server) importRawMessage(ctx context.Context, raw []byte, tenantID *int64) string {
|
||||
pm, err := mailparser.Parse(raw)
|
||||
if err != nil {
|
||||
s.logger.Warn("upload: parse failed", "err", err)
|
||||
return "error"
|
||||
}
|
||||
|
||||
id, err := s.store.Save(raw, pm.Date)
|
||||
id, err := s.store.Save(ctx, raw, pm.Date, tenantID)
|
||||
if err != nil {
|
||||
s.logger.Warn("upload: save failed", "err", err)
|
||||
return "error"
|
||||
|
||||
+40
-15
@@ -16,10 +16,12 @@ import (
|
||||
|
||||
// Session holds the claims extracted from a validated JWT.
|
||||
type Session struct {
|
||||
UserID int64
|
||||
Username string
|
||||
Role string
|
||||
JTI string // unique JWT ID
|
||||
UserID int64
|
||||
Username string
|
||||
Role string
|
||||
JTI string // unique JWT ID
|
||||
TenantID *int64
|
||||
TenantSlug string
|
||||
}
|
||||
|
||||
// Manager handles login, token issuance, validation, and logout.
|
||||
@@ -83,7 +85,7 @@ func (m *Manager) Login(username, password string) (string, *userstore.User, err
|
||||
email = username + "@ldap.local"
|
||||
}
|
||||
|
||||
ldapUser, upsertErr := m.store.UpsertLDAPUser(username, email, role)
|
||||
ldapUser, upsertErr := m.store.UpsertLDAPUser(username, email, role, nil)
|
||||
if upsertErr == nil {
|
||||
return m.issueToken(ldapUser)
|
||||
}
|
||||
@@ -98,13 +100,20 @@ func (m *Manager) Login(username, password string) (string, *userstore.User, err
|
||||
func (m *Manager) issueToken(user *userstore.User) (string, *userstore.User, error) {
|
||||
jti := generateJTI()
|
||||
now := time.Now()
|
||||
|
||||
var tenantIDVal int64
|
||||
if user.TenantID != nil {
|
||||
tenantIDVal = *user.TenantID
|
||||
}
|
||||
|
||||
claims := jwt.MapClaims{
|
||||
"sub": user.Username,
|
||||
"role": user.Role,
|
||||
"uid": user.ID,
|
||||
"jti": jti,
|
||||
"iat": now.Unix(),
|
||||
"exp": now.Add(8 * time.Hour).Unix(),
|
||||
"sub": user.Username,
|
||||
"role": user.Role,
|
||||
"uid": user.ID,
|
||||
"jti": jti,
|
||||
"iat": now.Unix(),
|
||||
"exp": now.Add(8 * time.Hour).Unix(),
|
||||
"tenant_id": tenantIDVal,
|
||||
}
|
||||
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
|
||||
@@ -156,11 +165,25 @@ func (m *Manager) ValidateToken(tokenStr string) (*Session, error) {
|
||||
userID = v
|
||||
}
|
||||
|
||||
var tenantID *int64
|
||||
switch v := claims["tenant_id"].(type) {
|
||||
case float64:
|
||||
if v != 0 {
|
||||
id := int64(v)
|
||||
tenantID = &id
|
||||
}
|
||||
case int64:
|
||||
if v != 0 {
|
||||
tenantID = &v
|
||||
}
|
||||
}
|
||||
|
||||
return &Session{
|
||||
UserID: userID,
|
||||
Username: username,
|
||||
Role: role,
|
||||
JTI: jti,
|
||||
TenantID: tenantID,
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -196,12 +219,14 @@ func (m *Manager) Logout(tokenStr string) error {
|
||||
}
|
||||
|
||||
// HasRole returns true when userRole satisfies the required role level.
|
||||
// Hierarchy: admin > auditor > user
|
||||
// Hierarchy: superadmin > admin > domain_admin > auditor > user
|
||||
func HasRole(userRole, required string) bool {
|
||||
levels := map[string]int{
|
||||
userstore.RoleUser: 1,
|
||||
userstore.RoleAuditor: 2,
|
||||
userstore.RoleAdmin: 3,
|
||||
userstore.RoleUser: 1,
|
||||
userstore.RoleAuditor: 2,
|
||||
userstore.RoleDomainAdmin: 3,
|
||||
userstore.RoleAdmin: 4,
|
||||
userstore.RoleSuperAdmin: 5,
|
||||
}
|
||||
return levels[userRole] >= levels[required]
|
||||
}
|
||||
|
||||
@@ -23,6 +23,7 @@ type Importer struct {
|
||||
mailStore *storage.Store
|
||||
idx index.Indexer
|
||||
logger *slog.Logger
|
||||
TenantID *int64 // optional tenant assignment for stored mails
|
||||
}
|
||||
|
||||
// NewImporter creates a new Importer wired to the storage and index backends.
|
||||
@@ -229,8 +230,9 @@ func (imp *Importer) fetchBatch(ctx context.Context, c *imapclient.Client, uids
|
||||
|
||||
// storeAndIndex saves a raw email to storage and indexes it.
|
||||
func (imp *Importer) storeAndIndex(raw []byte, log *slog.Logger) error {
|
||||
ctx := context.Background()
|
||||
// Save to file storage (deduplicates by SHA256 automatically)
|
||||
id, err := imp.mailStore.Save(raw, time.Now())
|
||||
id, err := imp.mailStore.Save(ctx, raw, time.Now(), imp.TenantID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("save: %w", err)
|
||||
}
|
||||
|
||||
@@ -18,6 +18,7 @@ type Importer struct {
|
||||
mailStore *storage.Store
|
||||
idx index.Indexer
|
||||
logger *slog.Logger
|
||||
TenantID *int64 // optional tenant assignment for stored mails
|
||||
}
|
||||
|
||||
// NewImporter creates a new Importer wired to the storage and index backends.
|
||||
@@ -125,8 +126,9 @@ func (imp *Importer) doImport(ctx context.Context, acc *Account, password string
|
||||
|
||||
// storeAndIndex saves a raw email to storage and indexes it.
|
||||
func (imp *Importer) storeAndIndex(raw []byte, log *slog.Logger) error {
|
||||
ctx := context.Background()
|
||||
// Save to file storage (deduplicates by SHA256 automatically)
|
||||
id, err := imp.mailStore.Save(raw, time.Now())
|
||||
id, err := imp.mailStore.Save(ctx, raw, time.Now(), imp.TenantID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("pop3 save: %w", err)
|
||||
}
|
||||
|
||||
+52
-9
@@ -5,6 +5,7 @@ package smtpd
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"fmt"
|
||||
@@ -22,6 +23,10 @@ import (
|
||||
"github.com/archivmail/internal/storage"
|
||||
)
|
||||
|
||||
// DomainToTenantFunc resolves an e-mail domain to a tenant ID.
|
||||
// Returns nil if no tenant matches the domain.
|
||||
type DomainToTenantFunc func(ctx context.Context, domain string) (*int64, error)
|
||||
|
||||
// Stats holds runtime statistics for the SMTP daemon.
|
||||
type Stats struct {
|
||||
Received atomic.Int64 // total emails successfully stored
|
||||
@@ -35,14 +40,16 @@ type IndexCallback func(raw []byte, id string)
|
||||
|
||||
// Daemon is the embedded receive-only SMTP server.
|
||||
type Daemon struct {
|
||||
cfg config.SMTPConfig
|
||||
store *storage.Store
|
||||
logger *slog.Logger
|
||||
stats Stats
|
||||
server *smtp.Server
|
||||
mu sync.Mutex
|
||||
running bool
|
||||
indexCallback IndexCallback
|
||||
cfg config.SMTPConfig
|
||||
store *storage.Store
|
||||
logger *slog.Logger
|
||||
stats Stats
|
||||
server *smtp.Server
|
||||
mu sync.Mutex
|
||||
running bool
|
||||
indexCallback IndexCallback
|
||||
domainToTenant DomainToTenantFunc // optional domain→tenant routing
|
||||
defaultTenantID *int64 // fallback tenant if no domain matches
|
||||
}
|
||||
|
||||
// New creates a new SMTP Daemon. Call Start() to begin accepting connections.
|
||||
@@ -56,6 +63,39 @@ func New(cfg config.SMTPConfig, store *storage.Store, logger *slog.Logger) *Daem
|
||||
return d
|
||||
}
|
||||
|
||||
// SetDomainToTenant wires in the domain→tenant resolution function.
|
||||
func (d *Daemon) SetDomainToTenant(fn DomainToTenantFunc, defaultTenantID *int64) {
|
||||
d.domainToTenant = fn
|
||||
d.defaultTenantID = defaultTenantID
|
||||
}
|
||||
|
||||
// resolveTenantFromRcpts extracts the domain from RCPT TO addresses and
|
||||
// resolves it to a tenant ID via the configured DomainToTenantFunc.
|
||||
func (d *Daemon) resolveTenantFromRcpts(rcpts []string) *int64 {
|
||||
if d.domainToTenant == nil {
|
||||
return d.defaultTenantID
|
||||
}
|
||||
ctx := context.Background()
|
||||
for _, rcpt := range rcpts {
|
||||
// Strip angle brackets if present
|
||||
addr := strings.Trim(rcpt, "<>")
|
||||
at := strings.LastIndex(addr, "@")
|
||||
if at < 0 {
|
||||
continue
|
||||
}
|
||||
domain := strings.ToLower(addr[at+1:])
|
||||
tenantID, err := d.domainToTenant(ctx, domain)
|
||||
if err != nil {
|
||||
d.logger.Warn("SMTP: tenant lookup failed", "domain", domain, "err", err)
|
||||
continue
|
||||
}
|
||||
if tenantID != nil {
|
||||
return tenantID
|
||||
}
|
||||
}
|
||||
return d.defaultTenantID
|
||||
}
|
||||
|
||||
// SetIndexCallback sets the function called after each successfully stored mail.
|
||||
func (d *Daemon) SetIndexCallback(cb IndexCallback) {
|
||||
d.indexCallback = cb
|
||||
@@ -232,7 +272,10 @@ func (s *session) Data(r io.Reader) error {
|
||||
}
|
||||
raw := buf.Bytes()
|
||||
|
||||
id, err := s.daemon.store.Save(raw, time.Now())
|
||||
// Determine tenant from RCPT TO domain routing
|
||||
tenantID := s.daemon.resolveTenantFromRcpts(s.rcpts)
|
||||
|
||||
id, err := s.daemon.store.Save(context.Background(), raw, time.Now(), tenantID)
|
||||
if err != nil {
|
||||
s.daemon.stats.Rejected.Add(1)
|
||||
s.daemon.logger.Error("SMTP: storage failed", "from", s.from, "err", err)
|
||||
|
||||
+160
-36
@@ -193,15 +193,33 @@ func (s *Store) initSchema(ctx context.Context) error {
|
||||
CREATE INDEX IF NOT EXISTS idx_emails_mail_from ON emails (mail_from);
|
||||
CREATE INDEX IF NOT EXISTS idx_emails_subject ON emails USING gin (to_tsvector('simple', subject));
|
||||
`)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// Phase 2b migrations: tenant isolation
|
||||
_, err = s.db.Exec(ctx, `
|
||||
ALTER TABLE emails ADD COLUMN IF NOT EXISTS tenant_id BIGINT;
|
||||
CREATE INDEX IF NOT EXISTS idx_emails_tenant ON emails (tenant_id);
|
||||
CREATE TABLE IF NOT EXISTS email_refs (
|
||||
id BIGSERIAL PRIMARY KEY,
|
||||
email_id TEXT NOT NULL REFERENCES emails(id),
|
||||
tenant_id BIGINT NOT NULL,
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
|
||||
UNIQUE(email_id, tenant_id)
|
||||
);
|
||||
CREATE INDEX IF NOT EXISTS idx_email_refs_tenant ON email_refs (tenant_id);
|
||||
CREATE INDEX IF NOT EXISTS idx_email_refs_email ON email_refs (email_id);
|
||||
`)
|
||||
return err
|
||||
}
|
||||
|
||||
// ── Core operations ───────────────────────────────────────────────────────
|
||||
|
||||
// Save writes raw email bytes to storage. The ID is the hex-encoded SHA256 of
|
||||
// the plaintext content. If the file already exists, Save is a no-op (dedup).
|
||||
// It also inserts metadata into the emails table if a DB is configured.
|
||||
func (s *Store) Save(raw []byte, _ time.Time) (string, error) {
|
||||
// the plaintext content. If the file already exists, Save ensures an email_ref
|
||||
// exists for the tenant (cross-tenant dedup: one file, many refs).
|
||||
// tenantID may be nil for system-level ingestion without tenant assignment.
|
||||
func (s *Store) Save(ctx context.Context, raw []byte, _ time.Time, tenantID *int64) (string, error) {
|
||||
// Hash plaintext for dedup (always before encryption)
|
||||
sum := sha256.Sum256(raw)
|
||||
id := fmt.Sprintf("%x", sum[:]) // 64 hex chars
|
||||
@@ -211,36 +229,46 @@ func (s *Store) Save(raw []byte, _ time.Time) (string, error) {
|
||||
return "", fmt.Errorf("storage: mkdir shard: %w", err)
|
||||
}
|
||||
|
||||
// Dedup: if file already exists, return same id
|
||||
fileExists := false
|
||||
if _, err := os.Stat(path); err == nil {
|
||||
return id, nil
|
||||
fileExists = true
|
||||
}
|
||||
|
||||
// Determine what to write: encrypted or plaintext
|
||||
var toWrite []byte
|
||||
if s.key != nil {
|
||||
encrypted, err := s.encrypt(raw)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
toWrite = encrypted
|
||||
} else {
|
||||
toWrite = raw
|
||||
}
|
||||
|
||||
if err := os.WriteFile(path, toWrite, 0o644); err != nil {
|
||||
return "", fmt.Errorf("storage: write: %w", err)
|
||||
}
|
||||
|
||||
// Insert metadata into DB (best-effort parse)
|
||||
if s.db != nil {
|
||||
pm, parseErr := mailparser.Parse(raw)
|
||||
if parseErr == nil {
|
||||
s.insertMeta(context.Background(), id, pm, len(raw))
|
||||
if !fileExists {
|
||||
// Determine what to write: encrypted or plaintext
|
||||
var toWrite []byte
|
||||
if s.key != nil {
|
||||
encrypted, err := s.encrypt(raw)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
toWrite = encrypted
|
||||
} else {
|
||||
// Insert minimal metadata even if parse fails
|
||||
s.insertMetaMinimal(context.Background(), id, len(raw))
|
||||
toWrite = raw
|
||||
}
|
||||
|
||||
if err := os.WriteFile(path, toWrite, 0o644); err != nil {
|
||||
return "", fmt.Errorf("storage: write: %w", err)
|
||||
}
|
||||
|
||||
// Insert metadata into DB (best-effort parse)
|
||||
if s.db != nil {
|
||||
pm, parseErr := mailparser.Parse(raw)
|
||||
if parseErr == nil {
|
||||
s.insertMeta(ctx, id, pm, len(raw), tenantID)
|
||||
} else {
|
||||
s.insertMetaMinimal(ctx, id, len(raw), tenantID)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Ensure email_ref entry for this tenant (even if file already existed)
|
||||
if s.db != nil && tenantID != nil {
|
||||
_, _ = s.db.Exec(ctx, `
|
||||
INSERT INTO email_refs (email_id, tenant_id)
|
||||
VALUES ($1, $2)
|
||||
ON CONFLICT (email_id, tenant_id) DO NOTHING
|
||||
`, id, *tenantID)
|
||||
}
|
||||
|
||||
return id, nil
|
||||
@@ -402,24 +430,24 @@ func (s *Store) firstAndLastFromFS() (first, last *MailRef, err error) {
|
||||
// ── Metadata helpers ──────────────────────────────────────────────────────
|
||||
|
||||
// insertMeta inserts parsed email metadata into the emails table.
|
||||
func (s *Store) insertMeta(ctx context.Context, id string, pm *mailparser.ParsedMail, size int) {
|
||||
func (s *Store) insertMeta(ctx context.Context, id string, pm *mailparser.ParsedMail, size int, tenantID *int64) {
|
||||
mailTo := strings.Join(pm.To, ", ")
|
||||
hasAttach := len(pm.Attachments) > 0
|
||||
|
||||
_, _ = s.db.Exec(ctx, `
|
||||
INSERT INTO emails (id, received_at, mail_from, mail_to, subject, size_bytes, has_attach)
|
||||
VALUES ($1, $2, $3, $4, $5, $6, $7)
|
||||
INSERT INTO emails (id, received_at, mail_from, mail_to, subject, size_bytes, has_attach, tenant_id)
|
||||
VALUES ($1, $2, $3, $4, $5, $6, $7, $8)
|
||||
ON CONFLICT (id) DO NOTHING
|
||||
`, id, pm.Date, pm.From, mailTo, pm.Subject, int64(size), hasAttach)
|
||||
`, id, pm.Date, pm.From, mailTo, pm.Subject, int64(size), hasAttach, tenantID)
|
||||
}
|
||||
|
||||
// insertMetaMinimal inserts minimal metadata when parsing fails.
|
||||
func (s *Store) insertMetaMinimal(ctx context.Context, id string, size int) {
|
||||
func (s *Store) insertMetaMinimal(ctx context.Context, id string, size int, tenantID *int64) {
|
||||
_, _ = s.db.Exec(ctx, `
|
||||
INSERT INTO emails (id, received_at, size_bytes)
|
||||
VALUES ($1, NOW(), $2)
|
||||
INSERT INTO emails (id, received_at, size_bytes, tenant_id)
|
||||
VALUES ($1, NOW(), $2, $3)
|
||||
ON CONFLICT (id) DO NOTHING
|
||||
`, id, int64(size))
|
||||
`, id, int64(size), tenantID)
|
||||
}
|
||||
|
||||
// SaveMeta upserts metadata for a given email ID. Used by the backfill process.
|
||||
@@ -602,6 +630,102 @@ func (s *Store) VerifyIntegrity(ctx context.Context, id string) (bool, error) {
|
||||
return ok, nil
|
||||
}
|
||||
|
||||
// GetTenantForMail returns the tenant_id stored directly on the email record.
|
||||
// Returns nil if no tenant is assigned or the mail does not exist.
|
||||
func (s *Store) GetTenantForMail(ctx context.Context, id string) (*int64, error) {
|
||||
if s.db == nil {
|
||||
return nil, nil
|
||||
}
|
||||
var tenantID *int64
|
||||
err := s.db.QueryRow(ctx, `SELECT tenant_id FROM emails WHERE id = $1`, id).Scan(&tenantID)
|
||||
if errors.Is(err, pgx.ErrNoRows) {
|
||||
return nil, nil
|
||||
}
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("storage: get tenant for mail: %w", err)
|
||||
}
|
||||
return tenantID, nil
|
||||
}
|
||||
|
||||
// GetAllIDsByTenant returns all email IDs visible to a tenant.
|
||||
// If tenantID is nil, all IDs are returned (superadmin / no-tenant context).
|
||||
func (s *Store) GetAllIDsByTenant(ctx context.Context, tenantID *int64) ([]string, error) {
|
||||
if s.db != nil {
|
||||
var (
|
||||
rows pgx.Rows
|
||||
err error
|
||||
)
|
||||
if tenantID == nil {
|
||||
rows, err = s.db.Query(ctx, `SELECT id FROM emails ORDER BY received_at`)
|
||||
} else {
|
||||
rows, err = s.db.Query(ctx,
|
||||
`SELECT email_id FROM email_refs WHERE tenant_id = $1`, *tenantID)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("storage: get ids by tenant: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
var ids []string
|
||||
for rows.Next() {
|
||||
var id string
|
||||
if err := rows.Scan(&id); err != nil {
|
||||
continue
|
||||
}
|
||||
ids = append(ids, id)
|
||||
}
|
||||
return ids, rows.Err()
|
||||
}
|
||||
// fallback: walk store (no tenant filtering possible without DB)
|
||||
var ids []string
|
||||
err := s.WalkStore(ctx, func(id string) error {
|
||||
ids = append(ids, id)
|
||||
return nil
|
||||
})
|
||||
return ids, err
|
||||
}
|
||||
|
||||
// StatsByTenant returns mail count and total size filtered by tenant.
|
||||
// If tenantID is nil, aggregate over all emails.
|
||||
func (s *Store) StatsByTenant(ctx context.Context, tenantID *int64) (map[string]interface{}, error) {
|
||||
if s.db == nil {
|
||||
st, err := s.statsFromFS()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return map[string]interface{}{
|
||||
"count": st.TotalMails,
|
||||
"total_size": st.TotalBytes,
|
||||
}, nil
|
||||
}
|
||||
|
||||
var count int64
|
||||
var totalSize int64
|
||||
|
||||
if tenantID == nil {
|
||||
err := s.db.QueryRow(ctx,
|
||||
`SELECT COALESCE(COUNT(*),0), COALESCE(SUM(size_bytes),0) FROM emails`,
|
||||
).Scan(&count, &totalSize)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("storage: stats by tenant: %w", err)
|
||||
}
|
||||
} else {
|
||||
err := s.db.QueryRow(ctx, `
|
||||
SELECT COALESCE(COUNT(e.id),0), COALESCE(SUM(e.size_bytes),0)
|
||||
FROM email_refs r
|
||||
JOIN emails e ON e.id = r.email_id
|
||||
WHERE r.tenant_id = $1
|
||||
`, *tenantID).Scan(&count, &totalSize)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("storage: stats by tenant: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return map[string]interface{}{
|
||||
"count": count,
|
||||
"total_size": totalSize,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// GetAllIDs returns all email IDs from the DB, or walks the store if no DB.
|
||||
func (s *Store) GetAllIDs(ctx context.Context) ([]string, error) {
|
||||
if s.db != nil {
|
||||
|
||||
@@ -12,9 +12,11 @@ import (
|
||||
)
|
||||
|
||||
const (
|
||||
RoleUser = "user"
|
||||
RoleAdmin = "admin"
|
||||
RoleAuditor = "auditor"
|
||||
RoleUser = "user"
|
||||
RoleAdmin = "admin"
|
||||
RoleAuditor = "auditor"
|
||||
RoleDomainAdmin = "domain_admin"
|
||||
RoleSuperAdmin = "superadmin"
|
||||
|
||||
bcryptCost = 12
|
||||
)
|
||||
@@ -28,6 +30,7 @@ type User struct {
|
||||
Source string // "local" or "ldap"
|
||||
Active bool
|
||||
CreatedAt time.Time
|
||||
TenantID *int64 `json:"tenant_id,omitempty"`
|
||||
}
|
||||
|
||||
// CreateUserRequest holds parameters for creating a new user.
|
||||
@@ -36,6 +39,7 @@ type CreateUserRequest struct {
|
||||
Email string
|
||||
Password string
|
||||
Role string
|
||||
TenantID *int64
|
||||
}
|
||||
|
||||
// UpdateUserRequest holds optional fields for updating a user.
|
||||
@@ -75,13 +79,14 @@ func (s *Store) initSchema(ctx context.Context) error {
|
||||
username VARCHAR(100) UNIQUE NOT NULL,
|
||||
email VARCHAR(255) UNIQUE NOT NULL,
|
||||
password_hash VARCHAR(255) NOT NULL DEFAULT '',
|
||||
role VARCHAR(20) NOT NULL CHECK (role IN ('user','auditor','admin')),
|
||||
role VARCHAR(20) NOT NULL DEFAULT 'user',
|
||||
source VARCHAR(20) NOT NULL DEFAULT 'local',
|
||||
active BOOLEAN NOT NULL DEFAULT true,
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
|
||||
last_login_at TIMESTAMPTZ
|
||||
);
|
||||
ALTER TABLE users ADD COLUMN IF NOT EXISTS last_login_at TIMESTAMPTZ;
|
||||
ALTER TABLE users ADD COLUMN IF NOT EXISTS tenant_id BIGINT;
|
||||
CREATE TABLE IF NOT EXISTS token_blacklist (
|
||||
jti VARCHAR(255) PRIMARY KEY,
|
||||
expires_at TIMESTAMPTZ NOT NULL
|
||||
@@ -92,6 +97,7 @@ func (s *Store) initSchema(ctx context.Context) error {
|
||||
attempted_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
|
||||
);
|
||||
CREATE INDEX IF NOT EXISTS idx_login_attempts_username_time ON login_attempts (username, attempted_at);
|
||||
CREATE INDEX IF NOT EXISTS idx_users_tenant ON users (tenant_id);
|
||||
`)
|
||||
return err
|
||||
}
|
||||
@@ -112,10 +118,10 @@ func (s *Store) Create(req CreateUserRequest) (*User, error) {
|
||||
ctx := context.Background()
|
||||
var id int64
|
||||
err = s.pool.QueryRow(ctx,
|
||||
`INSERT INTO users (username, email, password_hash, role, source, active, created_at)
|
||||
VALUES ($1, $2, $3, $4, 'local', true, NOW())
|
||||
`INSERT INTO users (username, email, password_hash, role, source, active, created_at, tenant_id)
|
||||
VALUES ($1, $2, $3, $4, 'local', true, NOW(), $5)
|
||||
RETURNING id`,
|
||||
req.Username, req.Email, string(hash), req.Role,
|
||||
req.Username, req.Email, string(hash), req.Role, req.TenantID,
|
||||
).Scan(&id)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("userstore: create: %w", err)
|
||||
@@ -128,7 +134,7 @@ func (s *Store) Create(req CreateUserRequest) (*User, error) {
|
||||
func (s *Store) GetByID(id int64) (*User, error) {
|
||||
ctx := context.Background()
|
||||
row := s.pool.QueryRow(ctx,
|
||||
`SELECT id, username, email, role, source, active, created_at FROM users WHERE id = $1`, id,
|
||||
`SELECT id, username, email, role, source, active, created_at, tenant_id FROM users WHERE id = $1`, id,
|
||||
)
|
||||
return scanUser(row)
|
||||
}
|
||||
@@ -137,7 +143,7 @@ func (s *Store) GetByID(id int64) (*User, error) {
|
||||
func (s *Store) GetByUsername(username string) (*User, error) {
|
||||
ctx := context.Background()
|
||||
row := s.pool.QueryRow(ctx,
|
||||
`SELECT id, username, email, role, source, active, created_at FROM users WHERE username = $1`, username,
|
||||
`SELECT id, username, email, role, source, active, created_at, tenant_id FROM users WHERE username = $1`, username,
|
||||
)
|
||||
return scanUser(row)
|
||||
}
|
||||
@@ -147,13 +153,13 @@ func (s *Store) GetByUsername(username string) (*User, error) {
|
||||
func (s *Store) VerifyPassword(username, password string) (*User, error) {
|
||||
ctx := context.Background()
|
||||
row := s.pool.QueryRow(ctx,
|
||||
`SELECT id, username, email, role, source, active, created_at, password_hash FROM users WHERE username = $1`,
|
||||
`SELECT id, username, email, role, source, active, created_at, tenant_id, password_hash FROM users WHERE username = $1`,
|
||||
username,
|
||||
)
|
||||
|
||||
var u User
|
||||
var hash string
|
||||
err := row.Scan(&u.ID, &u.Username, &u.Email, &u.Role, &u.Source, &u.Active, &u.CreatedAt, &hash)
|
||||
err := row.Scan(&u.ID, &u.Username, &u.Email, &u.Role, &u.Source, &u.Active, &u.CreatedAt, &u.TenantID, &hash)
|
||||
if errors.Is(err, pgx.ErrNoRows) {
|
||||
return nil, errors.New("userstore: user not found")
|
||||
}
|
||||
@@ -225,10 +231,10 @@ func (s *Store) List(role string) ([]*User, error) {
|
||||
|
||||
if role == "" {
|
||||
rows, err = s.pool.Query(ctx,
|
||||
`SELECT id, username, email, role, source, active, created_at FROM users ORDER BY id`)
|
||||
`SELECT id, username, email, role, source, active, created_at, tenant_id FROM users ORDER BY id`)
|
||||
} else {
|
||||
rows, err = s.pool.Query(ctx,
|
||||
`SELECT id, username, email, role, source, active, created_at FROM users WHERE role = $1 ORDER BY id`, role)
|
||||
`SELECT id, username, email, role, source, active, created_at, tenant_id FROM users WHERE role = $1 ORDER BY id`, role)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("userstore: list: %w", err)
|
||||
@@ -246,6 +252,28 @@ func (s *Store) List(role string) ([]*User, error) {
|
||||
return users, rows.Err()
|
||||
}
|
||||
|
||||
// ListByTenant returns all users belonging to a specific tenant.
|
||||
func (s *Store) ListByTenant(ctx context.Context, tenantID int64) ([]*User, error) {
|
||||
rows, err := s.pool.Query(ctx,
|
||||
`SELECT id, username, email, role, source, active, created_at, tenant_id FROM users WHERE tenant_id = $1 ORDER BY id`,
|
||||
tenantID,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("userstore: list by tenant: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var users []*User
|
||||
for rows.Next() {
|
||||
u, err := scanUserRow(rows)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
users = append(users, u)
|
||||
}
|
||||
return users, rows.Err()
|
||||
}
|
||||
|
||||
// BlacklistToken adds a JWT ID to the token blacklist.
|
||||
func (s *Store) BlacklistToken(jti string, expires time.Time) error {
|
||||
ctx := context.Background()
|
||||
@@ -295,12 +323,12 @@ func (s *Store) CountRecentFailures(username string, window time.Duration) (int,
|
||||
return count, err
|
||||
}
|
||||
|
||||
// AdminCount returns the number of active admin users.
|
||||
// AdminCount returns the number of active privileged users (admin, domain_admin, superadmin).
|
||||
func (s *Store) AdminCount() (int, error) {
|
||||
ctx := context.Background()
|
||||
var count int
|
||||
err := s.pool.QueryRow(ctx,
|
||||
`SELECT COUNT(*) FROM users WHERE role = 'admin' AND active = true`,
|
||||
`SELECT COUNT(*) FROM users WHERE role IN ('admin','domain_admin','superadmin') AND active = true`,
|
||||
).Scan(&count)
|
||||
return count, err
|
||||
}
|
||||
@@ -311,7 +339,7 @@ func (s *Store) DeleteSafe(id int64) error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if user.Role == RoleAdmin {
|
||||
if user.Role == RoleAdmin || user.Role == RoleDomainAdmin || user.Role == RoleSuperAdmin {
|
||||
count, err := s.AdminCount()
|
||||
if err != nil {
|
||||
return fmt.Errorf("userstore: admin count: %w", err)
|
||||
@@ -331,16 +359,18 @@ func (s *Store) CleanExpiredTokens() error {
|
||||
}
|
||||
|
||||
// UpsertLDAPUser creates or updates an LDAP-sourced user.
|
||||
func (s *Store) UpsertLDAPUser(username, email, role string) (*User, error) {
|
||||
// tenantID may be nil for users not associated with a specific tenant.
|
||||
func (s *Store) UpsertLDAPUser(username, email, role string, tenantID *int64) (*User, error) {
|
||||
ctx := context.Background()
|
||||
_, err := s.pool.Exec(ctx, `
|
||||
INSERT INTO users (username, email, password_hash, role, source, active, created_at)
|
||||
VALUES ($1, $2, '', $3, 'ldap', true, NOW())
|
||||
INSERT INTO users (username, email, password_hash, role, source, active, created_at, tenant_id)
|
||||
VALUES ($1, $2, '', $3, 'ldap', true, NOW(), $4)
|
||||
ON CONFLICT (username) DO UPDATE SET
|
||||
email = EXCLUDED.email,
|
||||
role = EXCLUDED.role,
|
||||
source = 'ldap'
|
||||
`, username, email, role)
|
||||
email = EXCLUDED.email,
|
||||
role = EXCLUDED.role,
|
||||
source = 'ldap',
|
||||
tenant_id = COALESCE(EXCLUDED.tenant_id, users.tenant_id)
|
||||
`, username, email, role, tenantID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("userstore: upsert ldap: %w", err)
|
||||
}
|
||||
@@ -351,7 +381,7 @@ func (s *Store) UpsertLDAPUser(username, email, role string) (*User, error) {
|
||||
|
||||
func scanUser(row pgx.Row) (*User, error) {
|
||||
var u User
|
||||
err := row.Scan(&u.ID, &u.Username, &u.Email, &u.Role, &u.Source, &u.Active, &u.CreatedAt)
|
||||
err := row.Scan(&u.ID, &u.Username, &u.Email, &u.Role, &u.Source, &u.Active, &u.CreatedAt, &u.TenantID)
|
||||
if errors.Is(err, pgx.ErrNoRows) {
|
||||
return nil, fmt.Errorf("userstore: not found")
|
||||
}
|
||||
@@ -363,7 +393,7 @@ func scanUser(row pgx.Row) (*User, error) {
|
||||
|
||||
func scanUserRow(rows pgx.Rows) (*User, error) {
|
||||
var u User
|
||||
if err := rows.Scan(&u.ID, &u.Username, &u.Email, &u.Role, &u.Source, &u.Active, &u.CreatedAt); err != nil {
|
||||
if err := rows.Scan(&u.ID, &u.Username, &u.Email, &u.Role, &u.Source, &u.Active, &u.CreatedAt, &u.TenantID); err != nil {
|
||||
return nil, fmt.Errorf("userstore: scan row: %w", err)
|
||||
}
|
||||
return &u, nil
|
||||
|
||||
Reference in New Issue
Block a user