Skip to content

Commit 8595e2d

Browse files
committed
PR feedback
1 parent 81855f7 commit 8595e2d

File tree

4 files changed

+129
-33
lines changed

4 files changed

+129
-33
lines changed

src/Servers/Kestrel/Core/src/Internal/Http3/Http3Connection.cs

Lines changed: 32 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@ internal class Http3Connection : IHttp3StreamLifetimeHandler, IRequestProcessor
4646
private int _stoppedAcceptingStreams;
4747
private bool _gracefulCloseStarted;
4848
private int _activeRequestCount;
49+
private CancellationTokenSource _acceptStreamsCts = new CancellationTokenSource();
4950
private readonly Http3PeerSettings _serverSettings = new Http3PeerSettings();
5051
private readonly Http3PeerSettings _clientSettings = new Http3PeerSettings();
5152
private readonly StreamCloseAwaitable _streamCompletionAwaitable = new StreamCloseAwaitable();
@@ -107,7 +108,8 @@ public void StopProcessingNextRequest(bool serverInitiated)
107108

108109
if (Interlocked.CompareExchange(ref _gracefulCloseInitiator, initiator, GracefulCloseInitiator.None) == GracefulCloseInitiator.None)
109110
{
110-
UpdateConnectionState();
111+
// Break out of AcceptStreams so connection state can be updated.
112+
_acceptStreamsCts.Cancel();
111113
}
112114
}
113115
}
@@ -240,15 +242,20 @@ public async Task ProcessRequestsAsync<TContext>(IHttpApplication<TContext> appl
240242
// Don't delay on waiting to send outbound control stream settings.
241243
outboundControlStreamTask = ProcessOutboundControlStreamAsync(outboundControlStream);
242244

243-
while (true)
245+
while (_stoppedAcceptingStreams == 0)
244246
{
245-
var streamContext = await _multiplexedContext.AcceptAsync();
247+
var streamContext = await _multiplexedContext.AcceptAsync(_acceptStreamsCts.Token);
246248

247249
try
248250
{
249251
if (streamContext == null)
250252
{
251-
break;
253+
if (_acceptStreamsCts.Token.IsCancellationRequested)
254+
{
255+
_acceptStreamsCts = new CancellationTokenSource();
256+
}
257+
258+
continue;
252259
}
253260

254261
var streamDirectionFeature = streamContext.Features.Get<IStreamDirectionFeature>();
@@ -267,9 +274,10 @@ public async Task ProcessRequestsAsync<TContext>(IHttpApplication<TContext> appl
267274
}
268275
else
269276
{
270-
// TODO race condition between checking this and updating highest stream ID
277+
// Request stream
278+
271279
// https://quicwg.org/base-drafts/draft-ietf-quic-http.html#section-5.2-2
272-
if (_stoppedAcceptingStreams == 1)
280+
if (_gracefulCloseStarted)
273281
{
274282
// https://quicwg.org/base-drafts/draft-ietf-quic-http.html#section-4.1.2-3
275283
streamContext.Features.Get<IProtocolErrorCodeFeature>()!.Error = (long)Http3ErrorCode.RequestRejected;
@@ -280,7 +288,6 @@ public async Task ProcessRequestsAsync<TContext>(IHttpApplication<TContext> appl
280288
// Request stream IDs are tracked.
281289
UpdateHighestOpenedRequestStreamId(streamIdFeature.StreamId);
282290

283-
// Request stream
284291
var persistentStateFeature = streamContext.Features.Get<IPersistentStateFeature>();
285292
Debug.Assert(persistentStateFeature != null, $"Required {nameof(IPersistentStateFeature)} not on stream context.");
286293

@@ -353,7 +360,10 @@ public async Task ProcessRequestsAsync<TContext>(IHttpApplication<TContext> appl
353360
// Only send goaway if the connection close was initiated on the server.
354361
if (!clientAbort)
355362
{
356-
await SendGoAwayAsync(GetCurrentGoAwayStreamId());
363+
if (TryStopAcceptingStreams() || _gracefulCloseStarted)
364+
{
365+
await SendGoAwayAsync(GetCurrentGoAwayStreamId());
366+
}
357367
}
358368

359369
// Abort active request streams.
@@ -430,37 +440,31 @@ private void UpdateConnectionState()
430440
return;
431441
}
432442

433-
int activeRequestCount;
434-
lock (_streams)
443+
if (_gracefulCloseInitiator != GracefulCloseInitiator.None)
435444
{
436-
activeRequestCount = _activeRequestCount;
437-
}
445+
int activeRequestCount;
446+
lock (_streams)
447+
{
448+
activeRequestCount = _activeRequestCount;
449+
}
438450

439-
if (_gracefulCloseInitiator != GracefulCloseInitiator.None && !_gracefulCloseStarted)
440-
{
441-
_gracefulCloseStarted = true;
451+
if (!_gracefulCloseStarted)
452+
{
453+
_gracefulCloseStarted = true;
442454

443-
_errorCodeFeature.Error = (long)Http3ErrorCode.NoError;
444-
Log.Http3ConnectionClosing(_context.ConnectionId);
455+
_errorCodeFeature.Error = (long)Http3ErrorCode.NoError;
456+
Log.Http3ConnectionClosing(_context.ConnectionId);
445457

446-
if (_gracefulCloseInitiator == GracefulCloseInitiator.Server && activeRequestCount > 0)
447-
{
448-
if (TryStopAcceptingStreams())
458+
if (_gracefulCloseInitiator == GracefulCloseInitiator.Server && activeRequestCount > 0)
449459
{
450460
// Go away with largest streamid to initiate graceful shutdown.
451461
SendGoAwayAsync(VariableLengthIntegerHelper.EightByteLimit).Preserve();
452462
}
453463
}
454-
}
455464

456-
if (_activeRequestCount == 0)
457-
{
458-
if (_gracefulCloseStarted)
465+
if (activeRequestCount == 0)
459466
{
460-
if (TryStopAcceptingStreams())
461-
{
462-
SendGoAwayAsync(GetCurrentGoAwayStreamId()).Preserve();
463-
}
467+
TryStopAcceptingStreams();
464468
}
465469
}
466470
}

src/Servers/Kestrel/Transport.Quic/src/Internal/QuicStreamContext.cs

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -327,6 +327,17 @@ private async Task DoSend()
327327

328328
_clientAbort = true;
329329
}
330+
catch (QuicConnectionAbortedException ex)
331+
{
332+
// Abort from peer.
333+
Error = ex.ErrorCode;
334+
_log.StreamAborted(this, ex.ErrorCode, ex);
335+
336+
// This could be ignored if _shutdownReason is already set.
337+
shutdownReason = new ConnectionResetException(ex.Message, ex);
338+
339+
_clientAbort = true;
340+
}
330341
catch (QuicOperationAbortedException ex)
331342
{
332343
// AbortWrite has been called for the stream.

src/Servers/Kestrel/Transport.Quic/test/QuicConnectionContextTests.cs

Lines changed: 42 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,48 @@ public class QuicConnectionContextTests : TestApplicationErrorLoggerLoggedTest
2424
{
2525
private static readonly byte[] TestData = Encoding.UTF8.GetBytes("Hello world");
2626

27+
[ConditionalFact]
28+
[MsQuicSupported]
29+
public async Task AcceptAsync_CancellationThenAccept_AcceptStreamAfterCancellation()
30+
{
31+
// Arrange
32+
var connectionClosedTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously);
33+
34+
await using var connectionListener = await QuicTestHelpers.CreateConnectionListenerFactory(LoggerFactory);
35+
36+
// Act
37+
var acceptTask = connectionListener.AcceptAndAddFeatureAsync().DefaultTimeout();
38+
39+
var options = QuicTestHelpers.CreateClientConnectionOptions(connectionListener.EndPoint);
40+
41+
using var clientConnection = new QuicConnection(QuicImplementationProviders.MsQuic, options);
42+
await clientConnection.ConnectAsync().DefaultTimeout();
43+
44+
await using var serverConnection = await acceptTask.DefaultTimeout();
45+
46+
// Wait for stream and then cancel
47+
var cts = new CancellationTokenSource();
48+
var acceptStreamTask = serverConnection.AcceptAsync(cts.Token);
49+
cts.Cancel();
50+
51+
var serverStream = await acceptStreamTask.DefaultTimeout();
52+
Assert.Null(serverStream);
53+
54+
// Wait for stream after cancellation
55+
acceptStreamTask = serverConnection.AcceptAsync();
56+
57+
await using var clientStream = clientConnection.OpenBidirectionalStream();
58+
await clientStream.WriteAsync(TestData);
59+
60+
// Assert
61+
serverStream = await acceptStreamTask.DefaultTimeout();
62+
Assert.NotNull(serverStream);
63+
64+
var read = await serverStream.Transport.Input.ReadAtLeastAsync(TestData.Length).DefaultTimeout();
65+
Assert.Equal(TestData, read.Buffer.ToArray());
66+
serverStream.Transport.Input.AdvanceTo(read.Buffer.End);
67+
}
68+
2769
[ConditionalFact]
2870
[MsQuicSupported]
2971
public async Task AcceptAsync_ClientClosesConnection_ServerNotified()
@@ -495,10 +537,6 @@ public async Task StreamPool_ManyConcurrentStreams_StreamPoolFull()
495537
const int StreamsSent = 101;
496538
for (var i = 0; i < StreamsSent; i++)
497539
{
498-
// TODO: Race condition in QUIC library.
499-
// Delay between sending streams to avoid
500-
// https://github.com/dotnet/runtime/issues/55249
501-
await Task.Delay(100);
502540
streamTasks.Add(SendStream(requestState));
503541
}
504542

src/Servers/Kestrel/test/Interop.FunctionalTests/Http3/Http3RequestTests.cs

Lines changed: 44 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -219,6 +219,48 @@ public async Task GET_ServerStreaming_ClientReadsPartialResponse(HttpProtocols p
219219
}
220220
}
221221

222+
[ConditionalTheory]
223+
[MsQuicSupported]
224+
[InlineData(HttpProtocols.Http3, Skip = "https://github.com/dotnet/runtime/issues/56969")]
225+
[InlineData(HttpProtocols.Http2)]
226+
public async Task POST_ClientSendsOnlyHeaders_RequestReceivedOnServer(HttpProtocols protocol)
227+
{
228+
// Arrange
229+
var builder = CreateHostBuilder(context =>
230+
{
231+
return Task.CompletedTask;
232+
}, protocol: protocol);
233+
234+
using (var host = builder.Build())
235+
using (var client = CreateClient())
236+
{
237+
await host.StartAsync();
238+
239+
var requestContent = new StreamingHttpContext();
240+
241+
var request = new HttpRequestMessage(HttpMethod.Post, $"https://127.0.0.1:{host.GetPort()}/");
242+
request.Content = requestContent;
243+
request.Version = GetProtocol(protocol);
244+
request.VersionPolicy = HttpVersionPolicy.RequestVersionExact;
245+
246+
// Act
247+
var responseTask = client.SendAsync(request, CancellationToken.None).DefaultTimeout();
248+
249+
var requestStream = await requestContent.GetStreamAsync().DefaultTimeout();
250+
251+
// Send headers
252+
await requestStream.FlushAsync().DefaultTimeout();
253+
254+
var response = await responseTask.DefaultTimeout();
255+
256+
// Assert
257+
response.EnsureSuccessStatusCode();
258+
Assert.Equal(GetProtocol(protocol), response.Version);
259+
260+
await host.StopAsync();
261+
}
262+
}
263+
222264
[ConditionalFact]
223265
[MsQuicSupported]
224266
public async Task POST_ServerCompletesWithoutReadingRequestBody_ClientGetsResponse()
@@ -937,7 +979,8 @@ private async Task WaitForLogAsync(Func<IEnumerable<WriteContext>, bool> testLog
937979
{
938980
Logger.LogInformation($"Started waiting for logs: {message}");
939981

940-
for (int i = 0; i < 5; i++)
982+
var retryCount = !Debugger.IsAttached ? 5 : int.MaxValue;
983+
for (var i = 0; i < retryCount; i++)
941984
{
942985
if (testLogs(TestSink.Writes))
943986
{

0 commit comments

Comments
 (0)