Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 27 additions & 2 deletions agent-schema.json
Original file line number Diff line number Diff line change
Expand Up @@ -704,7 +704,8 @@
"a2a",
"lsp",
"user_prompt",
"openapi"
"openapi",
"model_picker"
]
},
"instruction": {
Expand Down Expand Up @@ -840,6 +841,13 @@
"items": {
"type": "string"
}
},
"models": {
"type": "array",
"description": "List of allowed models for the model_picker tool.",
"items": {
"type": "string"
}
}
},
"additionalProperties": false,
Expand Down Expand Up @@ -890,7 +898,8 @@
"api",
"a2a",
"lsp",
"user_prompt"
"user_prompt",
"model_picker"
]
}
}
Expand Down Expand Up @@ -958,6 +967,22 @@
]
}
]
},
{
"allOf": [
{
"properties": {
"type": {
"const": "model_picker"
}
}
},
{
"required": [
"models"
]
}
]
}
]
},
Expand Down
48 changes: 48 additions & 0 deletions examples/model_picker.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
#!/usr/bin/env docker agent run

# This example demonstrates the model_picker toolset, which lets the agent
# dynamically switch between models mid-conversation. The agent can pick the
# best model for each sub-task (e.g. a fast model for simple questions, a
# powerful one for complex reasoning) and revert back when done.

agents:
root:
model: google/gemini-2.5-flash-lite
description: A versatile assistant that picks the best model for each task
instruction: |
You are a helpful assistant with access to multiple AI models.
toolsets:
- type: filesystem
- type: shell
- type: model_picker
instruction: |
{ORIGINAL_INSTRUCTIONS}

## Model selection policy

Your default model (`gemini-2.5-flash-lite`) is fast and cheap but
limited. You MUST follow this policy for every user message:

1. **Classify first.** Decide whether the request is *trivial*
(greetings, single-fact lookups, yes/no answers, short
clarifications) or *non-trivial* (anything else: writing, coding,
analysis, planning, multi-step reasoning, tool use, etc.).

2. **Trivial → stay on the default model.** Answer directly.

3. **Non-trivial → switch before you do any work.**
Call `change_model` to `claude-haiku-4-5` as the very first action,
*before* reasoning, planning, or calling any other tool.
Then carry out the task.

4. **ALWAYS revert when done.** After completing a non-trivial task,
you MUST call `revert_model` as your very last action so the next
turn starts on the cheap default again. This is mandatory—treat
it as the final step of every non-trivial request. Never end your
turn on a non-default model.

**Important:** never start working on a non-trivial task while still
on the default model. When in doubt, switch.
models:
- google/gemini-2.5-flash-lite
- anthropic/claude-haiku-4-5
3 changes: 3 additions & 0 deletions pkg/config/latest/types.go
Original file line number Diff line number Diff line change
Expand Up @@ -575,6 +575,9 @@ type Toolset struct {

// For the `fetch` tool
Timeout int `json:"timeout,omitempty"`

// For the `model_picker` tool
Models []string `json:"models,omitempty"`
}

