Skip to content

Commit c7b303f

Browse files
feat: vk provider routing added
1 parent 469f117 commit c7b303f

File tree

14 files changed

+523
-65
lines changed

14 files changed

+523
-65
lines changed

framework/configstore/migrations.go

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,9 @@ func triggerMigrations(ctx context.Context, db *gorm.DB) error {
1919
if err := migrationAddCustomProviderConfigJSONColumn(ctx, db); err != nil {
2020
return err
2121
}
22+
if err := migrationAddVirtualKeyProviderConfigTable(db); err != nil {
23+
return err
24+
}
2225
return nil
2326
}
2427

@@ -247,3 +250,33 @@ func migrationAddCustomProviderConfigJSONColumn(ctx context.Context, db *gorm.DB
247250
}
248251
return nil
249252
}
253+
254+
func migrationAddVirtualKeyProviderConfigTable(db *gorm.DB) error {
255+
m := migration.New(db, migration.DefaultOptions, []*migration.Migration{{
256+
ID: "addvirtualkeyproviderconfig",
257+
Migrate: func(tx *gorm.DB) error {
258+
migrator := tx.Migrator()
259+
260+
if !migrator.HasTable(&TableVirtualKeyProviderConfig{}) {
261+
if err := migrator.CreateTable(&TableVirtualKeyProviderConfig{}); err != nil {
262+
return err
263+
}
264+
}
265+
266+
return nil
267+
},
268+
Rollback: func(tx *gorm.DB) error {
269+
migrator := tx.Migrator()
270+
271+
if err := migrator.DropTable(&TableVirtualKeyProviderConfig{}); err != nil {
272+
return err
273+
}
274+
return nil
275+
},
276+
}})
277+
err := m.Migrate()
278+
if err != nil {
279+
return fmt.Errorf("error while running db migration: %s", err.Error())
280+
}
281+
return nil
282+
}

