|  | 
|  | 1 | +package middleware | 
|  | 2 | + | 
|  | 3 | +import ( | 
|  | 4 | +	"encoding/json" | 
|  | 5 | +	"strings" | 
|  | 6 | +	"time" | 
|  | 7 | + | 
|  | 8 | +	"github.com/gofiber/fiber/v2" | 
|  | 9 | +	"github.com/mudler/LocalAI/core/services" | 
|  | 10 | +	"github.com/rs/zerolog/log" | 
|  | 11 | +) | 
|  | 12 | + | 
|  | 13 | +// MetricsMiddleware creates a middleware that tracks API usage metrics | 
|  | 14 | +// Note: Uses CONTEXT_LOCALS_KEY_MODEL_NAME constant defined in request.go | 
|  | 15 | +func MetricsMiddleware(metricsStore services.MetricsStore) fiber.Handler { | 
|  | 16 | +	return func(c *fiber.Ctx) error { | 
|  | 17 | +		path := c.Path() | 
|  | 18 | + | 
|  | 19 | +		// Skip tracking for UI routes, static files, and non-API endpoints | 
|  | 20 | +		if shouldSkipMetrics(path) { | 
|  | 21 | +			return c.Next() | 
|  | 22 | +		} | 
|  | 23 | + | 
|  | 24 | +		// Record start time | 
|  | 25 | +		start := time.Now() | 
|  | 26 | + | 
|  | 27 | +		// Get endpoint category | 
|  | 28 | +		endpoint := categorizeEndpoint(path) | 
|  | 29 | + | 
|  | 30 | +		// Continue with the request | 
|  | 31 | +		err := c.Next() | 
|  | 32 | + | 
|  | 33 | +		// Record metrics after request completes | 
|  | 34 | +		duration := time.Since(start) | 
|  | 35 | +		success := err == nil && c.Response().StatusCode() < 400 | 
|  | 36 | + | 
|  | 37 | +		// Extract model name from context (set by RequestExtractor middleware) | 
|  | 38 | +		// Use the same constant as RequestExtractor | 
|  | 39 | +		model := "unknown" | 
|  | 40 | +		if modelVal, ok := c.Locals(CONTEXT_LOCALS_KEY_MODEL_NAME).(string); ok && modelVal != "" { | 
|  | 41 | +			model = modelVal | 
|  | 42 | +			log.Debug().Str("model", model).Str("endpoint", endpoint).Msg("Recording metrics for request") | 
|  | 43 | +		} else { | 
|  | 44 | +			// Fallback: try to extract from path params or query | 
|  | 45 | +			model = extractModelFromRequest(c) | 
|  | 46 | +			log.Debug().Str("model", model).Str("endpoint", endpoint).Msg("Recording metrics for request (fallback)") | 
|  | 47 | +		} | 
|  | 48 | + | 
|  | 49 | +		// Extract backend from response headers if available | 
|  | 50 | +		backend := string(c.Response().Header.Peek("X-LocalAI-Backend")) | 
|  | 51 | + | 
|  | 52 | +		// Record the request | 
|  | 53 | +		metricsStore.RecordRequest(endpoint, model, backend, success, duration) | 
|  | 54 | + | 
|  | 55 | +		return err | 
|  | 56 | +	} | 
|  | 57 | +} | 
|  | 58 | + | 
|  | 59 | +// shouldSkipMetrics determines if a request should be excluded from metrics | 
|  | 60 | +func shouldSkipMetrics(path string) bool { | 
|  | 61 | +	// Skip UI routes | 
|  | 62 | +	skipPrefixes := []string{ | 
|  | 63 | +		"/views/", | 
|  | 64 | +		"/static/", | 
|  | 65 | +		"/browse/", | 
|  | 66 | +		"/chat/", | 
|  | 67 | +		"/text2image/", | 
|  | 68 | +		"/tts/", | 
|  | 69 | +		"/talk/", | 
|  | 70 | +		"/models/edit/", | 
|  | 71 | +		"/import-model", | 
|  | 72 | +		"/settings", | 
|  | 73 | +		"/api/models",     // UI API endpoints | 
|  | 74 | +		"/api/backends",   // UI API endpoints | 
|  | 75 | +		"/api/operations", // UI API endpoints | 
|  | 76 | +		"/api/p2p",        // UI API endpoints | 
|  | 77 | +		"/api/metrics",    // Metrics API itself | 
|  | 78 | +	} | 
|  | 79 | + | 
|  | 80 | +	for _, prefix := range skipPrefixes { | 
|  | 81 | +		if strings.HasPrefix(path, prefix) { | 
|  | 82 | +			return true | 
|  | 83 | +		} | 
|  | 84 | +	} | 
|  | 85 | + | 
|  | 86 | +	// Also skip root path and other UI pages | 
|  | 87 | +	if path == "/" || path == "/index" { | 
|  | 88 | +		return true | 
|  | 89 | +	} | 
|  | 90 | + | 
|  | 91 | +	return false | 
|  | 92 | +} | 
|  | 93 | + | 
|  | 94 | +// categorizeEndpoint maps request paths to friendly endpoint categories | 
|  | 95 | +func categorizeEndpoint(path string) string { | 
|  | 96 | +	// OpenAI-compatible endpoints | 
|  | 97 | +	if strings.HasPrefix(path, "/v1/chat/completions") || strings.HasPrefix(path, "/chat/completions") { | 
|  | 98 | +		return "chat" | 
|  | 99 | +	} | 
|  | 100 | +	if strings.HasPrefix(path, "/v1/completions") || strings.HasPrefix(path, "/completions") { | 
|  | 101 | +		return "completions" | 
|  | 102 | +	} | 
|  | 103 | +	if strings.HasPrefix(path, "/v1/embeddings") || strings.HasPrefix(path, "/embeddings") { | 
|  | 104 | +		return "embeddings" | 
|  | 105 | +	} | 
|  | 106 | +	if strings.HasPrefix(path, "/v1/images/generations") || strings.HasPrefix(path, "/images/generations") { | 
|  | 107 | +		return "image-generation" | 
|  | 108 | +	} | 
|  | 109 | +	if strings.HasPrefix(path, "/v1/audio/transcriptions") || strings.HasPrefix(path, "/audio/transcriptions") { | 
|  | 110 | +		return "transcriptions" | 
|  | 111 | +	} | 
|  | 112 | +	if strings.HasPrefix(path, "/v1/audio/speech") || strings.HasPrefix(path, "/audio/speech") { | 
|  | 113 | +		return "text-to-speech" | 
|  | 114 | +	} | 
|  | 115 | +	if strings.HasPrefix(path, "/v1/models") || strings.HasPrefix(path, "/models") { | 
|  | 116 | +		return "models" | 
|  | 117 | +	} | 
|  | 118 | + | 
|  | 119 | +	// LocalAI-specific endpoints | 
|  | 120 | +	if strings.HasPrefix(path, "/v1/internal") { | 
|  | 121 | +		return "internal" | 
|  | 122 | +	} | 
|  | 123 | +	if strings.Contains(path, "/tts") { | 
|  | 124 | +		return "text-to-speech" | 
|  | 125 | +	} | 
|  | 126 | +	if strings.Contains(path, "/stt") || strings.Contains(path, "/whisper") { | 
|  | 127 | +		return "speech-to-text" | 
|  | 128 | +	} | 
|  | 129 | +	if strings.Contains(path, "/sound-generation") { | 
|  | 130 | +		return "sound-generation" | 
|  | 131 | +	} | 
|  | 132 | + | 
|  | 133 | +	// Default to the first path segment | 
|  | 134 | +	parts := strings.Split(strings.Trim(path, "/"), "/") | 
|  | 135 | +	if len(parts) > 0 { | 
|  | 136 | +		return parts[0] | 
|  | 137 | +	} | 
|  | 138 | + | 
|  | 139 | +	return "unknown" | 
|  | 140 | +} | 
|  | 141 | + | 
|  | 142 | +// extractModelFromRequest attempts to extract the model name from the request | 
|  | 143 | +func extractModelFromRequest(c *fiber.Ctx) string { | 
|  | 144 | +	// Try query parameter first | 
|  | 145 | +	model := c.Query("model") | 
|  | 146 | +	if model != "" { | 
|  | 147 | +		return model | 
|  | 148 | +	} | 
|  | 149 | + | 
|  | 150 | +	// Try to extract from JSON body for POST requests | 
|  | 151 | +	if c.Method() == fiber.MethodPost { | 
|  | 152 | +		// Read body | 
|  | 153 | +		bodyBytes := c.Body() | 
|  | 154 | +		if len(bodyBytes) > 0 { | 
|  | 155 | +			// Parse JSON | 
|  | 156 | +			var reqBody map[string]interface{} | 
|  | 157 | +			if err := json.Unmarshal(bodyBytes, &reqBody); err == nil { | 
|  | 158 | +				if modelVal, ok := reqBody["model"]; ok { | 
|  | 159 | +					if modelStr, ok := modelVal.(string); ok { | 
|  | 160 | +						return modelStr | 
|  | 161 | +					} | 
|  | 162 | +				} | 
|  | 163 | +			} | 
|  | 164 | +		} | 
|  | 165 | +	} | 
|  | 166 | + | 
|  | 167 | +	// Try path parameter for endpoints like /models/:model | 
|  | 168 | +	model = c.Params("model") | 
|  | 169 | +	if model != "" { | 
|  | 170 | +		return model | 
|  | 171 | +	} | 
|  | 172 | + | 
|  | 173 | +	return "unknown" | 
|  | 174 | +} | 
0 commit comments