Skip to content

Commit 469d46d

Browse files
committed
Address feedback
1 parent 47d3d76 commit 469d46d

File tree

7 files changed

+156
-74
lines changed

7 files changed

+156
-74
lines changed

src/Libraries/Microsoft.Extensions.AI.Abstractions/ChatCompletion/ChatResponse.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,7 @@ public string Text
8585
{
8686
0 => string.Empty,
8787
1 => messages[0].Text,
88-
_ => messages.SelectMany(m => m.Contents).ConcatText(),
88+
_ => string.Join(Environment.NewLine, messages.Select(m => m.Text).Where(s => !string.IsNullOrEmpty(s))),
8989
};
9090
}
9191
}

src/Libraries/Microsoft.Extensions.AI.Abstractions/README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ From the command-line:
1010
dotnet add package Microsoft.Extensions.AI.Abstractions
1111
```
1212

13-
Or directly in the C# project file:
13+
or directly in the C# project file:
1414

1515
```xml
1616
<ItemGroup>

src/Libraries/Microsoft.Extensions.AI.OpenAI/OpenAIAssistantClient.cs

Lines changed: 59 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -112,59 +112,73 @@ public async IAsyncEnumerable<ChatResponseUpdate> GetStreamingResponseAsync(
112112
}
113113

114114
// Process each update.
115-
await foreach (var update in updates.ConfigureAwait(false))
115+
List<ChatResponseUpdate> responseUpdates = [];
116+
try
116117
{
117-
switch (update)
118+
string? responseId = null;
119+
await foreach (var update in updates.ConfigureAwait(false))
118120
{
119-
case MessageContentUpdate mcu:
120-
yield return new(mcu.Role == MessageRole.User ? ChatRole.User : ChatRole.Assistant, mcu.Text)
121-
{
122-
ChatThreadId = threadId,
123-
RawRepresentation = mcu,
124-
};
125-
break;
121+
switch (update)
122+
{
123+
case MessageContentUpdate mcu:
124+
ChatResponseUpdate responseUpdate = new(mcu.Role == MessageRole.User ? ChatRole.User : ChatRole.Assistant, mcu.Text)
125+
{
126+
ChatThreadId = threadId,
127+
RawRepresentation = mcu,
128+
ResponseId = responseId,
129+
};
130+
responseUpdates.Add(responseUpdate);
131+
yield return responseUpdate;
132+
break;
126133

127-
case ThreadUpdate tu when options is not null:
128-
threadId ??= tu.Value.Id;
129-
break;
134+
case ThreadUpdate tu when options is not null:
135+
threadId ??= tu.Value.Id;
136+
break;
130137

131-
case RunUpdate ru:
132-
threadId ??= ru.Value.ThreadId;
138+
case RunUpdate ru:
139+
threadId ??= ru.Value.ThreadId;
140+
responseId ??= ru.Value.Id;
133141

134-
ChatResponseUpdate ruUpdate = new()
135-
{
136-
AuthorName = ru.Value.AssistantId,
137-
ChatThreadId = threadId,
138-
CreatedAt = ru.Value.CreatedAt,
139-
ModelId = ru.Value.Model,
140-
RawRepresentation = ru,
141-
ResponseId = ru.Value.Id,
142-
Role = ChatRole.Assistant,
143-
};
144-
145-
if (ru.Value.Usage is { } usage)
146-
{
147-
ruUpdate.Contents.Add(new UsageContent(new()
142+
ChatResponseUpdate ruUpdate = new()
148143
{
149-
InputTokenCount = usage.InputTokenCount,
150-
OutputTokenCount = usage.OutputTokenCount,
151-
TotalTokenCount = usage.TotalTokenCount,
152-
}));
153-
}
154-
155-
if (ru is RequiredActionUpdate rau && rau.ToolCallId is string toolCallId && rau.FunctionName is string functionName)
156-
{
157-
ruUpdate.Contents.Add(
158-
new FunctionCallContent(
159-
JsonSerializer.Serialize(new[] { ru.Value.Id, toolCallId }, OpenAIJsonContext.Default.StringArray!),
160-
functionName,
161-
JsonSerializer.Deserialize(rau.FunctionArguments, OpenAIJsonContext.Default.IDictionaryStringObject)!));
162-
}
163-
164-
yield return ruUpdate;
165-
break;
144+
AuthorName = ru.Value.AssistantId,
145+
ChatThreadId = threadId,
146+
CreatedAt = ru.Value.CreatedAt,
147+
ModelId = ru.Value.Model,
148+
RawRepresentation = ru,
149+
ResponseId = responseId,
150+
Role = ChatRole.Assistant,
151+
};
152+
153+
if (ru.Value.Usage is { } usage)
154+
{
155+
ruUpdate.Contents.Add(new UsageContent(new()
156+
{
157+
InputTokenCount = usage.InputTokenCount,
158+
OutputTokenCount = usage.OutputTokenCount,
159+
TotalTokenCount = usage.TotalTokenCount,
160+
}));
161+
}
162+
163+
if (ru is RequiredActionUpdate rau && rau.ToolCallId is string toolCallId && rau.FunctionName is string functionName)
164+
{
165+
ruUpdate.Contents.Add(
166+
new FunctionCallContent(
167+
JsonSerializer.Serialize(new[] { ru.Value.Id, toolCallId }, OpenAIJsonContext.Default.StringArray!),
168+
functionName,
169+
JsonSerializer.Deserialize(rau.FunctionArguments, OpenAIJsonContext.Default.IDictionaryStringObject)!));
170+
}
171+
172+
responseUpdates.Add(ruUpdate);
173+
yield return ruUpdate;
174+
break;
175+
}
166176
}
167177
}
178+
finally
179+
{
180+
chatMessages.AddRangeFromUpdates(responseUpdates);
181+
}
168182
}
169183

