forked from langchain4j/langchain4j
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Added support for streaming ToolExecutionRequest from LLM (langchain4…
…j#44) Now, the StreamingChatLanguageModel can be used in conjunction with tools. One can send tool specifications along with a message to the LLM, and the LLM can either stream a response or initiate a request to execute a tool (also as a stream of tokens).
- Loading branch information
1 parent
81e1a9c
commit 3cc75c7
Showing
9 changed files
with
219 additions
and
50 deletions.
There are no files selected for viewing
51 changes: 51 additions & 0 deletions
51
langchain4j-core/src/main/java/dev/langchain4j/model/StreamingResponseHandler.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,51 @@ | ||
package dev.langchain4j.model; | ||
|
||
public interface StreamingResponseHandler { | ||
|
||
/** | ||
* This method is invoked each time LLM sends a token. | ||
* | ||
* @param token single token, part of a complete response | ||
*/ | ||
void onNext(String token); | ||
|
||
/** | ||
* This method is invoked when LLM decides to execute a tool. | ||
* It is supposed to work exclusively with the StreamingChatLanguageModel. | ||
* | ||
* @param name the name of the tool that LLM has chosen to execute | ||
*/ | ||
default void onToolName(String name) { | ||
} | ||
|
||
/** | ||
* This method is invoked each time LLM sends a token. | ||
* This is how the following string with arguments { "argument": "value" } can be streamed: | ||
* 1. "{" | ||
* 2. " \"" | ||
* 3. "argument" | ||
* 4: "\":" | ||
* 5. " \"" | ||
* 6. "value" | ||
* 7. "\" " | ||
* 8. "}" | ||
* It is supposed to work exclusively with the StreamingChatLanguageModel. | ||
* | ||
* @param arguments single token, a part of the complete arguments JSON object | ||
*/ | ||
default void onToolArguments(String arguments) { | ||
} | ||
|
||
/** | ||
* This method is invoked once LLM has finished responding. | ||
*/ | ||
default void onComplete() { | ||
} | ||
|
||
/** | ||
* This method is invoked when an error occurs during streaming. | ||
* | ||
* @param error the Throwable error that occurred | ||
*/ | ||
void onError(Throwable error); | ||
} |
14 changes: 0 additions & 14 deletions
14
langchain4j-core/src/main/java/dev/langchain4j/model/StreamingResultHandler.java
This file was deleted.
Oops, something went wrong.
19 changes: 8 additions & 11 deletions
19
langchain4j-core/src/main/java/dev/langchain4j/model/chat/StreamingChatLanguageModel.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,24 +1,21 @@ | ||
package dev.langchain4j.model.chat; | ||
|
||
import dev.langchain4j.WillChangeSoon; | ||
import dev.langchain4j.agent.tool.ToolSpecification; | ||
import dev.langchain4j.data.message.ChatMessage; | ||
import dev.langchain4j.data.message.UserMessage; | ||
import dev.langchain4j.model.StreamingResultHandler; | ||
import dev.langchain4j.model.input.Prompt; | ||
import dev.langchain4j.model.StreamingResponseHandler; | ||
|
||
import java.util.List; | ||
|
||
public interface StreamingChatLanguageModel { | ||
|
||
@WillChangeSoon("Most probably StreamingResultHandler will be replaced with fluent API") | ||
void sendUserMessage(String userMessage, StreamingResultHandler handler); | ||
void sendUserMessage(String userMessage, StreamingResponseHandler handler); | ||
|
||
@WillChangeSoon("Most probably StreamingResultHandler will be replaced with fluent API") | ||
void sendUserMessage(UserMessage userMessage, StreamingResultHandler handler); | ||
void sendUserMessage(UserMessage userMessage, StreamingResponseHandler handler); | ||
|
||
@WillChangeSoon("Most probably StreamingResultHandler will be replaced with fluent API") | ||
void sendUserMessage(Object structuredPrompt, StreamingResultHandler handler); | ||
void sendUserMessage(Object structuredPrompt, StreamingResponseHandler handler); | ||
|
||
@WillChangeSoon("Most probably StreamingResultHandler will be replaced with fluent API") | ||
void sendMessages(List<ChatMessage> messages, StreamingResultHandler handler); | ||
void sendMessages(List<ChatMessage> messages, StreamingResponseHandler handler); | ||
|
||
void sendMessages(List<ChatMessage> messages, List<ToolSpecification> toolSpecifications, StreamingResponseHandler handler); | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
12 changes: 4 additions & 8 deletions
12
langchain4j-core/src/main/java/dev/langchain4j/model/language/StreamingLanguageModel.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,17 +1,13 @@ | ||
package dev.langchain4j.model.language; | ||
|
||
import dev.langchain4j.WillChangeSoon; | ||
import dev.langchain4j.model.StreamingResultHandler; | ||
import dev.langchain4j.model.StreamingResponseHandler; | ||
import dev.langchain4j.model.input.Prompt; | ||
|
||
public interface StreamingLanguageModel { | ||
|
||
@WillChangeSoon("Most probably StreamingResultHandler will be replaced with fluent API") | ||
void process(String text, StreamingResultHandler handler); | ||
void process(String text, StreamingResponseHandler handler); | ||
|
||
@WillChangeSoon("Most probably StreamingResultHandler will be replaced with fluent API") | ||
void process(Prompt prompt, StreamingResultHandler handler); | ||
void process(Prompt prompt, StreamingResponseHandler handler); | ||
|
||
@WillChangeSoon("Most probably StreamingResultHandler will be replaced with fluent API") | ||
void process(Object structuredPrompt, StreamingResultHandler handler); | ||
void process(Object structuredPrompt, StreamingResponseHandler handler); | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
113 changes: 113 additions & 0 deletions
113
langchain4j/src/test/java/dev/langchain4j/model/openai/OpenAiStreamingChatModelIT.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,113 @@ | ||
package dev.langchain4j.model.openai; | ||
|
||
import dev.langchain4j.agent.tool.ToolSpecification; | ||
import dev.langchain4j.data.message.UserMessage; | ||
import dev.langchain4j.model.StreamingResponseHandler; | ||
import dev.langchain4j.model.chat.StreamingChatLanguageModel; | ||
import org.junit.jupiter.api.Test; | ||
|
||
import java.util.concurrent.CompletableFuture; | ||
import java.util.concurrent.ExecutionException; | ||
import java.util.concurrent.TimeoutException; | ||
|
||
import static dev.langchain4j.agent.tool.JsonSchemaProperty.INTEGER; | ||
import static dev.langchain4j.data.message.UserMessage.userMessage; | ||
import static java.util.Collections.singletonList; | ||
import static java.util.concurrent.TimeUnit.SECONDS; | ||
import static org.assertj.core.api.Assertions.assertThat; | ||
|
||
class OpenAiStreamingChatModelIT { | ||
|
||
StreamingChatLanguageModel model | ||
= OpenAiStreamingChatModel.withApiKey(System.getenv("OPENAI_API_KEY")); | ||
|
||
@Test | ||
void should_stream_answer() throws ExecutionException, InterruptedException, TimeoutException { | ||
|
||
CompletableFuture<String> future = new CompletableFuture<>(); | ||
|
||
model.sendUserMessage( | ||
"What is the capital of Germany?", | ||
new StreamingResponseHandler() { | ||
|
||
final StringBuilder answerBuilder = new StringBuilder(); | ||
|
||
@Override | ||
public void onNext(String partialResult) { | ||
answerBuilder.append(partialResult); | ||
System.out.println("onPartialResult: '" + partialResult + "'"); | ||
} | ||
|
||
@Override | ||
public void onComplete() { | ||
future.complete(answerBuilder.toString()); | ||
System.out.println("onComplete"); | ||
} | ||
|
||
@Override | ||
public void onError(Throwable error) { | ||
future.completeExceptionally(error); | ||
} | ||
}); | ||
|
||
String answer = future.get(30, SECONDS); | ||
|
||
assertThat(answer).contains("Berlin"); | ||
} | ||
|
||
@Test | ||
void should_stream_tool_execution_request() throws Exception { | ||
|
||
ToolSpecification toolSpecification = ToolSpecification.builder() | ||
.name("calculator") | ||
.description("returns a sum of two numbers") | ||
.addParameter("first", INTEGER) | ||
.addParameter("second", INTEGER) | ||
.build(); | ||
|
||
UserMessage userMessage = userMessage("Two plus two?"); | ||
|
||
CompletableFuture<String> future = new CompletableFuture<>(); | ||
|
||
model.sendMessages( | ||
singletonList(userMessage), | ||
singletonList(toolSpecification), | ||
new StreamingResponseHandler() { | ||
|
||
final StringBuilder answerBuilder = new StringBuilder(); | ||
|
||
@Override | ||
public void onNext(String partialResult) { | ||
answerBuilder.append(partialResult); | ||
System.out.println("onPartialResult: '" + partialResult + "'"); | ||
} | ||
|
||
@Override | ||
public void onToolName(String name) { | ||
answerBuilder.append(name); | ||
System.out.println("onToolName: '" + name + "'"); | ||
} | ||
|
||
@Override | ||
public void onToolArguments(String arguments) { | ||
answerBuilder.append(arguments); | ||
System.out.println("onToolArguments: '" + arguments + "'"); | ||
} | ||
|
||
@Override | ||
public void onComplete() { | ||
future.complete(answerBuilder.toString()); | ||
System.out.println("onComplete"); | ||
} | ||
|
||
@Override | ||
public void onError(Throwable error) { | ||
future.completeExceptionally(error); | ||
} | ||
}); | ||
|
||
String answer = future.get(30, SECONDS); | ||
|
||
assertThat(answer).isEqualToIgnoringWhitespace("calculator {\"first\": 2, \"second\": 2}"); | ||
} | ||
} |