framework/configstore/sqlite.go

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -845,6 +845,7 @@ func (s *SQLiteConfigStore) GetVirtualKeys(ctx context.Context) ([]TableVirtualK
845845
Preload("Customer").
846846
Preload("Budget").
847847
Preload("RateLimit").
848+
Preload("ProviderConfigs").
848849
Preload("Keys", func(db *gorm.DB) *gorm.DB {
849850
return db.Select("id, key_id, models_json")
850851
}).Find(&virtualKeys).Error; err != nil {
@@ -861,6 +862,7 @@ func (s *SQLiteConfigStore) GetVirtualKey(ctx context.Context, id string) (*Tabl
861862
Preload("Customer").
862863
Preload("Budget").
863864
Preload("RateLimit").
865+
Preload("ProviderConfigs").
864866
Preload("Keys", func(db *gorm.DB) *gorm.DB {
865867
return db.Select("id, key_id, models_json")
866868
}).First(&virtualKey, "id = ?", id).Error; err != nil {
@@ -923,6 +925,47 @@ func (s *SQLiteConfigStore) UpdateVirtualKey(ctx context.Context, virtualKey *Ta
923925
return nil
924926
}
925927

928+
func (s *SQLiteConfigStore) GetVirtualKeyProviderConfigsFromValue(virtualKeyValue string) ([]TableVirtualKeyProviderConfig, error) {
929+
var virtualKey TableVirtualKey
930+
if err := s.db.First(&virtualKey, "value = ?", virtualKeyValue).Error; err != nil {
931+
return nil, err
932+
}
933+
934+
if virtualKey.ID == "" {
935+
return nil, nil
936+
}
937+
938+
var providerConfigs []TableVirtualKeyProviderConfig
939+
if err := s.db.Where("virtual_key_id = ?", virtualKey.ID).Find(&providerConfigs).Error; err != nil {
940+
return nil, err
941+
}
942+
return providerConfigs, nil
943+
}
944+
945+
func (s *SQLiteConfigStore) CreateVirtualKeyProviderConfig(virtualKeyProviderConfig *TableVirtualKeyProviderConfig, tx ...*gorm.DB) error {
946+
var txDB *gorm.DB
947+
if len(tx) > 0 {
948+
txDB = tx[0]
949+
} else {
950+
txDB = s.db
951+
}
952+
return txDB.Create(virtualKeyProviderConfig).Error
953+
}
954+
955+
func (s *SQLiteConfigStore) UpdateVirtualKeyProviderConfig(virtualKeyProviderConfig *TableVirtualKeyProviderConfig, tx ...*gorm.DB) error {
956+
var txDB *gorm.DB
957+
if len(tx) > 0 {
958+
txDB = tx[0]
959+
} else {
960+
txDB = s.db
961+
}
962+
return txDB.Save(virtualKeyProviderConfig).Error
963+
}
964+
965+
func (s *SQLiteConfigStore) DeleteVirtualKeyProviderConfig(id uint) error {
966+
return s.db.Delete(&TableVirtualKeyProviderConfig{}, "id = ?", id).Error
967+
}
968+
926969
// GetKeysByIDs retrieves multiple keys by their IDs
927970
func (s *SQLiteConfigStore) GetKeysByIDs(ctx context.Context, ids []string) ([]TableKey, error) {
928971
if len(ids) == 0 {

framework/configstore/store.go

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,12 @@ type ConfigStore interface {
5959
UpdateVirtualKey(ctx context.Context, virtualKey *TableVirtualKey, tx ...*gorm.DB) error
6060
DeleteVirtualKey(ctx context.Context, id string) error
6161

62+
// Virtual key provider config CRUD
63+
GetVirtualKeyProviderConfigsFromValue(virtualKeyValue string) ([]TableVirtualKeyProviderConfig, error)
64+
CreateVirtualKeyProviderConfig(virtualKeyProviderConfig *TableVirtualKeyProviderConfig, tx ...*gorm.DB) error
65+
UpdateVirtualKeyProviderConfig(virtualKeyProviderConfig *TableVirtualKeyProviderConfig, tx ...*gorm.DB) error
66+
DeleteVirtualKeyProviderConfig(id uint) error
67+
6268
// Team CRUD
6369
GetTeams(ctx context.Context, customerID string) ([]TableTeam, error)
6470
GetTeam(ctx context.Context, id string) (*TableTeam, error)

framework/configstore/tables.go

Lines changed: 23 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -616,13 +616,13 @@ type TableTeam struct {
616616

617617
// TableVirtualKey represents a virtual key with budget, rate limits, and team/customer association
618618
type TableVirtualKey struct {
619-
ID string `gorm:"primaryKey;type:varchar(255)" json:"id"`
620-
Name string `gorm:"uniqueIndex:idx_virtual_key_name;type:varchar(255);not null" json:"name"`
621-
Description string `gorm:"type:text" json:"description,omitempty"`
622-
Value string `gorm:"uniqueIndex:idx_virtual_key_value;type:varchar(255);not null" json:"value"` // The virtual key value
623-
IsActive bool `gorm:"default:true" json:"is_active"`
624-
AllowedModels []string `gorm:"type:text;serializer:json" json:"allowed_models"` // Empty means all models allowed
625-
AllowedProviders []string `gorm:"type:text;serializer:json" json:"allowed_providers"` // Empty means all providers allowed
619+
ID string `gorm:"primaryKey;type:varchar(255)" json:"id"`
620+
Name string `gorm:"uniqueIndex:idx_virtual_key_name;type:varchar(255);not null" json:"name"`
621+
Description string `gorm:"type:text" json:"description,omitempty"`
622+
Value string `gorm:"uniqueIndex:idx_virtual_key_value;type:varchar(255);not null" json:"value"` // The virtual key value
623+
IsActive bool `gorm:"default:true" json:"is_active"`
624+
AllowedModels []string `gorm:"type:text;serializer:json" json:"allowed_models"` // Empty means all models allowed
625+
ProviderConfigs []TableVirtualKeyProviderConfig `gorm:"foreignKey:VirtualKeyID;constraint:OnDelete:CASCADE" json:"provider_configs"` // Empty means all providers allowed
626626

627627
// Foreign key relationships (mutually exclusive: either TeamID or CustomerID, not both)
628628
TeamID *string `gorm:"type:varchar(255);index" json:"team_id,omitempty"`
@@ -641,6 +641,14 @@ type TableVirtualKey struct {
641641
UpdatedAt time.Time `gorm:"index;not null" json:"updated_at"`
642642
}
643643

644+
// TableVirtualKeyProviderConfig represents a provider configuration for a virtual key
645+
type TableVirtualKeyProviderConfig struct {
646+
ID uint `gorm:"primaryKey;autoIncrement" json:"id"`
647+
VirtualKeyID string `gorm:"type:varchar(255);not null" json:"virtual_key_id"`
648+
Provider string `gorm:"type:varchar(50);not null" json:"provider"`
649+
Weight float64 `gorm:"default:1.0" json:"weight"`
650+
}
651+
644652
// TableModelPricing represents pricing information for AI models
645653
type TableModelPricing struct {
646654
ID uint `gorm:"primaryKey;autoIncrement" json:"id"`
@@ -675,11 +683,14 @@ type TableModelPricing struct {
675683
}
676684

677685
// Table names
678-
func (TableBudget) TableName() string { return "governance_budgets" }
679-
func (TableRateLimit) TableName() string { return "governance_rate_limits" }
680-
func (TableCustomer) TableName() string { return "governance_customers" }
681-
func (TableTeam) TableName() string { return "governance_teams" }
682-
func (TableVirtualKey) TableName() string { return "governance_virtual_keys" }
686+
func (TableBudget) TableName() string { return "governance_budgets" }
687+
func (TableRateLimit) TableName() string { return "governance_rate_limits" }
688+
func (TableCustomer) TableName() string { return "governance_customers" }
689+
func (TableTeam) TableName() string { return "governance_teams" }
690+
func (TableVirtualKey) TableName() string { return "governance_virtual_keys" }
691+
func (TableVirtualKeyProviderConfig) TableName() string {
692+
return "governance_virtual_key_provider_configs"
693+
}
683694
func (TableConfig) TableName() string { return "governance_config" }
684695
func (TableModelPricing) TableName() string { return "governance_model_pricing" }
685696

framework/pricing/main.go

Lines changed: 81 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ package pricing
33
import (
44
"context"
55
"fmt"
6+
"slices"
67
"sync"
78
"time"
89

@@ -26,6 +27,8 @@ type PricingManager struct {
2627
pricingData map[string]configstore.TableModelPricing
2728
mu sync.RWMutex
2829

30+
modelPool map[schemas.ModelProvider][]string
31+
2932
// Background sync worker
3033
syncTicker *time.Ticker
3134
done chan struct{}
@@ -73,9 +76,12 @@ func Init(ctx context.Context, configStore configstore.ConfigStore, logger schem
7376
configStore: configStore,
7477
logger: logger,
7578
pricingData: make(map[string]configstore.TableModelPricing),
79+
modelPool: make(map[schemas.ModelProvider][]string),
7680
done: make(chan struct{}),
7781
}
7882

83+
logger.Info("initializing pricing manager...")
84+
7985
if configStore != nil {
8086
// Load initial pricing data
8187
if err := pm.loadPricingFromDatabase(ctx); err != nil {
@@ -86,14 +92,16 @@ func Init(ctx context.Context, configStore configstore.ConfigStore, logger schem
8692
if err := pm.syncPricing(ctx); err != nil {
8793
return nil, fmt.Errorf("failed to sync pricing data: %w", err)
8894
}
89-
9095
} else {
9196
// Load pricing data from config memory
9297
if err := pm.loadPricingIntoMemory(); err != nil {
9398
return nil, fmt.Errorf("failed to load pricing data from config memory: %w", err)
9499
}
95100
}
96101

102+
// Populate model pool with normalized providers
103+
pm.populateModelPool()
104+
97105
// Start background sync worker
98106
pm.startSyncWorker(ctx)
99107
pm.configStore = configStore
@@ -327,6 +335,78 @@ func (pm *PricingManager) CalculateCostFromUsage(provider string, model string,
327335
return totalCost
328336
}
329337

338+
// populateModelPool populates the model pool with all available models per provider (thread-safe)
339+
func (pm *PricingManager) populateModelPool() {
340+
// Clear existing model pool
341+
pm.modelPool = make(map[schemas.ModelProvider][]string)
342+
343+
// Map to track unique models per provider
344+
providerModels := make(map[schemas.ModelProvider]map[string]bool)
345+
346+
// Iterate through all pricing data to collect models per provider
347+
for _, pricing := range pm.pricingData {
348+
// Normalize provider before adding to model pool
349+
normalizedProvider := schemas.ModelProvider(normalizeProvider(pricing.Provider))
350+
351+
// Initialize map for this provider if not exists
352+
if providerModels[normalizedProvider] == nil {
353+
providerModels[normalizedProvider] = make(map[string]bool)
354+
}
355+
356+
// Add model to the provider's model set (using map for deduplication)
357+
providerModels[normalizedProvider][pricing.Model] = true
358+
}
359+
360+
// Convert sets to slices
361+
for provider, modelSet := range providerModels {
362+
models := make([]string, 0, len(modelSet))
363+
for model := range modelSet {
364+
models = append(models, model)
365+
}
366+
pm.modelPool[provider] = models
367+
}
368+
369+
// Log the populated model pool for debugging
370+
totalModels := 0
371+
for provider, models := range pm.modelPool {
372+
totalModels += len(models)
373+
pm.logger.Debug("populated %d models for provider %s", len(models), string(provider))
374+
}
375+
pm.logger.Info("populated model pool with %d models across %d providers", totalModels, len(pm.modelPool))
376+
}
377+
378+
// GetModelsForProvider returns all available models for a given provider (thread-safe)
379+
func (pm *PricingManager) GetModelsForProvider(provider schemas.ModelProvider) []string {
380+
pm.mu.RLock()
381+
defer pm.mu.RUnlock()
382+
383+
models, exists := pm.modelPool[provider]
384+
if !exists {
385+
return []string{}
386+
}
387+
388+
// Return a copy to prevent external modification
389+
result := make([]string, len(models))
390+
copy(result, models)
391+
return result
392+
}
393+
394+
// GetProvidersForModel returns all providers for a given model (thread-safe)
395+
func (pm *PricingManager) GetProvidersForModel(model string) []schemas.ModelProvider {
396+
pm.mu.RLock()
397+
defer pm.mu.RUnlock()
398+
399+
providers := make([]schemas.ModelProvider, 0)
400+
for provider, models := range pm.modelPool {
401+
if slices.Contains(models, model) {
402+
providers = append(providers, provider)
403+
} else if slices.Contains(models, model) {
404+
providers = append(providers, provider)
405+
}
406+
}
407+
return providers
408+
}
409+
330410
// getPricing returns pricing information for a model (thread-safe)
331411
func (pm *PricingManager) getPricing(model, provider string, requestType schemas.RequestType) (*configstore.TableModelPricing, bool) {
332412
pm.mu.RLock()

framework/pricing/sync.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -62,8 +62,8 @@ func (pm *PricingManager) syncPricing(ctx context.Context) error {
6262
pricingData, err := pm.loadPricingFromURL()
6363
if err != nil {
6464
// Check if we have existing data in database
65-
pricingRecords, err := pm.configStore.GetModelPrices(ctx)
66-
if err != nil {
65+
pricingRecords, pricingErr := pm.configStore.GetModelPrices(ctx)
66+
if pricingErr != nil {
6767
return fmt.Errorf("failed to get pricing records: %w", err)
6868
}
6969
if len(pricingRecords) > 0 {

plugins/governance/resolver.go

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -153,11 +153,17 @@ func (r *BudgetResolver) isModelAllowed(vk *configstore.TableVirtualKey, model s
153153
// isProviderAllowed checks if the requested provider is allowed for this VK
154154
func (r *BudgetResolver) isProviderAllowed(vk *configstore.TableVirtualKey, provider schemas.ModelProvider) bool {
155155
// Empty AllowedProviders means all providers are allowed
156-
if len(vk.AllowedProviders) == 0 {
156+
if len(vk.ProviderConfigs) == 0 {
157157
return true
158158
}
159159

160-
return slices.Contains(vk.AllowedProviders, string(provider))
160+
for _, pc := range vk.ProviderConfigs {
161+
if pc.Provider == string(provider) {
162+
return true
163+
}
164+
}
165+
166+
return false
161167
}
162168

163169
// checkRateLimits checks the VK's rate limits using flexible approach

plugins/governance/tracker.go

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -220,7 +220,6 @@ func (t *UsageTracker) PerformStartupResets(ctx context.Context) error {
220220
}
221221
}
222222
}
223-
t.logger.Info("startup reset summary: VKs with RL=%d, without RL=%d, RL resets=%d", vksWithRateLimits, vksWithoutRateLimits, len(resetRateLimits))
224223
if len(errs) > 0 {
225224
t.logger.Error("startup reset encountered %d errors: %v", len(errs), errs)
226225
return fmt.Errorf("startup reset completed with %d errors", len(errs))

0 commit comments

Comments
 (0)