Skip to content

Commit

Permalink
Added support for streaming to AI Services (langchain4j#45)
Browse files Browse the repository at this point in the history
Now one can use TokenStream as a return type in their AI Service in order
to stream the response from the LLM
  • Loading branch information
langchain4j authored Jul 23, 2023
1 parent 529ef6b commit 80f71fe
Show file tree
Hide file tree
Showing 14 changed files with 557 additions and 55 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -23,16 +23,22 @@ public ToolExecutor(Object object, Method method) {
this.method = method;
}

public String execute(Map<String, Object> argumentsMap) {
public String execute(ToolExecutionRequest toolExecutionRequest) {
log.debug("About to execute {}", toolExecutionRequest);

Object[] arguments = prepareArguments(argumentsMap);
// TODO ensure this method never throws exceptions

Object[] arguments = prepareArguments(toolExecutionRequest.argumentsAsMap());
try {
return execute(arguments);
String result = execute(arguments);
log.debug("Tool execution result: {}", result);
return result;
} catch (IllegalAccessException e) {
try {
method.setAccessible(true);
return execute(arguments);
String result = execute(arguments);
log.debug("Tool execution result: {}", result);
return result;
} catch (IllegalAccessException e2) {
throw new RuntimeException(e2);
} catch (InvocationTargetException e2) {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
package dev.langchain4j.internal;

import java.util.Collection;

import static dev.langchain4j.internal.Exceptions.illegalArgument;

public class ValidationUtils {
Expand All @@ -12,6 +14,14 @@ public static <T> T ensureNotNull(T object, String name) {
return object;
}

/**
 * Ensures that the given collection is neither {@code null} nor empty.
 *
 * @param collection the collection to validate
 * @param name       the argument name used in the error message
 * @return the validated collection, unchanged
 * @throws IllegalArgumentException if the collection is {@code null} or empty
 */
public static <T extends Collection<?>> T ensureNotEmpty(T collection, String name) {
    boolean nullOrEmpty = collection == null || collection.isEmpty();
    if (nullOrEmpty) {
        throw illegalArgument("%s cannot be null or empty", name);
    }
    return collection;
}

public static String ensureNotBlank(String string, String name) {
if (string == null || string.trim().isEmpty()) {
throw illegalArgument("%s cannot be null or blank", name);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ public int countTokens(String text) {
@Override
public int countTokens(ChatMessage message) {
    // Per-message overhead plus the tokens of the message text and its role.
    int overhead = extraTokensPerEachMessage();
    int textTokens = countTokens(message.text()); // TODO count functions
    int roleTokens = countTokens(roleFrom(message).toString());
    return overhead + textTokens + roleTokens;
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
package dev.langchain4j.service;

import dev.langchain4j.agent.tool.ToolExecutor;
import dev.langchain4j.agent.tool.ToolSpecification;
import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.memory.ChatMemory;
import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.model.chat.StreamingChatLanguageModel;
import dev.langchain4j.model.moderation.ModerationModel;
import dev.langchain4j.retriever.Retriever;

import java.util.List;
import java.util.Map;

/**
 * Mutable holder for all collaborators an AI Service needs at runtime.
 * Populated by the service builder and shared by the handlers that serve
 * the proxied interface. Any field may be {@code null} if the corresponding
 * feature was not configured.
 */
class AiServiceContext {

// The user-defined AI Service interface being proxied.
Class<?> aiServiceClass;

// Synchronous and streaming models; streaming is used when the service
// method returns a TokenStream. Presumably only one of the two needs to be
// configured — TODO confirm against the builder.
ChatLanguageModel chatLanguageModel;
StreamingChatLanguageModel streamingChatLanguageModel;

// Optional conversation memory; callers null-check before use.
ChatMemory chatMemory;

// Optional moderation model.
ModerationModel moderationModel;

// Tool metadata sent to the LLM, and executors looked up by tool name
// when the LLM requests a tool execution.
List<ToolSpecification> toolSpecifications;
Map<String, ToolExecutor> toolExecutors;

// Optional retriever for augmenting prompts with relevant text segments.
Retriever<TextSegment> retriever;
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
package dev.langchain4j.service;

import dev.langchain4j.agent.tool.ToolExecutionRequest;
import dev.langchain4j.agent.tool.ToolExecutor;
import dev.langchain4j.data.message.ToolExecutionResultMessage;
import dev.langchain4j.model.StreamingResponseHandler;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.util.function.Consumer;

import static dev.langchain4j.data.message.AiMessage.aiMessage;
import static dev.langchain4j.data.message.ToolExecutionResultMessage.toolExecutionResultMessage;
import static dev.langchain4j.internal.ValidationUtils.ensureNotNull;

/**
* Handles response from LLM for AI Service that is streamed token-by-token.
* Handles both regular (text) responses and responses with the request to execute a tool.
*/
/**
 * Handles response from LLM for AI Service that is streamed token-by-token.
 * Handles both regular (text) responses and responses with the request to execute a tool.
 */
class AiServiceStreamingResponseHandler implements StreamingResponseHandler {

    private final Logger log = LoggerFactory.getLogger(AiServiceStreamingResponseHandler.class);

    private final AiServiceContext context;

    // User-supplied callbacks; tokenHandler is mandatory, the other two are optional.
    private final Consumer<String> tokenHandler;
    private final Runnable completionHandler;
    private final Consumer<Throwable> errorHandler;

    // Streamed fragments are accumulated here until onComplete() decides whether
    // the LLM produced a plain answer or a tool-execution request.
    private final StringBuilder answerBuilder;
    private final StringBuilder toolNameBuilder;
    private final StringBuilder toolArgumentsBuilder;

    AiServiceStreamingResponseHandler(AiServiceContext context,
                                      Consumer<String> tokenHandler,
                                      Runnable completionHandler,
                                      Consumer<Throwable> errorHandler) {
        this.context = ensureNotNull(context, "context");

        this.tokenHandler = ensureNotNull(tokenHandler, "tokenHandler");
        this.completionHandler = completionHandler;
        this.errorHandler = errorHandler;

        this.answerBuilder = new StringBuilder();
        this.toolNameBuilder = new StringBuilder();
        this.toolArgumentsBuilder = new StringBuilder();
    }

    @Override
    public void onNext(String partialResult) {
        // Forward each text token to the user while also accumulating the full answer.
        answerBuilder.append(partialResult);
        tokenHandler.accept(partialResult);
    }

    @Override
    public void onToolName(String name) {
        toolNameBuilder.append(name);
    }

    @Override
    public void onToolArguments(String arguments) {
        toolArgumentsBuilder.append(arguments);
    }

    @Override
    public void onComplete() {

        String toolName = toolNameBuilder.toString();
        String toolArguments = toolArgumentsBuilder.toString();

        if (toolName.isEmpty()) {
            // Plain text answer: remember it (if memory is configured) and signal completion.
            if (context.chatMemory != null) {
                context.chatMemory.add(aiMessage(answerBuilder.toString()));
            }
            if (completionHandler != null) {
                completionHandler.run();
            }
        } else {

            ToolExecutionRequest toolExecutionRequest = ToolExecutionRequest.builder()
                    .name(toolName)
                    .arguments(toolArguments)
                    .build();

            if (context.chatMemory != null) {
                context.chatMemory.add(aiMessage(toolExecutionRequest));
            }

            ToolExecutor toolExecutor = context.toolExecutors.get(toolName);
            if (toolExecutor == null) {
                // The LLM requested a tool we do not know about: surface a clear error
                // through the normal error path instead of failing with an NPE below.
                onError(new IllegalStateException("No tool executor found for tool: " + toolName));
                return;
            }
            String toolExecutionResult = toolExecutor.execute(toolExecutionRequest);
            ToolExecutionResultMessage toolExecutionResultMessage
                    = toolExecutionResultMessage(toolExecutionRequest.name(), toolExecutionResult);

            // Guarded like every other chatMemory access in this method (was unguarded before).
            if (context.chatMemory != null) {
                context.chatMemory.add(toolExecutionResultMessage);
            }

            // TODO what if there are multiple tool executions in a row? (for the future)
            // NOTE(review): the tool flow still requires chatMemory — messages() below
            // will NPE if it is null; confirm tools are only configured together with memory.
            context.streamingChatLanguageModel.sendMessages(
                    context.chatMemory.messages(),
                    // TODO does it make sense to send tools if LLM will not call them in the response anyway? (current openai behavior)
                    context.toolSpecifications,
                    new AiServiceStreamingResponseHandler(context, tokenHandler, completionHandler, errorHandler)
            );
        }
    }

    @Override
    public void onError(Throwable error) {
        if (errorHandler != null) {
            try {
                errorHandler.accept(error);
            } catch (Exception e) {
                // The user's error handler itself failed; log both problems, never rethrow.
                log.error("While handling the following error...", error);
                log.error("...the following error happened", e);
            }
        } else {
            log.warn("Ignored error", error);
        }
    }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
package dev.langchain4j.service;

import dev.langchain4j.data.message.ChatMessage;

import java.util.List;
import java.util.function.Consumer;

import static dev.langchain4j.internal.ValidationUtils.ensureNotEmpty;
import static dev.langchain4j.internal.ValidationUtils.ensureNotNull;

/**
 * TokenStream implementation backing AI Service methods that return TokenStream.
 * Implements the fluent configuration chain
 * {@code onNext(...).onComplete(...).onError(...)/ignoreErrors().start()}:
 * each step captures one callback, and {@link AiServiceOnStart#start()} finally
 * sends the messages to the streaming model.
 */
class AiServiceTokenStream implements TokenStream {

// Messages to send to the LLM when start() is invoked; validated non-empty.
private final List<ChatMessage> messagesToSend;
private final AiServiceContext context;

AiServiceTokenStream(List<ChatMessage> messagesToSend, AiServiceContext context) {
this.messagesToSend = ensureNotEmpty(messagesToSend, "messagesToSend");
this.context = ensureNotNull(context, "context");
// Streaming requires a streaming model to be configured on the context.
ensureNotNull(context.streamingChatLanguageModel, "streamingChatLanguageModel");
}

/**
 * Registers the mandatory per-token callback and returns the next step of the
 * fluent chain, where a completion handler and/or an error handler can be set.
 */
@Override
public OnCompleteOrOnError onNext(Consumer<String> tokenHandler) {

return new OnCompleteOrOnError() {

@Override
public OnError onComplete(Runnable completionHandler) {

return new OnError() {

@Override
public OnStart onError(Consumer<Throwable> errorHandler) {
return new AiServiceOnStart(tokenHandler, completionHandler, errorHandler);
}

// No error handler: errors will only be logged (see AiServiceStreamingResponseHandler).
@Override
public OnStart ignoreErrors() {
return new AiServiceOnStart(tokenHandler, completionHandler, null);
}
};
}

// Skipping onComplete: no completion callback is registered.
@Override
public OnStart onError(Consumer<Throwable> errorHandler) {
return new AiServiceOnStart(tokenHandler, null, errorHandler);
}

// Neither completion nor error callback is registered.
@Override
public OnStart ignoreErrors() {
return new AiServiceOnStart(tokenHandler, null, null);
}
};
}

/**
 * Terminal step of the fluent chain: holds the collected callbacks and,
 * on {@link #start()}, kicks off the actual streaming call.
 */
private class AiServiceOnStart implements OnStart {

private final Consumer<String> tokenHandler;
private final Runnable completionHandler; // nullable
private final Consumer<Throwable> errorHandler; // nullable

private AiServiceOnStart(Consumer<String> tokenHandler,
Runnable completionHandler,
Consumer<Throwable> errorHandler) {
this.tokenHandler = ensureNotNull(tokenHandler, "tokenHandler");
this.completionHandler = completionHandler;
this.errorHandler = errorHandler;
}

// Sends the messages to the streaming model; the handler fans the streamed
// tokens out to the callbacks captured above.
@Override
public void start() {

context.streamingChatLanguageModel.sendMessages(
messagesToSend,
context.toolSpecifications,
new AiServiceStreamingResponseHandler(
context,
tokenHandler,
completionHandler,
errorHandler
)
);
}
}
}
Loading

0 comments on commit 80f71fe

Please sign in to comment.