Skip to content

Commit cc15b1b

Browse files
[SignalR] Implement IConnectionLifetimeFeature (dotnet#20604)
1 parent 0d8d4e7 commit cc15b1b

File tree

7 files changed

+239
-10
lines changed

7 files changed

+239
-10
lines changed

src/Servers/Connections.Abstractions/src/DefaultConnectionContext.cs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,6 @@ public class DefaultConnectionContext : ConnectionContext,
2626
public DefaultConnectionContext() :
2727
this(Guid.NewGuid().ToString())
2828
{
29-
ConnectionClosed = _connectionClosedTokenSource.Token;
3029
}
3130

3231
/// <summary>
@@ -45,6 +44,8 @@ public DefaultConnectionContext(string id)
4544
Features.Set<IConnectionTransportFeature>(this);
4645
Features.Set<IConnectionLifetimeFeature>(this);
4746
Features.Set<IConnectionEndPointFeature>(this);
47+
48+
ConnectionClosed = _connectionClosedTokenSource.Token;
4849
}
4950

5051
public DefaultConnectionContext(string id, IDuplexPipe transport, IDuplexPipe application)

src/SignalR/common/Http.Connections/src/Internal/HttpConnectionContext.cs

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,8 @@ internal class HttpConnectionContext : ConnectionContext,
2929
ITransferFormatFeature,
3030
IHttpContextFeature,
3131
IHttpTransportFeature,
32-
IConnectionInherentKeepAliveFeature
32+
IConnectionInherentKeepAliveFeature,
33+
IConnectionLifetimeFeature
3334
{
3435
private static long _tenSeconds = TimeSpan.FromSeconds(10).Ticks;
3536

@@ -41,6 +42,7 @@ internal class HttpConnectionContext : ConnectionContext,
4142
private PipeWriterStream _applicationStream;
4243
private IDuplexPipe _application;
4344
private IDictionary<object, object> _items;
45+
private CancellationTokenSource _connectionClosedTokenSource;
4446

4547
private CancellationTokenSource _sendCts;
4648
private bool _activeSend;
@@ -82,6 +84,10 @@ public HttpConnectionContext(string connectionId, string connectionToken, ILogge
8284
Features.Set<IHttpContextFeature>(this);
8385
Features.Set<IHttpTransportFeature>(this);
8486
Features.Set<IConnectionInherentKeepAliveFeature>(this);
87+
Features.Set<IConnectionLifetimeFeature>(this);
88+
89+
_connectionClosedTokenSource = new CancellationTokenSource();
90+
ConnectionClosed = _connectionClosedTokenSource.Token;
8591
}
8692

8793
public CancellationTokenSource Cancellation { get; set; }
@@ -170,6 +176,15 @@ public IDuplexPipe Application
170176

171177
public HttpContext HttpContext { get; set; }
172178

179+
public override CancellationToken ConnectionClosed { get; set; }
180+
181+
public override void Abort()
182+
{
183+
ThreadPool.UnsafeQueueUserWorkItem(cts => ((CancellationTokenSource)cts).Cancel(), _connectionClosedTokenSource);
184+
185+
HttpContext?.Abort();
186+
}
187+
173188
public void OnHeartbeat(Action<object> action, object state)
174189
{
175190
lock (_heartbeatLock)
@@ -305,6 +320,9 @@ private async Task WaitOnTasks(Task applicationTask, Task transportTask, bool cl
305320
// Now complete the application
306321
Application?.Output.Complete();
307322
Application?.Input.Complete();
323+
324+
// Trigger ConnectionClosed
325+
ThreadPool.UnsafeQueueUserWorkItem(cts => ((CancellationTokenSource)cts).Cancel(), _connectionClosedTokenSource);
308326
}
309327
}
310328
else
@@ -313,6 +331,9 @@ private async Task WaitOnTasks(Task applicationTask, Task transportTask, bool cl
313331
Application?.Output.Complete(transportTask.Exception?.InnerException);
314332
Application?.Input.Complete();
315333

334+
// Trigger ConnectionClosed
335+
ThreadPool.UnsafeQueueUserWorkItem(cts => ((CancellationTokenSource)cts).Cancel(), _connectionClosedTokenSource);
336+
316337
try
317338
{
318339
// A poorly written application *could* in theory get stuck forever and it'll show up as a memory leak

src/SignalR/common/Http.Connections/test/HttpConnectionDispatcherTests.cs

Lines changed: 168 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -961,7 +961,7 @@ public async Task CompletedEndPointEndsConnection()
961961
}
962962

963963
[Fact]
964-
public async Task SynchronusExceptionEndsConnection()
964+
public async Task SynchronousExceptionEndsConnection()
965965
{
966966
bool ExpectedErrors(WriteContext writeContext)
967967
{
@@ -2269,6 +2269,173 @@ bool ExpectedErrors(WriteContext writeContext)
22692269
}
22702270
}
22712271

2272+
[Fact]
2273+
public async Task LongPollingConnectionClosingTriggersConnectionClosedToken()
2274+
{
2275+
using (StartVerifiableLog())
2276+
{
2277+
var manager = CreateConnectionManager(LoggerFactory);
2278+
var pipeOptions = new PipeOptions(pauseWriterThreshold: 2, resumeWriterThreshold: 1);
2279+
var connection = manager.CreateConnection(pipeOptions, pipeOptions);
2280+
connection.TransportType = HttpTransportType.LongPolling;
2281+
2282+
var dispatcher = new HttpConnectionDispatcher(manager, LoggerFactory);
2283+
2284+
var context = MakeRequest("/foo", connection);
2285+
2286+
var services = new ServiceCollection();
2287+
services.AddSingleton<NeverEndingConnectionHandler>();
2288+
var builder = new ConnectionBuilder(services.BuildServiceProvider());
2289+
builder.UseConnectionHandler<NeverEndingConnectionHandler>();
2290+
var app = builder.Build();
2291+
var options = new HttpConnectionDispatcherOptions();
2292+
2293+
var pollTask = dispatcher.ExecuteAsync(context, options, app);
2294+
Assert.True(pollTask.IsCompleted);
2295+
2296+
// Now send the second poll
2297+
pollTask = dispatcher.ExecuteAsync(context, options, app);
2298+
2299+
// Issue the delete request and make sure the poll completes
2300+
var deleteContext = new DefaultHttpContext();
2301+
deleteContext.Request.Path = "/foo";
2302+
deleteContext.Request.QueryString = new QueryString($"?id={connection.ConnectionId}");
2303+
deleteContext.Request.Method = "DELETE";
2304+
2305+
Assert.False(pollTask.IsCompleted);
2306+
2307+
await dispatcher.ExecuteAsync(deleteContext, options, app).OrTimeout();
2308+
2309+
await pollTask.OrTimeout();
2310+
2311+
// Verify that transport shuts down
2312+
await connection.TransportTask.OrTimeout();
2313+
2314+
// Verify the response from the DELETE request
2315+
Assert.Equal(StatusCodes.Status202Accepted, deleteContext.Response.StatusCode);
2316+
Assert.Equal("text/plain", deleteContext.Response.ContentType);
2317+
Assert.Equal(HttpConnectionStatus.Disposed, connection.Status);
2318+
2319+
// Verify the connection not removed because application is hanging
2320+
Assert.True(manager.TryGetConnection(connection.ConnectionId, out _));
2321+
2322+
Assert.True(connection.ConnectionClosed.IsCancellationRequested);
2323+
}
2324+
}
2325+
2326+
[Fact]
2327+
public async Task SSEConnectionClosingTriggersConnectionClosedToken()
2328+
{
2329+
using (StartVerifiableLog())
2330+
{
2331+
var manager = CreateConnectionManager(LoggerFactory);
2332+
var connection = manager.CreateConnection();
2333+
connection.TransportType = HttpTransportType.ServerSentEvents;
2334+
var dispatcher = new HttpConnectionDispatcher(manager, LoggerFactory);
2335+
var context = MakeRequest("/foo", connection);
2336+
SetTransport(context, connection.TransportType);
2337+
var services = new ServiceCollection();
2338+
services.AddSingleton<NeverEndingConnectionHandler>();
2339+
var builder = new ConnectionBuilder(services.BuildServiceProvider());
2340+
builder.UseConnectionHandler<NeverEndingConnectionHandler>();
2341+
var app = builder.Build();
2342+
var options = new HttpConnectionDispatcherOptions();
2343+
_ = dispatcher.ExecuteAsync(context, options, app);
2344+
2345+
// Close the SSE connection
2346+
connection.Transport.Output.Complete();
2347+
2348+
var tcs = new TaskCompletionSource<object>(TaskCreationOptions.RunContinuationsAsynchronously);
2349+
connection.ConnectionClosed.Register(() => tcs.SetResult(null));
2350+
await tcs.Task.OrTimeout();
2351+
}
2352+
}
2353+
2354+
[Fact]
2355+
public async Task WebSocketConnectionClosingTriggersConnectionClosedToken()
2356+
{
2357+
using (StartVerifiableLog())
2358+
{
2359+
var manager = CreateConnectionManager(LoggerFactory);
2360+
var connection = manager.CreateConnection();
2361+
connection.TransportType = HttpTransportType.WebSockets;
2362+
2363+
var dispatcher = new HttpConnectionDispatcher(manager, LoggerFactory);
2364+
2365+
var context = MakeRequest("/foo", connection);
2366+
SetTransport(context, HttpTransportType.WebSockets);
2367+
2368+
var services = new ServiceCollection();
2369+
services.AddSingleton<NeverEndingConnectionHandler>();
2370+
var builder = new ConnectionBuilder(services.BuildServiceProvider());
2371+
builder.UseConnectionHandler<NeverEndingConnectionHandler>();
2372+
var app = builder.Build();
2373+
var options = new HttpConnectionDispatcherOptions();
2374+
options.WebSockets.CloseTimeout = TimeSpan.FromSeconds(1);
2375+
2376+
_ = dispatcher.ExecuteAsync(context, options, app);
2377+
2378+
var websocket = (TestWebSocketConnectionFeature)context.Features.Get<IHttpWebSocketFeature>();
2379+
await websocket.Accepted.OrTimeout();
2380+
await websocket.Client.CloseOutputAsync(WebSocketCloseStatus.NormalClosure, "", cancellationToken: default).OrTimeout();
2381+
2382+
var tcs = new TaskCompletionSource<object>(TaskCreationOptions.RunContinuationsAsynchronously);
2383+
connection.ConnectionClosed.Register(() => tcs.SetResult(null));
2384+
await tcs.Task.OrTimeout();
2385+
}
2386+
}
2387+
2388+
public class CustomHttpRequestLifetimeFeature : IHttpRequestLifetimeFeature
2389+
{
2390+
public CancellationToken RequestAborted { get; set; }
2391+
2392+
private CancellationTokenSource _cts;
2393+
public CustomHttpRequestLifetimeFeature()
2394+
{
2395+
_cts = new CancellationTokenSource();
2396+
RequestAborted = _cts.Token;
2397+
}
2398+
2399+
public void Abort()
2400+
{
2401+
_cts.Cancel();
2402+
}
2403+
}
2404+
2405+
[Fact]
2406+
public async Task AbortingConnectionAbortsHttpContextAndTriggersConnectionClosedToken()
2407+
{
2408+
using (StartVerifiableLog())
2409+
{
2410+
var manager = CreateConnectionManager(LoggerFactory);
2411+
var connection = manager.CreateConnection();
2412+
connection.TransportType = HttpTransportType.ServerSentEvents;
2413+
var dispatcher = new HttpConnectionDispatcher(manager, LoggerFactory);
2414+
var context = MakeRequest("/foo", connection);
2415+
var lifetimeFeature = new CustomHttpRequestLifetimeFeature();
2416+
context.Features.Set<IHttpRequestLifetimeFeature>(lifetimeFeature);
2417+
SetTransport(context, connection.TransportType);
2418+
2419+
var services = new ServiceCollection();
2420+
services.AddSingleton<NeverEndingConnectionHandler>();
2421+
var builder = new ConnectionBuilder(services.BuildServiceProvider());
2422+
builder.UseConnectionHandler<NeverEndingConnectionHandler>();
2423+
var app = builder.Build();
2424+
var options = new HttpConnectionDispatcherOptions();
2425+
_ = dispatcher.ExecuteAsync(context, options, app);
2426+
2427+
connection.Abort();
2428+
2429+
var tcs = new TaskCompletionSource<object>(TaskCreationOptions.RunContinuationsAsynchronously);
2430+
connection.ConnectionClosed.Register(() => tcs.SetResult(null));
2431+
await tcs.Task.OrTimeout();
2432+
2433+
tcs = new TaskCompletionSource<object>(TaskCreationOptions.RunContinuationsAsynchronously);
2434+
lifetimeFeature.RequestAborted.Register(() => tcs.SetResult(null));
2435+
await tcs.Task.OrTimeout();
2436+
}
2437+
}
2438+
22722439
private static async Task CheckTransportSupported(HttpTransportType supportedTransports, HttpTransportType transportType, int status, ILoggerFactory loggerFactory)
22732440
{
22742441
var manager = CreateConnectionManager(loggerFactory);

src/SignalR/server/Core/src/HubConnectionContext.cs

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ public class HubConnectionContext
3737
private readonly SemaphoreSlim _writeLock = new SemaphoreSlim(1);
3838
private readonly object _receiveMessageTimeoutLock = new object();
3939
private readonly ISystemClock _systemClock;
40+
private readonly CancellationTokenRegistration _closedRegistration;
4041

4142
private StreamTracker _streamTracker;
4243
private long _lastSendTimeStamp;
@@ -66,6 +67,7 @@ public HubConnectionContext(ConnectionContext connectionContext, HubConnectionCo
6667
_connectionContext = connectionContext;
6768
_logger = loggerFactory.CreateLogger<HubConnectionContext>();
6869
ConnectionAborted = _connectionAbortedTokenSource.Token;
70+
_closedRegistration = connectionContext.ConnectionClosed.Register((state) => ((HubConnectionContext)state).Abort(), this);
6971

7072
HubCallerContext = new DefaultHubCallerContext(this);
7173

@@ -624,12 +626,6 @@ private static void AbortConnection(object state)
624626
finally
625627
{
626628
_ = InnerAbortConnection(connection);
627-
628-
// Use _streamTracker to avoid lazy init from StreamTracker getter if it doesn't exist
629-
if (connection._streamTracker != null)
630-
{
631-
connection._streamTracker.CompleteAll(new OperationCanceledException("The underlying connection was closed."));
632-
}
633629
}
634630

635631
static async Task InnerAbortConnection(HubConnectionContext connection)
@@ -670,6 +666,17 @@ internal void StopClientTimeout()
670666
}
671667
}
672668

669+
internal void Cleanup()
670+
{
671+
_closedRegistration.Dispose();
672+
673+
// Use _streamTracker to avoid lazy init from StreamTracker getter if it doesn't exist
674+
if (_streamTracker != null)
675+
{
676+
_streamTracker.CompleteAll(new OperationCanceledException("The underlying connection was closed."));
677+
}
678+
}
679+
673680
private static class Log
674681
{
675682
// Category: HubConnectionContext

src/SignalR/server/Core/src/HubConnectionHandler.cs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -139,6 +139,8 @@ public override async Task OnConnectedAsync(ConnectionContext connection)
139139
}
140140
finally
141141
{
142+
connectionContext.Cleanup();
143+
142144
Log.ConnectedEnding(_logger);
143145
await _lifetimeManager.OnDisconnectedAsync(connectionContext);
144146
}

src/SignalR/server/SignalR/test/HubConnectionHandlerTestUtils/Hubs.cs

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -221,7 +221,6 @@ public async Task StreamDontRead(ChannelReader<string> source)
221221
}
222222
}
223223

