13
13
using Microsoft . Extensions . Logging ;
14
14
using Microsoft . Extensions . Logging . Abstractions ;
15
15
using Microsoft . Shared . Diagnostics ;
16
- using static Microsoft . Extensions . AI . OpenTelemetryConsts . GenAI ;
17
16
18
17
#pragma warning disable CA2213 // Disposable fields should be disposed
19
18
#pragma warning disable EA0002 // Use 'System.TimeProvider' to make the code easier to test
@@ -233,7 +232,7 @@ public override async Task<ChatResponse> GetResponseAsync(
233
232
functionCallContents ? . Clear ( ) ;
234
233
235
234
// Make the call to the inner client.
236
- response = await base . GetResponseAsync ( messages , options , cancellationToken ) . ConfigureAwait ( false ) ;
235
+ response = await base . GetResponseAsync ( messages , options , cancellationToken ) ;
237
236
if ( response is null )
238
237
{
239
238
Throw . InvalidOperationException ( $ "The inner { nameof ( IChatClient ) } returned a null { nameof ( ChatResponse ) } .") ;
@@ -279,7 +278,7 @@ public override async Task<ChatResponse> GetResponseAsync(
279
278
280
279
// Add the responses from the function calls into the augmented history and also into the tracked
281
280
// list of response messages.
282
- var modeAndMessages = await ProcessFunctionCallsAsync ( augmentedHistory , options ! , functionCallContents ! , iteration , consecutiveErrorCount , cancellationToken ) . ConfigureAwait ( false ) ;
281
+ var modeAndMessages = await ProcessFunctionCallsAsync ( augmentedHistory , options ! , functionCallContents ! , iteration , consecutiveErrorCount , cancellationToken ) ;
283
282
responseMessages . AddRange ( modeAndMessages . MessagesAdded ) ;
284
283
consecutiveErrorCount = modeAndMessages . NewConsecutiveErrorCount ;
285
284
@@ -325,7 +324,7 @@ public override async IAsyncEnumerable<ChatResponseUpdate> GetStreamingResponseA
325
324
updates . Clear ( ) ;
326
325
functionCallContents ? . Clear ( ) ;
327
326
328
- await foreach ( var update in base . GetStreamingResponseAsync ( messages , options , cancellationToken ) . ConfigureAwait ( false ) )
327
+ await foreach ( var update in base . GetStreamingResponseAsync ( messages , options , cancellationToken ) )
329
328
{
330
329
if ( update is null )
331
330
{
@@ -356,7 +355,7 @@ public override async IAsyncEnumerable<ChatResponseUpdate> GetStreamingResponseA
356
355
FixupHistories ( originalMessages , ref messages , ref augmentedHistory , response , responseMessages , ref lastIterationHadThreadId ) ;
357
356
358
357
// Process all of the functions, adding their results into the history.
359
- var modeAndMessages = await ProcessFunctionCallsAsync ( augmentedHistory , options , functionCallContents , iteration , consecutiveErrorCount , cancellationToken ) . ConfigureAwait ( false ) ;
358
+ var modeAndMessages = await ProcessFunctionCallsAsync ( augmentedHistory , options , functionCallContents , iteration , consecutiveErrorCount , cancellationToken ) ;
360
359
responseMessages . AddRange ( modeAndMessages . MessagesAdded ) ;
361
360
consecutiveErrorCount = modeAndMessages . NewConsecutiveErrorCount ;
362
361
@@ -534,7 +533,7 @@ private static void UpdateOptionsForNextIteration(ref ChatOptions options, strin
534
533
if ( functionCallContents . Count == 1 )
535
534
{
536
535
FunctionInvocationResult result = await ProcessFunctionCallAsync (
537
- messages , options , functionCallContents , iteration , 0 , captureCurrentIterationExceptions , cancellationToken ) . ConfigureAwait ( false ) ;
536
+ messages , options , functionCallContents , iteration , 0 , captureCurrentIterationExceptions , cancellationToken ) ;
538
537
539
538
IList < ChatMessage > added = CreateResponseMessages ( [ result ] ) ;
540
539
ThrowIfNoFunctionResultsAdded ( added ) ;
@@ -549,13 +548,15 @@ private static void UpdateOptionsForNextIteration(ref ChatOptions options, strin
549
548
550
549
if ( AllowConcurrentInvocation )
551
550
{
552
- // Schedule the invocation of every function.
553
- // In this case we always capture exceptions because the ordering is nondeterministic
551
+ // Rather than await'ing each function before invoking the next, invoke all of them
552
+ // and then await all of them. We avoid forcibly introducing parallelism via Task.Run,
553
+ // but if a function invocation completes asynchronously, its processing can overlap
554
+ // with the processing of other the other invocation invocations.
554
555
results = await Task . WhenAll (
555
556
from i in Enumerable . Range ( 0 , functionCallContents . Count )
556
- select Task . Run ( ( ) => ProcessFunctionCallAsync (
557
+ select ProcessFunctionCallAsync (
557
558
messages , options , functionCallContents ,
558
- iteration , i , captureExceptions : true , cancellationToken ) ) ) . ConfigureAwait ( false ) ;
559
+ iteration , i , captureExceptions : true , cancellationToken ) ) ;
559
560
}
560
561
else
561
562
{
@@ -565,7 +566,7 @@ select Task.Run(() => ProcessFunctionCallAsync(
565
566
{
566
567
results [ i ] = await ProcessFunctionCallAsync (
567
568
messages , options , functionCallContents ,
568
- iteration , i , captureCurrentIterationExceptions , cancellationToken ) . ConfigureAwait ( false ) ;
569
+ iteration , i , captureCurrentIterationExceptions , cancellationToken ) ;
569
570
}
570
571
}
571
572
@@ -663,7 +664,7 @@ private async Task<FunctionInvocationResult> ProcessFunctionCallAsync(
663
664
object ? result ;
664
665
try
665
666
{
666
- result = await InvokeFunctionAsync ( context , cancellationToken ) . ConfigureAwait ( false ) ;
667
+ result = await InvokeFunctionAsync ( context , cancellationToken ) ;
667
668
}
668
669
catch ( Exception e ) when ( ! cancellationToken . IsCancellationRequested )
669
670
{
@@ -763,7 +764,7 @@ FunctionResultContent CreateFunctionResultContent(FunctionInvocationResult resul
763
764
try
764
765
{
765
766
CurrentContext = context ; // doesn't need to be explicitly reset after, as that's handled automatically at async method exit
766
- result = await context . Function . InvokeAsync ( context . Arguments , cancellationToken ) . ConfigureAwait ( false ) ;
767
+ result = await context . Function . InvokeAsync ( context . Arguments , cancellationToken ) ;
767
768
}
768
769
catch ( Exception e )
769
770
{
0 commit comments