Skip to content

Commit e085f69

Browse files
authored
Merge pull request #211 from johnoliver/streaming-2
Add streaming support to chat completions
2 parents 89a5479 + 93889d8 commit e085f69

File tree

19 files changed

+530
-115
lines changed

19 files changed

+530
-115
lines changed

aiservices/google/src/main/java/com/microsoft/semantickernel/aiservices/google/chatcompletion/GeminiChatCompletion.java

+49-1
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
import com.google.protobuf.Value;
1717
import com.microsoft.semantickernel.Kernel;
1818
import com.microsoft.semantickernel.aiservices.google.GeminiService;
19+
import com.microsoft.semantickernel.aiservices.google.GeminiServiceBuilder;
1920
import com.microsoft.semantickernel.aiservices.google.implementation.MonoConverter;
2021
import com.microsoft.semantickernel.contextvariables.ContextVariableTypes;
2122
import com.microsoft.semantickernel.exceptions.AIException;
@@ -36,7 +37,7 @@
3637
import com.microsoft.semantickernel.services.chatcompletion.ChatCompletionService;
3738
import com.microsoft.semantickernel.services.chatcompletion.ChatHistory;
3839
import com.microsoft.semantickernel.services.chatcompletion.ChatMessageContent;
39-
import com.microsoft.semantickernel.aiservices.google.GeminiServiceBuilder;
40+
import com.microsoft.semantickernel.services.chatcompletion.StreamingChatContent;
4041
import java.io.IOException;
4142
import java.time.OffsetDateTime;
4243
import java.util.ArrayList;
@@ -75,6 +76,53 @@ public Mono<List<ChatMessageContent<?>>> getChatMessageContentsAsync(String prom
7576
invocationContext);
7677
}
7778

