@@ -216,9 +216,19 @@ public override async Task<ChatResponse> GetResponseAsync(IList<ChatMessage> cha
216
216
// fast path out by just returning the original response.
217
217
if ( iteration == 0 && ! requiresFunctionInvocation )
218
218
{
219
+ Debug . Assert ( originalChatMessages == chatMessages ,
220
+ "Expected the history to be the original, such that there's no additional work to do to keep it up to date." ) ;
219
221
return response ;
220
222
}
221
223
224
+ // If chatMessages is different from originalChatMessages, we previously created a different history
225
+ // in order to avoid sending state back to an inner client that was already tracking it. But we still
226
+ // need that original history to contain all the state. So copy it over if necessary.
227
+ if ( chatMessages != originalChatMessages )
228
+ {
229
+ AddRange ( originalChatMessages , response . Messages ) ;
230
+ }
231
+
222
232
// Track aggregatable details from the response.
223
233
( responseMessages ??= [ ] ) . AddRange ( response . Messages ) ;
224
234
if ( response . Usage is not null )
@@ -249,7 +259,6 @@ public override async Task<ChatResponse> GetResponseAsync(IList<ChatMessage> cha
249
259
}
250
260
251
261
// If the response indicates the inner client is tracking the history, clear it to avoid re-sending the state.
252
- // In that case, we also avoid touching the user's history, so that we don't need to clear it.
253
262
if ( response . ChatThreadId is not null )
254
263
{
255
264
if ( chatMessages == originalChatMessages )
@@ -261,10 +270,24 @@ public override async Task<ChatResponse> GetResponseAsync(IList<ChatMessage> cha
261
270
chatMessages . Clear ( ) ;
262
271
}
263
272
}
273
+ else if ( chatMessages != originalChatMessages )
274
+ {
275
+ // This should be a very rare case. In a previous iteration, we got back a non-null
276
+ // chatThreadId, so we forked chatMessages. But now, we got back a null chatThreadId,
277
+ // and chatMessages is no longer the full history. Thankfully, we've been keeping
278
+ // originalChatMessages up to date; we can just switch back to use it.
279
+ chatMessages = originalChatMessages ;
280
+ }
264
281
265
282
// Add the responses from the function calls into the history.
266
283
var modeAndMessages = await ProcessFunctionCallsAsync ( chatMessages , options ! , functionCallContents ! , iteration , cancellationToken ) . ConfigureAwait ( false ) ;
267
284
responseMessages . AddRange ( modeAndMessages . MessagesAdded ) ;
285
+
286
+ if ( chatMessages != originalChatMessages )
287
+ {
288
+ AddRange ( originalChatMessages , modeAndMessages . MessagesAdded ) ;
289
+ }
290
+
268
291
if ( UpdateOptionsForMode ( modeAndMessages . Mode , ref options ! , response . ChatThreadId ) )
269
292
{
270
293
// Terminate
@@ -311,6 +334,19 @@ public override async IAsyncEnumerable<ChatResponseUpdate> GetStreamingResponseA
311
334
Activity . Current = activity ; // workaround for https://github.com/dotnet/runtime/issues/47802
312
335
}
313
336
337
+ // Make sure that any of the response messages that were added to the chat history also get
338
+ // added to the original history if it's different.
339
+ if ( chatMessages != originalChatMessages )
340
+ {
341
+ // If chatThreadId was null previously, then we would have added any function result content into
342
+ // the original chat messages, passed those chat messages to GetStreamingResponseAsync, and it would
343
+ // have added all the new response messages into the original chat messages. But chatThreadId was
344
+ // non-null, hence we forked chatMessages. chatMessages then included only the function result content
345
+ // and should now include that function result content plus the response messages. None of that is
346
+ // in the original, so we can just add everything from chatMessages into the original.
347
+ AddRange ( originalChatMessages , chatMessages ) ;
348
+ }
349
+
314
350
// If there are no tools to call, or for any other reason we should stop, return the response.
315
351
if ( functionCallContents is not { Count : > 0 } ||
316
352
options ? . Tools is not { Count : > 0 } ||
@@ -332,14 +368,17 @@ public override async IAsyncEnumerable<ChatResponseUpdate> GetStreamingResponseA
332
368
chatMessages . Clear ( ) ;
333
369
}
334
370
}
371
+ else if ( chatMessages != originalChatMessages )
372
+ {
373
+ // This should be a very rare case. In a previous iteration, we got back a non-null
374
+ // chatThreadId, so we forked chatMessages. But now, we got back a null chatThreadId,
375
+ // and chatMessages is no longer the full history. Thankfully, we've been keeping
376
+ // originalChatMessages up to date; we can just switch back to use it.
377
+ chatMessages = originalChatMessages ;
378
+ }
335
379
336
380
// Process all of the functions, adding their results into the history.
337
381
var modeAndMessages = await ProcessFunctionCallsAsync ( chatMessages , options , functionCallContents , iteration , cancellationToken ) . ConfigureAwait ( false ) ;
338
- if ( UpdateOptionsForMode ( modeAndMessages . Mode , ref options , chatThreadId ) )
339
- {
340
- // Terminate
341
- yield break ;
342
- }
343
382
344
383
// Stream any generated function results. These are already part of the history,
345
384
// but we stream them out for informational purposes.
@@ -361,6 +400,12 @@ public override async IAsyncEnumerable<ChatResponseUpdate> GetStreamingResponseA
361
400
yield return toolResultUpdate ;
362
401
Activity . Current = activity ; // workaround for https://github.com/dotnet/runtime/issues/47802
363
402
}
403
+
404
+ if ( UpdateOptionsForMode ( modeAndMessages . Mode , ref options , chatThreadId ) )
405
+ {
406
+ // Terminate
407
+ yield break ;
408
+ }
364
409
}
365
410
}
366
411
@@ -407,10 +452,7 @@ private static bool UpdateOptionsForMode(ContinueMode mode, ref ChatOptions opti
407
452
// as otherwise we'll be in an infinite loop.
408
453
options = options . Clone ( ) ;
409
454
options . ToolMode = null ;
410
- if ( chatThreadId is not null )
411
- {
412
- options . ChatThreadId = chatThreadId ;
413
- }
455
+ options . ChatThreadId = chatThreadId ;
414
456
415
457
break ;
416
458
@@ -419,10 +461,7 @@ private static bool UpdateOptionsForMode(ContinueMode mode, ref ChatOptions opti
419
461
options = options . Clone ( ) ;
420
462
options . Tools = null ;
421
463
options . ToolMode = null ;
422
- if ( chatThreadId is not null )
423
- {
424
- options . ChatThreadId = chatThreadId ;
425
- }
464
+ options . ChatThreadId = chatThreadId ;
426
465
427
466
break ;
428
467
@@ -433,7 +472,7 @@ private static bool UpdateOptionsForMode(ContinueMode mode, ref ChatOptions opti
433
472
default :
434
473
// As with the other modes, ensure we've propagated the chat thread ID to the options.
435
474
// We only need to clone the options if we're actually mutating it.
436
- if ( chatThreadId is not null && options . ChatThreadId != chatThreadId )
475
+ if ( options . ChatThreadId != chatThreadId )
437
476
{
438
477
options = options . Clone ( ) ;
439
478
options . ChatThreadId = chatThreadId ;
@@ -468,6 +507,8 @@ private static bool UpdateOptionsForMode(ContinueMode mode, ref ChatOptions opti
468
507
FunctionInvocationResult result = await ProcessFunctionCallAsync (
469
508
chatMessages , options , functionCallContents , iteration , 0 , cancellationToken ) . ConfigureAwait ( false ) ;
470
509
IList < ChatMessage > added = AddResponseMessages ( chatMessages , [ result ] ) ;
510
+
511
+ ThrowIfNoFunctionResultsAdded ( added ) ;
471
512
return ( result . ContinueMode , added ) ;
472
513
}
473
514
else
@@ -505,10 +546,23 @@ select Task.Run(() => ProcessFunctionCallAsync(
505
546
}
506
547
}
507
548
549
+ ThrowIfNoFunctionResultsAdded ( added ) ;
508
550
return ( continueMode , added ) ;
509
551
}
510
552
}
511
553
554
+ /// <summary>
555
+ /// Throws an exception if <paramref name="chatMessages"/> is empty due to an override of
556
+ /// <see cref="AddResponseMessages"/> not having added any messages.
557
+ /// </summary>
558
+ private void ThrowIfNoFunctionResultsAdded ( IList < ChatMessage > chatMessages )
559
+ {
560
+ if ( chatMessages . Count == 0 )
561
+ {
562
+ Throw . InvalidOperationException ( $ "{ GetType ( ) . Name } .{ nameof ( AddResponseMessages ) } did not add any function result messages.") ;
563
+ }
564
+ }
565
+
512
566
/// <summary>Processes the function call described in <paramref name="callContents"/>[<paramref name="iteration"/>].</summary>
513
567
/// <param name="chatMessages">The current chat contents, inclusive of the function call contents being processed.</param>
514
568
/// <param name="options">The options used for the response being processed.</param>
@@ -533,6 +587,7 @@ private async Task<FunctionInvocationResult> ProcessFunctionCallAsync(
533
587
FunctionInvocationContext context = new ( )
534
588
{
535
589
ChatMessages = chatMessages ,
590
+ Options = options ,
536
591
CallContent = callContent ,
537
592
Function = function ,
538
593
Iteration = iteration ,
@@ -698,6 +753,22 @@ FunctionResultContent CreateFunctionResultContent(FunctionInvocationResult resul
698
753
return result ;
699
754
}
700
755
756
+ /// <summary>Adds all messages from <paramref name="source"/> into <paramref name="destination"/>.</summary>
757
+ private static void AddRange ( IList < ChatMessage > destination , IEnumerable < ChatMessage > source )
758
+ {
759
+ if ( destination is List < ChatMessage > list )
760
+ {
761
+ list . AddRange ( source ) ;
762
+ }
763
+ else
764
+ {
765
+ foreach ( var message in source )
766
+ {
767
+ destination . Add ( message ) ;
768
+ }
769
+ }
770
+ }
771
+
701
772
private static TimeSpan GetElapsedTime ( long startingTimestamp ) =>
702
773
#if NET
703
774
Stopwatch . GetElapsedTime ( startingTimestamp ) ;
0 commit comments