func (t *Toolset) UnmarshalYAML(unmarshal func(any) error) error {
Expand Down
7 changes: 7 additions & 0 deletions pkg/config/latest/validate.go
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,9 @@ func (t *Toolset) validate() error {
if len(t.FileTypes) > 0 && t.Type != "lsp" {
return errors.New("file_types can only be used with type 'lsp'")
}
if len(t.Models) > 0 && t.Type != "model_picker" {
return errors.New("models can only be used with type 'model_picker'")
}
if t.Sandbox != nil && t.Type != "shell" {
return errors.New("sandbox can only be used with type 'shell'")
}
Expand Down Expand Up @@ -154,6 +157,10 @@ func (t *Toolset) validate() error {
if t.URL == "" {
return errors.New("openapi toolset requires a url to be set")
}
case "model_picker":
if len(t.Models) == 0 {
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

MEDIUM: Missing validation for individual model strings

The validation checks that len(t.Models) > 0, but it doesn't validate whether each string in the Models slice is non-empty or valid. A configuration like this would pass validation:

model_picker:
  models:
    - ""
    - "gpt-4"

But would cause runtime errors when the agent tries to switch to the empty model string. Consider adding:

for i, model := range t.Models {
    if strings.TrimSpace(model) == "" {
        return fmt.Errorf("toolset %q: models[%d] cannot be empty", t.Name, i)
    }
}

return errors.New("model_picker toolset requires at least one model in the 'models' list")
}
}

return nil
Expand Down
112 changes: 80 additions & 32 deletions pkg/runtime/runtime.go
Original file line number Diff line number Diff line change
Expand Up @@ -96,11 +96,6 @@ func ResumeReject(reason string) ResumeRequest {
// ToolHandlerFunc is a function type for handling tool calls
type ToolHandlerFunc func(ctx context.Context, sess *session.Session, toolCall tools.ToolCall, events chan Event) (*tools.ToolCallResult, error)

type ToolHandler struct {
handler ToolHandlerFunc
tool tools.Tool
}

// ElicitationRequestHandler is a function type for handling elicitation requests
type ElicitationRequestHandler func(ctx context.Context, message string, schema map[string]any) (map[string]any, error)

Expand Down Expand Up @@ -196,7 +191,7 @@ type ToolsChangeSubscriber interface {

// LocalRuntime manages the execution of agents
type LocalRuntime struct {
toolMap map[string]ToolHandler
toolMap map[string]ToolHandlerFunc
team *team.Team
currentAgent string
resumeChan chan ResumeRequest
Expand Down Expand Up @@ -297,7 +292,7 @@ func NewLocalRuntime(agents *team.Team, opts ...Opt) (*LocalRuntime, error) {
}

r := &LocalRuntime{
toolMap: make(map[string]ToolHandler),
toolMap: make(map[string]ToolHandlerFunc),
team: agents,
currentAgent: defaultAgent.Name(),
resumeChan: make(chan ResumeRequest),
Expand Down Expand Up @@ -909,30 +904,14 @@ func (r *LocalRuntime) emitToolsProgressively(ctx context.Context, a *agent.Agen
send(ToolsetInfo(totalTools, false, r.currentAgent))
}

// registerDefaultTools registers the default tool handlers
// registerDefaultTools registers the runtime-managed tool handlers.
// The tool definitions themselves come from the agent's toolsets; this only
// maps tool names to the runtime handler functions that implement them.
func (r *LocalRuntime) registerDefaultTools() {
slog.Debug("Registering default tools")

tt := builtin.NewTransferTaskTool()
ht := builtin.NewHandoffTool()
ttTools, _ := tt.Tools(context.TODO())
htTools, _ := ht.Tools(context.TODO())
allTools := append(ttTools, htTools...)

handlers := map[string]ToolHandlerFunc{
builtin.ToolNameTransferTask: r.handleTaskTransfer,
builtin.ToolNameHandoff: r.handleHandoff,
}

for _, t := range allTools {
if h, exists := handlers[t.Name]; exists {
r.toolMap[t.Name] = ToolHandler{handler: h, tool: t}
} else {
slog.Warn("No handler found for default tool", "tool", t.Name)
}
}

slog.Debug("Registered default tools", "count", len(r.toolMap))
r.toolMap[builtin.ToolNameTransferTask] = r.handleTaskTransfer
r.toolMap[builtin.ToolNameHandoff] = r.handleHandoff
r.toolMap[builtin.ToolNameChangeModel] = r.handleChangeModel
r.toolMap[builtin.ToolNameRevertModel] = r.handleRevertModel
}

func (r *LocalRuntime) finalizeEventChannel(ctx context.Context, sess *session.Session, events chan Event) {
Expand Down Expand Up @@ -1579,8 +1558,8 @@ func (r *LocalRuntime) processToolCalls(ctx context.Context, sess *session.Sessi
// Pick the handler: runtime-managed tools (transfer_task, handoff)
// have dedicated handlers; everything else goes through the toolset.
var runTool func()
if def, exists := r.toolMap[toolCall.Function.Name]; exists {
runTool = func() { r.runAgentTool(callCtx, def.handler, sess, toolCall, tool, events, a) }
if handler, exists := r.toolMap[toolCall.Function.Name]; exists {
runTool = func() { r.runAgentTool(callCtx, handler, sess, toolCall, tool, events, a) }
} else {
runTool = func() { r.runTool(callCtx, tool, toolCall, events, sess, a) }
}
Expand Down Expand Up @@ -2089,6 +2068,75 @@ func (r *LocalRuntime) handleHandoff(_ context.Context, _ *session.Session, tool
return tools.ResultSuccess(handoffMessage), nil
}

// findModelPickerTool returns the ModelPickerTool from the current agent's
// toolsets, or nil if the agent has no model_picker configured.
func (r *LocalRuntime) findModelPickerTool() *builtin.ModelPickerTool {
a, err := r.team.Agent(r.currentAgent)
if err != nil {
return nil
}
for _, ts := range a.ToolSets() {
if mpt, ok := tools.As[*builtin.ModelPickerTool](ts); ok {
return mpt
}
}
return nil
}

// handleChangeModel handles the change_model tool call by switching the current agent's model.
func (r *LocalRuntime) handleChangeModel(ctx context.Context, _ *session.Session, toolCall tools.ToolCall, events chan Event) (*tools.ToolCallResult, error) {
var params builtin.ChangeModelArgs
if err := json.Unmarshal([]byte(toolCall.Function.Arguments), &params); err != nil {
return nil, fmt.Errorf("invalid arguments: %w", err)
}

if params.Model == "" {
return tools.ResultError("model parameter is required"), nil
}

// Validate the requested model against the allowed list
mpt := r.findModelPickerTool()
if mpt == nil {
return tools.ResultError("model_picker is not configured for this agent"), nil
}
allowed := mpt.AllowedModels()
if !slices.Contains(allowed, params.Model) {
return tools.ResultError(fmt.Sprintf(
"model %q is not in the allowed list. Available models: %s",
params.Model, strings.Join(allowed, ", "),
)), nil
}

return r.setModelAndEmitInfo(ctx, params.Model, events)
}

// handleRevertModel handles the revert_model tool call by reverting the current agent to its default model.
func (r *LocalRuntime) handleRevertModel(ctx context.Context, _ *session.Session, _ tools.ToolCall, events chan Event) (*tools.ToolCallResult, error) {
return r.setModelAndEmitInfo(ctx, "", events)
}

// setModelAndEmitInfo sets the model for the current agent and emits an updated
// AgentInfo event so the UI reflects the change. An empty modelRef reverts to
// the agent's default model.
func (r *LocalRuntime) setModelAndEmitInfo(ctx context.Context, modelRef string, events chan Event) (*tools.ToolCallResult, error) {
if err := r.SetAgentModel(ctx, r.currentAgent, modelRef); err != nil {
return tools.ResultError(fmt.Sprintf("failed to set model: %v", err)), nil
}

if a, err := r.team.Agent(r.currentAgent); err == nil {
events <- AgentInfo(a.Name(), r.getEffectiveModelID(a), a.Description(), a.WelcomeMessage())
} else {
slog.Warn("Failed to retrieve agent after model change; UI may not reflect the update", "agent", r.currentAgent, "error", err)
}

if modelRef == "" {
slog.Info("Model reverted via model_picker tool", "agent", r.currentAgent)
return tools.ResultSuccess("Model reverted to the agent's default model"), nil
}
slog.Info("Model changed via model_picker tool", "agent", r.currentAgent, "model", modelRef)
return tools.ResultSuccess(fmt.Sprintf("Model changed to %s", modelRef)), nil
}

// Summarize generates a summary for the session based on the conversation history.
// The additionalPrompt parameter allows users to provide additional instructions
// for the summarization (e.g., "focus on code changes" or "include action items").
Expand Down
8 changes: 8 additions & 0 deletions pkg/teamloader/registry.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ func NewDefaultToolsetRegistry() *ToolsetRegistry {
r.Register("lsp", createLSPTool)
r.Register("user_prompt", createUserPromptTool)
r.Register("openapi", createOpenAPITool)
r.Register("model_picker", createModelPickerTool)
return r
}

Expand Down Expand Up @@ -327,3 +328,10 @@ func createOpenAPITool(ctx context.Context, toolset latest.Toolset, _ string, ru

return builtin.NewOpenAPITool(specURL, headers), nil
}

func createModelPickerTool(_ context.Context, toolset latest.Toolset, _ string, _ *config.RuntimeConfig) (tools.ToolSet, error) {
if len(toolset.Models) == 0 {
return nil, fmt.Errorf("model_picker toolset requires at least one model")
}
return builtin.NewModelPickerTool(toolset.Models), nil
}
81 changes: 81 additions & 0 deletions pkg/tools/builtin/model_picker.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
package builtin

import (
"context"
"fmt"
"strings"

"github.com/docker/cagent/pkg/tools"
)

const (
ToolNameChangeModel = "change_model"
ToolNameRevertModel = "revert_model"
)

// ModelPickerTool provides tools for dynamically switching the agent's model mid-conversation.
type ModelPickerTool struct {
models []string // list of available model references
}

// Verify interface compliance
var (
_ tools.ToolSet = (*ModelPickerTool)(nil)
_ tools.Instructable = (*ModelPickerTool)(nil)
)

// ChangeModelArgs are the arguments for the change_model tool.
type ChangeModelArgs struct {
Model string `json:"model" jsonschema:"The model to switch to. Must be one of the available models."`
}

// NewModelPickerTool creates a new ModelPickerTool with the given list of allowed models.
func NewModelPickerTool(models []string) *ModelPickerTool {
return &ModelPickerTool{models: models}
}

// Instructions returns guidance for the LLM on when and how to use the model picker tools.
func (t *ModelPickerTool) Instructions() string {
return "## Model Switching\n\n" +
"You have access to multiple models and can switch between them mid-conversation " +
"using the `" + ToolNameChangeModel + "` and `" + ToolNameRevertModel + "` tools.\n\n" +
"Available models: " + strings.Join(t.models, ", ") + ".\n\n" +
"Use `" + ToolNameChangeModel + "` when the current task would benefit from a different model's strengths " +
"(e.g., switching to a faster model for simple tasks or a more capable model for complex reasoning).\n" +
"Use `" + ToolNameRevertModel + "` to return to the original model after the specialized task is complete."
}

// AllowedModels returns the list of models this tool allows switching to.
func (t *ModelPickerTool) AllowedModels() []string {
return t.models
}

// Tools returns the change_model and revert_model tool definitions.
func (t *ModelPickerTool) Tools(context.Context) ([]tools.Tool, error) {
return []tools.Tool{
{
Name: ToolNameChangeModel,
Category: "model",
Description: fmt.Sprintf(
"Change the current model to one of the available models: %s. "+
"Use this when you need a different model for the current task.",
strings.Join(t.models, ", "),
),
Parameters: tools.MustSchemaFor[ChangeModelArgs](),
Annotations: tools.ToolAnnotations{
ReadOnlyHint: true,
Title: "Change Model",
},
},
{
Name: ToolNameRevertModel,
Category: "model",
Description: "Revert to the agent's original/default model. " +
"Use this after completing a task that required a different model.",
Annotations: tools.ToolAnnotations{
ReadOnlyHint: true,
Title: "Revert Model",
},
},
}, nil
}
Loading
Loading