Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
107 changes: 50 additions & 57 deletions pkg/model/provider/rulebased/client.go
Original file line number Diff line number Diff line change
@@ -1,8 +1,5 @@
// Package rulebased provides a rule-based model router that selects
// the appropriate model based on NLP analysis of the input using Bleve.
//
// Routes are defined with example texts, and Bleve's full-text search
// determines the best matching route based on text similarity.
// the appropriate model based on text similarity using Bleve full-text search.
//
// A model becomes a rule-based router when it has routing rules configured.
// The model's provider/model fields define the fallback model, and each
Expand Down Expand Up @@ -43,17 +40,11 @@ type ProviderFactory func(ctx context.Context, modelSpec string, models map[stri
// Client implements the Provider interface for rule-based model routing.
type Client struct {
base.Config
routes []route
routes []Provider
fallback Provider
index bleve.Index
}

// route represents a single routing rule.
type route struct {
model string
provider Provider
}

// NewClient creates a new rule-based routing client.
// The cfg parameter should have Routing rules configured. The provider/model
// fields of cfg define the fallback model that is used when no routing rule matches.
Expand All @@ -69,11 +60,21 @@ func NewClient(ctx context.Context, cfg *latest.ModelConfig, models map[string]l
return nil, fmt.Errorf("creating bleve index: %w", err)
}

// Create fallback provider from the model's provider/model fields
// On any subsequent error, close the index before returning.
var cleanupErr error
defer func() {
if cleanupErr != nil {
_ = index.Close()
}
}()

routeOpts := filterOutMaxTokens(opts)

// Create fallback provider from the model's provider/model fields.
fallbackSpec := cfg.Provider + "/" + cfg.Model
fallback, err := providerFactory(ctx, fallbackSpec, models, env, filterOutMaxTokens(opts)...)
fallback, err := providerFactory(ctx, fallbackSpec, models, env, routeOpts...)
if err != nil {
_ = index.Close()
cleanupErr = err
return nil, fmt.Errorf("creating fallback provider %q: %w", fallbackSpec, err)
}

Expand All @@ -87,27 +88,28 @@ func NewClient(ctx context.Context, cfg *latest.ModelConfig, models map[string]l
fallback: fallback,
}

// Process routing rules
// Process routing rules. Each example is indexed with a doc ID
// that encodes the route index (e.g. "r0_e1") so we can map
// search hits back to the corresponding provider.
for i, rule := range cfg.Routing {
if rule.Model == "" {
_ = index.Close()
return nil, fmt.Errorf("routing rule %d: 'model' field is required", i)
cleanupErr = fmt.Errorf("routing rule %d: 'model' field is required", i)
return nil, cleanupErr
}

provider, err := providerFactory(ctx, rule.Model, models, env, filterOutMaxTokens(opts)...)
provider, err := providerFactory(ctx, rule.Model, models, env, routeOpts...)
if err != nil {
_ = index.Close()
cleanupErr = err
return nil, fmt.Errorf("creating provider for routing rule %q: %w", rule.Model, err)
}

routeIndex := len(client.routes)
client.routes = append(client.routes, route{model: rule.Model, provider: provider})
client.routes = append(client.routes, provider)

// Index examples for this route
for j, example := range rule.Examples {
docID := fmt.Sprintf("r%d_e%d", routeIndex, j)
if err := index.Index(docID, map[string]any{"text": example, "route": routeIndex}); err != nil {
_ = index.Close()
if err := index.Index(docID, map[string]any{"text": example}); err != nil {
cleanupErr = err
return nil, fmt.Errorf("indexing example: %w", err)
}
}
Expand All @@ -124,27 +126,23 @@ func createIndex() (bleve.Index, error) {
textField := mapping.NewTextFieldMapping()
textField.Analyzer = "en"
docMapping.AddFieldMappingsAt("text", textField)
docMapping.AddFieldMappingsAt("route", mapping.NewNumericFieldMapping())

indexMapping.DefaultMapping = docMapping

return bleve.NewMemOnly(indexMapping)
}

// filterOutMaxTokens removes WithMaxTokens options from the slice.
// This is necessary because child providers may have different token limits
// than the parent router, and should determine their own limits.
// Child providers may have different token limits than the parent router.
func filterOutMaxTokens(opts []options.Opt) []options.Opt {
var filtered []options.Opt
for _, opt := range opts {
if opt == nil {
continue
}
// Test if this option sets maxTokens by applying it to an empty ModelOptions
var test options.ModelOptions
opt(&test)
// If maxTokens was set, skip this option
if test.MaxTokens() != 0 {
var probe options.ModelOptions
opt(&probe)
if probe.MaxTokens() != 0 {
continue
}
filtered = append(filtered, opt)
Expand Down Expand Up @@ -173,6 +171,7 @@ func (c *Client) CreateChatCompletionStream(
}

// selectProvider finds the best matching provider for the messages.
// Bleve returns hits sorted by score, so the top hit determines the route.
func (c *Client) selectProvider(messages []chat.Message) Provider {
userMessage := getLastUserMessage(messages)
if userMessage == "" {
Expand All @@ -183,8 +182,7 @@ func (c *Client) selectProvider(messages []chat.Message) Provider {
query.SetField("text")

searchRequest := bleve.NewSearchRequest(query)
searchRequest.Size = 10
searchRequest.Fields = []string{"route"}
searchRequest.Size = 1

results, err := c.index.Search(searchRequest)
if err != nil {
Expand All @@ -196,41 +194,36 @@ func (c *Client) selectProvider(messages []chat.Message) Provider {
return c.defaultProvider()
}

// Find best matching route by aggregating scores
scores := make(map[int]float64)
for _, hit := range results.Hits {
var routeIdx int
if _, err := fmt.Sscanf(hit.ID, "r%d_e", &routeIdx); err == nil {
if hit.Score > scores[routeIdx] {
scores[routeIdx] = hit.Score
}
}
// Parse the route index from the top hit's doc ID (e.g. "r2_e0" → 2).
hit := results.Hits[0]
routeIdx, ok := parseRouteIndex(hit.ID)
if !ok || routeIdx >= len(c.routes) {
return c.defaultProvider()
}

bestRoute, bestScore := -1, 0.0
for idx, score := range scores {
if score > bestScore {
bestRoute, bestScore = idx, score
}
}
selected := c.routes[routeIdx]
slog.Debug("Route matched",
"model", selected.ID(),
"score", hit.Score,
)
return selected
}

if bestRoute >= 0 && bestRoute < len(c.routes) {
slog.Debug("Route matched",
"model", c.routes[bestRoute].model,
"score", bestScore,
)
return c.routes[bestRoute].provider
// parseRouteIndex extracts the route index from a doc ID like "r2_e0".
func parseRouteIndex(docID string) (int, bool) {
var idx int
if _, err := fmt.Sscanf(docID, "r%d_e", &idx); err != nil || idx < 0 {
return 0, false
}

return c.defaultProvider()
return idx, true
}

func (c *Client) defaultProvider() Provider {
if c.fallback != nil {
return c.fallback
}
if len(c.routes) > 0 {
return c.routes[0].provider
return c.routes[0]
}
return nil
}
Expand Down
52 changes: 33 additions & 19 deletions pkg/model/provider/rulebased/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,11 +40,9 @@ func (m *mockProvider) BaseConfig() base.Config {
// mockProviderFactory creates a mock provider factory for testing.
// It resolves model references from the models map or parses inline specs.
func mockProviderFactory(_ context.Context, modelSpec string, models map[string]latest.ModelConfig, _ environment.Provider, _ ...options.Opt) (Provider, error) {
// Check if it's a model reference
if cfg, exists := models[modelSpec]; exists {
return &mockProvider{id: cfg.Provider + "/" + cfg.Model}, nil
}
// Otherwise treat as inline spec
return &mockProvider{id: modelSpec}, nil
}

Expand All @@ -62,7 +60,7 @@ func TestNewClient(t *testing.T) {
name: "valid config with routing rules",
modelCfg: latest.ModelConfig{
Provider: "openai",
Model: "gpt-4o", // fallback
Model: "gpt-4o",
Routing: []latest.RoutingRule{
{
Model: "anthropic/claude-3-haiku",
Expand All @@ -80,7 +78,7 @@ func TestNewClient(t *testing.T) {
name: "routing with model references",
modelCfg: latest.ModelConfig{
Provider: "anthropic",
Model: "claude-haiku-4-5", // fallback
Model: "claude-haiku-4-5",
Routing: []latest.RoutingRule{
{
Model: "fast",
Expand Down Expand Up @@ -183,7 +181,7 @@ func TestClient_SelectProvider(t *testing.T) {

cfg := &latest.ModelConfig{
Provider: "openai",
Model: "gpt-4o", // fallback
Model: "gpt-4o",
Routing: []latest.RoutingRule{
{
Model: "anthropic/claude-3-haiku",
Expand Down Expand Up @@ -262,11 +260,9 @@ func TestCreateIndex(t *testing.T) {
require.NoError(t, err)
defer index.Close()

// Index a document
err = index.Index("test", map[string]any{"text": "hello world", "route": 0})
err = index.Index("test", map[string]any{"text": "hello world"})
require.NoError(t, err)

// Search for it
query := bleve.NewMatchQuery("hello")
query.SetField("text")
results, err := index.Search(bleve.NewSearchRequest(query))
Expand Down Expand Up @@ -298,10 +294,9 @@ func TestClient_ID(t *testing.T) {
func TestClient_DefaultProvider(t *testing.T) {
t.Parallel()

// Test that fallback is always used for empty messages
cfg := &latest.ModelConfig{
Provider: "openai",
Model: "gpt-4o", // fallback
Model: "gpt-4o",
Routing: []latest.RoutingRule{
{
Model: "anthropic/claude-3-haiku",
Expand All @@ -314,16 +309,13 @@ func TestClient_DefaultProvider(t *testing.T) {
require.NoError(t, err)
defer client.Close()

// Empty message should use fallback
provider := client.selectProvider(nil)
assert.Equal(t, "openai/gpt-4o", provider.ID())
}

func TestClient_CreateChatCompletionStream_NilProvider(t *testing.T) {
t.Parallel()

// Create a client with no routes and no fallback by directly manipulating the struct
// This simulates an edge case where defaultProvider returns nil
index, err := createIndex()
require.NoError(t, err)

Expand All @@ -335,7 +327,6 @@ func TestClient_CreateChatCompletionStream_NilProvider(t *testing.T) {
}
defer client.Close()

// Attempt to create stream should return error, not panic
messages := []chat.Message{{Role: chat.MessageRoleUser, Content: "hello"}}
_, err = client.CreateChatCompletionStream(t.Context(), messages, nil)
require.Error(t, err)
Expand All @@ -348,8 +339,6 @@ func TestClient_ModelsMapStoredInBaseConfig(t *testing.T) {
// This test verifies that the models map and env are stored in the base config.
// This is required for CloneWithOptions to work correctly with routers
// that use model references (e.g., "fast" instead of "anthropic/claude-haiku-4-5").
// Without this, cloning a router would fail because model references can't be resolved
// and the environment provider would be nil.

models := map[string]latest.ModelConfig{
"fast": {Provider: "anthropic", Model: "claude-haiku-4-5"},
Expand All @@ -358,7 +347,7 @@ func TestClient_ModelsMapStoredInBaseConfig(t *testing.T) {

cfg := &latest.ModelConfig{
Provider: "anthropic",
Model: "claude-haiku-4-5", // fallback
Model: "claude-haiku-4-5",
Routing: []latest.RoutingRule{
{
Model: "fast",
Expand All @@ -371,21 +360,46 @@ func TestClient_ModelsMapStoredInBaseConfig(t *testing.T) {
},
}

// Create a mock env provider
mockEnv := &mockEnvProvider{}

client, err := NewClient(t.Context(), cfg, models, mockEnv, mockProviderFactory)
require.NoError(t, err)
defer client.Close()

// Verify the models map and env are stored in the base config
baseConfig := client.BaseConfig()
assert.NotNil(t, baseConfig.Models, "Models map should be stored in base config for cloning")
assert.Equal(t, models, baseConfig.Models, "Models map should match what was passed to NewClient")
assert.NotNil(t, baseConfig.Env, "Env should be stored in base config for cloning")
assert.Equal(t, mockEnv, baseConfig.Env, "Env should match what was passed to NewClient")
}

func TestParseRouteIndex(t *testing.T) {
t.Parallel()

tests := []struct {
docID string
wantIdx int
wantOK bool
}{
{"r0_e0", 0, true},
{"r2_e5", 2, true},
{"r10_e3", 10, true},
{"invalid", 0, false},
{"", 0, false},
}

for _, tt := range tests {
t.Run(tt.docID, func(t *testing.T) {
t.Parallel()
idx, ok := parseRouteIndex(tt.docID)
assert.Equal(t, tt.wantOK, ok)
if ok {
assert.Equal(t, tt.wantIdx, idx)
}
})
}
}

// mockEnvProvider is a minimal mock for environment.Provider.
type mockEnvProvider struct{}

Expand Down