Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions agent-schema.json
Original file line number Diff line number Diff line change
Expand Up @@ -812,6 +812,10 @@
"type": "string",
"description": "A comma-delimited list of regular expressions of tools to toonify"
},
"model": {
"type": "string",
"description": "Model to use for the LLM turn that processes tool results from this toolset. Enables per-tool model routing: cheaper/faster models handle simple tool results (e.g. knowledge-base lookups, file reads) while the agent's primary model handles complex reasoning. Value can be a model name from the models section or an inline provider/model format (e.g. 'openai/gpt-4o-mini')."
},
"ref": {
"type": "string",
"description": "Reference to a Docker MCP tool (e.g., 'docker:context7') or a named MCP definition from the top-level 'mcps' section"
Expand Down
46 changes: 46 additions & 0 deletions per_tool_model_routing.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
# Per-Tool Model Routing Example
#
# This example demonstrates how to use the `model` field on toolsets
# to automatically route specific tool results through a cheaper/faster
# model, while keeping the agent's primary model for complex reasoning.
#
# When the LLM calls a tool from a toolset with a `model` field, the
# next LLM turn (processing the tool results) uses the specified model
# instead of the agent's primary model. This is a one-shot override:
# subsequent turns return to the primary model.

version: v1

name: per-tool-model-routing

# Named models referenced below by `agents[].model` and `toolsets[].model`.
# NOTE(review): confirm these dated snapshot IDs against the current
# Anthropic model list — the `-20250514` suffix may not match the 4.5
# releases.
models:
  primary:
    provider: anthropic
    model: claude-sonnet-4-5-20250514
  fast:
    provider: anthropic
    model: claude-haiku-4-5-20250514

agents:
  - name: assistant
    model: primary
    description: >
      An assistant that uses a fast model for simple tool operations
      and the primary model for complex reasoning.
    instruction: >
      You are a helpful assistant. Use the available tools to help the user.
    toolsets:
      # The filesystem toolset uses the fast model to process results.
      # Reading files and listing directories are simple operations that
      # don't need the most capable model to interpret.
      - type: filesystem
        model: fast

      # The shell toolset also uses the fast model. Most shell command
      # outputs (ls, cat, grep, etc.) are straightforward to interpret.
      - type: shell
        model: fast

      # The think tool stays on the primary model (no model override).
      # Complex reasoning benefits from the agent's full capabilities.
      - type: think
5 changes: 5 additions & 0 deletions pkg/config/latest/types.go
Original file line number Diff line number Diff line change
Expand Up @@ -556,6 +556,11 @@ type Toolset struct {
Instruction string `json:"instruction,omitempty"`
Toon string `json:"toon,omitempty"`

// Model overrides the LLM used for the turn that processes tool results
// from this toolset, enabling per-toolset model routing. Value can be a
// model name from the models section or "provider/model" (e.g. "openai/gpt-4o-mini").
Model string `json:"model,omitempty"`

Defer DeferConfig `json:"defer" yaml:"defer,omitempty"`

// For the `mcp` tool
Expand Down
42 changes: 32 additions & 10 deletions pkg/runtime/model_switcher.go
Original file line number Diff line number Diff line change
Expand Up @@ -117,23 +117,45 @@ func (r *LocalRuntime) SetAgentModel(ctx context.Context, agentName, modelRef st
return nil
}

// Try parsing as inline spec (provider/model)
// Try single inline spec (provider/model)
prov, err := r.resolveModelRef(ctx, modelRef)
if err != nil {
return fmt.Errorf("failed to resolve model %q: %w", modelRef, err)
}
a.SetModelOverride(prov)
slog.Info("Set agent model override (inline)", "agent", agentName, "model", prov.ID())
return nil
}

// resolveModelRef resolves a model reference to a single provider.
// The reference can be a named model from the config or an inline
// "provider/model" spec (e.g. "openai/gpt-4o-mini").
func (r *LocalRuntime) resolveModelRef(ctx context.Context, modelRef string) (provider.Provider, error) {
if r.modelSwitcherCfg == nil {
return nil, fmt.Errorf("model switching not configured for this runtime")
}

// Try named model from config first.
if modelCfg, exists := r.modelSwitcherCfg.Models[modelRef]; exists {
if isAlloyModelConfig(modelCfg) {
return nil, fmt.Errorf("model reference %q is an alloy (multi-model) config and cannot be used as a single model override", modelRef)
}
modelCfg.Name = modelRef
return r.createProviderFromConfig(ctx, &modelCfg)
}

// Try inline "provider/model" format.
providerName, modelName, ok := strings.Cut(modelRef, "/")
if !ok {
return fmt.Errorf("invalid model reference %q: expected a model name from config or 'provider/model' format", modelRef)
if !ok || providerName == "" || modelName == "" {
return nil, fmt.Errorf("invalid model reference %q: expected a model name from config or 'provider/model' format", modelRef)
}

inlineCfg := &latest.ModelConfig{
Provider: providerName,
Model: modelName,
}
prov, err := r.createProviderFromConfig(ctx, inlineCfg)
if err != nil {
return fmt.Errorf("failed to create inline model: %w", err)
}
a.SetModelOverride(prov)
slog.Info("Set agent model override (inline)", "agent", agentName, "model", prov.ID())
return nil

return r.createProviderFromConfig(ctx, inlineCfg)
}

// isAlloyModelConfig checks if a model config is an alloy model (multiple models).
Expand Down
55 changes: 55 additions & 0 deletions pkg/runtime/model_switcher_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -392,3 +392,58 @@ func TestBuildCatalogChoicesWithDuplicates(t *testing.T) {
assert.NotEqual(t, "openai/gpt-4o", c.Ref, "should not include duplicates from config")
}
}

func TestResolveModelRef_RejectsAlloyConfig(t *testing.T) {
t.Parallel()

r := &LocalRuntime{
modelSwitcherCfg: &ModelSwitcherConfig{
Models: map[string]latest.ModelConfig{
// Alloy config: no provider, comma-separated models
"alloy_model": {Model: "openai/gpt-4o,anthropic/claude-sonnet-4-0"},
},
},
}

_, err := r.resolveModelRef(t.Context(), "alloy_model")
require.Error(t, err)
assert.Contains(t, err.Error(), "alloy")
}

// TestResolveModelRef_NilConfig verifies that resolution fails cleanly when
// the runtime has no model switcher configuration.
func TestResolveModelRef_NilConfig(t *testing.T) {
	t.Parallel()

	// Zero-value runtime: modelSwitcherCfg is nil.
	var rt LocalRuntime

	_, err := rt.resolveModelRef(t.Context(), "openai/gpt-4o")
	require.Error(t, err)
	assert.Contains(t, err.Error(), "not configured")
}

// TestResolveModelRef_InvalidFormat verifies that references which are
// neither a named model nor a well-formed "provider/model" spec are rejected.
func TestResolveModelRef_InvalidFormat(t *testing.T) {
	t.Parallel()

	rt := &LocalRuntime{
		modelSwitcherCfg: &ModelSwitcherConfig{Models: map[string]latest.ModelConfig{}},
	}

	// Subtest name -> malformed reference.
	cases := map[string]string{
		"no slash":       "invalid",
		"empty provider": "/model",
		"empty model":    "provider/",
	}

	for name, ref := range cases {
		t.Run(name, func(t *testing.T) {
			t.Parallel()
			_, err := rt.resolveModelRef(t.Context(), ref)
			require.Error(t, err)
			assert.Contains(t, err.Error(), "invalid model reference")
		})
	}
}
29 changes: 29 additions & 0 deletions pkg/runtime/runtime.go
Original file line number Diff line number Diff line change
Expand Up @@ -1038,10 +1038,22 @@ func (r *LocalRuntime) RunStream(ctx context.Context, sess *session.Session) <-c
// Use a runtime copy of maxIterations so we don't modify the session's persistent config
runtimeMaxIterations := sess.MaxIterations

// toolModelOverride holds the per-toolset model from the most recent
// tool calls. It applies for one LLM turn, then resets.
var toolModelOverride string
var prevAgentName string

for {
// Set elicitation handler on all MCP toolsets before getting tools
a := r.CurrentAgent()

// Clear per-tool model override on agent switch so it doesn't
// leak from one agent's toolset into another agent's turn.
if a.Name() != prevAgentName {
toolModelOverride = ""
prevAgentName = a.Name()
}

r.emitAgentWarnings(a, events)
r.configureToolsetHandlers(a, events)

Expand Down Expand Up @@ -1115,6 +1127,20 @@ func (r *LocalRuntime) RunStream(ctx context.Context, sess *session.Session) <-c

model := a.Model()

// Per-tool model routing: use a cheaper model for this turn
// if the previous tool calls specified one, then reset.
if toolModelOverride != "" {
if overrideModel, err := r.resolveModelRef(ctx, toolModelOverride); err != nil {
slog.Warn("Failed to resolve per-tool model override; using agent default",
"model_override", toolModelOverride, "error", err)
} else {
slog.Info("Using per-tool model override for this turn",
"agent", a.Name(), "override", overrideModel.ID(), "primary", model.ID())
model = overrideModel
}
toolModelOverride = ""
}

// Apply thinking setting based on session state.
// When thinking is disabled: clone with thinking=false to clear any thinking config.
// When thinking is enabled: clone with thinking=true to ensure defaults are applied
Expand Down Expand Up @@ -1290,6 +1316,9 @@ func (r *LocalRuntime) RunStream(ctx context.Context, sess *session.Session) <-c

r.processToolCalls(ctx, sess, res.Calls, agentTools, events)

// Record per-toolset model override for the next LLM turn.
toolModelOverride = resolveToolCallModelOverride(res.Calls, agentTools)

if res.Stopped {
slog.Debug("Conversation stopped", "agent", a.Name())
break
Expand Down
31 changes: 31 additions & 0 deletions pkg/runtime/tool_model_override.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
package runtime

import (
"log/slog"

"github.com/docker/docker-agent/pkg/tools"
)

// resolveToolCallModelOverride inspects the given tool calls and returns the
// model override configured on the first called tool that has one, or the
// empty string when no called tool carries an override. Calls naming unknown
// tools are ignored.
func resolveToolCallModelOverride(calls []tools.ToolCall, agentTools []tools.Tool) string {
	if len(calls) == 0 {
		return ""
	}

	// Index the agent's tools by name for O(1) lookup per call.
	byName := make(map[string]tools.Tool, len(agentTools))
	for _, tool := range agentTools {
		byName[tool.Name] = tool
	}

	for _, c := range calls {
		tool, known := byName[c.Function.Name]
		if !known || tool.ModelOverride == "" {
			continue
		}
		slog.Debug("Per-tool model override detected",
			"tool", c.Function.Name, "model", tool.ModelOverride)
		return tool.ModelOverride
	}

	return ""
}
82 changes: 82 additions & 0 deletions pkg/runtime/tool_model_override_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
package runtime

import (
"testing"

"github.com/stretchr/testify/assert"

"github.com/docker/docker-agent/pkg/tools"
)

// TestResolveToolCallModelOverride_NoCalls verifies that an empty set of
// tool calls yields no override.
// Adds t.Parallel() for consistency with the other runtime tests.
func TestResolveToolCallModelOverride_NoCalls(t *testing.T) {
	t.Parallel()

	result := resolveToolCallModelOverride(nil, nil)
	assert.Empty(t, result)
}

// TestResolveToolCallModelOverride_NoOverride verifies that calling a tool
// with no ModelOverride set yields no override.
// Adds t.Parallel() for consistency with the other runtime tests.
func TestResolveToolCallModelOverride_NoOverride(t *testing.T) {
	t.Parallel()

	agentTools := []tools.Tool{
		{Name: "read_file"},
		{Name: "write_file"},
	}
	calls := []tools.ToolCall{
		{Function: tools.FunctionCall{Name: "read_file"}},
	}

	result := resolveToolCallModelOverride(calls, agentTools)
	assert.Empty(t, result)
}

// TestResolveToolCallModelOverride_SingleOverride verifies that a call to a
// tool carrying a ModelOverride returns that override.
// Adds t.Parallel() for consistency with the other runtime tests.
func TestResolveToolCallModelOverride_SingleOverride(t *testing.T) {
	t.Parallel()

	agentTools := []tools.Tool{
		{Name: "read_file", ModelOverride: "openai/gpt-4o-mini"},
		{Name: "write_file"},
	}
	calls := []tools.ToolCall{
		{Function: tools.FunctionCall{Name: "read_file"}},
	}

	result := resolveToolCallModelOverride(calls, agentTools)
	assert.Equal(t, "openai/gpt-4o-mini", result)
}

// TestResolveToolCallModelOverride_FirstOverrideWins verifies that when
// several called tools carry different overrides, the first call's wins.
// Adds t.Parallel() for consistency with the other runtime tests.
func TestResolveToolCallModelOverride_FirstOverrideWins(t *testing.T) {
	t.Parallel()

	agentTools := []tools.Tool{
		{Name: "read_file", ModelOverride: "openai/gpt-4o-mini"},
		{Name: "search_kb", ModelOverride: "anthropic/claude-haiku"},
	}
	calls := []tools.ToolCall{
		{Function: tools.FunctionCall{Name: "read_file"}},
		{Function: tools.FunctionCall{Name: "search_kb"}},
	}

	result := resolveToolCallModelOverride(calls, agentTools)
	assert.Equal(t, "openai/gpt-4o-mini", result)
}

// TestResolveToolCallModelOverride_MixedOverrideAndNonOverride verifies that
// calls to tools without an override are skipped, and the first call whose
// tool does carry one supplies the result.
// Adds t.Parallel() for consistency with the other runtime tests.
func TestResolveToolCallModelOverride_MixedOverrideAndNonOverride(t *testing.T) {
	t.Parallel()

	agentTools := []tools.Tool{
		{Name: "read_file"},
		{Name: "search_kb", ModelOverride: "openai/gpt-4o-mini"},
	}
	calls := []tools.ToolCall{
		{Function: tools.FunctionCall{Name: "read_file"}},
		{Function: tools.FunctionCall{Name: "search_kb"}},
	}

	// read_file has no override, search_kb does. Since read_file is first
	// but has no override, we skip it and use search_kb's.
	result := resolveToolCallModelOverride(calls, agentTools)
	assert.Equal(t, "openai/gpt-4o-mini", result)
}

// TestResolveToolCallModelOverride_UnknownTool verifies that a call naming a
// tool not present in the agent's toolset is ignored.
// Adds t.Parallel() for consistency with the other runtime tests.
func TestResolveToolCallModelOverride_UnknownTool(t *testing.T) {
	t.Parallel()

	agentTools := []tools.Tool{
		{Name: "read_file"},
	}
	calls := []tools.ToolCall{
		{Function: tools.FunctionCall{Name: "unknown_tool"}},
	}

	result := resolveToolCallModelOverride(calls, agentTools)
	assert.Empty(t, result)
}
Loading
Loading