Skip to content

Commit

Permalink
OpenAI: Support parallel tool calling (langchain4j#338)
Browse files Browse the repository at this point in the history
This PR introduces support for [parallel tool
calling](https://platform.openai.com/docs/guides/function-calling/parallel-function-calling)
in the OpenAI integration.
  • Loading branch information
dliubarskyi authored Dec 8, 2023
1 parent 09ab6a1 commit 303b2ab
Show file tree
Hide file tree
Showing 56 changed files with 1,560 additions and 351 deletions.
1 change: 1 addition & 0 deletions langchain4j-azure-open-ai/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@
<artifactId>assertj-core</artifactId>
<scope>test</scope>
</dependency>

</dependencies>

<licenses>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,15 @@
import dev.langchain4j.model.chat.StreamingChatLanguageModel;
import dev.langchain4j.model.chat.TokenCountEstimator;
import dev.langchain4j.model.openai.OpenAiTokenizer;
import dev.langchain4j.model.output.Response;

import java.time.Duration;
import java.util.List;

import static dev.langchain4j.internal.Utils.getOrDefault;
import static dev.langchain4j.model.azure.AzureOpenAiModelName.GPT_3_5_TURBO;
import static dev.langchain4j.model.azure.InternalAzureOpenAiHelper.setupOpenAIClient;
import static dev.langchain4j.model.azure.InternalAzureOpenAiHelper.toFunctions;
import static java.util.Collections.singletonList;

/**
Expand Down Expand Up @@ -135,7 +137,7 @@ private void generate(List<ChatMessage> messages,
Integer inputTokenCount = tokenizer == null ? null : tokenizer.estimateTokenCountInMessages(messages);

if (toolSpecifications != null && !toolSpecifications.isEmpty()) {
options.setFunctions(InternalAzureOpenAiHelper.toFunctions(toolSpecifications));
options.setFunctions(toFunctions(toolSpecifications));
if (tokenizer != null) {
inputTokenCount += tokenizer.estimateTokenCountInToolSpecifications(toolSpecifications);
}
Expand All @@ -156,7 +158,8 @@ private void generate(List<ChatMessage> messages,
responseBuilder.append(chatCompletions);
handle(chatCompletions, handler);
});
handler.onComplete(responseBuilder.build());
Response<AiMessage> response = responseBuilder.build(tokenizer, toolThatMustBeExecuted != null);
handler.onComplete(response);
} catch (Exception exception) {
handler.onError(exception);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ public void generate(String prompt, StreamingResponseHandler<String> handler) {
handle(completions, handler);
});

Response<AiMessage> response = responseBuilder.build();
Response<AiMessage> response = responseBuilder.build(tokenizer, false);
handler.onComplete(Response.from(
response.content().text(),
response.tokenUsage(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,14 @@
import com.azure.ai.openai.models.*;
import dev.langchain4j.agent.tool.ToolExecutionRequest;
import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.model.Tokenizer;
import dev.langchain4j.model.output.Response;
import dev.langchain4j.model.output.TokenUsage;

import java.util.List;
import java.util.concurrent.atomic.AtomicInteger;

import static dev.langchain4j.model.openai.InternalOpenAiHelper.finishReasonFrom;
import static dev.langchain4j.model.azure.InternalAzureOpenAiHelper.finishReasonFrom;
import static java.util.Collections.singletonList;

/**
* This class needs to be thread safe because it is called when a streaming result comes back
Expand All @@ -21,9 +22,9 @@ class AzureOpenAiStreamingResponseBuilder {
private final StringBuffer contentBuilder = new StringBuffer();
private final StringBuffer toolNameBuilder = new StringBuffer();
private final StringBuffer toolArgumentsBuilder = new StringBuffer();
private volatile CompletionsFinishReason finishReason;

private final Integer inputTokenCount;
private final AtomicInteger outputTokenCount = new AtomicInteger();
private volatile String finishReason;

public AzureOpenAiStreamingResponseBuilder(Integer inputTokenCount) {
this.inputTokenCount = inputTokenCount;
Expand All @@ -46,7 +47,7 @@ public void append(ChatCompletions completions) {

CompletionsFinishReason finishReason = chatCompletionChoice.getFinishReason();
if (finishReason != null) {
this.finishReason = finishReason.toString();
this.finishReason = finishReason;
}

com.azure.ai.openai.models.ChatMessage delta = chatCompletionChoice.getDelta();
Expand All @@ -57,20 +58,17 @@ public void append(ChatCompletions completions) {
String content = delta.getContent();
if (content != null) {
contentBuilder.append(content);
outputTokenCount.incrementAndGet();
return;
}

FunctionCall functionCall = delta.getFunctionCall();
if (functionCall != null) {
if (functionCall.getName() != null) {
toolNameBuilder.append(functionCall.getName());
outputTokenCount.incrementAndGet();
}

if (functionCall.getArguments() != null) {
toolArgumentsBuilder.append(functionCall.getArguments());
outputTokenCount.incrementAndGet();
}
}
}
Expand All @@ -92,39 +90,63 @@ public void append(Completions completions) {

CompletionsFinishReason completionsFinishReason = completionChoice.getFinishReason();
if (completionsFinishReason != null) {
this.finishReason = completionsFinishReason.toString();
this.finishReason = completionsFinishReason;
}

String token = completionChoice.getText();
if (token != null) {
contentBuilder.append(token);
outputTokenCount.incrementAndGet();
}
}

public Response<AiMessage> build() {
public Response<AiMessage> build(Tokenizer tokenizer, boolean forcefulToolExecution) {

String content = contentBuilder.toString();
if (!content.isEmpty()) {
return Response.from(
AiMessage.from(content),
new TokenUsage(inputTokenCount, outputTokenCount.get()),
tokenUsage(content, tokenizer),
finishReasonFrom(finishReason)
);
}

String toolName = toolNameBuilder.toString();
if (!toolName.isEmpty()) {
ToolExecutionRequest toolExecutionRequest = ToolExecutionRequest.builder()
.name(toolName)
.arguments(toolArgumentsBuilder.toString())
.build();
return Response.from(
AiMessage.from(ToolExecutionRequest.builder()
.name(toolName)
.arguments(toolArgumentsBuilder.toString())
.build()),
new TokenUsage(inputTokenCount, outputTokenCount.get()),
AiMessage.from(toolExecutionRequest),
tokenUsage(toolExecutionRequest, tokenizer, forcefulToolExecution),
finishReasonFrom(finishReason)
);
}

return null;
}

private TokenUsage tokenUsage(String content, Tokenizer tokenizer) {
if (tokenizer == null) {
return null;
}
int outputTokenCount = tokenizer.estimateTokenCountInText(content);
return new TokenUsage(inputTokenCount, outputTokenCount);
}

private TokenUsage tokenUsage(ToolExecutionRequest toolExecutionRequest, Tokenizer tokenizer, boolean forcefulToolExecution) {
if (tokenizer == null) {
return null;
}

int outputTokenCount = 0;
if (forcefulToolExecution) {
// OpenAI calculates output tokens differently when tool is executed forcefully
outputTokenCount += tokenizer.estimateTokenCountInForcefulToolExecutionRequest(toolExecutionRequest);
} else {
outputTokenCount = tokenizer.estimateTokenCountInToolExecutionRequests(singletonList(toolExecutionRequest));
}

return new TokenUsage(inputTokenCount, outputTokenCount);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -105,9 +105,10 @@ private static String nameFrom(ChatMessage message) {
private static FunctionCall functionCallFrom(ChatMessage message) {
if (message instanceof AiMessage) {
AiMessage aiMessage = (AiMessage) message;
if (aiMessage.toolExecutionRequest() != null) {
return new FunctionCall(aiMessage.toolExecutionRequest().name(),
aiMessage.toolExecutionRequest().arguments());
if (aiMessage.hasToolExecutionRequests()) {
// TODO switch to tools once supported
ToolExecutionRequest toolExecutionRequest = aiMessage.toolExecutionRequests().get(0);
return new FunctionCall(toolExecutionRequest.name(), toolExecutionRequest.arguments());
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,9 @@
import static dev.langchain4j.model.output.FinishReason.STOP;
import static org.assertj.core.api.Assertions.assertThat;

public class AzureOpenAIChatModelIT {
public class AzureOpenAiChatModelIT {

Logger logger = LoggerFactory.getLogger(AzureOpenAIChatModelIT.class);
Logger logger = LoggerFactory.getLogger(AzureOpenAiChatModelIT.class);

@Test
void should_generate_answer_and_return_token_usage_and_finish_reason_stop() {
Expand Down Expand Up @@ -102,14 +102,17 @@ void should_call_function_with_argument() {

Response<AiMessage> response = model.generate(Collections.singletonList(userMessage), toolSpecification);

assertThat(response.content().text()).isBlank();
assertThat(response.content().toolExecutionRequest().name()).isEqualTo(toolName);
AiMessage aiMessage = response.content();
assertThat(aiMessage.text()).isBlank();

assertThat(aiMessage.toolExecutionRequests()).hasSize(1);
ToolExecutionRequest toolExecutionRequest = aiMessage.toolExecutionRequests().get(0);
assertThat(toolExecutionRequest.name()).isEqualTo(toolName);

// We should get a response telling how to call the "getCurrentWeather" function, with the correct parameters in JSON format.
logger.info(response.toString());

// We can now call the function with the correct parameters.
ToolExecutionRequest toolExecutionRequest = response.content().toolExecutionRequest();
WeatherLocation weatherLocation = BinaryData.fromString(toolExecutionRequest.arguments()).toObject(WeatherLocation.class);
int currentWeather = 0;
currentWeather = getCurrentWeather(weatherLocation);
Expand All @@ -121,12 +124,13 @@ void should_call_function_with_argument() {
assertThat(weather).isEqualTo("The weather in Paris, France is 35 degrees celsius.");

// Now that we know the function's result, we can call the model again with the result as input.
ToolExecutionResultMessage toolExecutionResultMessage = toolExecutionResultMessage(toolName, weather);
ToolExecutionResultMessage toolExecutionResultMessage = toolExecutionResultMessage(toolExecutionRequest, weather);
SystemMessage systemMessage = SystemMessage.systemMessage("If the weather is above 30 degrees celsius, recommend the user wears a t-shirt and shorts.");

List<ChatMessage> chatMessages = new ArrayList<>();
chatMessages.add(systemMessage);
chatMessages.add(userMessage);
chatMessages.add(aiMessage);
chatMessages.add(toolExecutionResultMessage);

Response<AiMessage> response2 = model.generate(chatMessages);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,8 @@ public void onError(Throwable error) {
AiMessage aiMessage = response.content();
assertThat(aiMessage.text()).isNull();

ToolExecutionRequest toolExecutionRequest = aiMessage.toolExecutionRequest();
assertThat(aiMessage.toolExecutionRequests()).hasSize(1);
ToolExecutionRequest toolExecutionRequest = aiMessage.toolExecutionRequests().get(0);
assertThat(toolExecutionRequest.name()).isEqualTo("calculator");
assertThat(toolExecutionRequest.arguments()).isEqualToIgnoringWhitespace("{\"first\": 2, \"second\": 2}");

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import java.util.Objects;

import static dev.langchain4j.internal.Utils.quoted;
import static java.util.Collections.singletonMap;

public class JsonSchemaProperty {

Expand Down Expand Up @@ -95,4 +96,8 @@ public static JsonSchemaProperty enums(Class<?> enumClass) {

return from("enum", enumClass.getEnumConstants());
}

public static JsonSchemaProperty items(JsonSchemaProperty type) {
return from("items", singletonMap(type.key, type.value));
}
}
Original file line number Diff line number Diff line change
@@ -1,18 +1,25 @@
package dev.langchain4j.agent.tool;

import static dev.langchain4j.internal.Utils.quoted;
import java.util.Objects;

import static dev.langchain4j.internal.Utils.quoted;

public class ToolExecutionRequest {

private final String id;
private final String name;
private final String arguments;

private ToolExecutionRequest(Builder builder) {
this.id = builder.id;
this.name = builder.name;
this.arguments = builder.arguments;
}

public String id() {
return id;
}

public String name() {
return name;
}
Expand All @@ -29,13 +36,15 @@ public boolean equals(Object another) {
}

private boolean equalTo(ToolExecutionRequest another) {
return Objects.equals(name, another.name)
return Objects.equals(id, another.id)
&& Objects.equals(name, another.name)
&& Objects.equals(arguments, another.arguments);
}

@Override
public int hashCode() {
int h = 5381;
h += (h << 5) + Objects.hashCode(id);
h += (h << 5) + Objects.hashCode(name);
h += (h << 5) + Objects.hashCode(arguments);
return h;
Expand All @@ -44,7 +53,8 @@ public int hashCode() {
@Override
public String toString() {
return "ToolExecutionRequest {"
+ " name = " + quoted(name)
+ " id = " + quoted(id)
+ ", name = " + quoted(name)
+ ", arguments = " + quoted(arguments)
+ " }";
}
Expand All @@ -55,12 +65,18 @@ public static Builder builder() {

public static final class Builder {

private String id;
private String name;
private String arguments;

private Builder() {
}

public Builder id(String id) {
this.id = id;
return this;
}

public Builder name(String name) {
this.name = name;
return this;
Expand Down
Loading

0 comments on commit 303b2ab

Please sign in to comment.