224-
225224
public async Task<int> StreamingSum(ChannelReader<int> source)
226225
{
227226
var total = 0;
@@ -322,6 +321,14 @@ public async Task UploadDoesWorkOnComplete(ChannelReader<string> source)
322321
tcs.TrySetResult(42);
323322
}
324323
}
324+
325+
public async Task BlockingMethod()
326+
{
327+
var tcs = new TaskCompletionSource<object>(TaskCreationOptions.RunContinuationsAsynchronously);
328+
Context.ConnectionAborted.Register(state => ((TaskCompletionSource<object>)state).SetResult(null), tcs);
329+
330+
await tcs.Task;
331+
}
325332
}
326333

327334
public abstract class TestHub : Hub

src/SignalR/server/SignalR/test/HubConnectionHandlerTests.cs

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -948,6 +948,30 @@ bool ExpectedErrors(WriteContext writeContext)
948948
Assert.True(hasErrorLog);
949949
}
950950

951+
[Fact]
952+
public async Task HubMethodListeningToConnectionAbortedClosesOnConnectionContextAbort()
953+
{
954+
using (StartVerifiableLog())
955+
{
956+
var connectionHandler = HubConnectionHandlerTestUtils.GetHubConnectionHandler(typeof(MethodHub), loggerFactory: LoggerFactory);
957+
958+
using (var client = new TestClient())
959+
{
960+
var connectionHandlerTask = await client.ConnectAsync(connectionHandler);
961+
962+
var invokeTask = client.InvokeAsync(nameof(MethodHub.BlockingMethod));
963+
964+
client.Connection.Abort();
965+
966+
// If this completes then the server has completed the connection
967+
await connectionHandlerTask.OrTimeout();
968+
969+
// Nothing written to connection because it was closed
970+
Assert.False(invokeTask.IsCompleted);
971+
}
972+
}
973+
}
974+
951975
[Fact]
952976
public async Task DetailedExceptionEvenWhenNotExplicitlySet()
953977
{

0 commit comments

Comments
 (0)