Skip to content

Commit b83aece

Browse files
authored
Make CancellationToken available in call credentials interceptor (#2107)
1 parent 8af4723 commit b83aece

File tree

6 files changed

+138
-47
lines changed

6 files changed

+138
-47
lines changed

src/Grpc.AspNetCore.Server/Internal/HttpContextServerCallContext.cs

Lines changed: 3 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -77,12 +77,7 @@ protected override string PeerCore
7777
get
7878
{
7979
// Follows the standard at https://github.com/grpc/grpc/blob/master/doc/naming.md
80-
if (_peer == null)
81-
{
82-
_peer = BuildPeer();
83-
}
84-
85-
return _peer;
80+
return _peer ??= BuildPeer();
8681
}
8782
}
8883

@@ -291,10 +286,7 @@ private void EndCallCore()
291286

292287
private void LogCallEnd()
293288
{
294-
if (_activity != null)
295-
{
296-
_activity.AddTag(GrpcServerConstants.ActivityStatusCodeTag, _status.StatusCode.ToTrailerString());
297-
}
289+
_activity?.AddTag(GrpcServerConstants.ActivityStatusCodeTag, _status.StatusCode.ToTrailerString());
298290
if (_status.StatusCode != StatusCode.OK)
299291
{
300292
if (GrpcEventSource.Log.IsEnabled())
@@ -387,10 +379,7 @@ protected override Task WriteResponseHeadersAsyncCore(Metadata responseHeaders)
387379
public void Initialize(ISystemClock? clock = null)
388380
{
389381
_activity = GetHostActivity();
390-
if (_activity != null)
391-
{
392-
_activity.AddTag(GrpcServerConstants.ActivityMethodTag, MethodCore);
393-
}
382+
_activity?.AddTag(GrpcServerConstants.ActivityMethodTag, MethodCore);
394383

395384
if (GrpcEventSource.Log.IsEnabled())
396385
{

src/Grpc.Core.Api/AsyncAuthInterceptor.cs

Lines changed: 21 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
#endregion
1818

19+
using System.Threading;
1920
using System.Threading.Tasks;
2021
using Grpc.Core.Utils;
2122

@@ -34,16 +35,25 @@ namespace Grpc.Core;
3435
/// </summary>
3536
public class AuthInterceptorContext
3637
{
37-
readonly string serviceUrl;
38-
readonly string methodName;
38+
private readonly string serviceUrl;
39+
private readonly string methodName;
40+
private readonly CancellationToken cancellationToken;
3941

4042
/// <summary>
4143
/// Initializes a new instance of <c>AuthInterceptorContext</c>.
4244
/// </summary>
43-
public AuthInterceptorContext(string serviceUrl, string methodName)
45+
public AuthInterceptorContext(string serviceUrl, string methodName) : this(serviceUrl, methodName, CancellationToken.None)
46+
{
47+
}
48+
49+
/// <summary>
50+
/// Initializes a new instance of <c>AuthInterceptorContext</c>.
51+
/// </summary>
52+
public AuthInterceptorContext(string serviceUrl, string methodName, CancellationToken cancellationToken)
4453
{
4554
this.serviceUrl = GrpcPreconditions.CheckNotNull(serviceUrl, nameof(serviceUrl));
4655
this.methodName = GrpcPreconditions.CheckNotNull(methodName, nameof(methodName));
56+
this.cancellationToken = cancellationToken;
4757
}
4858

4959
/// <summary>
@@ -61,4 +71,12 @@ public string MethodName
6171
{
6272
get { return methodName; }
6373
}
74+
75+
/// <summary>
76+
/// The cancellation token of the RPC being called.
77+
/// </summary>
78+
public CancellationToken CancellationToken
79+
{
80+
get { return cancellationToken; }
81+
}
6482
}

src/Grpc.Net.Client/Internal/DefaultCallCredentialsConfigurator.cs

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
#region Copyright notice and license
1+
#region Copyright notice and license
22

33
// Copyright 2019 The gRPC Authors
44
//
@@ -20,15 +20,19 @@
2020

2121
namespace Grpc.Net.Client.Internal;
2222

23-
internal class DefaultCallCredentialsConfigurator : CallCredentialsConfiguratorBase
23+
internal sealed class DefaultCallCredentialsConfigurator : CallCredentialsConfiguratorBase
2424
{
2525
public AsyncAuthInterceptor? Interceptor { get; private set; }
26-
public IReadOnlyList<CallCredentials>? Credentials { get; private set; }
26+
public IReadOnlyList<CallCredentials>? CompositeCredentials { get; private set; }
2727

28-
public void Reset()
28+
// A place to cache the context to avoid creating a new instance for each auth interceptor call.
29+
// It's ok not to reset this state because the context is only used for the lifetime of the call.
30+
public AuthInterceptorContext? CachedContext { get; set; }
31+
32+
public void ResetPerCallCredentialState()
2933
{
3034
Interceptor = null;
31-
Credentials = null;
35+
CompositeCredentials = null;
3236
}
3337

3438
public override void SetAsyncAuthInterceptorCredentials(object? state, AsyncAuthInterceptor interceptor)
@@ -38,6 +42,6 @@ public override void SetAsyncAuthInterceptorCredentials(object? state, AsyncAuth
3842

3943
public override void SetCompositeCredentials(object? state, IReadOnlyList<CallCredentials> credentials)
4044
{
41-
Credentials = credentials;
45+
CompositeCredentials = credentials;
4246
}
4347
}

src/Grpc.Net.Client/Internal/GrpcCall.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -955,13 +955,13 @@ private async Task ReadCredentials(HttpRequestMessage request)
955955

956956
if (Options.Credentials != null)
957957
{
958-
await GrpcProtocolHelpers.ReadCredentialMetadata(configurator, Channel, request, Method, Options.Credentials).ConfigureAwait(false);
958+
await GrpcProtocolHelpers.ReadCredentialMetadata(configurator, Channel, request, Method, Options.Credentials, _callCts.Token).ConfigureAwait(false);
959959
}
960960
if (Channel.CallCredentials?.Count > 0)
961961
{
962962
foreach (var credentials in Channel.CallCredentials)
963963
{
964-
await GrpcProtocolHelpers.ReadCredentialMetadata(configurator, Channel, request, Method, credentials).ConfigureAwait(false);
964+
await GrpcProtocolHelpers.ReadCredentialMetadata(configurator, Channel, request, Method, credentials, _callCts.Token).ConfigureAwait(false);
965965
}
966966
}
967967
}

src/Grpc.Net.Client/Internal/GrpcProtocolHelpers.cs

Lines changed: 45 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
#region Copyright notice and license
1+
#region Copyright notice and license
22

33
// Copyright 2019 The gRPC Authors
44
//
@@ -121,13 +121,34 @@ internal static bool ShouldSkipHeader(string name)
121121
/* round an integer up to the next value with three significant figures */
122122
private static long TimeoutRoundUpToThreeSignificantFigures(long x)
123123
{
124-
if (x < 1000) return x;
125-
if (x < 10000) return RoundUp(x, 10);
126-
if (x < 100000) return RoundUp(x, 100);
127-
if (x < 1000000) return RoundUp(x, 1000);
128-
if (x < 10000000) return RoundUp(x, 10000);
129-
if (x < 100000000) return RoundUp(x, 100000);
130-
if (x < 1000000000) return RoundUp(x, 1000000);
124+
if (x < 1000)
125+
{
126+
return x;
127+
}
128+
if (x < 10000)
129+
{
130+
return RoundUp(x, 10);
131+
}
132+
if (x < 100000)
133+
{
134+
return RoundUp(x, 100);
135+
}
136+
if (x < 1000000)
137+
{
138+
return RoundUp(x, 1000);
139+
}
140+
if (x < 10000000)
141+
{
142+
return RoundUp(x, 10000);
143+
}
144+
if (x < 100000000)
145+
{
146+
return RoundUp(x, 100000);
147+
}
148+
if (x < 1000000000)
149+
{
150+
return RoundUp(x, 1000000);
151+
}
131152
return RoundUp(x, 10000000);
132153

133154
static long RoundUp(long x, long divisor)
@@ -235,7 +256,7 @@ internal static bool CanWriteCompressed(WriteOptions? writeOptions)
235256
return canCompress;
236257
}
237258

238-
internal static AuthInterceptorContext CreateAuthInterceptorContext(Uri baseAddress, IMethod method)
259+
internal static AuthInterceptorContext CreateAuthInterceptorContext(Uri baseAddress, IMethod method, CancellationToken cancellationToken)
239260
{
240261
var authority = baseAddress.Authority;
241262
if (baseAddress.Scheme == Uri.UriSchemeHttps && authority.EndsWith(":443", StringComparison.Ordinal))
@@ -252,38 +273,44 @@ internal static AuthInterceptorContext CreateAuthInterceptorContext(Uri baseAddr
252273
serviceUrl += "/";
253274
}
254275
serviceUrl += method.ServiceName;
255-
return new AuthInterceptorContext(serviceUrl, method.Name);
276+
return new AuthInterceptorContext(serviceUrl, method.Name, cancellationToken);
256277
}
257278

258279
internal static async Task ReadCredentialMetadata(
259280
DefaultCallCredentialsConfigurator configurator,
260281
GrpcChannel channel,
261282
HttpRequestMessage message,
262283
IMethod method,
263-
CallCredentials credentials)
284+
CallCredentials credentials,
285+
CancellationToken cancellationToken)
264286
{
265287
credentials.InternalPopulateConfiguration(configurator, null);
266288

267289
if (configurator.Interceptor != null)
268290
{
269-
var authInterceptorContext = GrpcProtocolHelpers.CreateAuthInterceptorContext(channel.Address, method);
291+
// Multiple auth interceptors can be called for a gRPC call.
292+
// These all have the same data: address, method and cancellation token.
293+
// Lazily allocate the context if it is needed.
294+
// Stored on the configurator instead of a ref parameter because ref parameters are not supported in async methods.
295+
configurator.CachedContext ??= CreateAuthInterceptorContext(channel.Address, method, cancellationToken);
296+
270297
var metadata = new Metadata();
271-
await configurator.Interceptor(authInterceptorContext, metadata).ConfigureAwait(false);
298+
await configurator.Interceptor(configurator.CachedContext, metadata).ConfigureAwait(false);
272299

273300
foreach (var entry in metadata)
274301
{
275302
AddHeader(message.Headers, entry);
276303
}
277304
}
278305

279-
if (configurator.Credentials != null)
306+
if (configurator.CompositeCredentials != null)
280307
{
281308
// Copy credentials locally. ReadCredentialMetadata will update it.
282-
var callCredentials = configurator.Credentials;
283-
foreach (var c in callCredentials)
309+
var compositeCredentials = configurator.CompositeCredentials;
310+
foreach (var callCredentials in compositeCredentials)
284311
{
285-
configurator.Reset();
286-
await ReadCredentialMetadata(configurator, channel, message, method, c).ConfigureAwait(false);
312+
configurator.ResetPerCallCredentialState();
313+
await ReadCredentialMetadata(configurator, channel, message, method, callCredentials, cancellationToken).ConfigureAwait(false);
287314
}
288315
}
289316
}

test/Grpc.Net.Client.Tests/CallCredentialTests.cs

Lines changed: 57 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
#region Copyright notice and license
1+
#region Copyright notice and license
22

33
// Copyright 2019 The gRPC Authors
44
//
@@ -18,6 +18,7 @@
1818

1919
using System.Net;
2020
using System.Net.Http.Headers;
21+
using System.Threading;
2122
using Greet;
2223
using Grpc.Core;
2324
using Grpc.Net.Client.Tests.Infrastructure;
@@ -79,19 +80,71 @@ public async Task CallCredentialsWithHttps_MetadataOnRequest()
7980
var invoker = HttpClientCallInvokerFactory.Create(httpClient);
8081

8182
// Act
83+
var syncPoint = new SyncPoint(runContinuationsAsynchronously: true);
8284
var callCredentials = CallCredentials.FromInterceptor(async (context, metadata) =>
8385
{
84-
// The operation is asynchronous to ensure delegate is awaited
85-
await Task.Delay(50);
86+
// The operation is asynchronous to ensure auth interceptor is awaited.
87+
// Sending the request and returning a response is blocked until the auth interceptor completes.
88+
await syncPoint.WaitToContinue();
89+
90+
// Set header.
8691
metadata.Add("authorization", "SECRET_TOKEN");
8792
});
8893
var call = invoker.AsyncUnaryCall<HelloRequest, HelloReply>(ClientTestHelpers.ServiceMethod, string.Empty, new CallOptions(credentials: callCredentials), new HelloRequest());
89-
await call.ResponseAsync.DefaultTimeout();
94+
var responseTask = call.ResponseAsync;
95+
96+
await syncPoint.WaitForSyncPoint().DefaultTimeout();
97+
98+
// Response task should be blocked waiting for the auth interceptor to complete.
99+
Assert.False(responseTask.IsCompleted);
100+
// Sending the request should be blocked waiting for the auth interceptor to complete.
101+
Assert.Null(authorizationValue);
102+
103+
syncPoint.Continue();
104+
await responseTask.DefaultTimeout();
90105

91106
// Assert
92107
Assert.AreEqual("SECRET_TOKEN", authorizationValue);
93108
}
94109

110+
[Test]
111+
public async Task CallCredentialsWithHttps_CancellationToken()
112+
{
113+
// Arrange
114+
string? authorizationValue = null;
115+
var httpClient = ClientTestHelpers.CreateTestClient(async request =>
116+
{
117+
authorizationValue = request.Headers.GetValues("authorization").Single();
118+
119+
var reply = new HelloReply { Message = "Hello world" };
120+
var streamContent = await ClientTestHelpers.CreateResponseContent(reply).DefaultTimeout();
121+
return ResponseUtils.CreateResponse(HttpStatusCode.OK, streamContent);
122+
});
123+
var invoker = HttpClientCallInvokerFactory.Create(httpClient);
124+
125+
// Act
126+
var unreachableAuthInterceptorSection = false;
127+
var callCredentials = CallCredentials.FromInterceptor(async (context, metadata) =>
128+
{
129+
var tcs = new TaskCompletionSource<object?>(TaskCreationOptions.RunContinuationsAsynchronously);
130+
context.CancellationToken.Register(s => ((TaskCompletionSource<object?>)s!).SetCanceled(), tcs);
131+
132+
await tcs.Task;
133+
134+
unreachableAuthInterceptorSection = true;
135+
});
136+
var call = invoker.AsyncUnaryCall<HelloRequest, HelloReply>(ClientTestHelpers.ServiceMethod, string.Empty, new CallOptions(credentials: callCredentials), new HelloRequest());
137+
var responseTask = call.ResponseAsync;
138+
139+
call.Dispose();
140+
141+
var ex = await ExceptionAssert.ThrowsAsync<RpcException>(() => responseTask).DefaultTimeout();
142+
Assert.AreEqual(StatusCode.Cancelled, ex.StatusCode);
143+
144+
// Assert
145+
Assert.False(unreachableAuthInterceptorSection);
146+
}
147+
95148
[Test]
96149
public async Task CallCredentialsWithHttp_NoMetadataOnRequest()
97150
{

0 commit comments

Comments
 (0)