170184
/// <inheritdoc />

src/Libraries/Microsoft.Extensions.AI/ChatCompletion/CachingChatClient.cs

Lines changed: 4 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -56,12 +56,9 @@ public override async Task<ChatResponse> GetResponseAsync(IList<ChatMessage> cha
5656

5757
if (await ReadCacheAsync(cacheKey, cancellationToken).ConfigureAwait(false) is { } result)
5858
{
59-
if (options?.ChatThreadId is null)
59+
foreach (ChatMessage message in result.Messages)
6060
{
61-
foreach (ChatMessage message in result.Messages)
62-
{
63-
chatMessages.Add(message);
64-
}
61+
chatMessages.Add(message);
6562
}
6663
}
6764
else
@@ -94,12 +91,9 @@ public override async IAsyncEnumerable<ChatResponseUpdate> GetStreamingResponseA
9491
yield return chunk;
9592
}
9693

97-
if (chatResponse.ChatThreadId is null)
94+
foreach (ChatMessage message in chatResponse.Messages)
9895
{
99-
foreach (ChatMessage message in chatResponse.Messages)
100-
{
101-
chatMessages.Add(message);
102-
}
96+
chatMessages.Add(message);
10397
}
10498
}
10599
else

src/Libraries/Microsoft.Extensions.AI/ChatCompletion/FunctionInvocationContext.cs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,9 @@ public IList<ChatMessage> ChatMessages
4545
set => _chatMessages = Throw.IfNull(value);
4646
}
4747

