ToolCallingChatOptions supports internalToolExecutionMaxIterations to limit the maximum number of tool-call iterations and prevent infinite recursive calls to the LLM in special cases #3380
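For context, a minimal usage sketch of the new option. This is illustrative only: the model id is a placeholder, a configured chat model instance named chatModel is assumed, and the builder method follows the options changes in the diff below.

```java
import org.springframework.ai.anthropic.AnthropicChatOptions;
import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.chat.prompt.Prompt;

// Cap internal tool execution at 3 iterations so a tool-calling loop cannot
// recurse against the model indefinitely.
AnthropicChatOptions options = AnthropicChatOptions.builder()
	.model("claude-3-7-sonnet-latest") // placeholder model id
	.internalToolExecutionMaxIterations(3) // new option introduced by this change
	.build();

// 'chatModel' is assumed to be an already-configured AnthropicChatModel.
ChatResponse response = chatModel.call(new Prompt("What is the weather in Paris?", options));
```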


Open · lambochen wants to merge 26 commits into main
Commits (26)
72bceab
ToolCallingChatOptions support internalToolExecutionMaxIterations
lambochen May 29, 2025
33847aa
rename internalToolExecutionMaxAttempts
lambochen May 29, 2025
e243d42
ToolExecutionEligibilityChecker add logical for attempts
lambochen May 29, 2025
670d691
OpenAiChatModel support internalToolExecutionMaxAttempts
lambochen May 29, 2025
092eb6c
internalToolExecutionEnabled set default value is Integer.MAX_VALUE
lambochen May 29, 2025
2d128ba
OpenAiChatModel support InternalToolExecutionMaxAttempts
lambochen May 29, 2025
253d19e
AnthropicChatModel support InternalToolExecutionMaxAttempts
lambochen May 29, 2025
93efce5
AzureOpenAiChatModel support InternalToolExecutionMaxAttempts
lambochen May 29, 2025
d660152
BedrockProxyChatModel support internalToolExecutionMaxAttempts
lambochen May 29, 2025
2c7fc3d
DeepSeekChatModel support internalToolExecutionMaxAttempts
lambochen May 29, 2025
1a11eca
MiniMaxChatModel support InternalToolExecutionMaxAttempts
lambochen May 29, 2025
69d0d5e
MistralAiChatModel support InternalToolExecutionMaxAttempts
lambochen May 29, 2025
f808731
OllamaChatModel support InternalToolExecutionMaxAttempts
lambochen May 29, 2025
bf9df71
VertexAiGeminiChatModel support InternalToolExecutionMaxAttempts
lambochen May 29, 2025
8dd816d
ZhiPuAiChatModel support InternalToolExecutionMaxAttempts
lambochen May 29, 2025
0ec9490
fix: api compatability
lambochen May 29, 2025
8aaf132
UT for IsToolExecutionRequiredWithAttempts
lambochen May 30, 2025
4269f37
UT for ToolExecutionEligibilityChecker attempts
lambochen May 30, 2025
8e1f251
merge main
lambochen May 31, 2025
fcf1a76
UT config for openai
lambochen May 31, 2025
ae8d1ab
any UT for internalTooolExecutionMaxAttempts
lambochen May 31, 2025
3b1162e
add UT for internalToolCallingExecutionMaxAttempts in some options
lambochen May 31, 2025
e894e67
code format by spring-javaformat plugin
lambochen May 31, 2025
c63341b
fix attempts check logical
lambochen May 31, 2025
56fba7c
ToolCallingChatOptions: rename attempts to iterations for tool execution
lambochen Jun 1, 2025
80c2e26
rename iterations to toolExecutionIterations
lambochen Jun 2, 2025
AnthropicChatModel.java
@@ -90,6 +90,7 @@
* @author Alexandros Pappas
* @author Jonghoon Park
* @author Soby Chacko
* @author lambochen
* @since 1.0.0
*/
public class AnthropicChatModel implements ChatModel {
@@ -174,6 +175,10 @@ public ChatResponse call(Prompt prompt) {
}

public ChatResponse internalCall(Prompt prompt, ChatResponse previousChatResponse) {
return this.internalCall(prompt, previousChatResponse, 1);
}

public ChatResponse internalCall(Prompt prompt, ChatResponse previousChatResponse, int iterations) {
ChatCompletionRequest request = createRequest(prompt, false);

ChatModelObservationContext observationContext = ChatModelObservationContext.builder()
@@ -203,7 +208,7 @@ public ChatResponse internalCall(Prompt prompt, ChatResponse previousChatRespons
return chatResponse;
});

if (this.toolExecutionEligibilityPredicate.isToolExecutionRequired(prompt.getOptions(), response)) {
if (this.toolExecutionEligibilityPredicate.isToolExecutionRequired(prompt.getOptions(), response, iterations)) {
var toolExecutionResult = this.toolCallingManager.executeToolCalls(prompt, response);
if (toolExecutionResult.returnDirect()) {
// Return tool execution result directly to the client.
@@ -215,7 +220,7 @@ public ChatResponse internalCall(Prompt prompt, ChatResponse previousChatRespons
else {
// Send the tool execution result back to the model.
return this.internalCall(new Prompt(toolExecutionResult.conversationHistory(), prompt.getOptions()),
response);
response, iterations + 1);
}
}

@@ -236,6 +241,10 @@ public Flux<ChatResponse> stream(Prompt prompt) {
}

public Flux<ChatResponse> internalStream(Prompt prompt, ChatResponse previousChatResponse) {
return this.internalStream(prompt, previousChatResponse, 1);
}

public Flux<ChatResponse> internalStream(Prompt prompt, ChatResponse previousChatResponse, int iterations) {
return Flux.deferContextual(contextView -> {
ChatCompletionRequest request = createRequest(prompt, true);

@@ -260,7 +269,8 @@ public Flux<ChatResponse> internalStream(Prompt prompt, ChatResponse previousCha
Usage accumulatedUsage = UsageCalculator.getCumulativeUsage(currentChatResponseUsage, previousChatResponse);
ChatResponse chatResponse = toChatResponse(chatCompletionResponse, accumulatedUsage);

if (this.toolExecutionEligibilityPredicate.isToolExecutionRequired(prompt.getOptions(), chatResponse) && chatResponse.hasFinishReasons(Set.of("tool_use"))) {
if (this.toolExecutionEligibilityPredicate.isToolExecutionRequired(prompt.getOptions(), chatResponse, iterations)
&& chatResponse.hasFinishReasons(Set.of("tool_use"))) {
// FIXME: bounded elastic needs to be used since tool calling
// is currently only synchronous
return Flux.defer(() -> {
@@ -274,7 +284,7 @@ public Flux<ChatResponse> internalStream(Prompt prompt, ChatResponse previousCha
else {
// Send the tool execution result back to the model.
return this.internalStream(new Prompt(toolExecutionResult.conversationHistory(), prompt.getOptions()),
chatResponse);
chatResponse, iterations + 1);
}
}).subscribeOn(Schedulers.boundedElastic());
}
@@ -437,6 +447,9 @@ Prompt buildRequestPrompt(Prompt prompt) {
requestOptions.setInternalToolExecutionEnabled(
ModelOptionsUtils.mergeOption(runtimeOptions.getInternalToolExecutionEnabled(),
this.defaultOptions.getInternalToolExecutionEnabled()));
requestOptions.setInternalToolExecutionMaxIterations(
ModelOptionsUtils.mergeOption(runtimeOptions.getInternalToolExecutionMaxIterations(),
defaultOptions.getInternalToolExecutionMaxIterations()));
requestOptions.setToolNames(ToolCallingChatOptions.mergeToolNames(runtimeOptions.getToolNames(),
this.defaultOptions.getToolNames()));
requestOptions.setToolCallbacks(ToolCallingChatOptions.mergeToolCallbacks(runtimeOptions.getToolCallbacks(),
@@ -447,6 +460,8 @@ Prompt buildRequestPrompt(Prompt prompt) {
else {
requestOptions.setHttpHeaders(this.defaultOptions.getHttpHeaders());
requestOptions.setInternalToolExecutionEnabled(this.defaultOptions.getInternalToolExecutionEnabled());
requestOptions
.setInternalToolExecutionMaxIterations(this.defaultOptions.getInternalToolExecutionMaxIterations());
requestOptions.setToolNames(this.defaultOptions.getToolNames());
requestOptions.setToolCallbacks(this.defaultOptions.getToolCallbacks());
requestOptions.setToolContext(this.defaultOptions.getToolContext());
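The iteration counter threaded through internalCall and internalStream above is consumed by the eligibility check. The actual ToolExecutionEligibilityPredicate changes are not part of this excerpt; the sketch below only illustrates how the new three-argument overload could gate execution. The two-argument method and the option getter follow the diff, while the wrapper interface name and the exact comparison are assumptions.

```java
import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.chat.prompt.ChatOptions;
import org.springframework.ai.model.tool.ToolCallingChatOptions;
import org.springframework.ai.model.tool.ToolExecutionEligibilityPredicate;

// Hypothetical illustration only; not the code added by this PR.
public interface IterationAwareEligibilityPredicate extends ToolExecutionEligibilityPredicate {

	default boolean isToolExecutionRequired(ChatOptions promptOptions, ChatResponse chatResponse, int iterations) {
		// Stop delegating back to the model once the configured cap is exceeded.
		if (promptOptions instanceof ToolCallingChatOptions toolOptions
				&& toolOptions.getInternalToolExecutionMaxIterations() != null
				&& iterations > toolOptions.getInternalToolExecutionMaxIterations()) {
			return false;
		}
		// Otherwise fall back to the existing two-argument check.
		return isToolExecutionRequired(promptOptions, chatResponse);
	}

}
```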
AnthropicChatOptions.java
@@ -44,6 +44,7 @@
* @author Thomas Vitale
* @author Alexandros Pappas
* @author Ilayaperumal Gopinathan
* @author lambochen
* @since 1.0.0
*/
@JsonInclude(Include.NON_NULL)
@@ -79,6 +80,9 @@ public class AnthropicChatOptions implements ToolCallingChatOptions {
@JsonIgnore
private Boolean internalToolExecutionEnabled;

@JsonIgnore
private Integer internalToolExecutionMaxIterations = ToolCallingChatOptions.DEFAULT_TOOL_EXECUTION_MAX_ITERATIONS;

@JsonIgnore
private Map<String, Object> toolContext = new HashMap<>();

@@ -109,6 +113,7 @@ public static AnthropicChatOptions fromOptions(AnthropicChatOptions fromOptions)
fromOptions.getToolCallbacks() != null ? new ArrayList<>(fromOptions.getToolCallbacks()) : null)
.toolNames(fromOptions.getToolNames() != null ? new HashSet<>(fromOptions.getToolNames()) : null)
.internalToolExecutionEnabled(fromOptions.getInternalToolExecutionEnabled())
.internalToolExecutionMaxIterations(fromOptions.getInternalToolExecutionMaxIterations())
.toolContext(fromOptions.getToolContext() != null ? new HashMap<>(fromOptions.getToolContext()) : null)
.httpHeaders(fromOptions.getHttpHeaders() != null ? new HashMap<>(fromOptions.getHttpHeaders()) : null)
.build();
@@ -226,6 +231,16 @@ public void setInternalToolExecutionEnabled(@Nullable Boolean internalToolExecut
this.internalToolExecutionEnabled = internalToolExecutionEnabled;
}

@Override
public Integer getInternalToolExecutionMaxIterations() {
return this.internalToolExecutionMaxIterations;
}

@Override
public void setInternalToolExecutionMaxIterations(@Nullable Integer internalToolExecutionMaxIterations) {
this.internalToolExecutionMaxIterations = internalToolExecutionMaxIterations;
}

@Override
@JsonIgnore
public Double getFrequencyPenalty() {
@@ -281,6 +296,7 @@ public boolean equals(Object o) {
&& Objects.equals(this.toolCallbacks, that.toolCallbacks)
&& Objects.equals(this.toolNames, that.toolNames)
&& Objects.equals(this.internalToolExecutionEnabled, that.internalToolExecutionEnabled)
&& Objects.equals(this.internalToolExecutionMaxIterations, that.internalToolExecutionMaxIterations)
&& Objects.equals(this.toolContext, that.toolContext)
&& Objects.equals(this.httpHeaders, that.httpHeaders);
}
@@ -289,7 +305,7 @@ public boolean equals(Object o) {
public int hashCode() {
return Objects.hash(this.model, this.maxTokens, this.metadata, this.stopSequences, this.temperature, this.topP,
this.topK, this.thinking, this.toolCallbacks, this.toolNames, this.internalToolExecutionEnabled,
this.toolContext, this.httpHeaders);
this.internalToolExecutionMaxIterations, this.toolContext, this.httpHeaders);
}

public static class Builder {
@@ -374,6 +390,11 @@ public Builder internalToolExecutionEnabled(@Nullable Boolean internalToolExecut
return this;
}

public Builder internalToolExecutionMaxIterations(@Nullable Integer internalToolExecutionMaxIterations) {
this.options.setInternalToolExecutionMaxIterations(internalToolExecutionMaxIterations);
return this;
}

public Builder toolContext(Map<String, Object> toolContext) {
if (this.options.toolContext == null) {
this.options.toolContext = toolContext;
AnthropicChatOptionsTests.java
@@ -22,13 +22,15 @@
import org.junit.jupiter.api.Test;

import org.springframework.ai.anthropic.api.AnthropicApi.ChatCompletionRequest.Metadata;
import org.springframework.ai.model.tool.ToolCallingChatOptions;

import static org.assertj.core.api.Assertions.assertThat;

/**
* Tests for {@link AnthropicChatOptions}.
*
* @author Alexandros Pappas
* @author lambochen
*/
class AnthropicChatOptionsTests {

@@ -42,10 +44,13 @@ void testBuilderWithAllFields() {
.topP(0.8)
.topK(50)
.metadata(new Metadata("userId_123"))
.internalToolExecutionMaxIterations(3)
.build();

assertThat(options).extracting("model", "maxTokens", "stopSequences", "temperature", "topP", "topK", "metadata")
.containsExactly("test-model", 100, List.of("stop1", "stop2"), 0.7, 0.8, 50, new Metadata("userId_123"));
assertThat(options)
.extracting("model", "maxTokens", "stopSequences", "temperature", "topP", "topK", "metadata",
"internalToolExecutionMaxIterations")
.containsExactly("test-model", 100, List.of("stop1", "stop2"), 0.7, 0.8, 50, new Metadata("userId_123"), 3);
}

@Test
@@ -59,6 +64,7 @@ void testCopy() {
.topK(50)
.metadata(new Metadata("userId_123"))
.toolContext(Map.of("key1", "value1"))
.internalToolExecutionMaxIterations(3)
.build();

AnthropicChatOptions copied = original.copy();
@@ -67,6 +73,8 @@ void testCopy() {
// Ensure deep copy
assertThat(copied.getStopSequences()).isNotSameAs(original.getStopSequences());
assertThat(copied.getToolContext()).isNotSameAs(original.getToolContext());

assertThat(copied.getInternalToolExecutionMaxIterations()).isEqualTo(3);
}

@Test
@@ -79,6 +87,7 @@ void testSetters() {
options.setTopP(0.8);
options.setStopSequences(List.of("stop1", "stop2"));
options.setMetadata(new Metadata("userId_123"));
options.setInternalToolExecutionMaxIterations(3);

assertThat(options.getModel()).isEqualTo("test-model");
assertThat(options.getMaxTokens()).isEqualTo(100);
@@ -87,6 +96,7 @@ void testSetters() {
assertThat(options.getTopP()).isEqualTo(0.8);
assertThat(options.getStopSequences()).isEqualTo(List.of("stop1", "stop2"));
assertThat(options.getMetadata()).isEqualTo(new Metadata("userId_123"));
assertThat(options.getInternalToolExecutionMaxIterations()).isEqualTo(3);
}

@Test
@@ -99,6 +109,8 @@ void testDefaultValues() {
assertThat(options.getTopP()).isNull();
assertThat(options.getStopSequences()).isNull();
assertThat(options.getMetadata()).isNull();
assertThat(options.getInternalToolExecutionMaxIterations())
.isEqualTo(ToolCallingChatOptions.DEFAULT_TOOL_EXECUTION_MAX_ITERATIONS);
}

}
AzureOpenAiChatModel.java
@@ -122,8 +122,10 @@
* @author Berjan Jonker
* @author Andres da Silva Santos
* @author Bart Veenstra
* @author lambochen
* @see ChatModel
* @see com.azure.ai.openai.OpenAIClient
* @see ToolCallingChatOptions
* @since 1.0.0
*/
public class AzureOpenAiChatModel implements ChatModel {
Expand Down Expand Up @@ -251,6 +253,10 @@ public ChatResponse call(Prompt prompt) {
}

public ChatResponse internalCall(Prompt prompt, ChatResponse previousChatResponse) {
return internalCall(prompt, previousChatResponse, 1);
}

public ChatResponse internalCall(Prompt prompt, ChatResponse previousChatResponse, int iterations) {

ChatModelObservationContext observationContext = ChatModelObservationContext.builder()
.prompt(prompt)
@@ -270,7 +276,7 @@ public ChatResponse internalCall(Prompt prompt, ChatResponse previousChatRespons
return chatResponse;
});

if (this.toolExecutionEligibilityPredicate.isToolExecutionRequired(prompt.getOptions(), response)) {
if (this.toolExecutionEligibilityPredicate.isToolExecutionRequired(prompt.getOptions(), response, iterations)) {
var toolExecutionResult = this.toolCallingManager.executeToolCalls(prompt, response);
if (toolExecutionResult.returnDirect()) {
// Return tool execution result directly to the client.
@@ -282,7 +288,7 @@ public ChatResponse internalCall(Prompt prompt, ChatResponse previousChatRespons
else {
// Send the tool execution result back to the model.
return this.internalCall(new Prompt(toolExecutionResult.conversationHistory(), prompt.getOptions()),
response);
response, iterations + 1);
}
}

@@ -298,6 +304,10 @@ public Flux<ChatResponse> stream(Prompt prompt) {
}

public Flux<ChatResponse> internalStream(Prompt prompt, ChatResponse previousChatResponse) {
return this.internalStream(prompt, previousChatResponse, 1);
}

public Flux<ChatResponse> internalStream(Prompt prompt, ChatResponse previousChatResponse, int iterations) {

return Flux.deferContextual(contextView -> {
ChatCompletionsOptions options = toAzureChatCompletionsOptions(prompt);
@@ -377,7 +387,8 @@ public Flux<ChatResponse> internalStream(Prompt prompt, ChatResponse previousCha
});

return chatResponseFlux.flatMap(chatResponse -> {
if (this.toolExecutionEligibilityPredicate.isToolExecutionRequired(prompt.getOptions(), chatResponse)) {
if (this.toolExecutionEligibilityPredicate.isToolExecutionRequired(prompt.getOptions(), chatResponse,
iterations)) {
// FIXME: bounded elastic needs to be used since tool calling
// is currently only synchronous
return Flux.defer(() -> {
@@ -393,7 +404,7 @@ public Flux<ChatResponse> internalStream(Prompt prompt, ChatResponse previousCha
// Send the tool execution result back to the model.
return this.internalStream(
new Prompt(toolExecutionResult.conversationHistory(), prompt.getOptions()),
chatResponse);
chatResponse, iterations + 1);
}
}).subscribeOn(Schedulers.boundedElastic());
}
@@ -666,6 +677,9 @@ Prompt buildRequestPrompt(Prompt prompt) {
requestOptions.setInternalToolExecutionEnabled(
ModelOptionsUtils.mergeOption(runtimeOptions.getInternalToolExecutionEnabled(),
this.defaultOptions.getInternalToolExecutionEnabled()));
requestOptions.setInternalToolExecutionMaxIterations(
ModelOptionsUtils.mergeOption(runtimeOptions.getInternalToolExecutionMaxIterations(),
this.defaultOptions.getInternalToolExecutionMaxIterations()));
requestOptions.setStreamUsage(ModelOptionsUtils.mergeOption(runtimeOptions.getStreamUsage(),
this.defaultOptions.getStreamUsage()));
requestOptions.setToolNames(ToolCallingChatOptions.mergeToolNames(runtimeOptions.getToolNames(),
@@ -677,6 +691,8 @@ Prompt buildRequestPrompt(Prompt prompt) {
}
else {
requestOptions.setInternalToolExecutionEnabled(this.defaultOptions.getInternalToolExecutionEnabled());
requestOptions
.setInternalToolExecutionMaxIterations(this.defaultOptions.getInternalToolExecutionMaxIterations());
requestOptions.setStreamUsage(this.defaultOptions.getStreamUsage());
requestOptions.setToolNames(this.defaultOptions.getToolNames());
requestOptions.setToolCallbacks(this.defaultOptions.getToolCallbacks());
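As in the Anthropic model above, buildRequestPrompt merges the runtime option with the configured default via ModelOptionsUtils.mergeOption. A small illustration of the assumed merge semantics (the runtime value wins when present, otherwise the default applies; the values below are arbitrary):

```java
import org.springframework.ai.model.ModelOptionsUtils;

// Assumed semantics of mergeOption(runtime, default): runtime value if non-null, else default.
Integer runtimeMax = null;
Integer defaultMax = 5;
Integer merged = ModelOptionsUtils.mergeOption(runtimeMax, defaultMax); // expected: 5

Integer overriddenMax = ModelOptionsUtils.mergeOption(3, defaultMax);   // expected: 3
```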
AzureOpenAiChatOptions.java
@@ -48,6 +48,7 @@
* @author Ilayaperumal Gopinathan
* @author Alexandros Pappas
* @author Andres da Silva Santos
* @author lambochen
*/
@JsonInclude(Include.NON_NULL)
public class AzureOpenAiChatOptions implements ToolCallingChatOptions {
@@ -200,6 +201,9 @@ public class AzureOpenAiChatOptions implements ToolCallingChatOptions {
@JsonIgnore
private Boolean internalToolExecutionEnabled;

@JsonIgnore
private Integer internalToolExecutionMaxIterations = ToolCallingChatOptions.DEFAULT_TOOL_EXECUTION_MAX_ITERATIONS;

/**
* Whether to include token usage information in streaming chat completion responses.
* Only applies to streaming responses.
@@ -257,6 +261,16 @@ public void setInternalToolExecutionEnabled(@Nullable Boolean internalToolExecut
this.internalToolExecutionEnabled = internalToolExecutionEnabled;
}

@Override
public Integer getInternalToolExecutionMaxIterations() {
return this.internalToolExecutionMaxIterations;
}

@Override
public void setInternalToolExecutionMaxIterations(Integer internalToolExecutionMaxIterations) {
this.internalToolExecutionMaxIterations = internalToolExecutionMaxIterations;
}

public static Builder builder() {
return new Builder();
}
@@ -284,6 +298,7 @@ public static AzureOpenAiChatOptions fromOptions(AzureOpenAiChatOptions fromOpti
.enhancements(fromOptions.getEnhancements())
.toolContext(fromOptions.getToolContext() != null ? new HashMap<>(fromOptions.getToolContext()) : null)
.internalToolExecutionEnabled(fromOptions.getInternalToolExecutionEnabled())
.internalToolExecutionMaxIterations(fromOptions.getInternalToolExecutionMaxIterations())
.streamOptions(fromOptions.getStreamOptions())
.toolCallbacks(
fromOptions.getToolCallbacks() != null ? new ArrayList<>(fromOptions.getToolCallbacks()) : null)
@@ -504,6 +519,7 @@ public boolean equals(Object o) {
&& Objects.equals(this.toolCallbacks, that.toolCallbacks)
&& Objects.equals(this.toolNames, that.toolNames)
&& Objects.equals(this.internalToolExecutionEnabled, that.internalToolExecutionEnabled)
&& Objects.equals(this.internalToolExecutionMaxIterations, that.internalToolExecutionMaxIterations)
&& Objects.equals(this.logprobs, that.logprobs) && Objects.equals(this.topLogProbs, that.topLogProbs)
&& Objects.equals(this.enhancements, that.enhancements)
&& Objects.equals(this.streamOptions, that.streamOptions)
@@ -518,10 +534,10 @@ public boolean equals(Object o) {
@Override
public int hashCode() {
return Objects.hash(this.logitBias, this.user, this.n, this.stop, this.deploymentName, this.responseFormat,
this.toolCallbacks, this.toolNames, this.internalToolExecutionEnabled, this.seed, this.logprobs,
this.topLogProbs, this.enhancements, this.streamOptions, this.reasoningEffort, this.enableStreamUsage,
this.toolContext, this.maxTokens, this.frequencyPenalty, this.presencePenalty, this.temperature,
this.topP);
this.toolCallbacks, this.toolNames, this.internalToolExecutionEnabled,
this.internalToolExecutionMaxIterations, this.seed, this.logprobs, this.topLogProbs, this.enhancements,
this.streamOptions, this.reasoningEffort, this.enableStreamUsage, this.toolContext, this.maxTokens,
this.frequencyPenalty, this.presencePenalty, this.temperature, this.topP);
}

public static class Builder {
@@ -664,6 +680,11 @@ public Builder internalToolExecutionEnabled(@Nullable Boolean internalToolExecut
return this;
}

public Builder internalToolExecutionMaxIterations(@Nullable Integer internalToolExecutionMaxIterations) {
this.options.setInternalToolExecutionMaxIterations(internalToolExecutionMaxIterations);
return this;
}

public AzureOpenAiChatOptions build() {
return this.options;
}