Commit 33199dc
langchain4j committed Sep 22, 2024
1 parent 84915ba
Showing 2 changed files with 92 additions and 7 deletions.

@@ -8,7 +8,14 @@
 import dev.langchain4j.agent.tool.ToolExecutionRequest;
 import dev.langchain4j.agent.tool.ToolParameters;
 import dev.langchain4j.agent.tool.ToolSpecification;
-import dev.langchain4j.data.message.*;
+import dev.langchain4j.data.message.AiMessage;
+import dev.langchain4j.data.message.ChatMessage;
+import dev.langchain4j.data.message.Content;
+import dev.langchain4j.data.message.ImageContent;
+import dev.langchain4j.data.message.SystemMessage;
+import dev.langchain4j.data.message.TextContent;
+import dev.langchain4j.data.message.ToolExecutionResultMessage;
+import dev.langchain4j.data.message.UserMessage;
 import dev.langchain4j.model.Tokenizer;
 
 import java.util.List;
@@ -19,7 +26,13 @@
 import static dev.langchain4j.internal.Exceptions.illegalArgument;
 import static dev.langchain4j.internal.Utils.isNullOrBlank;
 import static dev.langchain4j.internal.ValidationUtils.ensureNotBlank;
-import static dev.langchain4j.model.openai.OpenAiChatModelName.*;
+import static dev.langchain4j.model.openai.OpenAiChatModelName.GPT_3_5_TURBO;
+import static dev.langchain4j.model.openai.OpenAiChatModelName.GPT_3_5_TURBO_0125;
+import static dev.langchain4j.model.openai.OpenAiChatModelName.GPT_3_5_TURBO_1106;
+import static dev.langchain4j.model.openai.OpenAiChatModelName.GPT_4_0125_PREVIEW;
+import static dev.langchain4j.model.openai.OpenAiChatModelName.GPT_4_1106_PREVIEW;
+import static dev.langchain4j.model.openai.OpenAiChatModelName.GPT_4_TURBO_PREVIEW;
+import static dev.langchain4j.model.openai.OpenAiChatModelName.GPT_4_VISION_PREVIEW;
 import static java.util.Collections.singletonList;
 
 /**
@@ -247,9 +260,9 @@ private int estimateTokenCountInToolParameters(ToolParameters parameters) {
            } else {
                tokenCount -= 3;
            }
-           for (Object enumValue : (Object[]) entry.getValue()) {
+           for (String enumValue : (List<String>) entry.getValue()) {
                tokenCount += 3;
-               tokenCount += estimateTokenCountInText(enumValue.toString());
+               tokenCount += estimateTokenCountInText(enumValue);
            }
        }
    }
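The only behavioral change in the first file is the enum-value loop: the values are now read as a List<String> instead of being cast to Object[]. A hedged sketch of why that matters, using a hypothetical property map rather than the library's actual ToolParameters internals — if the "enum" entry holds a List<String>, the old Object[] cast fails at runtime:

import java.util.List;
import java.util.Map;

class EnumCastSketch {

    public static void main(String[] args) {
        // Assumed shape: the "enum" key of a tool parameter's JSON-schema
        // metadata holds its allowed values as a List<String>.
        Map<String, Object> property = Map.of("enum", List.of("CELSIUS", "fahrenheit", "Kelvin"));

        // Old code: (Object[]) property.get("enum") would throw
        // ClassCastException, because a List is not an Object[].
        // New code: iterate the List directly; no toString() needed.
        @SuppressWarnings("unchecked")
        List<String> enumValues = (List<String>) property.get("enum");

        int tokenCount = 0;
        for (String enumValue : enumValues) {
            tokenCount += 3;                  // fixed per-value overhead, as in the diff
            tokenCount += enumValue.length(); // stand-in for estimateTokenCountInText()
        }
        System.out.println(tokenCount);       // 9 + total characters of the three values
    }
}

The second changed file, below, adds a streaming integration test that exercises exactly this enum path.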
@@ -25,17 +25,20 @@
 import java.util.List;
 import java.util.Map;
 import java.util.concurrent.CompletableFuture;
-import java.util.concurrent.TimeUnit;
 import java.util.stream.Stream;
 
 import static dev.langchain4j.agent.tool.JsonSchemaProperty.ARRAY;
 import static dev.langchain4j.agent.tool.JsonSchemaProperty.STRING;
 import static dev.langchain4j.agent.tool.JsonSchemaProperty.description;
+import static dev.langchain4j.agent.tool.JsonSchemaProperty.from;
 import static dev.langchain4j.agent.tool.JsonSchemaProperty.items;
 import static dev.langchain4j.model.mistralai.MistralAiChatModelName.MISTRAL_LARGE_LATEST;
+import static dev.langchain4j.model.openai.OpenAiChatModelName.GPT_4_O_MINI;
+import static dev.langchain4j.service.StreamingAiServicesWithToolsIT.TemperatureUnit.CELSIUS;
 import static dev.langchain4j.service.StreamingAiServicesWithToolsIT.TransactionService.EXPECTED_SPECIFICATION;
 import static java.util.Arrays.asList;
 import static java.util.Collections.singletonList;
+import static java.util.concurrent.TimeUnit.SECONDS;
 import static java.util.stream.Collectors.toList;
 import static org.assertj.core.api.Assertions.assertThat;
 import static org.mockito.Mockito.any;
@@ -52,6 +55,7 @@ static Stream<StreamingChatLanguageModel> models() {
                 .baseUrl(System.getenv("OPENAI_BASE_URL"))
                 .apiKey(System.getenv("OPENAI_API_KEY"))
                 .organizationId(System.getenv("OPENAI_ORGANIZATION_ID"))
+                .modelName(GPT_4_O_MINI)
                 .temperature(0.0)
                 .logRequests(true)
                 .logResponses(true)
@@ -133,7 +137,7 @@ void should_use_tool_with_List_of_Strings_parameter(StreamingChatLanguageModel m
                 .onComplete(future::complete)
                 .onError(future::completeExceptionally)
                 .start();
-        Response<AiMessage> response = future.get(60, TimeUnit.SECONDS);
+        Response<AiMessage> response = future.get(60, SECONDS);
 
         // then
         assertThat(response.content().text()).contains("42", "57");
@@ -156,6 +160,74 @@ void should_use_tool_with_List_of_Strings_parameter(StreamingChatLanguageModel m
         );
     }
 
+    static class WeatherService {
+
+        static ToolSpecification EXPECTED_SPECIFICATION = ToolSpecification.builder()
+                .name("currentTemperature")
+                .description("")
+                .addParameter("arg0", STRING)
+                .addParameter("arg1", STRING, from("enum", asList("CELSIUS", "fahrenheit", "Kelvin")))
+                .build();
+
+        @Tool
+        int currentTemperature(String city, TemperatureUnit unit) {
+            System.out.printf("called currentTemperature(%s, %s)%n", city, unit);
+            return 19;
+        }
+    }
+
+    enum TemperatureUnit {
+        CELSIUS, fahrenheit, Kelvin
+    }
+
+    @ParameterizedTest
+    @MethodSource("models")
+    void should_use_tool_with_enum_parameter(StreamingChatLanguageModel model) throws Exception {
+
+        // given
+        WeatherService weatherService = spy(new WeatherService());
+
+        ChatMemory chatMemory = MessageWindowChatMemory.withMaxMessages(10);
+
+        StreamingChatLanguageModel spyModel = spy(model);
+
+        Assistant assistant = AiServices.builder(Assistant.class)
+                .streamingChatLanguageModel(spyModel)
+                .chatMemory(chatMemory)
+                .tools(weatherService)
+                .build();
+
+        String userMessage = "What is the temperature in Munich now, in Celsius?";
+
+        // when
+        CompletableFuture<Response<AiMessage>> future = new CompletableFuture<>();
+        assistant.chat(userMessage)
+                .onNext(token -> {
+                })
+                .onComplete(future::complete)
+                .onError(future::completeExceptionally)
+                .start();
+        Response<AiMessage> response = future.get(60, SECONDS);
+
+        // then
+        assertThat(response.content().text()).contains("19");
+
+        verify(weatherService).currentTemperature("Munich", CELSIUS);
+        verifyNoMoreInteractions(weatherService);
+
+        List<ChatMessage> messages = chatMemory.messages();
+        verify(spyModel).generate(
+                eq(singletonList(messages.get(0))),
+                eq(singletonList(WeatherService.EXPECTED_SPECIFICATION)),
+                any()
+        );
+        verify(spyModel).generate(
+                eq(asList(messages.get(0), messages.get(1), messages.get(2))),
+                eq(singletonList(WeatherService.EXPECTED_SPECIFICATION)),
+                any()
+        );
+    }
+
 
     @Test
     void should_use_tool_provider() throws Exception {
@@ -186,7 +258,7 @@ void should_use_tool_provider() throws Exception {
                 .onComplete(future::complete)
                 .onError(future::completeExceptionally)
                 .start();
-        Response<AiMessage> response = future.get(60, TimeUnit.SECONDS);
+        Response<AiMessage> response = future.get(60, SECONDS);
 
         // then
        assertThat(response.content().text()).contains("42", "57");
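A note on the mixed-case constants (CELSIUS, fahrenheit, Kelvin) in the new test's enum: presumably they are deliberate, checking that the string the model returns is mapped back to the Java enum constant by its exact name. A minimal sketch of that assumption — a valueOf-style lookup, which is not necessarily how langchain4j resolves tool arguments internally:

class EnumNameSketch {

    enum TemperatureUnit {
        CELSIUS, fahrenheit, Kelvin
    }

    public static void main(String[] args) {
        // Exact-name matches succeed, regardless of the constant's casing style:
        System.out.println(TemperatureUnit.valueOf("CELSIUS"));    // CELSIUS
        System.out.println(TemperatureUnit.valueOf("fahrenheit")); // fahrenheit

        // A case mismatch throws, so the model must echo the name verbatim:
        try {
            TemperatureUnit.valueOf("celsius");
        } catch (IllegalArgumentException e) {
            System.out.println("No enum constant named \"celsius\"");
        }
    }
}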