Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 25 additions & 0 deletions apps/server/api/internal/app.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"github.com/MonkyMars/PWS/lib"
"github.com/MonkyMars/PWS/services"
"github.com/MonkyMars/PWS/types"
"github.com/MonkyMars/PWS/workers"
"github.com/gofiber/fiber/v3"
)

Expand Down Expand Up @@ -50,6 +51,30 @@ func GetSystemHealth(c fiber.Ctx) error {
})
}

// GetAuditHealth returns the health status of the audit logging system
func GetAuditHealth(c fiber.Ctx) error {
healthStatus := workers.HealthStatus()

status := "ok"
message := "Audit system operational"

if !healthStatus["is_healthy"].(bool) {
status = "degraded"
message = "Audit system experiencing issues"
}

if !healthStatus["worker_running"].(bool) {
status = "error"
message = "Audit worker not running"
}

return response.Success(c, map[string]any{
"status": status,
"message": message,
"details": healthStatus,
})
}

// TODO: Add authentication middleware to protect this endpoint in production
// Thus making sure only authorized admins can access it
// Currently it's only available in development mode because of this issue
Expand Down
16 changes: 16 additions & 0 deletions apps/server/api/internal/audit.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
package internal

import (
"github.com/MonkyMars/PWS/api/response"
"github.com/MonkyMars/PWS/services"
"github.com/gofiber/fiber/v3"
)

