From 1767e2bf096242a2624029a4d0592652d0b44c62 Mon Sep 17 00:00:00 2001 From: Ben Olden-Cooligan Date: Fri, 29 Dec 2023 09:24:44 -0800 Subject: [PATCH] Terminate invalid connections --- .../Internal/ServerConnectionContext.cs | 35 +++++++++++++++++++ 1 file changed, 35 insertions(+) diff --git a/GrpcDotNetNamedPipes/Internal/ServerConnectionContext.cs b/GrpcDotNetNamedPipes/Internal/ServerConnectionContext.cs index 812536b..07c73c9 100644 --- a/GrpcDotNetNamedPipes/Internal/ServerConnectionContext.cs +++ b/GrpcDotNetNamedPipes/Internal/ServerConnectionContext.cs @@ -21,6 +21,7 @@ internal class ServerConnectionContext : TransportMessageHandler, IDisposable private readonly ConnectionLogger _logger; private readonly Dictionary> _methodHandlers; private readonly PayloadQueue _payloadQueue; + private readonly CancellationTokenSource _requestInitTimeoutCts = new(); public ServerConnectionContext(NamedPipeServerStream pipeStream, ConnectionLogger logger, Dictionary> methodHandlers) @@ -32,6 +33,12 @@ public ServerConnectionContext(NamedPipeServerStream pipeStream, ConnectionLogge _methodHandlers = methodHandlers; _payloadQueue = new PayloadQueue(); CancellationTokenSource = new CancellationTokenSource(); + + // We're supposed to receive a RequestInit message immediately after the pipe connects. 10s is chosen as a very + // conservative timeout. If this expires without receiving RequestInit, we can assume the client is not using + // the right protocol and we should terminate the connection rather than potentially leave it open forever. + Task.Delay(10_000, _requestInitTimeoutCts.Token) + .ContinueWith(_ => RequestInitTimeout(), TaskContinuationOptions.OnlyOnRanToCompletion); } public NamedPipeServerStream PipeStream { get; } @@ -61,10 +68,38 @@ public IServerStreamWriter CreateResponseStream(Marshaller public override void HandleRequestInit(string methodFullName, DateTime? deadline) { + _requestInitTimeoutCts.Cancel(); + if (!_methodHandlers.ContainsKey(methodFullName)) + { + _logger.Log("Unsupported method"); + try + { + WriteTrailers(StatusCode.Unimplemented, ""); + PipeStream.Disconnect(); + } + catch (Exception) + { + // Ignore + } + return; + } Deadline = new Deadline(deadline); Task.Run(async () => await _methodHandlers[methodFullName](this).ConfigureAwait(false)); } + private void RequestInitTimeout() + { + _logger.Log("Timed out waiting for RequestInit"); + try + { + PipeStream.Disconnect(); + } + catch (Exception) + { + // Ignore + } + } + public override void HandleHeaders(Metadata headers) => RequestHeaders = headers; public override void HandleCancel() => CancellationTokenSource.Cancel();