Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions agent-schema.json
Original file line number Diff line number Diff line change
Expand Up @@ -353,6 +353,17 @@
"2m30s",
"5m"
]
},
"no_fallback_status_codes": {
"type": "array",
"description": "HTTP status codes for which the fallback chain should be skipped entirely when using a models gateway (--models-gateway). If the primary model returns one of these codes, the error is returned immediately without trying fallback models. Useful with gateways that perform their own routing.",
"items": {
"type": "integer"
},
"examples": [
[401, 403],
[429]
]
}
},
"additionalProperties": false
Expand Down
7 changes: 7 additions & 0 deletions pkg/agent/agent.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ type Agent struct {
fallbackModels []provider.Provider // Fallback models to try if primary fails
fallbackRetries int // Number of retries per fallback model with exponential backoff
fallbackCooldown time.Duration // Duration to stick with fallback after non-retryable error
noFallbackStatusCodes []int // Status codes that skip fallback when using a gateway
modelOverrides atomic.Pointer[[]provider.Provider] // Optional model override(s) set at runtime (supports alloy)
subAgents []*Agent
handoffs []*Agent
Expand Down Expand Up @@ -188,6 +189,12 @@ func (a *Agent) FallbackCooldown() time.Duration {
return a.fallbackCooldown
}

// NoFallbackStatusCodes reports the HTTP status codes configured on this agent
// for which fallback is skipped when a models gateway is in use.
// The agent's own slice is returned without copying, so callers should treat
// it as read-only.
func (a *Agent) NoFallbackStatusCodes() []int {
	codes := a.noFallbackStatusCodes
	return codes
}

// Commands returns the named commands configured for this agent.
func (a *Agent) Commands() types.Commands {
return a.commands
Expand Down
8 changes: 8 additions & 0 deletions pkg/agent/opts.go
Original file line number Diff line number Diff line change
Expand Up @@ -166,3 +166,11 @@ func WithThinkingConfigured(configured bool) Opt {
a.thinkingConfigured = configured
}
}

// WithNoFallbackStatusCodes sets the HTTP status codes for which the fallback
// chain should be skipped entirely when using a models gateway.
//
// The codes are copied when the option is created, so later changes to the
// caller's slice (for example the slice owned by the loaded configuration) do
// not alter the agent's behaviour. A nil or empty input leaves the agent with
// no no-fallback codes.
func WithNoFallbackStatusCodes(codes []int) Opt {
	// Snapshot the input; appending to a nil slice always allocates fresh
	// backing storage, which detaches the result from the caller's slice.
	snapshot := append([]int(nil), codes...)
	return func(a *Agent) {
		a.noFallbackStatusCodes = snapshot
	}
}
16 changes: 16 additions & 0 deletions pkg/config/latest/types.go
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,13 @@ type FallbackConfig struct {
// retrying the primary. Only applies after a non-retryable error (e.g., 429).
// Default is 1 minute. Use Go duration format (e.g., "1m", "30s", "2m30s").
Cooldown Duration `json:"cooldown"`
// NoFallbackStatusCodes is a list of HTTP status codes for which the fallback
// chain should be skipped entirely when using a models gateway. If the primary
// model returns one of these status codes, the error is returned immediately
// without trying fallback models. This is useful with gateways that perform
// their own routing and return specific codes (e.g., 401, 403) that should
// not trigger client-side fallback.
NoFallbackStatusCodes []int `json:"no_fallback_status_codes,omitempty"`
}

// Duration is a wrapper around time.Duration that supports YAML/JSON unmarshaling
Expand Down Expand Up @@ -257,6 +264,15 @@ func (a *AgentConfig) GetFallbackCooldown() time.Duration {
return 0
}

// GetNoFallbackStatusCodes reports the configured HTTP status codes for which
// fallback is skipped when a models gateway is in use. It yields nil when the
// agent has no fallback configuration.
func (a *AgentConfig) GetNoFallbackStatusCodes() []int {
	fb := a.Fallback
	if fb == nil {
		return nil
	}
	return fb.NoFallbackStatusCodes
}

// ModelConfig represents the configuration for a model
type ModelConfig struct {
// Name is the manifest model name (map key), populated at runtime.
Expand Down
6 changes: 6 additions & 0 deletions pkg/config/latest/validate.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,12 @@ func (a *AgentConfig) validateFallback() error {
return errors.New("fallback.cooldown must be non-negative")
}

for _, code := range a.Fallback.NoFallbackStatusCodes {
if code < 400 || code > 599 {
return errors.New("fallback.no_fallback_status_codes must contain HTTP error codes (400-599)")
}
}

return nil
}

Expand Down
3 changes: 3 additions & 0 deletions pkg/config/v4/types.go
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,9 @@ type FallbackConfig struct {
// retrying the primary. Only applies after a non-retryable error (e.g., 429).
// Default is 1 minute. Use Go duration format (e.g., "1m", "30s", "2m30s").
Cooldown Duration `json:"cooldown"`
// NoFallbackStatusCodes is a list of HTTP status codes for which the fallback
// chain should be skipped entirely when using a models gateway.
NoFallbackStatusCodes []int `json:"no_fallback_status_codes,omitempty"`
}

// Duration is a wrapper around time.Duration that supports YAML/JSON unmarshaling
Expand Down
6 changes: 6 additions & 0 deletions pkg/config/v4/validate.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,12 @@ func (a *AgentConfig) validateFallback() error {
return errors.New("fallback.cooldown must be non-negative")
}

for _, code := range a.Fallback.NoFallbackStatusCodes {
if code < 400 || code > 599 {
return errors.New("fallback.no_fallback_status_codes must contain HTTP error codes (400-599)")
}
}

return nil
}