48+
/// <summary>Gets or sets the chat options associated with the operation that initiated this function call request.</summary>
49+
public ChatOptions? Options { get; set; }
50+
4851
/// <summary>Gets or sets the AI function to be invoked.</summary>
4952
public AIFunction Function
5053
{

src/Libraries/Microsoft.Extensions.AI/ChatCompletion/FunctionInvokingChatClient.cs

Lines changed: 86 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -216,9 +216,19 @@ public override async Task<ChatResponse> GetResponseAsync(IList<ChatMessage> cha
216216
// fast path out by just returning the original response.
217217
if (iteration == 0 && !requiresFunctionInvocation)
218218
{
219+
Debug.Assert(originalChatMessages == chatMessages,
220+
"Expected the history to be the original, such that there's no additional work to do to keep it up to date.");
219221
return response;
220222
}
221223

224+
// If chatMessages is different from originalChatMessages, we previously created a different history
225+
// in order to avoid sending state back to an inner client that was already tracking it. But we still
226+
// need that original history to contain all the state. So copy it over if necessary.
227+
if (chatMessages != originalChatMessages)
228+
{
229+
AddRange(originalChatMessages, response.Messages);
230+
}
231+
222232
// Track aggregatable details from the response.
223233
(responseMessages ??= []).AddRange(response.Messages);
224234
if (response.Usage is not null)
@@ -249,7 +259,6 @@ public override async Task<ChatResponse> GetResponseAsync(IList<ChatMessage> cha
249259
}
250260

251261
// If the response indicates the inner client is tracking the history, clear it to avoid re-sending the state.
252-
// In that case, we also avoid touching the user's history, so that we don't need to clear it.
253262
if (response.ChatThreadId is not null)
254263
{
255264
if (chatMessages == originalChatMessages)
@@ -261,10 +270,24 @@ public override async Task<ChatResponse> GetResponseAsync(IList<ChatMessage> cha
261270
chatMessages.Clear();
262271
}
263272
}
273+
else if (chatMessages != originalChatMessages)
274+
{
275+
// This should be a very rare case. In a previous iteration, we got back a non-null
276+
// chatThreadId, so we forked chatMessages. But now, we got back a null chatThreadId,
277+
// and chatMessages is no longer the full history. Thankfully, we've been keeping
278+
// originalChatMessages up to date; we can just switch back to use it.
279+
chatMessages = originalChatMessages;
280+
}
264281

265282
// Add the responses from the function calls into the history.
266283
var modeAndMessages = await ProcessFunctionCallsAsync(chatMessages, options!, functionCallContents!, iteration, cancellationToken).ConfigureAwait(false);
267284
responseMessages.AddRange(modeAndMessages.MessagesAdded);
285+
286+
if (chatMessages != originalChatMessages)
287+
{
288+
AddRange(originalChatMessages, modeAndMessages.MessagesAdded);
289+
}
290+
268291
if (UpdateOptionsForMode(modeAndMessages.Mode, ref options!, response.ChatThreadId))
269292
{
270293
// Terminate
@@ -311,6 +334,19 @@ public override async IAsyncEnumerable<ChatResponseUpdate> GetStreamingResponseA
311334
Activity.Current = activity; // workaround for https://github.com/dotnet/runtime/issues/47802
312335
}
313336

337+
// Make sure that any of the response messages that were added to the chat history also get
338+
// added to the original history if it's different.
339+
if (chatMessages != originalChatMessages)
340+
{
341+
// If chatThreadId was null previously, then we would have added any function result content into
342+
// the original chat messages, passed those chat messages to GetStreamingResponseAsync, and it would
343+
// have added all the new response messages into the original chat messages. But chatThreadId was
344+
// non-null, hence we forked chatMessages. chatMessages then included only the function result content
345+
// and should now include that function result content plus the response messages. None of that is
346+
// in the original, so we can just add everything from chatMessages into the original.
347+
AddRange(originalChatMessages, chatMessages);
348+
}
349+
314350
// If there are no tools to call, or for any other reason we should stop, return the response.
315351
if (functionCallContents is not { Count: > 0 } ||
316352
options?.Tools is not { Count: > 0 } ||
@@ -332,14 +368,17 @@ public override async IAsyncEnumerable<ChatResponseUpdate> GetStreamingResponseA
332368
chatMessages.Clear();
333369
}
334370
}
371+
else if (chatMessages != originalChatMessages)
372+
{
373+
// This should be a very rare case. In a previous iteration, we got back a non-null
374+
// chatThreadId, so we forked chatMessages. But now, we got back a null chatThreadId,
375+
// and chatMessages is no longer the full history. Thankfully, we've been keeping
376+
// originalChatMessages up to date; we can just switch back to use it.
377+
chatMessages = originalChatMessages;
378+
}
335379

336380
// Process all of the functions, adding their results into the history.
337381
var modeAndMessages = await ProcessFunctionCallsAsync(chatMessages, options, functionCallContents, iteration, cancellationToken).ConfigureAwait(false);
338-
if (UpdateOptionsForMode(modeAndMessages.Mode, ref options, chatThreadId))
339-
{
340-
// Terminate
341-
yield break;
342-
}
343382

344383
// Stream any generated function results. These are already part of the history,
345384
// but we stream them out for informational purposes.
@@ -361,6 +400,12 @@ public override async IAsyncEnumerable<ChatResponseUpdate> GetStreamingResponseA
361400
yield return toolResultUpdate;
362401
Activity.Current = activity; // workaround for https://github.com/dotnet/runtime/issues/47802
363402
}
403+
404+
if (UpdateOptionsForMode(modeAndMessages.Mode, ref options, chatThreadId))
405+
{
406+
// Terminate
407+
yield break;
408+
}
364409
}
365410
}
366411

