Skip to content

Commit

Permalink
DashScope: Encapsulate the calls to listeners into methods for better reuse. (langchain4j#1967)
Browse files Browse the repository at this point in the history

## Change
Encapsulate the calls to listeners into methods for better reuse.

## General checklist
<!-- Please double-check the following points and mark them like this:
[X] -->
- [X] There are no breaking changes
- [X] I have added unit and integration tests for my change
- [X] I have manually run all the unit and integration tests in the
module I have added/changed, and they are all green
- [ ] I have manually run all the unit and integration tests in the
[core](https://github.com/langchain4j/langchain4j/tree/main/langchain4j-core)
and
[main](https://github.com/langchain4j/langchain4j/tree/main/langchain4j)
modules, and they are all green
  • Loading branch information
jiangsier-xyz authored Oct 22, 2024
1 parent c858b0f commit 323ce96
Show file tree
Hide file tree
Showing 9 changed files with 176 additions and 489 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -148,14 +148,7 @@ private Response<AiMessage> generateByNonMultimodalModel(List<ChatMessage> messa

ChatModelRequest modelListenerRequest = createModelListenerRequest(param, messages, toolSpecifications);
Map<Object, Object> attributes = new ConcurrentHashMap<>();
ChatModelRequestContext requestContext = new ChatModelRequestContext(modelListenerRequest, attributes);
listeners.forEach(listener -> {
try {
listener.onRequest(requestContext);
} catch (Exception e) {
log.warn("Exception while calling model listener", e);
}
});
onListenRequest(listeners, modelListenerRequest, attributes);

try {
GenerationResult result = generation.call(param);
Expand All @@ -165,41 +158,10 @@ private Response<AiMessage> generateByNonMultimodalModel(List<ChatMessage> messa
finishReasonFrom(result)
);

ChatModelResponse modelListenerResponse = createModelListenerResponse(
result.getRequestId(),
param.getModel(),
response
);
ChatModelResponseContext responseContext = new ChatModelResponseContext(
modelListenerResponse,
modelListenerRequest,
attributes
);
listeners.forEach(listener -> {
try {
listener.onResponse(responseContext);
} catch (Exception e) {
log.warn("Exception while calling model listener", e);
}
});

onListenResponse(listeners, result.getRequestId(), response, modelListenerRequest, attributes);
return response;
} catch (NoApiKeyException | InputRequiredException | RuntimeException e) {
ChatModelErrorContext errorContext = new ChatModelErrorContext(
e,
modelListenerRequest,
null,
attributes
);

listeners.forEach(listener -> {
try {
listener.onError(errorContext);
} catch (Exception e2) {
log.warn("Exception while calling model listener", e2);
}
});

onListenError(listeners, null, e, modelListenerRequest, null, attributes);
throw e instanceof RuntimeException ?
(RuntimeException) e : new RuntimeException(e);
}
Expand All @@ -226,14 +188,7 @@ private Response<AiMessage> generateByMultimodalModel(List<ChatMessage> messages

ChatModelRequest modelListenerRequest = createModelListenerRequest(param, messages, toolSpecifications);
Map<Object, Object> attributes = new ConcurrentHashMap<>();
ChatModelRequestContext requestContext = new ChatModelRequestContext(modelListenerRequest, attributes);
listeners.forEach(listener -> {
try {
listener.onRequest(requestContext);
} catch (Exception e) {
log.warn("Exception while calling model listener", e);
}
});
onListenRequest(listeners, modelListenerRequest, attributes);

try {
MultiModalConversationResult result = conv.call(param);
Expand All @@ -242,41 +197,10 @@ private Response<AiMessage> generateByMultimodalModel(List<ChatMessage> messages
Response<AiMessage> response = Response.from(AiMessage.from(answer),
tokenUsageFrom(result), finishReasonFrom(result));

ChatModelResponse modelListenerResponse = createModelListenerResponse(
result.getRequestId(),
param.getModel(),
response
);
ChatModelResponseContext responseContext = new ChatModelResponseContext(
modelListenerResponse,
modelListenerRequest,
attributes
);
listeners.forEach(listener -> {
try {
listener.onResponse(responseContext);
} catch (Exception e) {
log.warn("Exception while calling model listener", e);
}
});

onListenResponse(listeners, result.getRequestId(), response, modelListenerRequest, attributes);
return response;
} catch (NoApiKeyException | UploadFileException | RuntimeException e) {
ChatModelErrorContext errorContext = new ChatModelErrorContext(
e,
modelListenerRequest,
null,
attributes
);

listeners.forEach(listener -> {
try {
listener.onError(errorContext);
} catch (Exception e2) {
log.warn("Exception while calling model listener", e2);
}
});

onListenError(listeners, null, e, modelListenerRequest, null, attributes);
throw e instanceof RuntimeException ?
(RuntimeException) e : new RuntimeException(e);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,7 @@
import dev.langchain4j.data.message.ToolExecutionResultMessage;
import dev.langchain4j.data.message.UserMessage;
import dev.langchain4j.internal.Utils;
import dev.langchain4j.model.chat.listener.ChatModelRequest;
import dev.langchain4j.model.chat.listener.ChatModelResponse;
import dev.langchain4j.model.chat.listener.*;
import dev.langchain4j.model.output.FinishReason;
import dev.langchain4j.model.output.Response;
import dev.langchain4j.model.output.TokenUsage;
Expand Down Expand Up @@ -574,4 +573,61 @@ static ChatModelResponse createModelListenerResponse(String responseId,
.aiMessage(response.content())
.build();
}

/**
 * Notifies every registered listener that a chat-model request is about to be sent.
 * A single failing listener must never break the model call, so any exception a
 * listener throws is logged and swallowed.
 *
 * @param listeners            listeners to notify; iterated in order
 * @param modelListenerRequest the request being sent to the model
 * @param attributes           mutable attribute map shared across the request lifecycle
 */
static void onListenRequest(List<ChatModelListener> listeners,
                            ChatModelRequest modelListenerRequest,
                            Map<Object, Object> attributes) {
    ChatModelRequestContext requestContext =
            new ChatModelRequestContext(modelListenerRequest, attributes);
    for (ChatModelListener listener : listeners) {
        try {
            listener.onRequest(requestContext);
        } catch (Exception e) {
            log.warn("Exception while calling model listener", e);
        }
    }
}

/**
 * Builds the listener-facing view of a successful model response and broadcasts it
 * to every registered listener. Listener exceptions are logged and swallowed so one
 * misbehaving listener cannot affect the others or the caller.
 *
 * @param listeners            listeners to notify; iterated in order
 * @param responseId           provider-assigned id of the response
 * @param response             the response returned by the model
 * @param modelListenerRequest the request this response corresponds to
 * @param attributes           mutable attribute map shared across the request lifecycle
 */
static void onListenResponse(List<ChatModelListener> listeners,
                             String responseId,
                             Response<AiMessage> response,
                             ChatModelRequest modelListenerRequest,
                             Map<Object, Object> attributes) {
    ChatModelResponse modelListenerResponse =
            createModelListenerResponse(responseId, modelListenerRequest.model(), response);
    ChatModelResponseContext responseContext = new ChatModelResponseContext(
            modelListenerResponse,
            modelListenerRequest,
            attributes
    );
    for (ChatModelListener listener : listeners) {
        try {
            listener.onResponse(responseContext);
        } catch (Exception e) {
            log.warn("Exception while calling model listener", e);
        }
    }
}

/**
 * Notifies every registered listener that a model call failed, attaching whatever
 * partial response was accumulated before the failure (may be null for non-streaming
 * callers). Listener exceptions are logged and swallowed; the original error is
 * propagated by the caller, not here.
 *
 * NOTE(review): callers pass {@code null} for both {@code responseId} and
 * {@code partialResponse} on synchronous failures — this assumes
 * {@code createModelListenerResponse} tolerates null arguments; confirm against
 * its implementation.
 *
 * @param listeners            listeners to notify; iterated in order
 * @param responseId           provider-assigned id of the (partial) response, or null
 * @param error                the failure to report
 * @param modelListenerRequest the request whose processing failed
 * @param partialResponse      response accumulated before the failure, or null
 * @param attributes           mutable attribute map shared across the request lifecycle
 */
static void onListenError(List<ChatModelListener> listeners,
                          String responseId,
                          Throwable error,
                          ChatModelRequest modelListenerRequest,
                          Response<AiMessage> partialResponse, Map<Object, Object> attributes) {
    ChatModelResponse partialModelListenerResponse =
            createModelListenerResponse(responseId, modelListenerRequest.model(), partialResponse);
    ChatModelErrorContext errorContext = new ChatModelErrorContext(
            error,
            modelListenerRequest,
            partialModelListenerResponse,
            attributes
    );
    for (ChatModelListener listener : listeners) {
        try {
            listener.onError(errorContext);
        } catch (Exception e) {
            log.warn("Exception while calling model listener", e);
        }
    }
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -130,14 +130,7 @@ private void generateByNonMultimodalModel(List<ChatMessage> messages, StreamingR

ChatModelRequest modelListenerRequest = createModelListenerRequest(param, messages, null);
Map<Object, Object> attributes = new ConcurrentHashMap<>();
ChatModelRequestContext requestContext = new ChatModelRequestContext(modelListenerRequest, attributes);
listeners.forEach(listener -> {
try {
listener.onRequest(requestContext);
} catch (Exception e) {
log.warn("Exception while calling model listener", e);
}
});
onListenRequest(listeners, modelListenerRequest, attributes);

QwenStreamingResponseBuilder responseBuilder = new QwenStreamingResponseBuilder();
AtomicReference<String> responseId = new AtomicReference<>();
Expand All @@ -159,72 +152,18 @@ public void onEvent(GenerationResult result) {
@Override
public void onComplete() {
Response<AiMessage> response = responseBuilder.build();

ChatModelResponse modelListenerResponse = createModelListenerResponse(
responseId.get(),
param.getModel(),
response
);
ChatModelResponseContext responseContext = new ChatModelResponseContext(
modelListenerResponse,
modelListenerRequest,
attributes
);
listeners.forEach(listener -> {
try {
listener.onResponse(responseContext);
} catch (Exception e) {
log.warn("Exception while calling model listener", e);
}
});

onListenResponse(listeners, responseId.get(), response, modelListenerRequest, attributes);
handler.onComplete(response);
}

@Override
public void onError(Exception e) {
Response<AiMessage> response = responseBuilder.build();

ChatModelResponse modelListenerPartialResponse = createModelListenerResponse(
responseId.get(),
param.getModel(),
response
);

ChatModelErrorContext errorContext = new ChatModelErrorContext(
e,
modelListenerRequest,
modelListenerPartialResponse,
attributes
);

listeners.forEach(listener -> {
try {
listener.onError(errorContext);
} catch (Exception ex) {
log.warn("Exception while calling model listener", ex);
}
});

onListenError(listeners, responseId.get(), e, modelListenerRequest, responseBuilder.build(), attributes);
handler.onError(e);
}
});
} catch (NoApiKeyException | InputRequiredException | RuntimeException e) {
ChatModelErrorContext errorContext = new ChatModelErrorContext(
e,
modelListenerRequest,
null,
attributes
);

listeners.forEach(listener -> {
try {
listener.onError(errorContext);
} catch (Exception e2) {
log.warn("Exception while calling model listener", e2);
}
});

onListenError(listeners, null, e, modelListenerRequest, null, attributes);
throw e instanceof RuntimeException ?
(RuntimeException) e : new RuntimeException(e);
}
Expand All @@ -246,14 +185,7 @@ private void generateByMultimodalModel(List<ChatMessage> messages, StreamingResp

ChatModelRequest modelListenerRequest = createModelListenerRequest(param, messages, null);
Map<Object, Object> attributes = new ConcurrentHashMap<>();
ChatModelRequestContext requestContext = new ChatModelRequestContext(modelListenerRequest, attributes);
listeners.forEach(listener -> {
try {
listener.onRequest(requestContext);
} catch (Exception e) {
log.warn("Exception while calling model listener", e);
}
});
onListenRequest(listeners, modelListenerRequest, attributes);

QwenStreamingResponseBuilder responseBuilder = new QwenStreamingResponseBuilder();
AtomicReference<String> responseId = new AtomicReference<>();
Expand All @@ -275,72 +207,18 @@ public void onEvent(MultiModalConversationResult result) {
@Override
public void onComplete() {
Response<AiMessage> response = responseBuilder.build();

ChatModelResponse modelListenerResponse = createModelListenerResponse(
responseId.get(),
param.getModel(),
response
);
ChatModelResponseContext responseContext = new ChatModelResponseContext(
modelListenerResponse,
modelListenerRequest,
attributes
);
listeners.forEach(listener -> {
try {
listener.onResponse(responseContext);
} catch (Exception e) {
log.warn("Exception while calling model listener", e);
}
});

onListenResponse(listeners, responseId.get(), response, modelListenerRequest, attributes);
handler.onComplete(response);
}

@Override
public void onError(Exception e) {
Response<AiMessage> response = responseBuilder.build();

ChatModelResponse modelListenerPartialResponse = createModelListenerResponse(
responseId.get(),
param.getModel(),
response
);

ChatModelErrorContext errorContext = new ChatModelErrorContext(
e,
modelListenerRequest,
modelListenerPartialResponse,
attributes
);

listeners.forEach(listener -> {
try {
listener.onError(errorContext);
} catch (Exception ex) {
log.warn("Exception while calling model listener", ex);
}
});

onListenError(listeners, responseId.get(), e, modelListenerRequest, responseBuilder.build(), attributes);
handler.onError(e);
}
});
} catch (NoApiKeyException | UploadFileException | InputRequiredException | RuntimeException e) {
ChatModelErrorContext errorContext = new ChatModelErrorContext(
e,
modelListenerRequest,
null,
attributes
);

listeners.forEach(listener -> {
try {
listener.onError(errorContext);
} catch (Exception e2) {
log.warn("Exception while calling model listener", e2);
}
});

onListenError(listeners, null, e, modelListenerRequest, null, attributes);
throw e instanceof RuntimeException ?
(RuntimeException) e : new RuntimeException(e);
}
Expand Down
Loading

0 comments on commit 323ce96

Please sign in to comment.