Expand Down
39 changes: 39 additions & 0 deletions pkg/runtime/fallback.go
Original file line number Diff line number Diff line change
Expand Up @@ -389,6 +389,18 @@ func getEffectiveRetries(a *agent.Agent) int {
return retries
}

// isNoFallbackStatusCode reports whether statusCode appears in codes, the
// agent's configured no-fallback list. It is used with models gateways to
// short-circuit the fallback chain for specific error codes. A nil or empty
// list never matches.
func isNoFallbackStatusCode(statusCode int, codes []int) bool {
	found := false
	// Stop scanning as soon as a match is seen.
	for i := 0; i < len(codes) && !found; i++ {
		found = codes[i] == statusCode
	}
	return found
}

// tryModelWithFallback attempts to create a stream and get a response using the primary model,
// falling back to configured fallback models if the primary fails.
//
Expand Down Expand Up @@ -441,6 +453,13 @@ func (r *LocalRuntime) tryModelWithFallback(
"cooldown_until", cooldownState.until.Format(time.RFC3339))
}

// Precompute the gateway no-fallback inputs used by the error paths below.
// When a models gateway is configured it handles its own routing, so for the
// configured status codes (e.g., auth errors) trying a different client-side
// fallback model won't help and the error is returned immediately instead.
useGateway := r.modelSwitcherCfg != nil && r.modelSwitcherCfg.ModelsGateway != ""
noFallbackCodes := a.NoFallbackStatusCodes()

var lastErr error
primaryFailedWithNonRetryable := false

Expand Down Expand Up @@ -508,6 +527,16 @@ func (r *LocalRuntime) tryModelWithFallback(
"model", modelEntry.provider.ID(),
"error", err)

if useGateway && len(noFallbackCodes) > 0 {
if sc := extractHTTPStatusCode(err); sc != 0 && isNoFallbackStatusCode(sc, noFallbackCodes) {
slog.Warn("Gateway no-fallback status code, skipping entire fallback chain",
"agent", a.Name(),
"status_code", sc,
"error", err)
return streamResult{}, nil, err
}
}

// Track if primary failed with non-retryable error
if !modelEntry.isFallback {
primaryFailedWithNonRetryable = true
Expand Down Expand Up @@ -544,6 +573,16 @@ func (r *LocalRuntime) tryModelWithFallback(
"model", modelEntry.provider.ID(),
"error", err)

if useGateway && len(noFallbackCodes) > 0 {
if sc := extractHTTPStatusCode(err); sc != 0 && isNoFallbackStatusCode(sc, noFallbackCodes) {
slog.Warn("Gateway no-fallback status code, skipping entire fallback chain",
"agent", a.Name(),
"status_code", sc,
"error", err)
return streamResult{}, nil, err
}
}

// Track if primary failed with non-retryable error
if !modelEntry.isFallback {
primaryFailedWithNonRetryable = true
Expand Down
196 changes: 196 additions & 0 deletions pkg/runtime/fallback_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -877,6 +877,202 @@ func TestFallbackModelsClonedWithThinkingEnabled(t *testing.T) {
})
}

func TestIsNoFallbackStatusCode(t *testing.T) {
t.Parallel()

tests := []struct {
name string
statusCode int
codes []int
expected bool
}{
{
name: "empty codes list",
statusCode: 401,
codes: nil,
expected: false,
},
{
name: "status code in list",
statusCode: 401,
codes: []int{401, 403},
expected: true,
},
{
name: "status code not in list",
statusCode: 429,
codes: []int{401, 403},
expected: false,
},
{
name: "single code match",
statusCode: 429,
codes: []int{429},
expected: true,
},
{
name: "zero status code",
statusCode: 0,
codes: []int{401, 403},
expected: false,
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
result := isNoFallbackStatusCode(tt.statusCode, tt.codes)
assert.Equal(t, tt.expected, result)
})
}
}

