Skip to content
32 changes: 32 additions & 0 deletions docs/_core_features/tools.md
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,38 @@ puts response.content
# => "Current weather at 52.52, 13.4: Temperature: 12.5°C, Wind Speed: 8.3 km/h, Conditions: Mainly clear, partly cloudy, and overcast."
```

### Tool Choice Control

Control when and how tools are called using `choice` and `parallel` options.

```ruby
chat = RubyLLM.chat(model: 'gpt-4o')

# Basic usage with defaults
chat.with_tools(Weather, Calculator) # uses provider defaults

# Force tool usage, one at a time
chat.with_tools(Weather, Calculator, choice: :required, parallel: false)

# Force specific tool
chat.with_tool(Weather, choice: :weather, parallel: true)
```

**Parameter Values:**
- **`choice`**: Controls tool choice behavior
- `:auto` Model decides whether to use any tools
- `:required` - Model must use one of the provided tools
- `:none` - Disable all tools
- `"tool_name"` - Force a specific tool (e.g., `:weather` for `Weather` tool)
- **`parallel`**: Controls parallel tool calls
- `true` Allow multiple tool calls simultaneously
- `false` - One at a time

If not provided, RubyLLM will use the provider's default behavior for tool choice and parallel tool calls.

> With `:required` or specific tool choices, the tool_choice is automatically reset to `nil` after tool execution to prevent infinite loops.
{: .note }

### Model Compatibility
{: .d-inline-block }

Expand Down
39 changes: 34 additions & 5 deletions lib/ruby_llm/chat.rb
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ module RubyLLM
class Chat
include Enumerable

attr_reader :model, :messages, :tools, :params, :headers, :schema
attr_reader :model, :messages, :tools, :tool_prefs, :params, :headers, :schema

def initialize(model: nil, provider: nil, assume_model_exists: false, context: nil)
if assume_model_exists && !provider
Expand All @@ -19,6 +19,7 @@ def initialize(model: nil, provider: nil, assume_model_exists: false, context: n
@temperature = nil
@messages = []
@tools = {}
@tool_prefs = { choice: nil, parallel: nil }
@params = {}
@headers = {}
@schema = nil
Expand All @@ -44,15 +45,19 @@ def with_instructions(instructions, replace: false)
self
end

def with_tool(tool)
tool_instance = tool.is_a?(Class) ? tool.new : tool
@tools[tool_instance.name.to_sym] = tool_instance
def with_tool(tool, choice: nil, parallel: nil)
unless tool.nil?
tool_instance = tool.is_a?(Class) ? tool.new : tool
@tools[tool_instance.name.to_sym] = tool_instance
end
update_tool_options(choice:, parallel:)
self
end

def with_tools(*tools, replace: false)
def with_tools(*tools, replace: false, choice: nil, parallel: nil)
@tools.clear if replace
tools.compact.each { |tool| with_tool tool }
update_tool_options(choice:, parallel:)
self
end

Expand Down Expand Up @@ -125,6 +130,7 @@ def complete(&) # rubocop:disable Metrics/PerceivedComplexity
response = @provider.complete(
messages,
tools: @tools,
tool_prefs: @tool_prefs,
temperature: @temperature,
model: @model,
params: @params,
Expand Down Expand Up @@ -200,6 +206,7 @@ def handle_tool_calls(response, &) # rubocop:disable Metrics/PerceivedComplexity
halt_result = result if result.is_a?(Tool::Halt)
end

reset_tool_choice if forced_tool_choice?
halt_result || complete(&)
end

Expand All @@ -208,5 +215,27 @@ def execute_tool(tool_call)
args = tool_call.arguments
tool.call(args)
end

def update_tool_options(choice:, parallel:)
unless choice.nil?
valid_tool_choices = %i[auto none required] + tools.keys
unless valid_tool_choices.include?(choice.to_sym)
raise InvalidToolChoiceError,
"Invalid tool choice: #{choice}. Valid choices are: #{valid_tool_choices.join(', ')}"
end

@tool_prefs[:choice] = choice.to_sym
end

@tool_prefs[:parallel] = !!parallel unless parallel.nil?
end

def forced_tool_choice?
@tool_prefs[:choice] && !%i[auto none].include?(@tool_prefs[:choice])
end

def reset_tool_choice
@tool_prefs[:choice] = nil
end
end
end
1 change: 1 addition & 0 deletions lib/ruby_llm/error.rb
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ def initialize(response = nil, message = nil)
# Error classes for non-HTTP errors
class ConfigurationError < StandardError; end
class InvalidRoleError < StandardError; end
class InvalidToolChoiceError < StandardError; end
class ModelNotFoundError < StandardError; end
class UnsupportedAttachmentError < StandardError; end

Expand Down
6 changes: 5 additions & 1 deletion lib/ruby_llm/provider.rb
Original file line number Diff line number Diff line change
Expand Up @@ -37,13 +37,16 @@ def configuration_requirements
self.class.configuration_requirements
end

def complete(messages, tools:, temperature:, model:, params: {}, headers: {}, schema: nil, &) # rubocop:disable Metrics/ParameterLists
# rubocop:disable Metrics/ParameterLists
def complete(messages, tools:, temperature:, model:, params: {}, headers: {}, schema: nil,
tool_prefs: nil, &)
normalized_temperature = maybe_normalize_temperature(temperature, model)

payload = Utils.deep_merge(
render_payload(
messages,
tools: tools,
tool_prefs: tool_prefs,
temperature: normalized_temperature,
model: model,
stream: block_given?,
Expand All @@ -58,6 +61,7 @@ def complete(messages, tools:, temperature:, model:, params: {}, headers: {}, sc
sync_response @connection, payload, headers
end
end
# rubocop:enable Metrics/ParameterLists

def list_models
response = @connection.get models_url
Expand Down
15 changes: 11 additions & 4 deletions lib/ruby_llm/providers/anthropic/chat.rb
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,17 @@ def completion_url
'/v1/messages'
end

def render_payload(messages, tools:, temperature:, model:, stream: false, schema: nil) # rubocop:disable Metrics/ParameterLists,Lint/UnusedMethodArgument
# rubocop:disable Metrics/ParameterLists,Lint/UnusedMethodArgument
def render_payload(messages, tools:, tool_prefs:,
temperature:, model:, stream: false, schema: nil)
system_messages, chat_messages = separate_messages(messages)
system_content = build_system_content(system_messages)

build_base_payload(chat_messages, model, stream).tap do |payload|
add_optional_fields(payload, system_content:, tools:, temperature:)
add_optional_fields(payload, system_content:, tools:, tool_prefs:, temperature:)
end
end
# rubocop:enable Metrics/ParameterLists,Lint/UnusedMethodArgument

def separate_messages(messages)
messages.partition { |msg| msg.role == :system }
Expand All @@ -44,8 +47,12 @@ def build_base_payload(chat_messages, model, stream)
}
end

def add_optional_fields(payload, system_content:, tools:, temperature:)
payload[:tools] = tools.values.map { |t| Tools.function_for(t) } if tools.any?
def add_optional_fields(payload, system_content:, tools:, tool_prefs:, temperature:)
if tools.any?
payload[:tools] = tools.values.map { |t| Tools.function_for(t) }
payload[:tool_choice] = build_tool_choice(tool_prefs) unless tool_prefs[:choice].nil?
end

payload[:system] = system_content unless system_content.empty?
payload[:temperature] = temperature unless temperature.nil?
end
Expand Down
19 changes: 19 additions & 0 deletions lib/ruby_llm/providers/anthropic/tools.rb
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,25 @@ def clean_parameters(parameters)
def required_parameters(parameters)
parameters.select { |_, param| param.required }.keys
end

def build_tool_choice(tool_prefs)
tool_choice = tool_prefs[:choice]
parallel_tool_calls = tool_prefs[:parallel]

{
type: case tool_choice
when :auto, :none
tool_choice
when :required
:any
else
:tool
end
}.tap do |tc|
tc[:name] = tool_choice if tc[:type] == :tool
tc[:disable_parallel_tool_use] = !parallel_tool_calls unless tc[:type] == :none || parallel_tool_calls.nil?
end
end
end
end
end
Expand Down
8 changes: 6 additions & 2 deletions lib/ruby_llm/providers/bedrock/chat.rb
Original file line number Diff line number Diff line change
Expand Up @@ -39,16 +39,20 @@ def completion_url
"model/#{@model_id}/invoke"
end

def render_payload(messages, tools:, temperature:, model:, stream: false, schema: nil) # rubocop:disable Lint/UnusedMethodArgument,Metrics/ParameterLists
# rubocop:disable Metrics/ParameterLists,Lint/UnusedMethodArgument
def render_payload(messages, tools:, tool_prefs:, temperature:, model:, stream: false,
schema: nil)
@model_id = model.id

system_messages, chat_messages = Anthropic::Chat.separate_messages(messages)
system_content = Anthropic::Chat.build_system_content(system_messages)

build_base_payload(chat_messages, model).tap do |payload|
Anthropic::Chat.add_optional_fields(payload, system_content:, tools:, temperature:)
Anthropic::Chat.add_optional_fields(payload, system_content:, tools:, tool_prefs:,
temperature:)
end
end
# rubocop:enable Metrics/ParameterLists,Lint/UnusedMethodArgument

def build_base_payload(chat_messages, model)
{
Expand Down
11 changes: 9 additions & 2 deletions lib/ruby_llm/providers/gemini/chat.rb
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,8 @@ def completion_url
"models/#{@model}:generateContent"
end

def render_payload(messages, tools:, temperature:, model:, stream: false, schema: nil) # rubocop:disable Metrics/ParameterLists,Lint/UnusedMethodArgument
# rubocop:disable Metrics/ParameterLists,Lint/UnusedMethodArgument
def render_payload(messages, tools:, tool_prefs:, temperature:, model:, stream: false, schema: nil)
@model = model.id
payload = {
contents: format_messages(messages),
Expand All @@ -25,9 +26,15 @@ def render_payload(messages, tools:, temperature:, model:, stream: false, schema
payload[:generationConfig][:responseSchema] = convert_schema_to_gemini(schema)
end

payload[:tools] = format_tools(tools) if tools.any?
if tools.any?
payload[:tools] = format_tools(tools)
# Gemini doesn't support controlling parallel tool calls
payload[:toolConfig] = build_tool_config(tool_prefs[:choice]) unless tool_prefs[:choice].nil?
end

payload
end
# rubocop:enable Metrics/ParameterLists,Lint/UnusedMethodArgument

private

Expand Down
19 changes: 19 additions & 0 deletions lib/ruby_llm/providers/gemini/tools.rb
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,25 @@ def param_type_for_gemini(type)
else 'STRING'
end
end

def build_tool_config(tool_choice)
{
functionCallingConfig: {
mode: forced_tool_choice?(tool_choice) ? 'any' : tool_choice
}.tap do |config|
# Use allowedFunctionNames to simulate specific tool choice
config[:allowedFunctionNames] = [tool_choice] if specific_tool_choice?(tool_choice)
end
}
end

def forced_tool_choice?(tool_choice)
tool_choice == :required || specific_tool_choice?(tool_choice)
end

def specific_tool_choice?(tool_choice)
!%i[auto none required].include?(tool_choice)
end
end
end
end
Expand Down
3 changes: 2 additions & 1 deletion lib/ruby_llm/providers/mistral/chat.rb
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@ def format_role(role)
end

# rubocop:disable Metrics/ParameterLists
def render_payload(messages, tools:, temperature:, model:, stream: false, schema: nil)
def render_payload(messages, tools:, tool_prefs:, temperature:, model:, stream: false,
schema: nil)
payload = super
payload.delete(:stream_options)
payload
Expand Down
11 changes: 9 additions & 2 deletions lib/ruby_llm/providers/openai/chat.rb
Original file line number Diff line number Diff line change
Expand Up @@ -11,15 +11,21 @@ def completion_url

module_function

def render_payload(messages, tools:, temperature:, model:, stream: false, schema: nil) # rubocop:disable Metrics/ParameterLists
# rubocop:disable Metrics/ParameterLists
def render_payload(messages, tools:, tool_prefs:, temperature:, model:, stream: false, schema: nil)
payload = {
model: model.id,
messages: format_messages(messages),
stream: stream
}

payload[:temperature] = temperature unless temperature.nil?
payload[:tools] = tools.map { |_, tool| tool_for(tool) } if tools.any?

if tools.any?
payload[:tools] = tools.map { |_, tool| tool_for(tool) }
payload[:tool_choice] = build_tool_choice(tool_prefs[:choice]) unless tool_prefs[:choice].nil?
payload[:parallel_tool_calls] = tool_prefs[:parallel] unless tool_prefs[:parallel].nil?
end

if schema
strict = schema[:strict] != false
Expand All @@ -37,6 +43,7 @@ def render_payload(messages, tools:, temperature:, model:, stream: false, schema
payload[:stream_options] = { include_usage: true } if stream
payload
end
# rubocop:enable Metrics/ParameterLists

def parse_completion_response(response)
data = response.body
Expand Down
14 changes: 14 additions & 0 deletions lib/ruby_llm/providers/openai/tools.rb
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,20 @@ def parse_tool_calls(tool_calls, parse_arguments: true)
]
end
end

def build_tool_choice(tool_choice)
case tool_choice
when :auto, :none, :required
tool_choice
else
{
type: 'function',
function: {
name: tool_choice
}
}
end
end
end
end
end
Expand Down