Added support for streaming ToolExecutionRequest from LLM (langchain4j#44)

Now, the StreamingChatLanguageModel can be used in conjunction with
tools.
One can send tool specifications along with a message to the LLM, and
the LLM can either stream a response or initiate a request to execute a
tool (also as a stream of tokens).
langchain4j authored Jul 23, 2023
1 parent 81e1a9c commit 3cc75c7
Showing 9 changed files with 219 additions and 50 deletions.
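In short, the new capability is invoked like this. The snippet below is a condensed sketch of the integration test at the bottom of this diff; "calculator", "first" and "second" are illustrative names, and the API calls themselves come from this commit.

    // model is a StreamingChatLanguageModel, handler a StreamingResponseHandler (both shown below)
    ToolSpecification calculator = ToolSpecification.builder()
            .name("calculator")
            .description("returns a sum of two numbers")
            .addParameter("first", INTEGER)
            .addParameter("second", INTEGER)
            .build();

    model.sendMessages(
            singletonList(userMessage("Two plus two?")),
            singletonList(calculator),
            handler); // handler's onToolName()/onToolArguments() receive the
                      // streamed tool execution request token by token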
@@ -0,0 +1,51 @@
package dev.langchain4j.model;

public interface StreamingResponseHandler {

    /**
     * This method is invoked each time the LLM sends a new token.
     *
     * @param token a single token, part of the complete response
     */
    void onNext(String token);

    /**
     * This method is invoked when the LLM decides to execute a tool.
     * It is intended to be used only with the StreamingChatLanguageModel.
     *
     * @param name the name of the tool that the LLM has chosen to execute
     */
    default void onToolName(String name) {
    }

    /**
     * This method is invoked each time the LLM streams a token of the tool arguments.
     * For example, the arguments string { "argument": "value" } could be streamed as:
     * 1. "{"
     * 2. " \""
     * 3. "argument"
     * 4. "\":"
     * 5. " \""
     * 6. "value"
     * 7. "\" "
     * 8. "}"
     * It is intended to be used only with the StreamingChatLanguageModel.
     *
     * @param arguments a single token, part of the complete arguments JSON object
     */
    default void onToolArguments(String arguments) {
    }

    /**
     * This method is invoked once the LLM has finished responding.
     */
    default void onComplete() {
    }

    /**
     * This method is invoked when an error occurs during streaming.
     *
     * @param error the Throwable error that occurred
     */
    void onError(Throwable error);
}
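To illustrate how these callbacks compose, here is a minimal sketch of a handler that accumulates the streamed tool name and argument tokens into a complete request. Only the callbacks above come from this commit; the ToolExecutionRequest builder shape is an assumption for illustration.

    // Minimal sketch: accumulate the streamed pieces into a ToolExecutionRequest
    // (dev.langchain4j.agent.tool.ToolExecutionRequest with a name()/arguments()
    // builder is assumed here, not part of this diff).
    StreamingResponseHandler handler = new StreamingResponseHandler() {

        final StringBuilder toolName = new StringBuilder();
        final StringBuilder toolArguments = new StringBuilder();

        @Override
        public void onNext(String token) {
            System.out.print(token); // plain answer tokens, when the LLM responds with text
        }

        @Override
        public void onToolName(String name) {
            toolName.append(name); // the tool name, typically delivered in one piece
        }

        @Override
        public void onToolArguments(String arguments) {
            toolArguments.append(arguments); // JSON fragments, as in the Javadoc example above
        }

        @Override
        public void onComplete() {
            if (toolName.length() > 0) {
                ToolExecutionRequest request = ToolExecutionRequest.builder()
                        .name(toolName.toString())
                        .arguments(toolArguments.toString())
                        .build();
                // hand the assembled request to the tool executor here
            }
        }

        @Override
        public void onError(Throwable error) {
            error.printStackTrace();
        }
    };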

This file was deleted (the old StreamingResultHandler, which the StreamingResponseHandler above replaces).

@@ -1,24 +1,21 @@
 package dev.langchain4j.model.chat;
 
-import dev.langchain4j.WillChangeSoon;
+import dev.langchain4j.agent.tool.ToolSpecification;
 import dev.langchain4j.data.message.ChatMessage;
 import dev.langchain4j.data.message.UserMessage;
-import dev.langchain4j.model.StreamingResultHandler;
-import dev.langchain4j.model.input.Prompt;
+import dev.langchain4j.model.StreamingResponseHandler;
 
 import java.util.List;
 
 public interface StreamingChatLanguageModel {
 
-    @WillChangeSoon("Most probably StreamingResultHandler will be replaced with fluent API")
-    void sendUserMessage(String userMessage, StreamingResultHandler handler);
+    void sendUserMessage(String userMessage, StreamingResponseHandler handler);
 
-    @WillChangeSoon("Most probably StreamingResultHandler will be replaced with fluent API")
-    void sendUserMessage(UserMessage userMessage, StreamingResultHandler handler);
+    void sendUserMessage(UserMessage userMessage, StreamingResponseHandler handler);
 
-    @WillChangeSoon("Most probably StreamingResultHandler will be replaced with fluent API")
-    void sendUserMessage(Object structuredPrompt, StreamingResultHandler handler);
+    void sendUserMessage(Object structuredPrompt, StreamingResponseHandler handler);
 
-    @WillChangeSoon("Most probably StreamingResultHandler will be replaced with fluent API")
-    void sendMessages(List<ChatMessage> messages, StreamingResultHandler handler);
+    void sendMessages(List<ChatMessage> messages, StreamingResponseHandler handler);
 
+    void sendMessages(List<ChatMessage> messages, List<ToolSpecification> toolSpecifications, StreamingResponseHandler handler);
 }
@@ -2,6 +2,11 @@
 
 import dev.langchain4j.model.input.Prompt;
 
+/**
+ * Represents an LLM with a simple text interface.
+ * It is recommended to use the ChatLanguageModel instead, as it offers greater capabilities.
+ * More details: https://openai.com/blog/gpt-4-api-general-availability
+ */
 public interface LanguageModel {
 
     String process(String text);
@@ -1,17 +1,13 @@
 package dev.langchain4j.model.language;
 
-import dev.langchain4j.WillChangeSoon;
-import dev.langchain4j.model.StreamingResultHandler;
+import dev.langchain4j.model.StreamingResponseHandler;
 import dev.langchain4j.model.input.Prompt;
 
 public interface StreamingLanguageModel {
 
-    @WillChangeSoon("Most probably StreamingResultHandler will be replaced with fluent API")
-    void process(String text, StreamingResultHandler handler);
+    void process(String text, StreamingResponseHandler handler);
 
-    @WillChangeSoon("Most probably StreamingResultHandler will be replaced with fluent API")
-    void process(Prompt prompt, StreamingResultHandler handler);
+    void process(Prompt prompt, StreamingResponseHandler handler);
 
-    @WillChangeSoon("Most probably StreamingResultHandler will be replaced with fluent API")
-    void process(Object structuredPrompt, StreamingResultHandler handler);
+    void process(Object structuredPrompt, StreamingResponseHandler handler);
 }
2 changes: 1 addition & 1 deletion langchain4j-parent/pom.xml
@@ -25,7 +25,7 @@
         <dependency>
             <groupId>dev.ai4j</groupId>
             <artifactId>openai4j</artifactId>
-            <version>0.5.2</version>
+            <version>0.6.0</version>
         </dependency>
 
         <dependency>
@@ -2,10 +2,14 @@
 
 import dev.ai4j.openai4j.OpenAiClient;
 import dev.ai4j.openai4j.chat.ChatCompletionRequest;
+import dev.ai4j.openai4j.chat.ChatCompletionResponse;
+import dev.ai4j.openai4j.chat.Delta;
+import dev.ai4j.openai4j.chat.FunctionCall;
+import dev.langchain4j.agent.tool.ToolSpecification;
 import dev.langchain4j.data.message.ChatMessage;
 import dev.langchain4j.data.message.UserMessage;
 import dev.langchain4j.data.segment.TextSegment;
-import dev.langchain4j.model.StreamingResultHandler;
+import dev.langchain4j.model.StreamingResponseHandler;
 import dev.langchain4j.model.chat.StreamingChatLanguageModel;
 import dev.langchain4j.model.chat.TokenCountEstimator;
 import dev.langchain4j.model.input.Prompt;
@@ -16,6 +20,7 @@
 
 import static dev.langchain4j.data.message.UserMessage.userMessage;
 import static dev.langchain4j.model.input.structured.StructuredPromptProcessor.toPrompt;
+import static dev.langchain4j.model.openai.OpenAiHelper.toFunctions;
 import static dev.langchain4j.model.openai.OpenAiHelper.toOpenAiMessages;
 import static dev.langchain4j.model.openai.OpenAiModelName.GPT_3_5_TURBO;
 import static java.time.Duration.ofSeconds;
@@ -67,28 +72,33 @@ public OpenAiStreamingChatModel(String apiKey,
     }
 
     @Override
-    public void sendUserMessage(String text, StreamingResultHandler handler) {
+    public void sendUserMessage(String text, StreamingResponseHandler handler) {
         sendUserMessage(userMessage(text), handler);
     }
 
     @Override
-    public void sendUserMessage(UserMessage userMessage, StreamingResultHandler handler) {
+    public void sendUserMessage(UserMessage userMessage, StreamingResponseHandler handler) {
         sendMessages(singletonList(userMessage), handler);
     }
 
     @Override
-    public void sendUserMessage(Object structuredPrompt, StreamingResultHandler handler) {
+    public void sendUserMessage(Object structuredPrompt, StreamingResponseHandler handler) {
         Prompt prompt = toPrompt(structuredPrompt);
         sendUserMessage(prompt.toUserMessage(), handler);
     }
 
     @Override
-    public void sendMessages(List<ChatMessage> messages, StreamingResultHandler handler) {
+    public void sendMessages(List<ChatMessage> messages, StreamingResponseHandler handler) {
+        sendMessages(messages, null, handler);
+    }
+
+    @Override
+    public void sendMessages(List<ChatMessage> messages, List<ToolSpecification> toolSpecifications, StreamingResponseHandler handler) {
         ChatCompletionRequest request = ChatCompletionRequest.builder()
                 .stream(true)
                 .model(modelName)
                 .messages(toOpenAiMessages(messages))
+                .functions(toFunctions(toolSpecifications))
                 .temperature(temperature)
                 .topP(topP)
                 .maxTokens(maxTokens)
@@ -97,17 +107,28 @@ public void sendMessages(List<ChatMessage> messages, StreamingResultHandler handler) {
                 .build();
 
         client.chatCompletion(request)
-                .onPartialResponse(partialResponse -> {
-                    String content = partialResponse.choices().get(0).delta().content();
-                    if (content != null) {
-                        handler.onPartialResult(content);
-                    }
-                })
+                .onPartialResponse(partialResponse -> handle(partialResponse, handler))
                 .onComplete(handler::onComplete)
                 .onError(handler::onError)
                 .execute();
     }
 
+    private static void handle(ChatCompletionResponse partialResponse,
+                               StreamingResponseHandler handler) {
+        Delta delta = partialResponse.choices().get(0).delta();
+        String content = delta.content();
+        FunctionCall functionCall = delta.functionCall();
+        if (content != null) {
+            handler.onNext(content);
+        } else if (functionCall != null) {
+            if (functionCall.name() != null) {
+                handler.onToolName(functionCall.name());
+            } else if (functionCall.arguments() != null) {
+                handler.onToolArguments(functionCall.arguments());
+            }
+        }
+    }
 
     @Override
     public int estimateTokenCount(String text) {
         return estimateTokenCount(userMessage(text));
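For context on what handle() receives: with the OpenAI function-calling API, a streamed function call arrives as a sequence of deltas in which the name typically comes first, followed by the arguments JSON in fragments. Roughly (illustrative and abbreviated, not taken from this diff):

    // Illustrative delta sequence for a streamed function call (abbreviated):
    //   delta 1: { "function_call": { "name": "calculator", "arguments": "" } }
    //   delta 2: { "function_call": { "arguments": "{\"" } }
    //   delta 3: { "function_call": { "arguments": "first" } }
    //   ...
    //   last chunk: finish_reason = "function_call"
    // handle() therefore routes name() to onToolName() and each
    // arguments() fragment to onToolArguments().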
@@ -3,7 +3,7 @@
 import dev.ai4j.openai4j.OpenAiClient;
 import dev.ai4j.openai4j.completion.CompletionRequest;
 import dev.langchain4j.data.segment.TextSegment;
-import dev.langchain4j.model.StreamingResultHandler;
+import dev.langchain4j.model.StreamingResponseHandler;
 import dev.langchain4j.model.input.Prompt;
 import dev.langchain4j.model.language.StreamingLanguageModel;
 import dev.langchain4j.model.language.TokenCountEstimator;
@@ -50,7 +50,7 @@ public OpenAiStreamingLanguageModel(String apiKey,
     }
 
     @Override
-    public void process(String text, StreamingResultHandler handler) {
+    public void process(String text, StreamingResponseHandler handler) {
         CompletionRequest request = CompletionRequest.builder()
                 .model(modelName)
                 .prompt(text)
@@ -61,7 +61,7 @@ public void process(String text, StreamingResultHandler handler) {
                 .onPartialResponse(partialResponse -> {
                     String partialResponseText = partialResponse.text();
                     if (partialResponseText != null) {
-                        handler.onPartialResult(partialResponseText);
+                        handler.onNext(partialResponseText);
                     }
                 })
                 .onComplete(handler::onComplete)
@@ -70,12 +70,12 @@ public void process(String text, StreamingResultHandler handler) {
     }
 
     @Override
-    public void process(Prompt prompt, StreamingResultHandler handler) {
+    public void process(Prompt prompt, StreamingResponseHandler handler) {
         process(prompt.text(), handler);
     }
 
     @Override
-    public void process(Object structuredPrompt, StreamingResultHandler handler) {
+    public void process(Object structuredPrompt, StreamingResponseHandler handler) {
         process(toPrompt(structuredPrompt), handler);
     }
 
@@ -0,0 +1,113 @@
package dev.langchain4j.model.openai;

import dev.langchain4j.agent.tool.ToolSpecification;
import dev.langchain4j.data.message.UserMessage;
import dev.langchain4j.model.StreamingResponseHandler;
import dev.langchain4j.model.chat.StreamingChatLanguageModel;
import org.junit.jupiter.api.Test;

import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.TimeoutException;

import static dev.langchain4j.agent.tool.JsonSchemaProperty.INTEGER;
import static dev.langchain4j.data.message.UserMessage.userMessage;
import static java.util.Collections.singletonList;
import static java.util.concurrent.TimeUnit.SECONDS;
import static org.assertj.core.api.Assertions.assertThat;

class OpenAiStreamingChatModelIT {

    StreamingChatLanguageModel model
            = OpenAiStreamingChatModel.withApiKey(System.getenv("OPENAI_API_KEY"));

    @Test
    void should_stream_answer() throws ExecutionException, InterruptedException, TimeoutException {

        CompletableFuture<String> future = new CompletableFuture<>();

        model.sendUserMessage(
                "What is the capital of Germany?",
                new StreamingResponseHandler() {

                    final StringBuilder answerBuilder = new StringBuilder();

                    @Override
                    public void onNext(String partialResult) {
                        answerBuilder.append(partialResult);
                        System.out.println("onPartialResult: '" + partialResult + "'");
                    }

                    @Override
                    public void onComplete() {
                        future.complete(answerBuilder.toString());
                        System.out.println("onComplete");
                    }

                    @Override
                    public void onError(Throwable error) {
                        future.completeExceptionally(error);
                    }
                });

        String answer = future.get(30, SECONDS);

        assertThat(answer).contains("Berlin");
    }

    @Test
    void should_stream_tool_execution_request() throws Exception {

        ToolSpecification toolSpecification = ToolSpecification.builder()
                .name("calculator")
                .description("returns a sum of two numbers")
                .addParameter("first", INTEGER)
                .addParameter("second", INTEGER)
                .build();

        UserMessage userMessage = userMessage("Two plus two?");

        CompletableFuture<String> future = new CompletableFuture<>();

        model.sendMessages(
                singletonList(userMessage),
                singletonList(toolSpecification),
                new StreamingResponseHandler() {

                    final StringBuilder answerBuilder = new StringBuilder();

                    @Override
                    public void onNext(String partialResult) {
                        answerBuilder.append(partialResult);
                        System.out.println("onPartialResult: '" + partialResult + "'");
                    }

                    @Override
                    public void onToolName(String name) {
                        answerBuilder.append(name);
                        System.out.println("onToolName: '" + name + "'");
                    }

                    @Override
                    public void onToolArguments(String arguments) {
                        answerBuilder.append(arguments);
                        System.out.println("onToolArguments: '" + arguments + "'");
                    }

                    @Override
                    public void onComplete() {
                        future.complete(answerBuilder.toString());
                        System.out.println("onComplete");
                    }

                    @Override
                    public void onError(Throwable error) {
                        future.completeExceptionally(error);
                    }
                });

        String answer = future.get(30, SECONDS);

        assertThat(answer).isEqualToIgnoringWhitespace("calculator {\"first\": 2, \"second\": 2}");
    }
}