func TestGatewayNoFallbackStatusCodes(t *testing.T) {
synctest.Test(t, func(t *testing.T) {
// Primary fails with 401 (configured as no-fallback when using gateway)
primary := &failingProvider{id: "primary/auth-fail", err: errors.New("401 unauthorized")}

// Fallback should NOT be tried because 401 is in no_fallback_status_codes
fallback := &countingProvider{
id: "fallback/should-not-be-called",
failCount: 0,
stream: newStreamBuilder().
AddContent("Fallback content").
AddStopWithUsage(5, 2).
Build(),
}

root := agent.New("root", "test",
agent.WithModel(primary),
agent.WithFallbackModel(fallback),
agent.WithFallbackRetries(2),
agent.WithNoFallbackStatusCodes([]int{401, 403}),
)

tm := team.New(team.WithAgents(root))
rt, err := NewLocalRuntime(tm,
WithSessionCompaction(false),
WithModelStore(mockModelStore{}),
WithModelSwitcherConfig(&ModelSwitcherConfig{
ModelsGateway: "https://gateway.example.com",
}),
)
require.NoError(t, err)

sess := session.New(session.WithUserMessage("test"))
sess.Title = "Gateway No-Fallback Test"

events := rt.RunStream(t.Context(), sess)

var gotError bool
var gotFallbackContent bool
for ev := range events {
if _, ok := ev.(*ErrorEvent); ok {
gotError = true
}
if choice, ok := ev.(*AgentChoiceEvent); ok {
if choice.Content == "Fallback content" {
gotFallbackContent = true
}
}
}

assert.True(t, gotError, "should get an error since 401 is in no-fallback codes")
assert.False(t, gotFallbackContent, "fallback should not be tried for no-fallback status code")
assert.Equal(t, 0, fallback.callCount, "fallback provider should not be called")
})
}

func TestGatewayNoFallbackStatusCodes_AllowsFallbackForOtherCodes(t *testing.T) {
synctest.Test(t, func(t *testing.T) {
// Primary fails with 429 (NOT in the no-fallback list)
primary := &failingProvider{id: "primary/rate-limited", err: errors.New("429 too many requests")}

// Fallback should be tried since 429 is not in no_fallback_status_codes
successStream := newStreamBuilder().
AddContent("Fallback success").
AddStopWithUsage(10, 5).
Build()
fallback := &mockProvider{id: "fallback/success", stream: successStream}

root := agent.New("root", "test",
agent.WithModel(primary),
agent.WithFallbackModel(fallback),
agent.WithFallbackRetries(0),
agent.WithNoFallbackStatusCodes([]int{401, 403}), // Only 401 and 403 block fallback
)

tm := team.New(team.WithAgents(root))
rt, err := NewLocalRuntime(tm,
WithSessionCompaction(false),
WithModelStore(mockModelStore{}),
WithModelSwitcherConfig(&ModelSwitcherConfig{
ModelsGateway: "https://gateway.example.com",
}),
)
require.NoError(t, err)

sess := session.New(session.WithUserMessage("test"))
sess.Title = "Gateway Allows Fallback Test"

events := rt.RunStream(t.Context(), sess)

var gotFallbackContent bool
for ev := range events {
if choice, ok := ev.(*AgentChoiceEvent); ok {
if choice.Content == "Fallback success" {
gotFallbackContent = true
}
}
}

assert.True(t, gotFallbackContent, "should receive fallback content since 429 is not in no-fallback list")
})
}

func TestGatewayNoFallbackStatusCodes_NoEffectWithoutGateway(t *testing.T) {
synctest.Test(t, func(t *testing.T) {
// Primary fails with 401
primary := &failingProvider{id: "primary/auth-fail", err: errors.New("401 unauthorized")}

// Even though 401 is in no_fallback_status_codes, without a gateway
// the fallback chain should proceed normally
successStream := newStreamBuilder().
AddContent("Fallback success without gateway").
AddStopWithUsage(10, 5).
Build()
fallback := &mockProvider{id: "fallback/success", stream: successStream}

root := agent.New("root", "test",
agent.WithModel(primary),
agent.WithFallbackModel(fallback),
agent.WithFallbackRetries(0),
agent.WithNoFallbackStatusCodes([]int{401, 403}),
)

tm := team.New(team.WithAgents(root))
// No WithModelSwitcherConfig — no gateway
rt, err := NewLocalRuntime(tm, WithSessionCompaction(false), WithModelStore(mockModelStore{}))
require.NoError(t, err)

sess := session.New(session.WithUserMessage("test"))
sess.Title = "No Gateway Test"

events := rt.RunStream(t.Context(), sess)

var gotFallbackContent bool
for ev := range events {
if choice, ok := ev.(*AgentChoiceEvent); ok {
if choice.Content == "Fallback success without gateway" {
gotFallbackContent = true
}
}
}

assert.True(t, gotFallbackContent, "fallback should proceed normally without a gateway")
})
}

// Verify interface compliance
var (
_ provider.Provider = (*mockProvider)(nil)
Expand Down
3 changes: 3 additions & 0 deletions pkg/teamloader/teamloader.go
Original file line number Diff line number Diff line change
Expand Up @@ -216,6 +216,9 @@ func LoadWithConfig(ctx context.Context, agentSource config.Source, runConfig *c
agent.WithFallbackRetries(agentConfig.GetFallbackRetries()),
agent.WithFallbackCooldown(agentConfig.GetFallbackCooldown()),
)
if codes := agentConfig.GetNoFallbackStatusCodes(); len(codes) > 0 {
opts = append(opts, agent.WithNoFallbackStatusCodes(codes))
}
}

agentTools, warnings := getToolsForAgent(ctx, &agentConfig, parentDir, runConfig, loadOpts.toolsetRegistry)
Expand Down