
Add ToolChoice support to Bedrock Converse (BedrockProxyChatModel) #2752

Open
@unavailable-username

Description


Hello,

I would like to set a ToolChoice on my Bedrock requests, but this does not appear to be supported as of 1.0.0-M7.

This line builds the ToolConfiguration and we need a way to pass in the tool choice: https://github.com/spring-projects/spring-ai/blob/v1.0.0-M7/models/spring-ai-bedrock-converse/src/main/java/org/springframework/ai/bedrock/converse/BedrockProxyChatModel.java#L406
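A minimal sketch of what that line could do once an optional tool choice is available. It assumes the chat options would expose a nullable `ToolChoice`; the helper name `buildToolConfiguration` and its `toolChoice` parameter are hypothetical, and only the AWS SDK for Java v2 model classes are real. When no choice is set, the configuration is left as it is today.

```kotlin
import software.amazon.awssdk.services.bedrockruntime.model.SpecificToolChoice
import software.amazon.awssdk.services.bedrockruntime.model.Tool
import software.amazon.awssdk.services.bedrockruntime.model.ToolChoice
import software.amazon.awssdk.services.bedrockruntime.model.ToolConfiguration

// Hypothetical helper: builds the Converse ToolConfiguration from the tool
// definitions plus an optional ToolChoice taken from the chat options.
fun buildToolConfiguration(tools: List<Tool>, toolChoice: ToolChoice?): ToolConfiguration {
    val builder = ToolConfiguration.builder().tools(tools)
    // Only set toolChoice when the caller asked for one; otherwise Bedrock
    // applies its default and the model decides whether to call a tool.
    if (toolChoice != null) {
        builder.toolChoice(toolChoice)
    }
    return builder.build()
}

fun main() {
    val choice = ToolChoice.builder()
        .tool(SpecificToolChoice.builder().name("ExampleTool").build())
        .build()
    val config = buildToolConfiguration(emptyList(), choice)
    println(config.toolChoice().tool().name()) // prints ExampleTool
}
```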

Expected Behavior

I want to configure my prompt options with a ToolChoice containing a SpecificToolChoice, so that the LLM (Anthropic Claude 3 on Bedrock) responds with a single tool use and no other messages. This reduces token usage and latency.
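For reference, the Converse API models ToolChoice as a union with three variants: auto (the model decides whether to call a tool), any (the model must call at least one tool), and tool (the model must call the named tool). The sketch below builds each variant with the AWS SDK for Java v2 model classes; the three function names are illustrative only. Support for any and tool varies by model, so check the Bedrock documentation for the model you target.

```kotlin
import software.amazon.awssdk.services.bedrockruntime.model.AnyToolChoice
import software.amazon.awssdk.services.bedrockruntime.model.AutoToolChoice
import software.amazon.awssdk.services.bedrockruntime.model.SpecificToolChoice
import software.amazon.awssdk.services.bedrockruntime.model.ToolChoice

// The model decides whether to call a tool or answer with text.
fun autoChoice(): ToolChoice =
    ToolChoice.builder().auto(AutoToolChoice.builder().build()).build()

// The model must call at least one of the configured tools.
fun anyChoice(): ToolChoice =
    ToolChoice.builder().any(AnyToolChoice.builder().build()).build()

// The model must call exactly the named tool (what this issue asks for).
fun specificChoice(toolName: String): ToolChoice =
    ToolChoice.builder()
        .tool(SpecificToolChoice.builder().name(toolName).build())
        .build()
```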

Proposal: add a BedrockChatOptions class that implements ToolCallingChatOptions and has a ToolChoice toolChoice field, which could be used like this:

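import org.springframework.ai.chat.prompt.Prompt
import software.amazon.awssdk.services.bedrockruntime.model.SpecificToolChoice
import software.amazon.awssdk.services.bedrockruntime.model.ToolChoice

// BedrockChatOptions is the proposed class; it does not exist in 1.0.0-M7
val toolChoice = ToolChoice.builder()
    // I specifically need `SpecificToolChoice`, but that should be irrelevant
    .tool(SpecificToolChoice.builder().name("ExampleTool").build())
    .build()
val chatOptions = BedrockChatOptions.builder()
    .toolChoice(toolChoice)
    .toolNames("ExampleTool")
    .build()
val prompt = Prompt("example", chatOptions)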

Current Behavior

I use ToolCallingChatOptions to configure my tools, but I cannot set a tool choice. My tool is invoked correctly, but the call is slower and uses more tokens than necessary, and my ChatResponse contains an additional message that I do not need.

val chatOptions = ToolCallingChatOptions.builder().toolNames("ExampleTool").build()
val prompt = Prompt("example", chatOptions)
chatModel.call(prompt)

Workarounds

One workaround I was using (which is potentially flaky and consumes extra tokens) is to add this to the prompt:
Do not respond with any text or any messages. Your only purpose is to make a tool call to {toolName}.

A more complete, but still hacky, workaround was to override the converse function of BedrockRuntimeClient and inject the tool choice into the request's tool configuration at runtime via reflection.

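import software.amazon.awssdk.services.bedrockruntime.BedrockRuntimeClient
import software.amazon.awssdk.services.bedrockruntime.model.ConverseRequest
import software.amazon.awssdk.services.bedrockruntime.model.ConverseResponse
import software.amazon.awssdk.services.bedrockruntime.model.SpecificToolChoice
import software.amazon.awssdk.services.bedrockruntime.model.ToolChoice
import software.amazon.awssdk.services.bedrockruntime.model.ToolConfiguration

// Delegates every call to the real client, but forces a specific tool choice on converse()
class BedrockAnthropicRuntimeClientWithToolSupport(
  private val defaultClient: BedrockRuntimeClient,
  private val toolName: String?,
) : BedrockRuntimeClient by defaultClient {
  override fun converse(converseRequest: ConverseRequest): ConverseResponse {
    // toolConfig() is null when the request carries no tools, so guard against it
    val toolConfig = converseRequest.toolConfig()
    if (toolName != null && toolConfig != null) {
      setToolChoiceUnsafe(toolConfig, toolName)
    }
    return defaultClient.converse(converseRequest)
  }

  companion object {
    // Unsafe: overwrites the private, otherwise immutable toolChoice field via reflection
    fun setToolChoiceUnsafe(toolConfig: ToolConfiguration, toolName: String) {
      val toolChoice = ToolChoice.builder()
        .tool(SpecificToolChoice.builder().name(toolName).build())
        .build()
      ToolConfiguration::class.java.getDeclaredField("toolChoice")
        .apply {
          isAccessible = true
          set(toolConfig, toolChoice)
        }
    }
  }
}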
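The reflective write mutates an SDK object that is meant to be immutable and would break silently if the private field were renamed. A possibly safer variant of the same interception, sketched below under the same delegation approach, rebuilds the request with toBuilder(), which the AWS SDK for Java v2 model classes provide. The names withSpecificToolChoice and BedrockRuntimeClientWithToolChoice are illustrative. Like the original, it only intercepts the synchronous converse call.

```kotlin
import software.amazon.awssdk.services.bedrockruntime.BedrockRuntimeClient
import software.amazon.awssdk.services.bedrockruntime.model.ConverseRequest
import software.amazon.awssdk.services.bedrockruntime.model.ConverseResponse
import software.amazon.awssdk.services.bedrockruntime.model.SpecificToolChoice
import software.amazon.awssdk.services.bedrockruntime.model.ToolChoice

// Returns a copy of the request whose tool configuration forces the named tool.
// Requests without a tool configuration are returned unchanged.
fun withSpecificToolChoice(request: ConverseRequest, toolName: String): ConverseRequest {
    val toolConfig = request.toolConfig() ?: return request
    val toolChoice = ToolChoice.builder()
        .tool(SpecificToolChoice.builder().name(toolName).build())
        .build()
    return request.toBuilder()
        .toolConfig(toolConfig.toBuilder().toolChoice(toolChoice).build())
        .build()
}

// Illustrative wrapper: same delegation as above, but no reflection.
class BedrockRuntimeClientWithToolChoice(
    private val delegate: BedrockRuntimeClient,
    private val toolName: String,
) : BedrockRuntimeClient by delegate {
    override fun converse(converseRequest: ConverseRequest): ConverseResponse =
        delegate.converse(withSpecificToolChoice(converseRequest, toolName))
}
```

Because the original request is copied rather than modified, the same ConverseRequest can be reused safely, for example when the tool-calling loop resends it.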

I then provided a bean that returns a factory function, so my services can build a separate chat model for each tool that needs a ToolChoice:

  @Bean("bedrockAnthropicProxyChatModelWithToolChoice")
  @ConditionalOnProperty("spring.ai.bedrock.converse.chat.options.model")
  fun bedrockAnthropicProxyChatModelWithToolChoice(
    credentialsProvider: AwsCredentialsProvider,
    regionProvider: AwsRegionProvider,
    connectionProperties: BedrockAwsConnectionProperties,
    chatProperties: BedrockConverseProxyChatProperties,
    toolCallingManager: ToolCallingManager?,
    observationRegistry: ObservationRegistry,
  ): (name: String) -> ChatModel {
    // Factory: each call builds a chat model whose runtime client forces the given tool
    return { toolName: String ->
      val runtimeClient = BedrockAnthropicRuntimeClientWithToolSupport(
        BedrockRuntimeClient.builder()
          .region(regionProvider.region)
          .credentialsProvider(credentialsProvider)
          .httpClientBuilder(null)
          // Use the configured timeout rather than a hard-coded 60 seconds
          .overrideConfiguration({ it.apiCallTimeout(connectionProperties.timeout) })
          .build(),
        toolName,
      )
      BedrockProxyChatModel.builder()
        .credentialsProvider(credentialsProvider)
        .region(regionProvider.region)
        .timeout(connectionProperties.timeout)
        .defaultOptions(chatProperties.options)
        .toolCallingManager(toolCallingManager)
        .observationRegistry(observationRegistry)
        .bedrockRuntimeClient(runtimeClient)
        .build()
    }
  }
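As written, every call to the returned function builds a new BedrockRuntimeClient and BedrockProxyChatModel. If a service asks for the same tool more than once, a small memoizing wrapper can reuse one model per tool name. PerToolCache below is a hypothetical helper, not part of Spring AI; it is plain Kotlin and only assumes the factory is safe to call once per key.

```kotlin
import java.util.concurrent.ConcurrentHashMap

// Wraps a per-key factory so each key is built at most once and then reused.
// ConcurrentHashMap.computeIfAbsent runs the factory atomically per key.
class PerToolCache<T : Any>(private val create: (String) -> T) : (String) -> T {
    private val cache = ConcurrentHashMap<String, T>()

    override fun invoke(toolName: String): T =
        cache.computeIfAbsent(toolName) { create(it) }
}

fun main() {
    var builds = 0
    val models = PerToolCache { name: String -> builds++; "model-for-$name" }
    println(models("ExampleTool")) // prints model-for-ExampleTool
    models("ExampleTool")
    println(builds) // prints 1
}
```

Since PerToolCache<ChatModel> is itself a (String) -> ChatModel, the bean could return PerToolCache { toolName: String -> ... } around the existing lambda body without changing its declared return type.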

The factory is then used in a service like this:

class ExampleService(
  @Qualifier("bedrockAnthropicProxyChatModelWithToolChoice")
  chatModelSupplier: (String) -> ChatModel,
) {
  // Chat model whose requests always force a call to ExampleTool
  private val chatModel: ChatModel = chatModelSupplier.invoke("ExampleTool")
}
