Skip to content

Commit 5cec925

Browse files
committed
Merged PR 49002: Remove use of ConfigureAwait from Microsoft.Extensions.AI.dll for AIFunction...
Remove use of ConfigureAwait from Microsoft.Extensions.AI.dll for AIFunction invocations (#6250) We try to use ConfigureAwait(false) throughout our libraries. However, we exempt ourselves from that in cases where user code is expected to be called back from within the async code, and there's a reasonable presumption that such code might care about the synchronization context. AIFunction fits that bill. And FunctionInvokingChatClient needs to invoke such functions, which means that we need to be able to successfully flow the context from where user code calls Get{Streaming}ResponseAsync through into wherever a FunctionInvokingChatClient is in the middleware pipeline. We could try to selectively avoid ConfigureAwait(false) on the path through middleware that could result in calls to FICC.Get{Streaming}ResponseAsync, but that's fairly brittle and hard to maintain. Instead, this PR just removes ConfigureAwait use from the M.E.AI library. It also fixes a few places where tasks were explicitly being created and queued to the thread pool.
2 parents d4094cc + f0bda61 commit 5cec925

22 files changed

+157
-86
lines changed

src/Libraries/Microsoft.Extensions.AI/ChatCompletion/AnonymousDelegatingChatClient.cs

+12-11
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,9 @@
44
using System;
55
using System.Collections.Generic;
66
using System.Diagnostics;
7+
#if !NET9_0_OR_GREATER
78
using System.Runtime.CompilerServices;
9+
#endif
810
using System.Threading;
911
using System.Threading.Channels;
1012
using System.Threading.Tasks;
@@ -100,8 +102,8 @@ async Task<ChatResponse> GetResponseViaSharedAsync(
100102
ChatResponse? response = null;
101103
await _sharedFunc(messages, options, async (messages, options, cancellationToken) =>
102104
{
103-
response = await InnerClient.GetResponseAsync(messages, options, cancellationToken).ConfigureAwait(false);
104-
}, cancellationToken).ConfigureAwait(false);
105+
response = await InnerClient.GetResponseAsync(messages, options, cancellationToken);
106+
}, cancellationToken);
105107

