Skip to content

Commit

Permalink
AI Services: added an option to configure tools programmatically (lan…
Browse files Browse the repository at this point in the history
…gchain4j#1364)

## Issue
Implements langchain4j#141


## Change
This PR introduces an option to configure tools programmatically when using AI Services.
Tools can now be provided as a map of `ToolSpecification` to
`ToolExecutor` pairs:

```java
ToolSpecification toolSpecification = ToolSpecification.builder()
    .name("get_booking_details")
    .description("Returns booking details")
    .addParameter("bookingNumber", type("string"))
    .build();

ToolExecutor toolExecutor = (toolExecutionRequest, memoryId) -> {
    Map<String, Object> arguments = toMap(toolExecutionRequest.arguments());
    assertThat(arguments).containsExactly(entry("bookingNumber", "123-456"));
    return "Booking period: from 1 July 2027 to 10 July 2027";
};

Assistant assistant = AiServices.builder(Assistant.class)
    .chatLanguageModel(chatLanguageModel)
    .tools(singletonMap(toolSpecification, toolExecutor))
    .build();

String answer = assistant.chat("When does my booking 123-456 starts?");

assertThat(answer).contains("2027");
```

This approach offers a lot of flexibility, as tools can now be loaded
from external sources such as databases and configuration files.
Tool names, descriptions, parameter names, and descriptions can all be
dynamically configured via `ToolSpecification`.

For instance, one of the LC4j users wants to store tools (where each
tool is an API endpoint) in a configuration file like this:
```json
[
  {
    "name": "get_order_details",
    "url": "https://url.com",
    "method": "POST",
    "description": "Get additional order details by providing an order ID"
    "parameters": {
      "order": {
        "type": "string",
        "description": "an order ID, for example: 300"
      }
    },
    "examples": [
      "Get additional details for order 300",
      "Show more information for order 301",
      "Show request delivery date for order 201"
    ]
  }
]
```

With this PR, this can be implemented like so:
```java
List<ApiTool> apiTools = loadFromFile("tools.json");
Map<ToolSpecification, ToolExecutor> tools = new HashMap<>();
for (ApiTool apiTool : apiTools) {
    if ("GET".equals(apiTool.getMethod())) {

        ToolSpecification toolSpecification = ToolSpecification.builder()
            .name(apiTool.getName())
            .description(apiTool.getDescription())
            .build();

            ToolExecutor toolExecutor = (toolExecutionRequest, memoryId) -> httpClient.get(apiTool.getUrl());
            
            tools.put(toolSpecification, toolExecutor);
    } else if ("POST".equals(apiTool.getMethod())) {

        ToolSpecification.Builder toolSpecificationBuilder = ToolSpecification.builder()
            .name(apiTool.getName())
            .description(apiTool.getDescription);

        apiTool.getParameters().forEach((parameterName, parameterProperties) -> {
            toolSpecificationBuilder.addParameter(parameterName, type(parameterProperties.get("type")), description(parameterProperties.get("description")));
        });

        ToolExecutor toolExecutor = (toolExecutionRequest, memoryId) -> httpClient.post(apiTool.getUrl(), toolExecutionRequest.arguments());

        tools.put(toolSpecificationBuilder.build(), toolExecutor);
    }
}
```

The drawback of this way of configuring tools is that it requires one to
implement a rather low-level `ToolExecutor` interface and manually parse
tool arguments. In future iterations, we could either automatically
parse arguments into a `Map<String, Object>` tree and/or allow users to
explicitly specify a `Class` to parse into.


## General checklist
<!-- Please double-check the following points and mark them like this:
[X] -->
- [X] There are no breaking changes
- [X] I have added unit and integration tests for my change
- [x] I have manually run all the unit and integration tests in the
module I have added/changed, and they are all green
- [x] I have manually run all the unit and integration tests in the
[core](https://github.com/langchain4j/langchain4j/tree/main/langchain4j-core)
and
[main](https://github.com/langchain4j/langchain4j/tree/main/langchain4j)
modules, and they are all green
<!-- Before adding documentation and example(s) (below), please wait
until the PR is reviewed and approved. -->
- [x] I have added/updated the
[documentation](https://github.com/langchain4j/langchain4j/tree/main/docs/docs)
- [ ] I have added an example in the [examples
repo](https://github.com/langchain4j/langchain4j-examples) (only for
"big" features)
- [ ] I have added/updated [Spring Boot
starter(s)](https://github.com/langchain4j/langchain4j-spring) (if
applicable)
  • Loading branch information
langchain4j authored Jul 2, 2024
1 parent 8d3d1ca commit c5e8af8
Show file tree
Hide file tree
Showing 6 changed files with 139 additions and 13 deletions.
40 changes: 39 additions & 1 deletion docs/docs/tutorials/6-tools.md
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,9 @@ Please note that tools/function calling is not the same as [JSON mode](/tutorial

## 2 levels of abstraction

LangChain4j provides two levels of abstraction for using tools.
LangChain4j provides two levels of abstraction for using tools:
- Low-level, using the `ChatLanguageModel` API
- High-level, using [AI Services](/tutorials/ai-services) and `@Tool`-annotated Java methods

### Low level Tool API

Expand Down Expand Up @@ -294,6 +296,42 @@ The value provided to the AI Service method will be automatically passed to the
This feature is useful if you have multiple users and/or multiple chats/memories per user
and wish to distinguish between them inside the `@Tool` method.

### Configuring Tools Programmatically

When using AI Services, tools can also be configured programmatically.
This approach offers a lot of flexibility, as tools can now be loaded
from external sources such as databases and configuration files.

Tool names, descriptions, parameter names, and descriptions
can all be configured using `ToolSpecification`:
```java
ToolSpecification toolSpecification = ToolSpecification.builder()
.name("get_booking_details")
.description("Returns booking details")
.addParameter("bookingNumber", type("string"), description("Booking number in B-12345 format"))
.build();
```

For each `ToolSpecification`, one needs to provide a `ToolExecutor` implementation
that will be handling tool execution requests generated by the LLM:
```java
ToolExecutor toolExecutor = (toolExecutionRequest, memoryId) -> {
Map<String, Object> arguments = fromJson(toolExecutionRequest.arguments());
String bookingNumber = arguments.get("bookingNumber").toString();
Booking booking = getBooking(bookingNumber);
return booking.toString();
};
```

Once we have one or multiple (`ToolSpecification`, `ToolExecutor`) pairs,
we can specify them when creating an AI Service:
```java
Assistant assistant = AiServices.builder(Assistant.class)
.chatLanguageModel(chatLanguageModel)
.tools(singletonMap(toolSpecification, toolExecutor))
.build();
```

## Related Tutorials

- [Great guide on Tools](https://www.youtube.com/watch?v=cjI_6Siry-s)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,13 @@
import static java.lang.annotation.RetentionPolicy.RUNTIME;

/**
* Java methods annotated with @Tool are considered tools that language model can use.
* When using OpenAI models, <a href="https://platform.openai.com/docs/guides/function-calling">function calling</a>
* Java methods annotated with {@code @Tool} are considered tools/functions that language model can execute/call.
* Tool/function calling LLM capability (e.g., see <a href="https://platform.openai.com/docs/guides/function-calling">OpenAI function calling documentation</a>)
* is used under the hood.
* A low-level {@link ToolSpecification} will be automatically created from the method signature
* (e.g. method name, method parameters (names and types), @Tool and @P annotations, etc.)
* and will be sent to the LLM.
* If LLM decides to call the tool, the arguments are automatically parsed and injected as method arguments.
*/
@Retention(RUNTIME)
@Target({METHOD})
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import static dev.langchain4j.internal.Utils.quoted;

/**
* Represents a request to execute a tool.
* Represents an LLM-generated request to execute a tool.
*/
public class ToolExecutionRequest {
private final String id;
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,18 @@
package dev.langchain4j.agent.tool;

import dev.langchain4j.service.MemoryId;

/**
* A low-level executor/handler of a {@link ToolExecutionRequest}.
*/
public interface ToolExecutor {

/**
* Executes a tool requests.
*
* @param toolExecutionRequest The tool execution request. Contains tool name and arguments.
* @param memoryId The ID of the chat memory. See {@link MemoryId} for more details.
* @return The result of the tool execution.
*/
String execute(ToolExecutionRequest toolExecutionRequest, Object memoryId);
}
41 changes: 36 additions & 5 deletions langchain4j/src/main/java/dev/langchain4j/service/AiServices.java
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import dev.langchain4j.agent.tool.DefaultToolExecutor;
import dev.langchain4j.agent.tool.Tool;
import dev.langchain4j.agent.tool.ToolExecutor;
import dev.langchain4j.agent.tool.ToolSpecification;
import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.data.message.ChatMessage;
Expand Down Expand Up @@ -37,6 +38,7 @@
import static dev.langchain4j.internal.Exceptions.illegalArgument;
import static dev.langchain4j.internal.ValidationUtils.ensureNotNull;
import static dev.langchain4j.spi.ServiceHelper.loadFactories;
import static java.util.Arrays.asList;
import static java.util.stream.Collectors.toList;

/**
Expand Down Expand Up @@ -294,7 +296,6 @@ public AiServices<T> moderationModel(ModerationModel moderationModel) {

/**
* Configures the tools that the LLM can use.
* A {@link ChatMemory} that can hold at least 3 messages is required for the tools to work properly.
*
* @param objectsWithTools One or more objects whose methods are annotated with {@link Tool}.
* All these tools (methods annotated with {@link Tool}) will be accessible to the LLM.
Expand All @@ -303,12 +304,11 @@ public AiServices<T> moderationModel(ModerationModel moderationModel) {
* @see Tool
*/
public AiServices<T> tools(Object... objectsWithTools) {
return tools(Arrays.asList(objectsWithTools));
return tools(asList(objectsWithTools));
}

/**
* Configures the tools that the LLM can use.
* A {@link ChatMemory} that can hold at least 3 messages is required for the tools to work properly.
*
* @param objectsWithTools A list of objects whose methods are annotated with {@link Tool}.
* All these tools (methods annotated with {@link Tool}) are accessible to the LLM.
Expand All @@ -318,8 +318,13 @@ public AiServices<T> tools(Object... objectsWithTools) {
*/
public AiServices<T> tools(List<Object> objectsWithTools) { // TODO Collection?
// TODO validate uniqueness of tool names
context.toolSpecifications = new ArrayList<>();
context.toolExecutors = new HashMap<>();

if (context.toolSpecifications == null) {
context.toolSpecifications = new ArrayList<>();
}
if (context.toolExecutors == null) {
context.toolExecutors = new HashMap<>();
}

for (Object objectWithTool : objectsWithTools) {
if (objectWithTool instanceof Class) {
Expand All @@ -338,6 +343,32 @@ public AiServices<T> tools(List<Object> objectsWithTools) { // TODO Collection?
return this;
}

/**
* Configures the tools that the LLM can use.
*
* @param tools A map of {@link ToolSpecification} to {@link ToolExecutor} entries.
* This method of configuring tools is useful when tools must be configured programmatically.
* Otherwise, it is recommended to use the {@link Tool}-annotated java methods
* and configure tools with the {@link #tools(Object...)} and {@link #tools(List)} methods.
* @return builder
*/
public AiServices<T> tools(Map<ToolSpecification, ToolExecutor> tools) {

if (context.toolSpecifications == null) {
context.toolSpecifications = new ArrayList<>();
}
if (context.toolExecutors == null) {
context.toolExecutors = new HashMap<>();
}

tools.forEach((toolSpecification, toolExecutor) -> {
context.toolSpecifications.add(toolSpecification);
context.toolExecutors.put(toolSpecification.name(), toolExecutor);
});

return this;
}

/**
* Deprecated. Use {@link #contentRetriever(ContentRetriever)}
* (e.g. {@link EmbeddingStoreContentRetriever}) instead.
Expand Down
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
package dev.langchain4j.service;

import dev.langchain4j.agent.tool.P;
import dev.langchain4j.agent.tool.Tool;
import dev.langchain4j.agent.tool.ToolExecutionRequest;
import dev.langchain4j.agent.tool.ToolSpecification;
import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.core.type.TypeReference;
import com.fasterxml.jackson.databind.ObjectMapper;
import dev.langchain4j.agent.tool.*;
import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.data.message.ChatMessage;
import dev.langchain4j.data.message.ToolExecutionResultMessage;
Expand All @@ -26,6 +26,7 @@

import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.stream.Stream;

import static dev.langchain4j.agent.tool.JsonSchemaProperty.description;
Expand All @@ -38,7 +39,9 @@
import static dev.langchain4j.service.AiServicesWithToolsIT.TransactionService.EXPECTED_SPECIFICATION;
import static java.util.Arrays.asList;
import static java.util.Collections.singletonList;
import static java.util.Collections.singletonMap;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.data.MapEntry.entry;
import static org.mockito.Mockito.*;

@ExtendWith(MockitoExtension.class)
Expand Down Expand Up @@ -646,4 +649,42 @@ void should_use_tool_with_pojo(ChatLanguageModel chatLanguageModel) {

assertThat(response.content().text()).contains("Amar", "Akbar", "Antony");
}

@ParameterizedTest
@MethodSource("models")
void should_use_programmatically_configured_tools(ChatLanguageModel chatLanguageModel) {

// given
ToolSpecification toolSpecification = ToolSpecification.builder()
.name("get_booking_details")
.description("Returns booking details")
.addParameter("bookingNumber", type("string"))
.build();

ToolExecutor toolExecutor = (toolExecutionRequest, memoryId) -> {
Map<String, Object> arguments = toMap(toolExecutionRequest.arguments());
assertThat(arguments).containsExactly(entry("bookingNumber", "123-456"));
return "Booking period: from 1 July 2027 to 10 July 2027";
};

Assistant assistant = AiServices.builder(Assistant.class)
.chatLanguageModel(chatLanguageModel)
.tools(singletonMap(toolSpecification, toolExecutor))
.build();

// when
Response<AiMessage> response = assistant.chat("When does my booking 123-456 starts?");

// then
assertThat(response.content().text()).contains("2027");
}

private static Map<String, Object> toMap(String arguments) {
try {
return new ObjectMapper().readValue(arguments, new TypeReference<Map<String, Object>>() {
});
} catch (JsonProcessingException e) {
throw new RuntimeException(e);
}
}
}

0 comments on commit c5e8af8

Please sign in to comment.