Skip to content

Commit 88e7fe7

Browse files
authored
Merge pull request microsoft#162 from dsgrieve/main
Make ChatHistory thread safe
2 parents b5b873b + 4b729df commit 88e7fe7

File tree

3 files changed

+34
-29
lines changed

3 files changed

+34
-29
lines changed

aiservices/openai/src/main/java/com/microsoft/semantickernel/aiservices/openai/chatcompletion/OpenAIChatCompletion.java

Lines changed: 24 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -184,7 +184,7 @@ private static class ChatMessages {
184184

185185
private final List<ChatRequestMessage> newMessages;
186186
private final List<ChatRequestMessage> allMessages;
187-
private final List<OpenAIChatMessageContent> newChatMessageContent;
187+
private final List<OpenAIChatMessageContent<?>> newChatMessageContent;
188188

189189
public ChatMessages(List<ChatRequestMessage> allMessages) {
190190
this.allMessages = Collections.unmodifiableList(allMessages);
@@ -195,7 +195,7 @@ public ChatMessages(List<ChatRequestMessage> allMessages) {
195195
private ChatMessages(
196196
List<ChatRequestMessage> allMessages,
197197
List<ChatRequestMessage> newMessages,
198-
List<OpenAIChatMessageContent> newChatMessageContent) {
198+
List<OpenAIChatMessageContent<?>> newChatMessageContent) {
199199
this.allMessages = Collections.unmodifiableList(allMessages);
200200
this.newMessages = Collections.unmodifiableList(newMessages);
201201
this.newChatMessageContent = Collections.unmodifiableList(newChatMessageContent);
@@ -219,8 +219,8 @@ public ChatMessages add(ChatRequestMessage requestMessage) {
219219
}
220220

221221
@CheckReturnValue
222-
public ChatMessages addChatMessage(List<OpenAIChatMessageContent> chatMessageContent) {
223-
ArrayList<OpenAIChatMessageContent> tmpChatMessageContent = new ArrayList<>(
222+
public ChatMessages addChatMessage(List<OpenAIChatMessageContent<?>> chatMessageContent) {
223+
ArrayList<OpenAIChatMessageContent<?>> tmpChatMessageContent = new ArrayList<>(
224224
newChatMessageContent);
225225
tmpChatMessageContent.addAll(chatMessageContent);
226226

@@ -357,19 +357,16 @@ private Mono<ChatMessages> internalChatMessageContentsAsync(
357357
// If we don't want to attempt to invoke any functions
358358
// Or if we are auto-invoking, but we somehow end up with other than 1 choice even though only 1 was requested
359359
if (autoInvokeAttempts == 0 || responseMessages.size() != 1) {
360-
return getChatMessageContentsAsync(completions)
361-
.flatMap(m -> {
362-
return Mono.just(messages.addChatMessage(m));
363-
});
360+
List<OpenAIChatMessageContent<?>> chatMessageContents = getChatMessageContentsAsync(completions);
361+
return Mono.just(messages.addChatMessage(chatMessageContents));
364362
}
365363
// Or if there are no tool calls to be done
366364
ChatResponseMessage response = responseMessages.get(0);
367365
List<ChatCompletionsToolCall> toolCalls = response.getToolCalls();
368366
if (toolCalls == null || toolCalls.isEmpty()) {
369-
return getChatMessageContentsAsync(completions)
370-
.flatMap(m -> {
371-
return Mono.just(messages.addChatMessage(m));
372-
});
367+
List<OpenAIChatMessageContent<?>> chatMessageContents = getChatMessageContentsAsync(
368+
completions);
369+
return Mono.just(messages.addChatMessage(chatMessageContents));
373370
}
374371

375372
ChatRequestAssistantMessage requestMessage = new ChatRequestAssistantMessage(
@@ -592,7 +589,7 @@ private OpenAIFunctionToolCall extractOpenAIFunctionToolCall(
592589
arguments);
593590
}
594591

595-
private Mono<List<OpenAIChatMessageContent>> getChatMessageContentsAsync(
592+
private List<OpenAIChatMessageContent<?>> getChatMessageContentsAsync(
596593
ChatCompletions completions) {
597594
FunctionResultMetadata<CompletionsUsage> completionMetadata = FunctionResultMetadata.build(
598595
completions.getId(),
@@ -606,22 +603,28 @@ private Mono<List<OpenAIChatMessageContent>> getChatMessageContentsAsync(
606603
.filter(Objects::nonNull)
607604
.collect(Collectors.toList());
608605

609-
return Flux.fromIterable(responseMessages)
610-
.flatMap(response -> {
606+
List<OpenAIChatMessageContent<?>> chatMessageContent =
607+
responseMessages
608+
.stream()
609+
.map(response -> {
611610
try {
612-
return Mono.just(new OpenAIChatMessageContent(
611+
return new OpenAIChatMessageContent<>(
613612
AuthorRole.ASSISTANT,
614613
response.getContent(),
615614
this.getModelId(),
616615
null,
617616
null,
618617
completionMetadata,
619-
formOpenAiToolCalls(response)));
620-
} catch (Exception e) {
621-
return Mono.error(e);
618+
formOpenAiToolCalls(response));
619+
} catch (SKCheckedException e) {
620+
LOGGER.warn("Failed to form chat message content", e);
621+
return null;
622622
}
623623
})
624-
.collectList();
624+
.filter(Objects::nonNull)
625+
.collect(Collectors.toList());
626+
627+
return chatMessageContent;
625628
}
626629

627630
private List<ChatMessageContent<?>> toOpenAIChatMessageContent(
@@ -931,7 +934,7 @@ private static boolean hasToolCallBeenExecuted(List<ChatRequestMessage> chatRequ
931934
}
932935

933936
private static List<ChatRequestMessage> getChatRequestMessages(
934-
List<? extends ChatMessageContent> messages) {
937+
List<? extends ChatMessageContent<?>> messages) {
935938
if (messages == null || messages.isEmpty()) {
936939
return new ArrayList<>();
937940
}

aiservices/openai/src/main/java/com/microsoft/semantickernel/aiservices/openai/chatcompletion/OpenAIChatMessageContent.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ public OpenAIChatMessageContent(
3636
@Nullable String modelId,
3737
@Nullable T innerContent,
3838
@Nullable Charset encoding,
39-
@Nullable FunctionResultMetadata metadata,
39+
@Nullable FunctionResultMetadata<?> metadata,
4040
@Nullable List<OpenAIFunctionToolCall> toolCall) {
4141
super(authorRole, content, modelId, innerContent, encoding, metadata);
4242

semantickernel-api/src/main/java/com/microsoft/semantickernel/services/chatcompletion/ChatHistory.java

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,13 @@
55
import com.microsoft.semantickernel.services.chatcompletion.message.ChatMessageTextContent;
66
import java.nio.charset.Charset;
77
import java.util.ArrayList;
8+
import java.util.Collection;
89
import java.util.Collections;
910
import java.util.Iterator;
1011
import java.util.List;
1112
import java.util.Optional;
1213
import java.util.Spliterator;
14+
import java.util.concurrent.ConcurrentLinkedQueue;
1315
import java.util.function.Consumer;
1416
import javax.annotation.Nullable;
1517

@@ -18,7 +20,7 @@
1820
*/
1921
public class ChatHistory implements Iterable<ChatMessageContent<?>> {
2022

21-
private final List<ChatMessageContent<?>> chatMessageContents;
23+
private final Collection<ChatMessageContent<?>> chatMessageContents;
2224

2325
/**
2426
* The default constructor
@@ -33,7 +35,7 @@ public ChatHistory() {
3335
* @param instructions The instructions to add to the chat history
3436
*/
3537
public ChatHistory(@Nullable String instructions) {
36-
this.chatMessageContents = new ArrayList<>();
38+
this.chatMessageContents = new ConcurrentLinkedQueue<>();
3739
if (instructions != null) {
3840
this.chatMessageContents.add(
3941
ChatMessageTextContent.systemMessage(instructions));
@@ -45,8 +47,8 @@ public ChatHistory(@Nullable String instructions) {
4547
*
4648
* @param chatMessageContents The chat message contents to add to the chat history
4749
*/
48-
public ChatHistory(List<? extends ChatMessageContent> chatMessageContents) {
49-
this.chatMessageContents = new ArrayList(chatMessageContents);
50+
public ChatHistory(List<? extends ChatMessageContent<?>> chatMessageContents) {
51+
this.chatMessageContents = new ConcurrentLinkedQueue<>(chatMessageContents);
5052
}
5153

5254
/**
@@ -55,7 +57,7 @@ public ChatHistory(List<? extends ChatMessageContent> chatMessageContents) {
5557
* @return List of messages in the chat
5658
*/
5759
public List<ChatMessageContent<?>> getMessages() {
58-
return Collections.unmodifiableList(chatMessageContents);
60+
return Collections.unmodifiableList(new ArrayList<>(chatMessageContents));
5961
}
6062

6163
/**
@@ -67,7 +69,7 @@ public Optional<ChatMessageContent<?>> getLastMessage() {
6769
if (chatMessageContents.isEmpty()) {
6870
return Optional.empty();
6971
}
70-
return Optional.of(chatMessageContents.get(chatMessageContents.size() - 1));
72+
return Optional.of(((ConcurrentLinkedQueue<ChatMessageContent<?>>)chatMessageContents).peek());
7173
}
7274

7375
/**
@@ -114,7 +116,7 @@ public Spliterator<ChatMessageContent<?>> spliterator() {
114116
* @param metadata The metadata of the message
115117
*/
116118
public ChatHistory addMessage(AuthorRole authorRole, String content, Charset encoding,
117-
FunctionResultMetadata metadata) {
119+
FunctionResultMetadata<?> metadata) {
118120
chatMessageContents.add(
119121
ChatMessageTextContent.builder()
120122
.withAuthorRole(authorRole)

0 commit comments

Comments
 (0)