79+
@Override
80+
public Flux<StreamingChatContent<?>> getStreamingChatMessageContentsAsync(
81+
ChatHistory chatHistory,
82+
@Nullable Kernel kernel,
83+
@Nullable InvocationContext invocationContext) {
84+
85+
LOGGER.warn("Streaming has been called on GeminiChatCompletion service. "
86+
+ "This is currently not supported in Gemini. "
87+
+ "The results will be returned in a non streaming fashion.");
88+
89+
return getChatMessageContentsAsync(chatHistory, kernel, invocationContext)
90+
.flatMapIterable(chatMessageContents -> chatMessageContents)
91+
.map(content -> {
92+
return new GeminiStreamingChatMessageContent(
93+
content.getAuthorRole(),
94+
content.getContent(),
95+
getModelId(),
96+
content.getInnerContent(),
97+
content.getEncoding(),
98+
content.getMetadata(),
99+
null,
100+
UUID.randomUUID().toString());
101+
});
102+
}
103+
104+
@Override
105+
public Flux<StreamingChatContent<?>> getStreamingChatMessageContentsAsync(String prompt,
106+
@Nullable Kernel kernel, @Nullable InvocationContext invocationContext) {
107+
LOGGER.warn("Streaming has been called on GeminiChatCompletion service. "
108+
+ "This is currently not supported in Gemini. "
109+
+ "The results will be returned in a non streaming fashion.");
110+
111+
return getChatMessageContentsAsync(prompt, kernel, invocationContext)
112+
.flatMapIterable(chatMessageContents -> chatMessageContents)
113+
.map(content -> {
114+
return new GeminiStreamingChatMessageContent(
115+
content.getAuthorRole(),
116+
content.getContent(),
117+
getModelId(),
118+
content.getInnerContent(),
119+
content.getEncoding(),
120+
content.getMetadata(),
121+
null,
122+
UUID.randomUUID().toString());
123+
});
124+
}
125+
78126
@Override
79127
public Mono<List<ChatMessageContent<?>>> getChatMessageContentsAsync(ChatHistory chatHistory,
80128
@Nullable Kernel kernel, @Nullable InvocationContext invocationContext) {
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
// Copyright (c) Microsoft. All rights reserved.
2+
package com.microsoft.semantickernel.aiservices.google.chatcompletion;
3+
4+
import com.microsoft.semantickernel.orchestration.FunctionResultMetadata;
5+
import com.microsoft.semantickernel.services.chatcompletion.AuthorRole;
6+
import com.microsoft.semantickernel.services.chatcompletion.StreamingChatContent;
7+
import java.nio.charset.Charset;
8+
import java.util.List;
9+
import javax.annotation.Nullable;
10+
11+
/**
12+
* Represents the content of a streaming chat message returned by the Gemini chat completion service.
13+
*
14+
* @param <T> The type of the inner content.
15+
*/
16+
public class GeminiStreamingChatMessageContent<T> extends GeminiChatMessageContent<T> implements
17+
StreamingChatContent<T> {
18+
19+
private final String id;
20+
21+
/**
22+
* Creates a new instance of the {@link GeminiStreamingChatMessageContent} class.
23+
*
24+
* @param authorRole The author role that generated the content.
25+
* @param content The content.
26+
* @param modelId The model id.
27+
* @param innerContent The inner content.
28+
* @param encoding The encoding.
29+
* @param metadata The metadata.
30+
* @param geminiFunctionCalls The function calls.
31+
*/
32+
public GeminiStreamingChatMessageContent(AuthorRole authorRole, String content,
33+
@Nullable String modelId, @Nullable T innerContent, @Nullable Charset encoding,
34+
@Nullable FunctionResultMetadata metadata,
35+
@Nullable List<GeminiFunctionCall> geminiFunctionCalls,
36+
String id) {
37+
super(authorRole, content, modelId, innerContent, encoding, metadata, geminiFunctionCalls);
38+
this.id = id;
39+
}
40+
41+
@Override
42+
public String getId() {
43+
return id;
44+
}
45+
}

semantickernel-api/src/main/java/com/microsoft/semantickernel/services/textcompletion/StreamingTextContent.java renamed to aiservices/google/src/main/java/com/microsoft/semantickernel/aiservices/google/textcompletion/GeminiStreamingTextContent.java

+5-4
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,22 @@
11
// Copyright (c) Microsoft. All rights reserved.
2-
package com.microsoft.semantickernel.services.textcompletion;
2+
package com.microsoft.semantickernel.aiservices.google.textcompletion;
33

4-
import com.microsoft.semantickernel.services.StreamingKernelContent;
4+
import com.microsoft.semantickernel.services.StreamingTextContent;
5+
import com.microsoft.semantickernel.services.textcompletion.TextContent;
56
import javax.annotation.Nullable;
67

78
/**
89
* GeminiStreamingTextContent is a wrapper for TextContent that allows for streaming.
910
*/
10-
public class StreamingTextContent extends StreamingKernelContent<TextContent> {
11+
public class GeminiStreamingTextContent extends StreamingTextContent<TextContent> {
1112

1213
/**
1314
* Initializes a new instance of the {@code GeminiStreamingTextContent} class with a provided text
1415
* content.
1516
*
1617
* @param content The text content.
1718
*/
18-
public StreamingTextContent(TextContent content) {
19+
public GeminiStreamingTextContent(TextContent content) {
1920
super(content, 0, null, null);
2021
}
2122

aiservices/google/src/main/java/com/microsoft/semantickernel/aiservices/google/textcompletion/GeminiTextGenerationService.java

+10-10
Original file line numberDiff line numberDiff line change
@@ -7,30 +7,29 @@
77
import com.google.cloud.vertexai.generativeai.GenerativeModel;
88
import com.microsoft.semantickernel.Kernel;
99
import com.microsoft.semantickernel.aiservices.google.GeminiService;
10+
import com.microsoft.semantickernel.aiservices.google.GeminiServiceBuilder;
1011
import com.microsoft.semantickernel.aiservices.google.implementation.MonoConverter;
1112
import com.microsoft.semantickernel.exceptions.AIException;
1213
import com.microsoft.semantickernel.exceptions.SKCheckedException;
1314
import com.microsoft.semantickernel.exceptions.SKException;
1415
import com.microsoft.semantickernel.orchestration.FunctionResultMetadata;
1516
import com.microsoft.semantickernel.orchestration.PromptExecutionSettings;
16-
import com.microsoft.semantickernel.aiservices.google.GeminiServiceBuilder;
17-
import com.microsoft.semantickernel.services.textcompletion.StreamingTextContent;
17+
import com.microsoft.semantickernel.services.StreamingTextContent;
1818
import com.microsoft.semantickernel.services.textcompletion.TextContent;
1919
import com.microsoft.semantickernel.services.textcompletion.TextGenerationService;
20-
import reactor.core.publisher.Flux;
21-
import reactor.core.publisher.Mono;
22-
23-
import org.slf4j.Logger;
24-
import org.slf4j.LoggerFactory;
25-
26-
import javax.annotation.Nullable;
2720
import java.io.IOException;
2821
import java.time.OffsetDateTime;
2922
import java.util.ArrayList;
3023
import java.util.List;
3124
import java.util.UUID;
25+
import javax.annotation.Nullable;
26+
import org.slf4j.Logger;
27+
import org.slf4j.LoggerFactory;
28+
import reactor.core.publisher.Flux;
29+
import reactor.core.publisher.Mono;
3230

3331
public class GeminiTextGenerationService extends GeminiService implements TextGenerationService {
32+
3433
private static final Logger LOGGER = LoggerFactory.getLogger(GeminiTextGenerationService.class);
3534

3635
public GeminiTextGenerationService(VertexAI client, String modelId) {
@@ -57,7 +56,7 @@ public Flux<StreamingTextContent> getStreamingTextContentsAsync(
5756
return this
5857
.internalGetTextAsync(prompt, executionSettings)
5958
.flatMapMany(it -> Flux.fromStream(it.stream())
60-
.map(StreamingTextContent::new));
59+
.map(GeminiStreamingTextContent::new));
6160
}
6261

6362
private Mono<List<TextContent>> internalGetTextAsync(String prompt,
@@ -124,6 +123,7 @@ private GenerativeModel getGenerativeModel(
124123

125124
public static class Builder extends
126125
GeminiServiceBuilder<GeminiTextGenerationService, GeminiTextGenerationService.Builder> {
126+
127127
@Override
128128
public GeminiTextGenerationService build() {
129129
if (this.client == null) {

aiservices/huggingface/src/main/java/com/microsoft/semantickernel/aiservices/huggingface/services/HuggingFaceTextGenerationService.java

+1-1
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
import com.microsoft.semantickernel.exceptions.SKException;
1010
import com.microsoft.semantickernel.orchestration.FunctionResultMetadata;
1111
import com.microsoft.semantickernel.orchestration.PromptExecutionSettings;
12-
import com.microsoft.semantickernel.services.textcompletion.StreamingTextContent;
12+
import com.microsoft.semantickernel.services.StreamingTextContent;
1313
import com.microsoft.semantickernel.services.textcompletion.TextContent;
1414
import com.microsoft.semantickernel.services.textcompletion.TextGenerationService;
1515
import java.util.List;

aiservices/openai/src/main/java/com/microsoft/semantickernel/aiservices/openai/chatcompletion/OpenAIChatCompletion.java

+88
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@
5656
import com.microsoft.semantickernel.services.chatcompletion.ChatCompletionService;
5757
import com.microsoft.semantickernel.services.chatcompletion.ChatHistory;
5858
import com.microsoft.semantickernel.services.chatcompletion.ChatMessageContent;
59+
import com.microsoft.semantickernel.services.chatcompletion.StreamingChatContent;
5960
import com.microsoft.semantickernel.services.chatcompletion.message.ChatMessageContentType;
6061
import com.microsoft.semantickernel.services.chatcompletion.message.ChatMessageImageContent;
6162
import com.microsoft.semantickernel.services.openai.OpenAiServiceBuilder;
@@ -64,6 +65,7 @@
6465
import java.util.Arrays;
6566
import java.util.Collections;
6667
import java.util.List;
68+
import java.util.Locale;
6769
import java.util.Map;
6870
import java.util.Objects;
6971
import java.util.stream.Collectors;
@@ -179,6 +181,92 @@ public Mono<List<ChatMessageContent<?>>> getChatMessageContentsAsync(
179181
});
180182
}
181183

184+
@Override
185+
public Flux<StreamingChatContent<?>> getStreamingChatMessageContentsAsync(
186+
ChatHistory chatHistory,
187+
@Nullable Kernel kernel,
188+
@Nullable InvocationContext invocationContext) {
189+
if (invocationContext != null && invocationContext.getToolCallBehavior()
190+
.isAutoInvokeAllowed()) {
191+
throw new SKException(
192+
"Auto invoke is not supported for streaming chat message contents");
193+
}
194+
195+
if (invocationContext != null
196+
&& invocationContext.returnMode() != InvocationReturnMode.NEW_MESSAGES_ONLY) {
197+
throw new SKException(
198+
"Streaming chat message contents only supports NEW_MESSAGES_ONLY return mode");
199+
}
200+
201+
List<ChatRequestMessage> chatRequestMessages = getChatRequestMessages(chatHistory);
202+
203+
ChatMessages messages = new ChatMessages(chatRequestMessages);
204+
205+
List<OpenAIFunction> functions = new ArrayList<>();
206+
if (kernel != null) {
207+
kernel.getPlugins()
208+
.forEach(plugin -> plugin.getFunctions().forEach((name, function) -> functions
209+
.add(OpenAIFunction.build(function.getMetadata(), plugin.getName()))));
210+
}
211+
212+
ChatCompletionsOptions options = executeHook(
213+
invocationContext,
214+
kernel,
215+
new PreChatCompletionEvent(
216+
getCompletionsOptions(
217+
this,
218+
messages.allMessages,
219+
functions,
220+
invocationContext)))
221+
.getOptions();
222+
223+
return getClient()
224+
.getChatCompletionsStreamWithResponse(
225+
getDeploymentName(),
226+
options,
227+
OpenAIRequestSettings.getRequestOptions())
228+
.flatMap(completionsResult -> {
229+
if (completionsResult.getStatusCode() >= 400) {
230+
//SemanticKernelTelemetry.endSpanWithError(span);
231+
return Mono.error(new AIException(ErrorCodes.SERVICE_ERROR,
232+
"Request failed: " + completionsResult.getStatusCode()));
233+
}
234+
//SemanticKernelTelemetry.endSpanWithUsage(span, completionsResult.getValue().getUsage());
235+
236+
return Mono.just(completionsResult.getValue());
237+
})
238+
.flatMap(completions -> {
239+
return Flux.fromIterable(completions.getChoices())
240+
.map(message -> {
241+
AuthorRole role = message.getDelta().getRole() == null
242+
? AuthorRole.ASSISTANT
243+
: AuthorRole.valueOf(message.getDelta().getRole().toString()
244+
.toUpperCase(Locale.ROOT));
245+
246+
return new OpenAIStreamingChatMessageContent<>(
247+
completions.getId(),
248+
role,
249+
message.getDelta().getContent(),
250+
getModelId(),
251+
null,
252+
null,
253+
null,
254+
Arrays.asList());
255+
});
256+
});
257+
}
258+
259+
@Override
260+
public Flux<StreamingChatContent<?>> getStreamingChatMessageContentsAsync(
261+
String prompt,
262+
@Nullable Kernel kernel,
263+
@Nullable InvocationContext invocationContext) {
264+
return getStreamingChatMessageContentsAsync(
265+
new ChatHistory().addUserMessage(prompt),
266+
kernel,
267+
invocationContext);
268+
}
269+
182270
// Holds messages temporarily as we build up our result
183271
private static class ChatMessages {
184272

Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
// Copyright (c) Microsoft. All rights reserved.
2+
package com.microsoft.semantickernel.aiservices.openai.chatcompletion;
3+
4+
import com.microsoft.semantickernel.orchestration.FunctionResultMetadata;
5+
import com.microsoft.semantickernel.services.chatcompletion.AuthorRole;
6+
import com.microsoft.semantickernel.services.chatcompletion.StreamingChatContent;
7+
import java.nio.charset.Charset;
8+
import java.util.List;
9+
import javax.annotation.Nullable;
10+
11+
public class OpenAIStreamingChatMessageContent<T> extends OpenAIChatMessageContent<T> implements
12+
StreamingChatContent<T> {
13+
14+
private final String id;
15+
16+
public OpenAIStreamingChatMessageContent(
17+
String id,
18+
AuthorRole authorRole,
19+
String content,
20+
@Nullable String modelId,
21+
@Nullable T innerContent,
22+
@Nullable Charset encoding,
23+
@Nullable FunctionResultMetadata metadata,
24+
@Nullable List<OpenAIFunctionToolCall> toolCall) {
25+
super(
26+
authorRole,
27+
content,
28+
modelId,
29+
innerContent,
30+
encoding,
31+
metadata,
32+
toolCall);
33+
34+
this.id = id;
35+
}
36+
37+
@Override
38+
public String getId() {
39+
return id;
40+
}
41+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
// Copyright (c) Microsoft. All rights reserved.
2+
package com.microsoft.semantickernel.aiservices.openai.textcompletion;
3+
4+
import com.microsoft.semantickernel.services.StreamingTextContent;
5+
import com.microsoft.semantickernel.services.textcompletion.TextContent;
6+
7+
import javax.annotation.Nullable;
8+
9+
/**
10+
* OpenAIStreamingTextContent is a wrapper for TextContent that allows for streaming.
11+
*/
12+
public class OpenAIStreamingTextContent extends StreamingTextContent<TextContent> {
13+
14+
/**
15+
* Initializes a new instance of the {@code OpenAIStreamingTextContent} class with a provided text
16+
* content.
17+
*
18+
* @param content The text content.
19+
*/
20+
public OpenAIStreamingTextContent(TextContent content) {
21+
super(content, 0, null, null);
22+
}
23+
24+
@Override
25+
@Nullable
26+
public String getContent() {
27+
TextContent content = getInnerContent();
28+
if (content == null) {
29+
return null;
30+
}
31+
return content.getContent();
32+
}
33+
34+
}

0 commit comments

Comments
 (0)