|
56 | 56 | import com.microsoft.semantickernel.services.chatcompletion.ChatCompletionService;
|
57 | 57 | import com.microsoft.semantickernel.services.chatcompletion.ChatHistory;
|
58 | 58 | import com.microsoft.semantickernel.services.chatcompletion.ChatMessageContent;
|
| 59 | +import com.microsoft.semantickernel.services.chatcompletion.StreamingChatContent; |
59 | 60 | import com.microsoft.semantickernel.services.chatcompletion.message.ChatMessageContentType;
|
60 | 61 | import com.microsoft.semantickernel.services.chatcompletion.message.ChatMessageImageContent;
|
61 | 62 | import com.microsoft.semantickernel.services.openai.OpenAiServiceBuilder;
|
|
64 | 65 | import java.util.Arrays;
|
65 | 66 | import java.util.Collections;
|
66 | 67 | import java.util.List;
|
| 68 | +import java.util.Locale; |
67 | 69 | import java.util.Map;
|
68 | 70 | import java.util.Objects;
|
69 | 71 | import java.util.stream.Collectors;
|
@@ -179,6 +181,92 @@ public Mono<List<ChatMessageContent<?>>> getChatMessageContentsAsync(
|
179 | 181 | });
|
180 | 182 | }
|
181 | 183 |
|
| 184 | + @Override |
| 185 | + public Flux<StreamingChatContent<?>> getStreamingChatMessageContentsAsync( |
| 186 | + ChatHistory chatHistory, |
| 187 | + @Nullable Kernel kernel, |
| 188 | + @Nullable InvocationContext invocationContext) { |
| 189 | + if (invocationContext != null && invocationContext.getToolCallBehavior() |
| 190 | + .isAutoInvokeAllowed()) { |
| 191 | + throw new SKException( |
| 192 | + "Auto invoke is not supported for streaming chat message contents"); |
| 193 | + } |
| 194 | + |
| 195 | + if (invocationContext != null |
| 196 | + && invocationContext.returnMode() != InvocationReturnMode.NEW_MESSAGES_ONLY) { |
| 197 | + throw new SKException( |
| 198 | + "Streaming chat message contents only supports NEW_MESSAGES_ONLY return mode"); |
| 199 | + } |
| 200 | + |
| 201 | + List<ChatRequestMessage> chatRequestMessages = getChatRequestMessages(chatHistory); |
| 202 | + |
| 203 | + ChatMessages messages = new ChatMessages(chatRequestMessages); |
| 204 | + |
| 205 | + List<OpenAIFunction> functions = new ArrayList<>(); |
| 206 | + if (kernel != null) { |
| 207 | + kernel.getPlugins() |
| 208 | + .forEach(plugin -> plugin.getFunctions().forEach((name, function) -> functions |
| 209 | + .add(OpenAIFunction.build(function.getMetadata(), plugin.getName())))); |
| 210 | + } |
| 211 | + |
| 212 | + ChatCompletionsOptions options = executeHook( |
| 213 | + invocationContext, |
| 214 | + kernel, |
| 215 | + new PreChatCompletionEvent( |
| 216 | + getCompletionsOptions( |
| 217 | + this, |
| 218 | + messages.allMessages, |
| 219 | + functions, |
| 220 | + invocationContext))) |
| 221 | + .getOptions(); |
| 222 | + |
| 223 | + return getClient() |
| 224 | + .getChatCompletionsStreamWithResponse( |
| 225 | + getDeploymentName(), |
| 226 | + options, |
| 227 | + OpenAIRequestSettings.getRequestOptions()) |
| 228 | + .flatMap(completionsResult -> { |
| 229 | + if (completionsResult.getStatusCode() >= 400) { |
| 230 | + //SemanticKernelTelemetry.endSpanWithError(span); |
| 231 | + return Mono.error(new AIException(ErrorCodes.SERVICE_ERROR, |
| 232 | + "Request failed: " + completionsResult.getStatusCode())); |
| 233 | + } |
| 234 | + //SemanticKernelTelemetry.endSpanWithUsage(span, completionsResult.getValue().getUsage()); |
| 235 | + |
| 236 | + return Mono.just(completionsResult.getValue()); |
| 237 | + }) |
| 238 | + .flatMap(completions -> { |
| 239 | + return Flux.fromIterable(completions.getChoices()) |
| 240 | + .map(message -> { |
| 241 | + AuthorRole role = message.getDelta().getRole() == null |
| 242 | + ? AuthorRole.ASSISTANT |
| 243 | + : AuthorRole.valueOf(message.getDelta().getRole().toString() |
| 244 | + .toUpperCase(Locale.ROOT)); |
| 245 | + |
| 246 | + return new OpenAIStreamingChatMessageContent<>( |
| 247 | + completions.getId(), |
| 248 | + role, |
| 249 | + message.getDelta().getContent(), |
| 250 | + getModelId(), |
| 251 | + null, |
| 252 | + null, |
| 253 | + null, |
| 254 | + Arrays.asList()); |
| 255 | + }); |
| 256 | + }); |
| 257 | + } |
| 258 | + |
| 259 | + @Override |
| 260 | + public Flux<StreamingChatContent<?>> getStreamingChatMessageContentsAsync( |
| 261 | + String prompt, |
| 262 | + @Nullable Kernel kernel, |
| 263 | + @Nullable InvocationContext invocationContext) { |
| 264 | + return getStreamingChatMessageContentsAsync( |
| 265 | + new ChatHistory().addUserMessage(prompt), |
| 266 | + kernel, |
| 267 | + invocationContext); |
| 268 | + } |
| 269 | + |
182 | 270 | // Holds messages temporarily as we build up our result
|
183 | 271 | private static class ChatMessages {
|
184 | 272 |
|
|
0 commit comments