Skip to content

[FEATURE] avoid second round trip for function call #656

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.concurrent.atomic.AtomicReference;
import java.util.stream.Collectors;
Expand Down Expand Up @@ -48,6 +49,7 @@
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.model.ModelOptionsUtils;
import org.springframework.ai.model.function.AbstractFunctionCallSupport;
import org.springframework.ai.model.function.FunctionCallback;
import org.springframework.ai.model.function.FunctionCallbackContext;
import org.springframework.ai.retry.RetryUtils;
import org.springframework.http.ResponseEntity;
Expand Down Expand Up @@ -391,6 +393,16 @@ public ChatCompletion build() {

}

@Override
protected boolean hasReturningFunction(RequestMessage responseMessage) {
return responseMessage.content()
.stream()
.filter(c -> c.type() == MediaContent.Type.TOOL_USE)
.map(MediaContent::name)
.map(functionName -> Optional.ofNullable(this.functionCallbackRegister.get(functionName)))
.anyMatch(functionCallback -> functionCallback.map(FunctionCallback::returningFunction).orElse(false));
}

@Override
protected ChatCompletionRequest doCreateToolResponseRequest(ChatCompletionRequest previousRequest,
RequestMessage responseMessage, List<RequestMessage> conversationHistory) {
Expand All @@ -414,8 +426,9 @@ protected ChatCompletionRequest doCreateToolResponseRequest(ChatCompletionReques

String functionResponse = this.functionCallbackRegister.get(functionName)
.call(ModelOptionsUtils.toJsonString(functionArguments));

toolResults.add(new MediaContent(Type.TOOL_RESULT, functionCallId, functionResponse));
if (functionResponse != null) {
toolResults.add(new MediaContent(Type.TOOL_RESULT, functionCallId, functionResponse));
}
}

// Add the function response to the conversation.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
import com.azure.ai.openai.models.ChatRequestSystemMessage;
import com.azure.ai.openai.models.ChatRequestToolMessage;
import com.azure.ai.openai.models.ChatRequestUserMessage;
import com.azure.ai.openai.models.ChatResponseMessage;
import com.azure.ai.openai.models.CompletionsFinishReason;
import com.azure.ai.openai.models.ContentFilterResultsForPrompt;
import com.azure.ai.openai.models.FunctionCall;
Expand All @@ -50,6 +51,7 @@
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.model.ModelOptionsUtils;
import org.springframework.ai.model.function.AbstractFunctionCallSupport;
import org.springframework.ai.model.function.FunctionCallback;
import org.springframework.ai.model.function.FunctionCallbackContext;
import org.springframework.util.Assert;
import org.springframework.util.CollectionUtils;
Expand Down Expand Up @@ -513,6 +515,15 @@ private ChatCompletionsOptions copy(ChatCompletionsOptions fromOptions) {
return copyOptions;
}

@Override
protected boolean hasReturningFunction(ChatRequestMessage responseMessage) {
return ((ChatRequestAssistantMessage) responseMessage).getToolCalls()
.stream()
.map(toolCall -> ((ChatCompletionsFunctionToolCall) toolCall).getFunction().getName())
.map(functionName -> Optional.ofNullable(this.functionCallbackRegister.get(functionName)))
.anyMatch(functionCallback -> functionCallback.map(FunctionCallback::returningFunction).orElse(false));
}

@Override
protected ChatCompletionsOptions doCreateToolResponseRequest(ChatCompletionsOptions previousRequest,
ChatRequestMessage responseMessage, List<ChatRequestMessage> conversationHistory) {
Expand All @@ -530,8 +541,10 @@ protected ChatCompletionsOptions doCreateToolResponseRequest(ChatCompletionsOpti

String functionResponse = this.functionCallbackRegister.get(functionName).call(functionArguments);

// Add the function response to the conversation.
conversationHistory.add(new ChatRequestToolMessage(functionResponse, toolCall.getId()));
if (functionResponse != null) {
// Add the function response to the conversation.
conversationHistory.add(new ChatRequestToolMessage(functionResponse, toolCall.getId()));
}
}

// Recursively call chatCompletionWithTools until the model doesn't call a
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,29 @@ void streamFunctionCallTest() {
assertThat(content).containsAnyOf("15.0", "15");
}

@Test
void functionCallWithoutCompleteRoundTrip() {

UserMessage userMessage = new UserMessage("What's the weather like in San Francisco?");

List<Message> messages = new ArrayList<>(List.of(userMessage));

final var spyingMockWeatherService = new SpyingMockWeatherService();
var promptOptions = AzureOpenAiChatOptions.builder()
.withDeploymentName(selectedModel)
.withFunctionCallbacks(List.of(FunctionCallbackWrapper.builder(spyingMockWeatherService)
.withName("getCurrentWeather")
.withDescription("Get the current weather in a given location")
.build()))
.build();

ChatResponse response = chatClient.call(new Prompt(messages, promptOptions));

logger.info("Response: {}", response);
final var interceptedRequest = spyingMockWeatherService.getInterceptedRequest();
assertThat(interceptedRequest.location()).containsIgnoringCase("San Francisco");
}

@SpringBootConfiguration
public static class TestConfiguration {

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
/*
* Copyright 2023 - 2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.ai.azure.openai.function;

import java.util.function.Function;

public class SpyingMockWeatherService implements Function<MockWeatherService.Request, Void> {

private MockWeatherService.Request interceptedRequest = null;

@Override
public Void apply(MockWeatherService.Request request) {
interceptedRequest = request;
return null;
}

public MockWeatherService.Request getInterceptedRequest() {
return interceptedRequest;
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
import org.springframework.ai.mistralai.api.MistralAiApi.ChatCompletionRequest;
import org.springframework.ai.model.ModelOptionsUtils;
import org.springframework.ai.model.function.AbstractFunctionCallSupport;
import org.springframework.ai.model.function.FunctionCallback;
import org.springframework.ai.model.function.FunctionCallbackContext;
import org.springframework.ai.retry.RetryUtils;
import org.springframework.http.ResponseEntity;
Expand Down Expand Up @@ -98,7 +99,6 @@ public ChatResponse call(Prompt prompt) {
var request = createRequest(prompt, false);

return retryTemplate.execute(ctx -> {

ResponseEntity<ChatCompletion> completionEntity = this.callWithFunctionSupport(request);

var chatCompletion = completionEntity.getBody();
Expand Down Expand Up @@ -239,13 +239,18 @@ private List<MistralAiApi.FunctionTool> getFunctionTools(Set<String> functionNam
}).toList();
}

//
// Function Calling Support
//
@Override
protected boolean hasReturningFunction(ChatCompletionMessage responseMessage) {
return responseMessage.toolCalls()
.stream()
.map(toolCall -> toolCall.function().name())
.map(functionName -> Optional.ofNullable(this.functionCallbackRegister.get(functionName)))
.anyMatch(functionCallback -> functionCallback.map(FunctionCallback::returningFunction).orElse(false));
}

@Override
protected ChatCompletionRequest doCreateToolResponseRequest(ChatCompletionRequest previousRequest,
ChatCompletionMessage responseMessage, List<ChatCompletionMessage> conversationHistory) {

// Every tool-call item requires a separate function call and a response (TOOL)
// message.
for (ToolCall toolCall : responseMessage.toolCalls()) {
Expand All @@ -258,10 +263,12 @@ protected ChatCompletionRequest doCreateToolResponseRequest(ChatCompletionReques
}

String functionResponse = this.functionCallbackRegister.get(functionName).call(functionArguments);
if (functionResponse != null) {
// Add the function response to the conversation.
conversationHistory.add(new ChatCompletionMessage(functionResponse, ChatCompletionMessage.Role.TOOL,
functionName, null));
}

// Add the function response to the conversation.
conversationHistory
.add(new ChatCompletionMessage(functionResponse, ChatCompletionMessage.Role.TOOL, functionName, null));
}

// Recursively call chatCompletionWithTools until the model doesn't call a
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.model.ModelOptionsUtils;
import org.springframework.ai.model.function.AbstractFunctionCallSupport;
import org.springframework.ai.model.function.FunctionCallback;
import org.springframework.ai.model.function.FunctionCallbackContext;
import org.springframework.ai.openai.api.OpenAiApi;
import org.springframework.ai.openai.api.OpenAiApi.ChatCompletion;
Expand Down Expand Up @@ -324,6 +325,15 @@ private List<OpenAiApi.FunctionTool> getFunctionTools(Set<String> functionNames)
}).toList();
}

@Override
protected boolean hasReturningFunction(ChatCompletionMessage responseMessage) {
return responseMessage.toolCalls()
.stream()
.map(toolCall -> toolCall.function().name())
.map(functionName -> Optional.ofNullable(this.functionCallbackRegister.get(functionName)))
.anyMatch(functionCallback -> functionCallback.map(FunctionCallback::returningFunction).orElse(false));
}

@Override
protected ChatCompletionRequest doCreateToolResponseRequest(ChatCompletionRequest previousRequest,
ChatCompletionMessage responseMessage, List<ChatCompletionMessage> conversationHistory) {
Expand All @@ -340,10 +350,11 @@ protected ChatCompletionRequest doCreateToolResponseRequest(ChatCompletionReques
}

String functionResponse = this.functionCallbackRegister.get(functionName).call(functionArguments);

// Add the function response to the conversation.
conversationHistory
.add(new ChatCompletionMessage(functionResponse, Role.TOOL, functionName, toolCall.id(), null));
if (functionResponse != null) {
// Add the function response to the conversation.
conversationHistory
.add(new ChatCompletionMessage(functionResponse, Role.TOOL, functionName, toolCall.id(), null));
}
}

// Recursively call chatCompletionWithTools until the model doesn't call a
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.model.ModelOptionsUtils;
import org.springframework.ai.model.function.AbstractFunctionCallSupport;
import org.springframework.ai.model.function.FunctionCallback;
import org.springframework.ai.model.function.FunctionCallbackContext;
import org.springframework.ai.vertexai.gemini.metadata.VertexAiChatResponseMetadata;
import org.springframework.ai.vertexai.gemini.metadata.VertexAiUsage;
Expand All @@ -57,6 +58,7 @@
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Collectors;

Expand Down Expand Up @@ -406,6 +408,14 @@ public void destroy() throws Exception {
}
}

@Override
protected boolean hasReturningFunction(Content responseMessage) {
final var functionName = responseMessage.getPartsList().get(0).getFunctionCall().getName();
return Optional.ofNullable(this.functionCallbackRegister.get(functionName))
.map(FunctionCallback::returningFunction)
.orElse(false);
}

@Override
protected GeminiRequest doCreateToolResponseRequest(GeminiRequest previousRequest, Content responseMessage,
List<Content> conversationHistory) {
Expand All @@ -420,17 +430,18 @@ protected GeminiRequest doCreateToolResponseRequest(GeminiRequest previousReques
}

String functionResponse = this.functionCallbackRegister.get(functionName).call(functionArguments);

Content contentFnResp = Content.newBuilder()
.addParts(Part.newBuilder()
.setFunctionResponse(FunctionResponse.newBuilder()
.setName(functionCall.getName())
.setResponse(jsonToStruct(functionResponse))
if (functionResponse != null) {
Content contentFnResp = Content.newBuilder()
.addParts(Part.newBuilder()
.setFunctionResponse(FunctionResponse.newBuilder()
.setName(functionCall.getName())
.setResponse(jsonToStruct(functionResponse))
.build())
.build())
.build())
.build();
.build();

conversationHistory.add(contentFnResp);
conversationHistory.add(contentFnResp);
}

return new GeminiRequest(conversationHistory, previousRequest.model());
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,20 +60,17 @@ protected Set<String> handleFunctionCallbackConfigurations(FunctionCallingOption

if (options != null) {
if (!CollectionUtils.isEmpty(options.getFunctionCallbacks())) {
options.getFunctionCallbacks().stream().forEach(functionCallback -> {

options.getFunctionCallbacks().forEach(functionCallback -> {
// Register the tool callback.
if (isRuntimeCall) {
this.functionCallbackRegister.put(functionCallback.getName(), functionCallback);
// Automatically enable the function, usually from prompt
// callback.
functionToCall.add(functionCallback.getName());
}
else {
this.functionCallbackRegister.putIfAbsent(functionCallback.getName(), functionCallback);
}

// Automatically enable the function, usually from prompt callback.
if (isRuntimeCall) {
functionToCall.add(functionCallback.getName());
}
});
}

Expand Down Expand Up @@ -147,6 +144,9 @@ protected Resp handleFunctionCallOrReturn(Req request, Resp response) {

Req newRequest = this.doCreateToolResponseRequest(request, responseMessage, conversationHistory);

if (!this.hasReturningFunction(responseMessage)) {
return response;
}
return this.callWithFunctionSupport(newRequest);
}

Expand Down Expand Up @@ -180,6 +180,8 @@ protected Flux<Resp> handleFunctionCallOrReturnStream(Req request, Flux<Resp> re

}

abstract protected boolean hasReturningFunction(Msg responseMessage);

abstract protected Req doCreateToolResponseRequest(Req previousRequest, Msg responseMessage,
List<Msg> conversationHistory);

Expand Down
Loading