Description
Hello,
I would like to set a `ToolChoice` on my Bedrock requests, but this does not appear to be supported as of 1.0.0-M7.
This line builds the `ToolConfiguration`, and there needs to be a way to pass the tool choice into it: https://github.com/spring-projects/spring-ai/blob/v1.0.0-M7/models/spring-ai-bedrock-converse/src/main/java/org/springframework/ai/bedrock/converse/BedrockProxyChatModel.java#L406
Expected Behavior
I configure my prompt options with a `ToolChoice` and a `SpecificToolChoice`, so that the LLM (Bedrock Anthropic Claude 3) responds with a single tool use and does not return any other messages. This cuts down on token use and latency.
Add a `BedrockChatOptions` class that implements `ToolCallingChatOptions` and has a `ToolChoice toolChoice` field, for example:
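```kotlin
import org.springframework.ai.chat.prompt.Prompt
import software.amazon.awssdk.services.bedrockruntime.model.SpecificToolChoice
import software.amazon.awssdk.services.bedrockruntime.model.ToolChoice

val toolChoice = ToolChoice.builder()
    // I specifically need `SpecificToolChoice`, but the request applies to any ToolChoice variant
    .tool(SpecificToolChoice.builder().name("ExampleTool").build())
    .build()
// Proposed API: BedrockChatOptions does not exist yet.
val chatOptions = BedrockChatOptions.builder()
    .toolChoice(toolChoice)
    .toolNames("ExampleTool")
    .build()
val prompt = Prompt("example", chatOptions)
```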
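To make the proposal concrete, here is a minimal, dependency-free sketch of the shape such an options class might take. `ToolChoiceSketch` and `BedrockChatOptionsSketch` are hypothetical stand-ins, not Spring AI or AWS SDK types. The three variants mirror the `auto`, `any` and `tool` members of Bedrock's `ToolChoice`, and `toJson()` renders the JSON shape the Converse API accepts under `toolConfig.toolChoice`. The `init` check is an assumed design choice: rejecting a specific choice that names an unregistered tool surfaces the mistake when the options are built rather than at request time.

```kotlin
/** Hypothetical mirror of Bedrock's three ToolChoice variants (not the SDK type). */
sealed interface ToolChoiceSketch {
    /** Renders this choice in the JSON shape of the Converse API's toolConfig.toolChoice. */
    fun toJson(): String

    /** The model decides whether to call a tool. */
    object Auto : ToolChoiceSketch {
        override fun toJson() = """{"auto":{}}"""
    }

    /** The model must call at least one of the supplied tools. */
    object AnyTool : ToolChoiceSketch {
        override fun toJson() = """{"any":{}}"""
    }

    /** The model must call exactly the named tool (assumes the name needs no JSON escaping). */
    data class Specific(val name: String) : ToolChoiceSketch {
        override fun toJson() = """{"tool":{"name":"$name"}}"""
    }
}

/** Hypothetical options holder: the registered tool names plus an optional tool choice. */
data class BedrockChatOptionsSketch(
    val toolNames: Set<String> = emptySet(),
    val toolChoice: ToolChoiceSketch? = null,
) {
    init {
        // Reject a specific choice that names a tool which is not registered.
        val choice = toolChoice
        if (choice is ToolChoiceSketch.Specific) {
            require(choice.name in toolNames) {
                "toolChoice names '${choice.name}', which is not in toolNames $toolNames"
            }
        }
    }

    /** Builder mirroring the call chain in the example above. */
    class Builder {
        private var toolNames: Set<String> = emptySet()
        private var toolChoice: ToolChoiceSketch? = null

        fun toolNames(vararg names: String) = apply { toolNames = names.toSet() }
        fun toolChoice(choice: ToolChoiceSketch?) = apply { toolChoice = choice }
        fun build() = BedrockChatOptionsSketch(toolNames, toolChoice)
    }

    companion object {
        fun builder() = Builder()
    }
}
```

In a real implementation the field would presumably hold the SDK's `ToolChoice` and be copied into `ToolConfiguration.builder().toolChoice(...)` at the line linked above.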
Current Behavior
I use `ToolCallingChatOptions` to configure my tools, but I cannot set a tool choice. My tool is invoked correctly, but the call is slower and costs more tokens than necessary, because my `ChatResponse` contains an additional message that I do not need.
```kotlin
val chatOptions = ToolCallingChatOptions.builder().toolNames("ExampleTool").build()
val prompt = Prompt("example", chatOptions)
chatModel.call(prompt)
```
Workarounds
One workaround I was using (potentially flaky, and it consumes extra tokens) is to add this to the prompt:
```
Do not respond with any text or any messages. Your only purpose is to make a tool call to {toolName}.
```
A more complete, but still hacky, workaround was to override the `converse` function of `BedrockRuntimeClient` and inject the tool choice into the request's tool configuration at runtime via reflection.
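```kotlin
import software.amazon.awssdk.services.bedrockruntime.BedrockRuntimeClient
import software.amazon.awssdk.services.bedrockruntime.model.ConverseRequest
import software.amazon.awssdk.services.bedrockruntime.model.ConverseResponse
import software.amazon.awssdk.services.bedrockruntime.model.SpecificToolChoice
import software.amazon.awssdk.services.bedrockruntime.model.ToolChoice
import software.amazon.awssdk.services.bedrockruntime.model.ToolConfiguration

/**
 * Delegates every call to [defaultClient], but forces [toolName] as the tool choice
 * on `converse` requests that carry a tool configuration.
 */
class BedrockAnthropicRuntimeClientWithToolSupport(
    private val defaultClient: BedrockRuntimeClient,
    private val toolName: String?,
) : BedrockRuntimeClient by defaultClient {
    override fun converse(converseRequest: ConverseRequest): ConverseResponse {
        // toolConfig() is null when no tools were registered; only patch it when present.
        val toolConfig: ToolConfiguration? = converseRequest.toolConfig()
        if (toolName != null && toolConfig != null) {
            setToolChoiceUnsafe(toolConfig, toolName)
        }
        return defaultClient.converse(converseRequest)
    }

    companion object {
        // Unsafe: writes the SDK's private `toolChoice` field, so it depends on
        // ToolConfiguration's internal field name and may break on SDK upgrades.
        fun setToolChoiceUnsafe(toolConfig: ToolConfiguration, toolName: String) {
            val toolChoice = ToolChoice.builder()
                .tool(SpecificToolChoice.builder().name(toolName).build())
                .build()
            ToolConfiguration::class.java.getDeclaredField("toolChoice")
                .apply {
                    isAccessible = true
                    set(toolConfig, toolChoice)
                }
        }
    }
}
```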
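The reflection call works, but it writes a private field of an SDK class and depends on its internal name. A possibly sturdier variant relies on AWS SDK for Java v2 model objects being immutable and exposing `toBuilder()`: the decorator could build a patched copy, e.g. `converseRequest.toBuilder().toolConfig(toolConfig.toBuilder().toolChoice(toolChoice).build()).build()`, and pass that to the delegate instead. The sketch below shows the same copy-instead-of-mutate decorator using hypothetical stand-in types (`ToolConfigSketch`, `RequestSketch`, `ConverseSketch`, `ForcedToolClient`, with Kotlin's data-class `copy()` playing the role of `toBuilder()`) so that it runs without the SDK.

```kotlin
/** Hypothetical stand-in for the SDK's immutable ToolConfiguration. */
data class ToolConfigSketch(val toolNames: List<String>, val toolChoice: String? = null)

/** Hypothetical stand-in for the SDK's immutable ConverseRequest. */
data class RequestSketch(val prompt: String, val toolConfig: ToolConfigSketch? = null)

/** Hypothetical stand-in for the runtime client's converse call. */
fun interface ConverseSketch {
    fun converse(request: RequestSketch): String
}

/**
 * Forces [toolName] without reflection: it builds a patched copy of the request
 * and leaves the caller's request object untouched.
 */
class ForcedToolClient(
    private val delegate: ConverseSketch,
    private val toolName: String?,
) : ConverseSketch {
    override fun converse(request: RequestSketch): String {
        val toolConfig = request.toolConfig
        // Pass through unchanged when no tool is forced or the request carries no tools.
        if (toolName == null || toolConfig == null) return delegate.converse(request)
        val patched = request.copy(toolConfig = toolConfig.copy(toolChoice = toolName))
        return delegate.converse(patched)
    }
}
```

Because the original request is never modified, the decorator does not depend on any private field name and stays safe if the same request object is reused.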
I then provided a bean so that my services could build a separate chat model for each tool that needs a forced `ToolChoice`:
```kotlin
@Bean("bedrockAnthropicProxyChatModelWithToolChoice")
@ConditionalOnProperty("spring.ai.bedrock.converse.chat.options.model")
fun bedrockAnthropicProxyChatModelWithToolChoice(
    credentialsProvider: AwsCredentialsProvider,
    regionProvider: AwsRegionProvider,
    connectionProperties: BedrockAwsConnectionProperties,
    chatProperties: BedrockConverseProxyChatProperties,
    toolCallingManager: ToolCallingManager?,
    observationRegistry: ObservationRegistry,
): (name: String) -> ChatModel {
    // Returns a factory: each call builds a runtime client and chat model that force `toolName`.
    return { toolName: String ->
        val runtimeClient = BedrockAnthropicRuntimeClientWithToolSupport(
            BedrockRuntimeClient.builder()
                .region(regionProvider.region)
                .credentialsProvider(credentialsProvider)
                .httpClientBuilder(null)
                .overrideConfiguration { it.apiCallTimeout(Duration.ofSeconds(60)) }
                .build(),
            toolName,
        )
        BedrockProxyChatModel.builder()
            .credentialsProvider(credentialsProvider)
            .region(regionProvider.region)
            .timeout(connectionProperties.timeout)
            .defaultOptions(chatProperties.options)
            .toolCallingManager(toolCallingManager)
            .bedrockRuntimeClient(runtimeClient)
            .build()
    }
}
```
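As written, every invocation of the returned lambda builds a fresh `BedrockRuntimeClient` and `BedrockProxyChatModel`. SDK clients are meant to be long-lived and shared, since each one maintains its own HTTP connection pool, so if the supplier is ever called per request rather than once per service, caching the result per tool name may be worthwhile. A minimal sketch, assuming a hypothetical helper named `memoizePerKey`:

```kotlin
import java.util.concurrent.ConcurrentHashMap

/**
 * Hypothetical helper: wraps a per-key factory so each key is built at most once
 * and the result is reused on later calls.
 */
fun <T : Any> memoizePerKey(factory: (String) -> T): (String) -> T {
    val cache = ConcurrentHashMap<String, T>()
    // computeIfAbsent applies the factory atomically, at most once per key.
    return { key -> cache.computeIfAbsent(key) { factory(it) } }
}
```

Inside the bean, the lambda could then be wrapped as `return memoizePerKey { toolName -> ... }`, leaving the per-tool construction logic unchanged.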
The bean is then used in a service like this:
class ExampleService(
@Qualifier("bedrockAnthropicProxyChatModelWithToolChoice")
chatModelSupplier: (String) -> ChatModel,
) {
private val chatModel: ChatModel = chatModelSupplier.invoke("ExampleTool")