@@ -3,6 +3,7 @@ package pricing
3
3
import (
4
4
"context"
5
5
"fmt"
6
+ "slices"
6
7
"sync"
7
8
"time"
8
9
@@ -26,6 +27,8 @@ type PricingManager struct {
26
27
pricingData map [string ]configstore.TableModelPricing
27
28
mu sync.RWMutex
28
29
30
+ modelPool map [schemas.ModelProvider ][]string
31
+
29
32
// Background sync worker
30
33
syncTicker * time.Ticker
31
34
done chan struct {}
@@ -73,9 +76,12 @@ func Init(ctx context.Context, configStore configstore.ConfigStore, logger schem
73
76
configStore : configStore ,
74
77
logger : logger ,
75
78
pricingData : make (map [string ]configstore.TableModelPricing ),
79
+ modelPool : make (map [schemas.ModelProvider ][]string ),
76
80
done : make (chan struct {}),
77
81
}
78
82
83
+ logger .Info ("initializing pricing manager..." )
84
+
79
85
if configStore != nil {
80
86
// Load initial pricing data
81
87
if err := pm .loadPricingFromDatabase (ctx ); err != nil {
@@ -86,14 +92,16 @@ func Init(ctx context.Context, configStore configstore.ConfigStore, logger schem
86
92
if err := pm .syncPricing (ctx ); err != nil {
87
93
return nil , fmt .Errorf ("failed to sync pricing data: %w" , err )
88
94
}
89
-
90
95
} else {
91
96
// Load pricing data from config memory
92
97
if err := pm .loadPricingIntoMemory (); err != nil {
93
98
return nil , fmt .Errorf ("failed to load pricing data from config memory: %w" , err )
94
99
}
95
100
}
96
101
102
+ // Populate model pool with normalized providers
103
+ pm .populateModelPool ()
104
+
97
105
// Start background sync worker
98
106
pm .startSyncWorker (ctx )
99
107
pm .configStore = configStore
@@ -327,6 +335,78 @@ func (pm *PricingManager) CalculateCostFromUsage(provider string, model string,
327
335
return totalCost
328
336
}
329
337
338
+ // populateModelPool populates the model pool with all available models per provider (thread-safe)
339
+ func (pm * PricingManager ) populateModelPool () {
340
+ // Clear existing model pool
341
+ pm .modelPool = make (map [schemas.ModelProvider ][]string )
342
+
343
+ // Map to track unique models per provider
344
+ providerModels := make (map [schemas.ModelProvider ]map [string ]bool )
345
+
346
+ // Iterate through all pricing data to collect models per provider
347
+ for _ , pricing := range pm .pricingData {
348
+ // Normalize provider before adding to model pool
349
+ normalizedProvider := schemas .ModelProvider (normalizeProvider (pricing .Provider ))
350
+
351
+ // Initialize map for this provider if not exists
352
+ if providerModels [normalizedProvider ] == nil {
353
+ providerModels [normalizedProvider ] = make (map [string ]bool )
354
+ }
355
+
356
+ // Add model to the provider's model set (using map for deduplication)
357
+ providerModels [normalizedProvider ][pricing.Model ] = true
358
+ }
359
+
360
+ // Convert sets to slices
361
+ for provider , modelSet := range providerModels {
362
+ models := make ([]string , 0 , len (modelSet ))
363
+ for model := range modelSet {
364
+ models = append (models , model )
365
+ }
366
+ pm .modelPool [provider ] = models
367
+ }
368
+
369
+ // Log the populated model pool for debugging
370
+ totalModels := 0
371
+ for provider , models := range pm .modelPool {
372
+ totalModels += len (models )
373
+ pm .logger .Debug ("populated %d models for provider %s" , len (models ), string (provider ))
374
+ }
375
+ pm .logger .Info ("populated model pool with %d models across %d providers" , totalModels , len (pm .modelPool ))
376
+ }
377
+
378
+ // GetModelsForProvider returns all available models for a given provider (thread-safe)
379
+ func (pm * PricingManager ) GetModelsForProvider (provider schemas.ModelProvider ) []string {
380
+ pm .mu .RLock ()
381
+ defer pm .mu .RUnlock ()
382
+
383
+ models , exists := pm .modelPool [provider ]
384
+ if ! exists {
385
+ return []string {}
386
+ }
387
+
388
+ // Return a copy to prevent external modification
389
+ result := make ([]string , len (models ))
390
+ copy (result , models )
391
+ return result
392
+ }
393
+
394
+ // GetProvidersForModel returns all providers for a given model (thread-safe)
395
+ func (pm * PricingManager ) GetProvidersForModel (model string ) []schemas.ModelProvider {
396
+ pm .mu .RLock ()
397
+ defer pm .mu .RUnlock ()
398
+
399
+ providers := make ([]schemas.ModelProvider , 0 )
400
+ for provider , models := range pm .modelPool {
401
+ if slices .Contains (models , model ) {
402
+ providers = append (providers , provider )
403
+ } else if slices .Contains (models , model ) {
404
+ providers = append (providers , provider )
405
+ }
406
+ }
407
+ return providers
408
+ }
409
+
330
410
// getPricing returns pricing information for a model (thread-safe)
331
411
func (pm * PricingManager ) getPricing (model , provider string , requestType schemas.RequestType ) (* configstore.TableModelPricing , bool ) {
332
412
pm .mu .RLock ()
0 commit comments