forked from langchain4j/langchain4j
Added support for streaming to AI Services (langchain4j#45)
Now TokenStream can be used as the return type of an AI Service method in order to stream the response from the LLM.
1 parent 529ef6b · commit 80f71fe
Showing 14 changed files with 557 additions and 55 deletions.
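For example, a minimal sketch of the new capability (the Assistant interface and the AiServices.create(...) wiring are illustrative assumptions, not part of this diff):

interface Assistant {

    // Declaring TokenStream as the return type switches this method to streaming.
    TokenStream chat(String userMessage);
}

// Assumed wiring: an AiServices factory configured with a StreamingChatLanguageModel.
Assistant assistant = AiServices.create(Assistant.class, streamingChatLanguageModel);

The returned TokenStream is consumed with the fluent onNext/onComplete/onError chain defined in AiServiceTokenStream below.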
langchain4j/src/main/java/dev/langchain4j/service/AiServiceContext.java
30 changes: 30 additions & 0 deletions
@@ -0,0 +1,30 @@
package dev.langchain4j.service;

import dev.langchain4j.agent.tool.ToolExecutor;
import dev.langchain4j.agent.tool.ToolSpecification;
import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.memory.ChatMemory;
import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.model.chat.StreamingChatLanguageModel;
import dev.langchain4j.model.moderation.ModerationModel;
import dev.langchain4j.retriever.Retriever;

import java.util.List;
import java.util.Map;

class AiServiceContext {

    Class<?> aiServiceClass;

    ChatLanguageModel chatLanguageModel;
    StreamingChatLanguageModel streamingChatLanguageModel;

    ChatMemory chatMemory;

    ModerationModel moderationModel;

    List<ToolSpecification> toolSpecifications;
    Map<String, ToolExecutor> toolExecutors;

    Retriever<TextSegment> retriever;
}
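AiServiceContext now carries both a blocking ChatLanguageModel and a StreamingChatLanguageModel, so the service proxy can presumably route each call based on the declared return type. A simplified, hypothetical sketch of that dispatch (the surrounding proxy method is not part of this diff):

// Hypothetical routing inside the AI Service proxy:
if (method.getReturnType() == TokenStream.class) {
    // streaming path introduced by this commit
    return new AiServiceTokenStream(messagesToSend, context);
} else {
    // existing blocking path via context.chatLanguageModel
}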
langchain4j/src/main/java/dev/langchain4j/service/AiServiceStreamingResponseHandler.java
119 changes: 119 additions & 0 deletions
@@ -0,0 +1,119 @@
package dev.langchain4j.service;

import dev.langchain4j.agent.tool.ToolExecutionRequest;
import dev.langchain4j.agent.tool.ToolExecutor;
import dev.langchain4j.data.message.ToolExecutionResultMessage;
import dev.langchain4j.model.StreamingResponseHandler;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.util.function.Consumer;

import static dev.langchain4j.data.message.AiMessage.aiMessage;
import static dev.langchain4j.data.message.ToolExecutionResultMessage.toolExecutionResultMessage;
import static dev.langchain4j.internal.ValidationUtils.ensureNotNull;

/**
 * Handles responses from the LLM, streamed token by token, for an AI Service.
 * Handles both regular (text) responses and responses with a request to execute a tool.
 */
class AiServiceStreamingResponseHandler implements StreamingResponseHandler {

    private final Logger log = LoggerFactory.getLogger(AiServiceStreamingResponseHandler.class);

    private final AiServiceContext context;

    private final Consumer<String> tokenHandler;
    private final Runnable completionHandler;
    private final Consumer<Throwable> errorHandler;

    private final StringBuilder answerBuilder;
    private final StringBuilder toolNameBuilder;
    private final StringBuilder toolArgumentsBuilder;

    AiServiceStreamingResponseHandler(AiServiceContext context,
                                      Consumer<String> tokenHandler,
                                      Runnable completionHandler,
                                      Consumer<Throwable> errorHandler) {
        this.context = ensureNotNull(context, "context");

        this.tokenHandler = ensureNotNull(tokenHandler, "tokenHandler");
        this.completionHandler = completionHandler;
        this.errorHandler = errorHandler;

        this.answerBuilder = new StringBuilder();
        this.toolNameBuilder = new StringBuilder();
        this.toolArgumentsBuilder = new StringBuilder();
    }

    @Override
    public void onNext(String partialResult) {
        answerBuilder.append(partialResult);
        tokenHandler.accept(partialResult);
    }

    @Override
    public void onToolName(String name) {
        toolNameBuilder.append(name);
    }

    @Override
    public void onToolArguments(String arguments) {
        toolArgumentsBuilder.append(arguments);
    }

    @Override
    public void onComplete() {

        String toolName = toolNameBuilder.toString();
        String toolArguments = toolArgumentsBuilder.toString();

        if (toolName.isEmpty()) {
            // Plain text answer: store it in chat memory (if present) and signal completion.
            if (context.chatMemory != null) {
                context.chatMemory.add(aiMessage(answerBuilder.toString()));
            }
            if (completionHandler != null) {
                completionHandler.run();
            }
        } else {
            // The LLM requested a tool execution: run the tool, append the result to the
            // conversation, and re-send it to the model with a fresh handler.
            // Note: this branch assumes a non-null chatMemory (messages() below requires it).
            ToolExecutionRequest toolExecutionRequest = ToolExecutionRequest.builder()
                    .name(toolName)
                    .arguments(toolArguments)
                    .build();

            if (context.chatMemory != null) {
                context.chatMemory.add(aiMessage(toolExecutionRequest));
            }

            ToolExecutor toolExecutor = context.toolExecutors.get(toolName); // TODO what if no such tool?
            String toolExecutionResult = toolExecutor.execute(toolExecutionRequest);
            ToolExecutionResultMessage toolExecutionResultMessage
                    = toolExecutionResultMessage(toolExecutionRequest.name(), toolExecutionResult);

            context.chatMemory.add(toolExecutionResultMessage);

            // TODO what if there are multiple tool executions in a row? (for the future)
            context.streamingChatLanguageModel.sendMessages(
                    context.chatMemory.messages(),
                    // TODO does it make sense to send tools if the LLM will not call them in the response anyway? (current OpenAI behavior)
                    context.toolSpecifications,
                    new AiServiceStreamingResponseHandler(context, tokenHandler, completionHandler, errorHandler)
            );
        }
    }

    @Override
    public void onError(Throwable error) {
        if (errorHandler != null) {
            try {
                errorHandler.accept(error);
            } catch (Exception e) {
                log.error("While handling the following error...", error);
                log.error("...the following error happened", e);
            }
        } else {
            log.warn("Ignored error", error);
        }
    }
}
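To make the accumulation above concrete: tool names and JSON arguments arrive from the LLM in fragments, and the handler simply buffers them until onComplete() decides between the text path and the tool path. A hypothetical callback trace:

// Hypothetical sequence of streamed callbacks for one tool call:
handler.onToolName("get");                    // toolNameBuilder: "get"
handler.onToolName("Weather");                // toolNameBuilder: "getWeather"
handler.onToolArguments("{\"city\":");        // raw JSON arguments accumulate
handler.onToolArguments("\"Berlin\"}");       // toolArgumentsBuilder: {"city":"Berlin"}
handler.onComplete();                         // executes the tool, re-sends the conversation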
langchain4j/src/main/java/dev/langchain4j/service/AiServiceTokenStream.java
85 changes: 85 additions & 0 deletions
@@ -0,0 +1,85 @@
package dev.langchain4j.service;

import dev.langchain4j.data.message.ChatMessage;

import java.util.List;
import java.util.function.Consumer;

import static dev.langchain4j.internal.ValidationUtils.ensureNotEmpty;
import static dev.langchain4j.internal.ValidationUtils.ensureNotNull;

class AiServiceTokenStream implements TokenStream {

    private final List<ChatMessage> messagesToSend;
    private final AiServiceContext context;

    AiServiceTokenStream(List<ChatMessage> messagesToSend, AiServiceContext context) {
        this.messagesToSend = ensureNotEmpty(messagesToSend, "messagesToSend");
        this.context = ensureNotNull(context, "context");
        ensureNotNull(context.streamingChatLanguageModel, "streamingChatLanguageModel");
    }

    @Override
    public OnCompleteOrOnError onNext(Consumer<String> tokenHandler) {

        return new OnCompleteOrOnError() {

            @Override
            public OnError onComplete(Runnable completionHandler) {

                return new OnError() {

                    @Override
                    public OnStart onError(Consumer<Throwable> errorHandler) {
                        return new AiServiceOnStart(tokenHandler, completionHandler, errorHandler);
                    }

                    @Override
                    public OnStart ignoreErrors() {
                        return new AiServiceOnStart(tokenHandler, completionHandler, null);
                    }
                };
            }

            @Override
            public OnStart onError(Consumer<Throwable> errorHandler) {
                return new AiServiceOnStart(tokenHandler, null, errorHandler);
            }

            @Override
            public OnStart ignoreErrors() {
                return new AiServiceOnStart(tokenHandler, null, null);
            }
        };
    }

    private class AiServiceOnStart implements OnStart {

        private final Consumer<String> tokenHandler;
        private final Runnable completionHandler;
        private final Consumer<Throwable> errorHandler;

        private AiServiceOnStart(Consumer<String> tokenHandler,
                                 Runnable completionHandler,
                                 Consumer<Throwable> errorHandler) {
            this.tokenHandler = ensureNotNull(tokenHandler, "tokenHandler");
            this.completionHandler = completionHandler;
            this.errorHandler = errorHandler;
        }

        @Override
        public void start() {

            context.streamingChatLanguageModel.sendMessages(
                    messagesToSend,
                    context.toolSpecifications,
                    new AiServiceStreamingResponseHandler(
                            context,
                            tokenHandler,
                            completionHandler,
                            errorHandler
                    )
            );
        }
    }
}
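The staged interfaces (OnCompleteOrOnError, OnError, OnStart) enforce the fluent call order, and nothing is sent to the model until start(). A usage sketch; the assistant service is the hypothetical one from the example near the top:

TokenStream stream = assistant.chat("Tell me a joke");

stream.onNext(token -> System.out.print(token))
      .onComplete(() -> System.out.println("\n[complete]"))
      .onError(Throwable::printStackTrace)
      .start(); // kicks off the streaming request to the LLM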