106108
if (response is null)
107109
{
@@ -133,20 +135,19 @@ public override IAsyncEnumerable<ChatResponseUpdate> GetStreamingResponseAsync(
133135
{
134136
var updates = Channel.CreateBounded<ChatResponseUpdate>(1);
135137

136-
#pragma warning disable CA2016 // explicitly not forwarding the cancellation token, as we need to ensure the channel is always completed
137-
_ = Task.Run(async () =>
138-
#pragma warning restore CA2016
138+
_ = ProcessAsync();
139+
async Task ProcessAsync()
139140
{
140141
Exception? error = null;
141142
try
142143
{
143144
await _sharedFunc(messages, options, async (messages, options, cancellationToken) =>
144145
{
145-
await foreach (var update in InnerClient.GetStreamingResponseAsync(messages, options, cancellationToken).ConfigureAwait(false))
146+
await foreach (var update in InnerClient.GetStreamingResponseAsync(messages, options, cancellationToken))
146147
{
147-
await updates.Writer.WriteAsync(update, cancellationToken).ConfigureAwait(false);
148+
await updates.Writer.WriteAsync(update, cancellationToken);
148149
}
149-
}, cancellationToken).ConfigureAwait(false);
150+
}, cancellationToken);
150151
}
151152
catch (Exception ex)
152153
{
@@ -157,7 +158,7 @@ await _sharedFunc(messages, options, async (messages, options, cancellationToken
157158
{
158159
_ = updates.Writer.TryComplete(error);
159160
}
160-
});
161+
}
161162

162163
#if NET9_0_OR_GREATER
163164
return updates.Reader.ReadAllAsync(cancellationToken);
@@ -166,7 +167,7 @@ await _sharedFunc(messages, options, async (messages, options, cancellationToken
166167
static async IAsyncEnumerable<ChatResponseUpdate> ReadAllAsync(
167168
ChannelReader<ChatResponseUpdate> channel, [EnumeratorCancellation] CancellationToken cancellationToken)
168169
{
169-
while (await channel.WaitToReadAsync(cancellationToken).ConfigureAwait(false))
170+
while (await channel.WaitToReadAsync(cancellationToken))
170171
{
171172
while (channel.TryRead(out var update))
172173
{
@@ -187,7 +188,7 @@ static async IAsyncEnumerable<ChatResponseUpdate> ReadAllAsync(
187188

188189
static async IAsyncEnumerable<ChatResponseUpdate> GetStreamingResponseAsyncViaGetResponseAsync(Task<ChatResponse> task)
189190
{
190-
ChatResponse response = await task.ConfigureAwait(false);
191+
ChatResponse response = await task;
191192
foreach (var update in response.ToChatResponseUpdates())
192193
{
193194
yield return update;

src/Libraries/Microsoft.Extensions.AI/ChatCompletion/CachingChatClient.cs

+9-9
Original file line numberDiff line numberDiff line change
@@ -55,10 +55,10 @@ public override async Task<ChatResponse> GetResponseAsync(
5555
// concurrent callers might trigger duplicate requests, but that's acceptable.
5656
var cacheKey = GetCacheKey(messages, options, _boxedFalse);
5757

58-
if (await ReadCacheAsync(cacheKey, cancellationToken).ConfigureAwait(false) is not { } result)
58+
if (await ReadCacheAsync(cacheKey, cancellationToken) is not { } result)
5959
{
60-
result = await base.GetResponseAsync(messages, options, cancellationToken).ConfigureAwait(false);
61-
await WriteCacheAsync(cacheKey, result, cancellationToken).ConfigureAwait(false);
60+
result = await base.GetResponseAsync(messages, options, cancellationToken);
61+
await WriteCacheAsync(cacheKey, result, cancellationToken);
6262
}
6363

6464
return result;
@@ -77,7 +77,7 @@ public override async IAsyncEnumerable<ChatResponseUpdate> GetStreamingResponseA
7777
// result and cache it. When we get a cache hit, we yield the non-streaming result as a streaming one.
7878

7979
var cacheKey = GetCacheKey(messages, options, _boxedTrue);
80-
if (await ReadCacheAsync(cacheKey, cancellationToken).ConfigureAwait(false) is { } chatResponse)
80+
if (await ReadCacheAsync(cacheKey, cancellationToken) is { } chatResponse)
8181
{
8282
// Yield all of the cached items.
8383
foreach (var chunk in chatResponse.ToChatResponseUpdates())
@@ -89,20 +89,20 @@ public override async IAsyncEnumerable<ChatResponseUpdate> GetStreamingResponseA
8989
{
9090
// Yield and store all of the items.
9191
List<ChatResponseUpdate> capturedItems = [];
92-
await foreach (var chunk in base.GetStreamingResponseAsync(messages, options, cancellationToken).ConfigureAwait(false))
92+
await foreach (var chunk in base.GetStreamingResponseAsync(messages, options, cancellationToken))
9393
{
9494
capturedItems.Add(chunk);
9595
yield return chunk;
9696
}
9797

9898
// Write the captured items to the cache as a non-streaming result.
99-
await WriteCacheAsync(cacheKey, capturedItems.ToChatResponse(), cancellationToken).ConfigureAwait(false);
99+
await WriteCacheAsync(cacheKey, capturedItems.ToChatResponse(), cancellationToken);
100100
}
101101
}
102102
else
103103
{
104104
var cacheKey = GetCacheKey(messages, options, _boxedTrue);
105-
if (await ReadCacheStreamingAsync(cacheKey, cancellationToken).ConfigureAwait(false) is { } existingChunks)
105+
if (await ReadCacheStreamingAsync(cacheKey, cancellationToken) is { } existingChunks)
106106
{
107107
// Yield all of the cached items.
108108
string? chatThreadId = null;
@@ -116,14 +116,14 @@ public override async IAsyncEnumerable<ChatResponseUpdate> GetStreamingResponseA
116116
{
117117
// Yield and store all of the items.
118118
List<ChatResponseUpdate> capturedItems = [];
119-
await foreach (var chunk in base.GetStreamingResponseAsync(messages, options, cancellationToken).ConfigureAwait(false))
119+
await foreach (var chunk in base.GetStreamingResponseAsync(messages, options, cancellationToken))
120120
{
121121
capturedItems.Add(chunk);
122122
yield return chunk;
123123
}
124124

125125
// Write the captured items to the cache.
126-
await WriteCacheStreamingAsync(cacheKey, capturedItems, cancellationToken).ConfigureAwait(false);
126+
await WriteCacheStreamingAsync(cacheKey, capturedItems, cancellationToken);
127127
}
128128
}
129129
}

src/Libraries/Microsoft.Extensions.AI/ChatCompletion/ChatClientBuilderChatClientExtensions.cs

-1
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22
// The .NET Foundation licenses this file to you under the MIT license.
33

44
using System;
5-
using Microsoft.Extensions.AI;
65
using Microsoft.Shared.Diagnostics;
76

87
namespace Microsoft.Extensions.AI;

src/Libraries/Microsoft.Extensions.AI/ChatCompletion/ChatClientStructuredOutputExtensions.cs

+1-1
Original file line numberDiff line numberDiff line change
@@ -221,7 +221,7 @@ public static async Task<ChatResponse<T>> GetResponseAsync<T>(
221221
messages = [.. messages, promptAugmentation];
222222
}
223223

224-
var result = await chatClient.GetResponseAsync(messages, options, cancellationToken).ConfigureAwait(false);
224+
var result = await chatClient.GetResponseAsync(messages, options, cancellationToken);
225225
return new ChatResponse<T>(result, serializerOptions) { IsWrappedInObject = isWrappedInObject };
226226
}
227227

src/Libraries/Microsoft.Extensions.AI/ChatCompletion/ConfigureOptionsChatClient.cs

+2-2
Original file line numberDiff line numberDiff line change
@@ -36,13 +36,13 @@ public ConfigureOptionsChatClient(IChatClient innerClient, Action<ChatOptions> c
3636
/// <inheritdoc/>
3737
public override async Task<ChatResponse> GetResponseAsync(
3838
IEnumerable<ChatMessage> messages, ChatOptions? options = null, CancellationToken cancellationToken = default) =>
39-
await base.GetResponseAsync(messages, Configure(options), cancellationToken).ConfigureAwait(false);
39+
await base.GetResponseAsync(messages, Configure(options), cancellationToken);
4040

4141
/// <inheritdoc/>
4242
public override async IAsyncEnumerable<ChatResponseUpdate> GetStreamingResponseAsync(
4343
IEnumerable<ChatMessage> messages, ChatOptions? options = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
4444
{
45-
await foreach (var update in base.GetStreamingResponseAsync(messages, Configure(options), cancellationToken).ConfigureAwait(false))
45+
await foreach (var update in base.GetStreamingResponseAsync(messages, Configure(options), cancellationToken))
4646
{
4747
yield return update;
4848
}

src/Libraries/Microsoft.Extensions.AI/ChatCompletion/DistributedCachingChatClient.cs

+4-4
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ public JsonSerializerOptions JsonSerializerOptions
5252
_ = Throw.IfNull(key);
5353
_jsonSerializerOptions.MakeReadOnly();
5454

55-
if (await _storage.GetAsync(key, cancellationToken).ConfigureAwait(false) is byte[] existingJson)
55+
if (await _storage.GetAsync(key, cancellationToken) is byte[] existingJson)
5656
{
5757
return (ChatResponse?)JsonSerializer.Deserialize(existingJson, _jsonSerializerOptions.GetTypeInfo(typeof(ChatResponse)));
5858
}
@@ -66,7 +66,7 @@ public JsonSerializerOptions JsonSerializerOptions
6666
_ = Throw.IfNull(key);
6767
_jsonSerializerOptions.MakeReadOnly();
6868

69-
if (await _storage.GetAsync(key, cancellationToken).ConfigureAwait(false) is byte[] existingJson)
69+
if (await _storage.GetAsync(key, cancellationToken) is byte[] existingJson)
7070
{
7171
return (IReadOnlyList<ChatResponseUpdate>?)JsonSerializer.Deserialize(existingJson, _jsonSerializerOptions.GetTypeInfo(typeof(IReadOnlyList<ChatResponseUpdate>)));
7272
}
@@ -82,7 +82,7 @@ protected override async Task WriteCacheAsync(string key, ChatResponse value, Ca
8282
_jsonSerializerOptions.MakeReadOnly();
8383

8484
var newJson = JsonSerializer.SerializeToUtf8Bytes(value, _jsonSerializerOptions.GetTypeInfo(typeof(ChatResponse)));
85-
await _storage.SetAsync(key, newJson, cancellationToken).ConfigureAwait(false);
85+
await _storage.SetAsync(key, newJson, cancellationToken);
8686
}
8787

8888
/// <inheritdoc />
@@ -93,7 +93,7 @@ protected override async Task WriteCacheStreamingAsync(string key, IReadOnlyList
9393
_jsonSerializerOptions.MakeReadOnly();
9494

9595
var newJson = JsonSerializer.SerializeToUtf8Bytes(value, _jsonSerializerOptions.GetTypeInfo(typeof(IReadOnlyList<ChatResponseUpdate>)));
96-
await _storage.SetAsync(key, newJson, cancellationToken).ConfigureAwait(false);
96+
await _storage.SetAsync(key, newJson, cancellationToken);
9797
}
9898

9999
/// <summary>Computes a cache key for the specified values.</summary>

src/Libraries/Microsoft.Extensions.AI/ChatCompletion/FunctionInvokingChatClient.cs

+14-13
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@
1313
using Microsoft.Extensions.Logging;
1414
using Microsoft.Extensions.Logging.Abstractions;
1515
using Microsoft.Shared.Diagnostics;
16-
using static Microsoft.Extensions.AI.OpenTelemetryConsts.GenAI;
1716

1817
#pragma warning disable CA2213 // Disposable fields should be disposed
1918
#pragma warning disable EA0002 // Use 'System.TimeProvider' to make the code easier to test
@@ -233,7 +232,7 @@ public override async Task<ChatResponse> GetResponseAsync(
233232
functionCallContents?.Clear();
234233

235234
// Make the call to the inner client.
236-
response = await base.GetResponseAsync(messages, options, cancellationToken).ConfigureAwait(false);
235+
response = await base.GetResponseAsync(messages, options, cancellationToken);
237236
if (response is null)
238237
{
239238
Throw.InvalidOperationException($"The inner {nameof(IChatClient)} returned a null {nameof(ChatResponse)}.");
@@ -279,7 +278,7 @@ public override async Task<ChatResponse> GetResponseAsync(
279278

280279
// Add the responses from the function calls into the augmented history and also into the tracked
281280
// list of response messages.
282-
var modeAndMessages = await ProcessFunctionCallsAsync(augmentedHistory, options!, functionCallContents!, iteration, consecutiveErrorCount, cancellationToken).ConfigureAwait(false);
281+
var modeAndMessages = await ProcessFunctionCallsAsync(augmentedHistory, options!, functionCallContents!, iteration, consecutiveErrorCount, cancellationToken);
283282
responseMessages.AddRange(modeAndMessages.MessagesAdded);
284283
consecutiveErrorCount = modeAndMessages.NewConsecutiveErrorCount;
285284

@@ -325,7 +324,7 @@ public override async IAsyncEnumerable<ChatResponseUpdate> GetStreamingResponseA
325324
updates.Clear();
326325
functionCallContents?.Clear();
327326

328-
await foreach (var update in base.GetStreamingResponseAsync(messages, options, cancellationToken).ConfigureAwait(false))
327+
await foreach (var update in base.GetStreamingResponseAsync(messages, options, cancellationToken))
329328
{
330329
if (update is null)
331330
{
@@ -356,7 +355,7 @@ public override async IAsyncEnumerable<ChatResponseUpdate> GetStreamingResponseA
356355
FixupHistories(originalMessages, ref messages, ref augmentedHistory, response, responseMessages, ref lastIterationHadThreadId);
357356

358357
// Process all of the functions, adding their results into the history.
359-
var modeAndMessages = await ProcessFunctionCallsAsync(augmentedHistory, options, functionCallContents, iteration, consecutiveErrorCount, cancellationToken).ConfigureAwait(false);
358+
var modeAndMessages = await ProcessFunctionCallsAsync(augmentedHistory, options, functionCallContents, iteration, consecutiveErrorCount, cancellationToken);
360359
responseMessages.AddRange(modeAndMessages.MessagesAdded);
361360
consecutiveErrorCount = modeAndMessages.NewConsecutiveErrorCount;
362361

@@ -534,7 +533,7 @@ private static void UpdateOptionsForNextIteration(ref ChatOptions options, strin
534533
if (functionCallContents.Count == 1)
535534
{
536535
FunctionInvocationResult result = await ProcessFunctionCallAsync(
537-
messages, options, functionCallContents, iteration, 0, captureCurrentIterationExceptions, cancellationToken).ConfigureAwait(false);
536+
messages, options, functionCallContents, iteration, 0, captureCurrentIterationExceptions, cancellationToken);
538537

539538
IList<ChatMessage> added = CreateResponseMessages([result]);
540539
ThrowIfNoFunctionResultsAdded(added);
@@ -549,13 +548,15 @@ private static void UpdateOptionsForNextIteration(ref ChatOptions options, strin
549548

550549
if (AllowConcurrentInvocation)
551550
{
552-
// Schedule the invocation of every function.
553-
// In this case we always capture exceptions because the ordering is nondeterministic
551+
// Rather than await'ing each function before invoking the next, invoke all of them
552+
// and then await all of them. We avoid forcibly introducing parallelism via Task.Run,
553+
// but if a function invocation completes asynchronously, its processing can overlap
554+
// with the processing of other the other invocation invocations.
554555
results = await Task.WhenAll(
555556
from i in Enumerable.Range(0, functionCallContents.Count)
556-
select Task.Run(() => ProcessFunctionCallAsync(
557+
select ProcessFunctionCallAsync(
557558
messages, options, functionCallContents,
558-
iteration, i, captureExceptions: true, cancellationToken))).ConfigureAwait(false);
559+
iteration, i, captureExceptions: true, cancellationToken));
559560
}
560561
else
561562
{
@@ -565,7 +566,7 @@ select Task.Run(() => ProcessFunctionCallAsync(
565566
{
566567
results[i] = await ProcessFunctionCallAsync(
567568
messages, options, functionCallContents,
568-
iteration, i, captureCurrentIterationExceptions, cancellationToken).ConfigureAwait(false);
569+
iteration, i, captureCurrentIterationExceptions, cancellationToken);
569570
}
570571
}
571572

@@ -663,7 +664,7 @@ private async Task<FunctionInvocationResult> ProcessFunctionCallAsync(
663664
object? result;
664665
try
665666
{
666-
result = await InvokeFunctionAsync(context, cancellationToken).ConfigureAwait(false);
667+
result = await InvokeFunctionAsync(context, cancellationToken);
667668
}
668669
catch (Exception e) when (!cancellationToken.IsCancellationRequested)
669670
{
@@ -763,7 +764,7 @@ FunctionResultContent CreateFunctionResultContent(FunctionInvocationResult resul
763764
try
764765
{
765766
CurrentContext = context; // doesn't need to be explicitly reset after, as that's handled automatically at async method exit
766-
result = await context.Function.InvokeAsync(context.Arguments, cancellationToken).ConfigureAwait(false);
767+
result = await context.Function.InvokeAsync(context.Arguments, cancellationToken);
767768
}
768769
catch (Exception e)
769770
{

src/Libraries/Microsoft.Extensions.AI/ChatCompletion/LoggingChatClient.cs

+3-3
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ public override async Task<ChatResponse> GetResponseAsync(
6060

6161
try
6262
{
63-
var response = await base.GetResponseAsync(messages, options, cancellationToken).ConfigureAwait(false);
63+
var response = await base.GetResponseAsync(messages, options, cancellationToken);
6464

6565
if (_logger.IsEnabled(LogLevel.Debug))
6666
{
@@ -127,7 +127,7 @@ public override async IAsyncEnumerable<ChatResponseUpdate> GetStreamingResponseA
127127
{
128128
try
129129
{
130-
if (!await e.MoveNextAsync().ConfigureAwait(false))
130+
if (!await e.MoveNextAsync())
131131
{
132132
break;
133133
}
@@ -164,7 +164,7 @@ public override async IAsyncEnumerable<ChatResponseUpdate> GetStreamingResponseA
164164
}
165165
finally
166166
{
167-
await e.DisposeAsync().ConfigureAwait(false);
167+
await e.DisposeAsync();
168168
}
169169
}
170170

src/Libraries/Microsoft.Extensions.AI/ChatCompletion/OpenTelemetryChatClient.cs

+2-2
Original file line numberDiff line numberDiff line change
@@ -145,7 +145,7 @@ public override async Task<ChatResponse> GetResponseAsync(
145145
Exception? error = null;
146146
try
147147
{
148-
response = await base.GetResponseAsync(messages, options, cancellationToken).ConfigureAwait(false);
148+
response = await base.GetResponseAsync(messages, options, cancellationToken);
149149
return response;
150150
}
151151
catch (Exception ex)
@@ -183,7 +183,7 @@ public override async IAsyncEnumerable<ChatResponseUpdate> GetStreamingResponseA
183183
throw;
184184
}
185185

186-
var responseEnumerator = updates.ConfigureAwait(false).GetAsyncEnumerator();
186+
var responseEnumerator = updates.GetAsyncEnumerator(cancellationToken);
187187
List<ChatResponseUpdate> trackedUpdates = [];
188188
Exception? error = null;
189189
try

src/Libraries/Microsoft.Extensions.AI/Embeddings/AnonymousDelegatingEmbeddingGenerator.cs

+1-1
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,6 @@ public override async Task<GeneratedEmbeddings<TEmbedding>> GenerateAsync(
3939
{
4040
_ = Throw.IfNull(values);
4141

42-
return await _generateFunc(values, options, InnerGenerator, cancellationToken).ConfigureAwait(false);
42+
return await _generateFunc(values, options, InnerGenerator, cancellationToken);
4343
}
4444
}

0 commit comments

Comments
 (0)