@@ -407,10 +452,7 @@ private static bool UpdateOptionsForMode(ContinueMode mode, ref ChatOptions opti
407452
// as otherwise we'll be in an infinite loop.
408453
options = options.Clone();
409454
options.ToolMode = null;
410-
if (chatThreadId is not null)
411-
{
412-
options.ChatThreadId = chatThreadId;
413-
}
455+
options.ChatThreadId = chatThreadId;
414456

415457
break;
416458

@@ -419,10 +461,7 @@ private static bool UpdateOptionsForMode(ContinueMode mode, ref ChatOptions opti
419461
options = options.Clone();
420462
options.Tools = null;
421463
options.ToolMode = null;
422-
if (chatThreadId is not null)
423-
{
424-
options.ChatThreadId = chatThreadId;
425-
}
464+
options.ChatThreadId = chatThreadId;
426465

427466
break;
428467

@@ -433,7 +472,7 @@ private static bool UpdateOptionsForMode(ContinueMode mode, ref ChatOptions opti
433472
default:
434473
// As with the other modes, ensure we've propagated the chat thread ID to the options.
435474
// We only need to clone the options if we're actually mutating it.
436-
if (chatThreadId is not null && options.ChatThreadId != chatThreadId)
475+
if (options.ChatThreadId != chatThreadId)
437476
{
438477
options = options.Clone();
439478
options.ChatThreadId = chatThreadId;
@@ -468,6 +507,8 @@ private static bool UpdateOptionsForMode(ContinueMode mode, ref ChatOptions opti
468507
FunctionInvocationResult result = await ProcessFunctionCallAsync(
469508
chatMessages, options, functionCallContents, iteration, 0, cancellationToken).ConfigureAwait(false);
470509
IList<ChatMessage> added = AddResponseMessages(chatMessages, [result]);
510+
511+
ThrowIfNoFunctionResultsAdded(added);
471512
return (result.ContinueMode, added);
472513
}
473514
else
@@ -505,10 +546,23 @@ select Task.Run(() => ProcessFunctionCallAsync(
505546
}
506547
}
507548

549+
ThrowIfNoFunctionResultsAdded(added);
508550
return (continueMode, added);
509551
}
510552
}
511553

554+
/// <summary>
555+
/// Throws an exception if <paramref name="chatMessages"/> is empty due to an override of
556+
/// <see cref="AddResponseMessages"/> not having added any messages.
557+
/// </summary>
558+
private void ThrowIfNoFunctionResultsAdded(IList<ChatMessage> chatMessages)
559+
{
560+
if (chatMessages.Count == 0)
561+
{
562+
Throw.InvalidOperationException($"{GetType().Name}.{nameof(AddResponseMessages)} did not add any function result messages.");
563+
}
564+
}
565+
512566
/// <summary>Processes the function call described in <paramref name="callContents"/>[<paramref name="iteration"/>].</summary>
513567
/// <param name="chatMessages">The current chat contents, inclusive of the function call contents being processed.</param>
514568
/// <param name="options">The options used for the response being processed.</param>
@@ -533,6 +587,7 @@ private async Task<FunctionInvocationResult> ProcessFunctionCallAsync(
533587
FunctionInvocationContext context = new()
534588
{
535589
ChatMessages = chatMessages,
590+
Options = options,
536591
CallContent = callContent,
537592
Function = function,
538593
Iteration = iteration,
@@ -698,6 +753,22 @@ FunctionResultContent CreateFunctionResultContent(FunctionInvocationResult resul
698753
return result;
699754
}
700755

756+
/// <summary>Adds all messages from <paramref name="source"/> into <paramref name="destination"/>.</summary>
757+
private static void AddRange(IList<ChatMessage> destination, IEnumerable<ChatMessage> source)
758+
{
759+
if (destination is List<ChatMessage> list)
760+
{
761+
list.AddRange(source);
762+
}
763+
else
764+
{
765+
foreach (var message in source)
766+
{
767+
destination.Add(message);
768+
}
769+
}
770+
}
771+
701772
private static TimeSpan GetElapsedTime(long startingTimestamp) =>
702773
#if NET
703774
Stopwatch.GetElapsedTime(startingTimestamp);

test/Shared/Throw/ThrowTest.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -388,10 +388,10 @@ public void Collection_IfReadOnly()
388388
_ = Throw.IfReadOnly(new List<int>());
389389

390390
IList<int> list = new int[4];
391-
Assert.Throws<ArgumentException>("list", () => Throw.IfReadOnly(list);
391+
Assert.Throws<ArgumentException>("list", () => Throw.IfReadOnly(list));
392392

393393
list = new ReadOnlyCollection<int>();
394-
Assert.Throws<ArgumentException>("list", () => Throw.IfReadOnly(list);
394+
Assert.Throws<ArgumentException>("list", () => Throw.IfReadOnly(list));
395395
}
396396

397397
#endregion

0 commit comments

Comments
 (0)