func GetLogs(c fiber.Ctx) error {
auditService := services.NewAuditService()
logs, err := auditService.GetLogs()
if err != nil {
return response.InternalServerError(c, "Failed to retrieve audit logs")
}
return response.Success(c, logs)
}
22 changes: 10 additions & 12 deletions apps/server/api/internal/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,8 +54,7 @@ func Login(c fiber.Ctx) error {
// Attempt login
user, err := authService.Login(&authRequest)
if err != nil {
logger.Error("Login failed", "email", authRequest.Email, "error", err)

logger.AuditError("Login failed", "email", authRequest.Email, "error", err.Error())
if errors.Is(err, lib.ErrInvalidCredentials) {
return response.Unauthorized(c, "Invalid email or password")
}
Expand All @@ -66,13 +65,13 @@ func Login(c fiber.Ctx) error {
// Generate tokens
accessToken, err := authService.GenerateAccessToken(user)
if err != nil {
logger.Error("Failed to generate access token", "user_id", user.Id, "error", err)
logger.AuditError("Failed to generate access token", "user_id", user.Id, "error", err)
return response.InternalServerError(c, "Failed to generate access token")
}

refreshToken, err := authService.GenerateRefreshToken(user)
if err != nil {
logger.Error("Failed to generate refresh token", "user_id", user.Id, "error", err)
logger.AuditError("Failed to generate refresh token", "user_id", user.Id, "error", err)
return response.InternalServerError(c, "Failed to generate refresh token")
}

Expand Down Expand Up @@ -143,25 +142,24 @@ func Register(c fiber.Ctx) error {
// Attempt registration
user, err := authService.Register(&registerRequest)
if err != nil {
logger.Error("Registration failed", "email", registerRequest.Email, "username", registerRequest.Username, "error", err)

if errors.Is(err, lib.ErrUserAlreadyExists) {
if errors.Is(err, lib.ErrUserAlreadyExists) || errors.Is(err, lib.ErrUsernameTaken) {
return response.Conflict(c, "User with this email or username already exists")
}

logger.AuditError("Registration failed", "email", registerRequest.Email, "username", registerRequest.Username, "error", err.Error())
return response.InternalServerError(c, "An error occurred during registration")
}

// Generate tokens for the new user
accessToken, err := authService.GenerateAccessToken(user)
if err != nil {
logger.Error("Failed to generate access token", "user_id", user.Id, "error", err)
logger.AuditError("Failed to generate access token", "user_id", user.Id, "error", err)
return response.InternalServerError(c, "Failed to generate access token")
}

refreshToken, err := authService.GenerateRefreshToken(user)
if err != nil {
logger.Error("Failed to generate refresh token", "user_id", user.Id, "error", err)
logger.AuditError("Failed to generate refresh token", "user_id", user.Id, "error", err)
return response.InternalServerError(c, "Failed to generate refresh token")
}

Expand Down Expand Up @@ -217,7 +215,7 @@ func Me(c fiber.Ctx) error {

claims, ok := claimsInterface.(*types.AuthClaims)
if !ok {
logger.Error("Invalid claims type in context", "type", fmt.Sprintf("%T", claimsInterface))
logger.AuditError("Invalid claims type in context", "type", fmt.Sprintf("%T", claimsInterface))
return response.Unauthorized(c, "Unauthorized")
}

Expand All @@ -227,7 +225,7 @@ func Me(c fiber.Ctx) error {
// Fetch user info
user, err := authService.GetUserByID(claims.Sub)
if err != nil {
logger.Error("Failed to retrieve user info", "user_id", claims.Sub, "error", err)
logger.AuditError("Failed to retrieve user info", "user_id", claims.Sub, "error", err)
return response.InternalServerError(c, "Failed to retrieve user info")
}

Expand Down Expand Up @@ -259,7 +257,7 @@ func Logout(c fiber.Ctx) error {
} else {
// Token is valid, blacklist it
if err := authService.BlacklistToken(accessToken, true); err != nil {
logger.Error("Failed to blacklist access token", "error", err)
logger.AuditError("Failed to blacklist access token", "error", err)
// Don't return error, continue with logout process
}
}
Expand Down
4 changes: 2 additions & 2 deletions apps/server/api/middleware/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ func AuthMiddleware() fiber.Handler {

claims, err := authService.ParseToken(token, true)
if err != nil {
logger.Error("Failed to parse access token", "error", err)
logger.AuditError("Failed to parse access token", "error", err)
return response.Unauthorized(c, "Invalid or expired access token")
}

Expand All @@ -33,7 +33,7 @@ func AuthMiddleware() fiber.Handler {
// Check if token is blacklisted with graceful Redis failure handling
blacklisted, err := cacheService.IsTokenBlacklisted(claims.Jti.String())
if err != nil {
logger.Error("Redis blacklist check failed, denying request for security", "error", err, "jti", claims.Jti.String())
logger.AuditError("Redis blacklist check failed, denying request for security", "error", err, "jti", claims.Jti.String())
// Do not return faulty Redis errors to the client, let the request through if Redis is down
} else if blacklisted {
// SECURITY: This could indicate a token reuse attack
Expand Down
3 changes: 3 additions & 0 deletions apps/server/api/routes/app.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,9 @@ func SetupAppRoutes(app *fiber.App) {
if cfg.Environment == "development" {
app.Get("/health", internal.GetSystemHealth)
app.Get("/health/database", internal.GetDatabaseHealth)
app.Get("/health/audit", internal.GetAuditHealth)
app.Get("/logs", internal.GetLogs)
}

app.Use(internal.NotFoundHandler)
}
14 changes: 14 additions & 0 deletions apps/server/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,9 @@ type Config struct {

// CORS Settings
Cors types.CorsConfig

// Audit Settings
Audit types.AuditConfig
}

var (
Expand Down Expand Up @@ -128,6 +131,17 @@ func Load() *Config {
AllowHeaders: getEnvSlice("CORS_ALLOW_HEADERS", []string{"Origin", "Content-Type", "Accept", "Authorization"}),
AllowCredentials: getEnvBool("CORS_ALLOW_CREDENTIALS", true),
},

// Audit Settings
Audit: types.AuditConfig{
BatchSize: getEnvInt("AUDIT_BATCH_SIZE", 10),
FlushTime: getEnvDuration("AUDIT_FLUSH_TIME", 2*time.Second),
ChannelSize: getEnvInt("AUDIT_CHANNEL_SIZE", 100),
MaxRetries: getEnvInt("AUDIT_MAX_RETRIES", 3),
MaxFailures: getEnvInt("AUDIT_MAX_FAILURES", 5),
RetentionDays: getEnvInt("AUDIT_RETENTION_DAYS", 90),
Enabled: getEnvBool("AUDIT_ENABLED", true),
},
}

// Validate configuration
Expand Down
116 changes: 116 additions & 0 deletions apps/server/config/logger.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,15 @@ package config

import (
"context"
"crypto/sha256"
"fmt"
"log/slog"
"os"
"sync"
"time"

"github.com/MonkyMars/PWS/types"
"github.com/goccy/go-json"
"github.com/gofiber/fiber/v3"
)

Expand Down Expand Up @@ -203,3 +207,115 @@ func (l *Logger) Performance(operation string, duration time.Duration) {
slog.Duration("duration", duration),
)
}

// AuditError logs error messages to both the standard logger and the audit system.
// This function creates an audit log entry that gets batched and stored in the database
// via the audit worker for persistent error tracking and analysis.
//
// Parameters:
// - message: A descriptive error message
// - attrs: Additional structured attributes to include in both logs
func (l *Logger) AuditError(message string, attrs ...any) {
// Log to standard logger first
l.Error(message, attrs...)

// Create audit log entry with validation
auditAttrs := make(map[string]any)

// Process attrs in pairs (key, value)
for i := 0; i < len(attrs)-1; i += 2 {
if key, ok := attrs[i].(string); ok && key != "" {
auditAttrs[key] = attrs[i+1]
}
}

auditLog := types.AuditLog{
Timestamp: time.Now(),
Level: "ERROR",
Message: message,
Attrs: auditAttrs,
}

entryHash := generateEntryHash(auditLog)
auditLog.EntryHash = entryHash

// Send to audit worker (non-blocking)
addAuditLogFunc := getAddAuditLogFunc()
if addAuditLogFunc != nil {
addAuditLogFunc(auditLog)
}
}

// getAddAuditLogFunc returns the AddAuditLog function to avoid circular imports
// This uses a lazy loading approach to access the workers.AddAuditLog function
func getAddAuditLogFunc() func(types.AuditLog) {
auditMutex.RLock()
defer auditMutex.RUnlock()
return globalAddAuditLogFunc
}

// Global variable to hold the AddAuditLog function reference
var (
globalAddAuditLogFunc func(types.AuditLog)
auditMutex sync.RWMutex
)

// SetAuditLogFunc sets the audit log function to avoid circular dependencies.
// This should be called during application initialization to wire up the audit logging.
func SetAuditLogFunc(fn func(types.AuditLog)) {
if fn == nil {
return // Don't set nil function
}
auditMutex.Lock()
defer auditMutex.Unlock()
globalAddAuditLogFunc = fn
}

// generateEntryHash creates a unique hash for an audit log entry
// This is used for deduplication to prevent the same entry from being inserted multiple times
func generateEntryHash(entry types.AuditLog) string {
// Validate required fields
if entry.Message == "" || entry.Level == "" {
return "" // Return empty hash for invalid entries
}

// Normalize timestamp to second precision to handle minor timing differences
timestamp := entry.Timestamp.Truncate(time.Second).Unix()

// Handle nil attrs map
attrs := entry.Attrs
if attrs == nil {
attrs = make(map[string]any)
}

// Create a deterministic representation of the entry
data := map[string]any{
"timestamp": timestamp,
"level": entry.Level,
"message": entry.Message,
"attrs": attrs,
}

// Convert to JSON for consistent hashing
jsonData, err := json.Marshal(data)
if err != nil {
// Fallback to a simple string representation with sanitized message
sanitizedMessage := entry.Message
if len(sanitizedMessage) > 100 {
sanitizedMessage = sanitizedMessage[:100] + "..."
}
fallbackHash := fmt.Sprintf("%d_%s_%s", timestamp, entry.Level, sanitizedMessage)

// Ensure fallback hash is not too long for database column
if len(fallbackHash) > 64 {
// Use a hash of the fallback string if it's too long
hash := sha256.Sum256([]byte(fallbackHash))
return fmt.Sprintf("%x", hash)
}
return fallbackHash
}

// Generate SHA256 hash
hash := sha256.Sum256(jsonData)
return fmt.Sprintf("%x", hash)
Comment on lines +313 to +320
Copy link

Copilot AI Sep 29, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[nitpick] Consider truncating the hash to reduce storage overhead since the full 64-character SHA256 hash may be unnecessarily long for deduplication purposes. A 16-32 character prefix would likely provide sufficient uniqueness.

Suggested change
return fmt.Sprintf("%x", hash)
}
return fallbackHash
}
// Generate SHA256 hash
hash := sha256.Sum256(jsonData)
return fmt.Sprintf("%x", hash)
return fmt.Sprintf("%x", hash)[:32] // Truncate to 32 hex chars (16 bytes)
}
return fallbackHash
}
// Generate SHA256 hash
hash := sha256.Sum256(jsonData)
return fmt.Sprintf("%x", hash)[:32] // Truncate to 32 hex chars (16 bytes)

Copilot uses AI. Check for mistakes.
}
Loading
Loading