Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 13 additions & 1 deletion include/ai/tools.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,24 @@ class ToolExecutor {
/// @param tools Available tools
/// @param messages Context messages
/// @param parallel Whether to execute in parallel (default: true)
/// @param options Optional generate options containing callbacks
/// @return Vector of tool execution results
static std::vector<ToolResult> execute_tools(
const std::vector<ToolCall>& tool_calls,
const ToolSet& tools,
const Messages& messages = {},
bool parallel = true);
bool parallel = true,
const GenerateOptions* options = nullptr);

/// Execute multiple tool calls with options (simplified interface)
/// @param tool_calls Vector of tool calls to execute
/// @param options Generate options containing tools, messages, and callbacks
/// @param parallel Whether to execute in parallel (default: false for safety)
/// @return Vector of tool execution results
static std::vector<ToolResult> execute_tools_with_options(
const std::vector<ToolCall>& tool_calls,
const GenerateOptions& options,
bool parallel = false);

/// Validate tool call arguments against tool schema
/// @param tool_call The tool call to validate
Expand Down
2 changes: 2 additions & 0 deletions include/ai/types/generate_options.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@ struct GenerateOptions {

// Callbacks for tool calling
std::optional<std::function<void(const GenerateStep&)>> on_step_finish;
std::optional<std::function<void(const ToolCall&)>> on_tool_call_start;
std::optional<std::function<void(const ToolResult&)>> on_tool_call_finish;

GenerateOptions(std::string model_name, std::string user_prompt)
: model(std::move(model_name)), prompt(std::move(user_prompt)) {}
Expand Down
9 changes: 7 additions & 2 deletions src/providers/base_provider_client.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -90,13 +90,18 @@ GenerateResult BaseProviderClient::generate_text_single_step(
auto parsed_result =
response_parser_->parse_success_response(json_response);

if (parsed_result.has_tool_calls()) {
ai::logger::log_debug("Model made {} tool calls",
parsed_result.tool_calls.size());
}

// Execute tools if the model made tool calls
if (parsed_result.has_tool_calls() && options.has_tools()) {
ai::logger::log_debug("Model made {} tool calls, executing them",
parsed_result.tool_calls.size());

auto tool_results = ToolExecutor::execute_tools(
parsed_result.tool_calls, options.tools, options.messages);
auto tool_results = ToolExecutor::execute_tools_with_options(
parsed_result.tool_calls, options, false);

parsed_result.tool_results = tool_results;
ai::logger::log_debug("Executed {} tools", tool_results.size());
Expand Down
18 changes: 15 additions & 3 deletions src/tools/multi_step_coordinator.cpp
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
#include <ai/logger.h>
#include <ai/tools.h>
#include <ai/types/enums.h>
#include <ai/types/tool.h>
Expand All @@ -16,6 +17,9 @@ GenerateResult MultiStepCoordinator::execute_multi_step(
GenerateOptions current_options = initial_options;

for (int step = 0; step < initial_options.max_steps; ++step) {
ai::logger::log_debug("Executing step {} of {}", step + 1,
initial_options.max_steps);

// Execute the current step
GenerateResult step_result = generate_func(current_options);

Expand Down Expand Up @@ -85,9 +89,10 @@ GenerateResult MultiStepCoordinator::execute_multi_step(
if (step_result.finish_reason == kFinishReasonToolCalls &&
step_result.has_tool_calls()) {
// Execute tools and prepare for next step
std::vector<ToolResult> tool_results = ToolExecutor::execute_tools(
step_result.tool_calls, initial_options.tools,
current_options.messages);
// Use sequential execution to avoid thread-safety issues
std::vector<ToolResult> tool_results =
ToolExecutor::execute_tools_with_options(step_result.tool_calls,
initial_options, false);

// Store tool results in the step
final_result.steps.back().tool_results = tool_results;
Expand Down Expand Up @@ -116,6 +121,13 @@ GenerateResult MultiStepCoordinator::execute_multi_step(
}
}

// Log if we hit the max steps limit
if (final_result.steps.size() == initial_options.max_steps &&
final_result.finish_reason != kFinishReasonStop) {
ai::logger::log_debug("Reached max steps limit ({}) without completion",
initial_options.max_steps);
}

return final_result;
}

Expand Down
25 changes: 23 additions & 2 deletions src/tools/tool_executor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -69,14 +69,27 @@ std::vector<ToolResult> ToolExecutor::execute_tools(
const std::vector<ToolCall>& tool_calls,
const ToolSet& tools,
const Messages& messages,
bool parallel) {
bool parallel,
const GenerateOptions* options) {
std::vector<ToolResult> results;
results.reserve(tool_calls.size());

if (!parallel) {
// Execute sequentially
for (const auto& tool_call : tool_calls) {
results.push_back(execute_tool(tool_call, tools, messages));
// Call the on_tool_call_start callback if provided
if (options && options->on_tool_call_start.has_value()) {
options->on_tool_call_start.value()(tool_call);
}

auto result = execute_tool(tool_call, tools, messages);

// Call the on_tool_call_finish callback if provided
if (options && options->on_tool_call_finish.has_value()) {
options->on_tool_call_finish.value()(result);
}

results.push_back(result);
}
} else {
// Execute in parallel using futures
Expand All @@ -99,6 +112,14 @@ std::vector<ToolResult> ToolExecutor::execute_tools(
return results;
}

std::vector<ToolResult> ToolExecutor::execute_tools_with_options(
const std::vector<ToolCall>& tool_calls,
const GenerateOptions& options,
bool parallel) {
return execute_tools(tool_calls, options.tools, options.messages, parallel,
&options);
}

bool ToolExecutor::validate_tool_call(const ToolCall& tool_call,
const Tool& tool) {
// Basic validation - check if the arguments match the expected schema
Expand Down