From c11308f047d3c46c079e810aaca3934c9a4fd2fa Mon Sep 17 00:00:00 2001 From: Miha Zupan Date: Tue, 21 Sep 2021 20:11:40 +0200 Subject: [PATCH] Add WebSockets telemetry --- .../ReverseProxy.Metrics.Sample/Startup.cs | 5 + .../WebSocketsTelemetryConsumer.cs | 21 ++ src/ReverseProxy/Forwarder/HttpForwarder.cs | 2 +- .../Utilities}/DelegatingStream.cs | 15 +- .../WebSocketCloseReason.cs | 14 + .../WebSocketsTelemetry/WebSocketsParser.cs | 132 ++++++++ .../WebSocketsTelemetry.cs | 22 ++ .../WebSocketsTelemetryExtensions.cs | 22 ++ .../WebSocketsTelemetryMiddleware.cs | 81 +++++ .../WebSocketsTelemetryStream.cs | 101 ++++++ .../TelemetryConsumptionExtensions.cs | 41 ++- .../IWebSocketsTelemetryConsumer.cs | 23 ++ .../WebSockets/WebSocketCloseReason.cs | 17 + .../WebSocketsEventListenerService.cs | 61 ++++ .../WebSocketsTelemetryTests.cs | 308 ++++++++++++++++++ .../Forwarder/HttpForwarderTests.cs | 1 + 16 files changed, 844 insertions(+), 22 deletions(-) create mode 100644 samples/ReverseProxy.Metrics.Sample/WebSocketsTelemetryConsumer.cs rename {test/ReverseProxy.Tests/Common => src/ReverseProxy/Utilities}/DelegatingStream.cs (91%) create mode 100644 src/ReverseProxy/WebSocketsTelemetry/WebSocketCloseReason.cs create mode 100644 src/ReverseProxy/WebSocketsTelemetry/WebSocketsParser.cs create mode 100644 src/ReverseProxy/WebSocketsTelemetry/WebSocketsTelemetry.cs create mode 100644 src/ReverseProxy/WebSocketsTelemetry/WebSocketsTelemetryExtensions.cs create mode 100644 src/ReverseProxy/WebSocketsTelemetry/WebSocketsTelemetryMiddleware.cs create mode 100644 src/ReverseProxy/WebSocketsTelemetry/WebSocketsTelemetryStream.cs create mode 100644 src/TelemetryConsumption/WebSockets/IWebSocketsTelemetryConsumer.cs create mode 100644 src/TelemetryConsumption/WebSockets/WebSocketCloseReason.cs create mode 100644 src/TelemetryConsumption/WebSockets/WebSocketsEventListenerService.cs create mode 100644 test/ReverseProxy.FunctionalTests/WebSocketsTelemetryTests.cs diff --git a/samples/ReverseProxy.Metrics.Sample/Startup.cs b/samples/ReverseProxy.Metrics.Sample/Startup.cs index cdad7c9d4..0cf587190 100644 --- a/samples/ReverseProxy.Metrics.Sample/Startup.cs +++ b/samples/ReverseProxy.Metrics.Sample/Startup.cs @@ -44,6 +44,8 @@ public void ConfigureServices(IServiceCollection services) // Registration of a consumer to events for HttpClient telemetry // Note: this depends on changes implemented in .NET 5 services.AddTelemetryConsumer(); + + services.AddTelemetryConsumer(); } /// @@ -55,6 +57,9 @@ public void Configure(IApplicationBuilder app) // Placed at the beginning so it is the first and last thing run for each request app.UsePerRequestMetricCollection(); + // Middleware used to intercept the WebSocket connection and collect telemetry exposed to WebSocketsTelemetryConsumer + app.UseWebSocketsTelemetry(); + app.UseRouting(); app.UseEndpoints(endpoints => { diff --git a/samples/ReverseProxy.Metrics.Sample/WebSocketsTelemetryConsumer.cs b/samples/ReverseProxy.Metrics.Sample/WebSocketsTelemetryConsumer.cs new file mode 100644 index 000000000..fd9b066d2 --- /dev/null +++ b/samples/ReverseProxy.Metrics.Sample/WebSocketsTelemetryConsumer.cs @@ -0,0 +1,21 @@ +using System; +using Microsoft.Extensions.Logging; +using Yarp.Telemetry.Consumption; + +namespace Yarp.Sample +{ + public sealed class WebSocketsTelemetryConsumer : IWebSocketsTelemetryConsumer + { + private readonly ILogger _logger; + + public WebSocketsTelemetryConsumer(ILogger logger) + { + _logger = logger ?? throw new ArgumentNullException(nameof(logger)); + } + + public void OnWebSocketClosed(DateTime timestamp, DateTime establishedTime, WebSocketCloseReason closeReason, long messagesRead, long messagesWritten) + { + _logger.LogInformation($"WebSocket connection closed ({closeReason}) after reading {messagesRead} and writing {messagesWritten} messages over {(timestamp - establishedTime).TotalSeconds:N2} seconds."); + } + } +} diff --git a/src/ReverseProxy/Forwarder/HttpForwarder.cs b/src/ReverseProxy/Forwarder/HttpForwarder.cs index 2d6c14ca0..919f37e47 100644 --- a/src/ReverseProxy/Forwarder/HttpForwarder.cs +++ b/src/ReverseProxy/Forwarder/HttpForwarder.cs @@ -576,7 +576,7 @@ private async ValueTask HandleUpgradedResponse(HttpContext conte var (secondResult, secondException) = await secondTask; if (secondResult != StreamCopyResult.Success) { - error = ReportResult(context, requestFinishedFirst, secondResult, secondException!); + error = ReportResult(context, !requestFinishedFirst, secondResult, secondException!); } else { diff --git a/test/ReverseProxy.Tests/Common/DelegatingStream.cs b/src/ReverseProxy/Utilities/DelegatingStream.cs similarity index 91% rename from test/ReverseProxy.Tests/Common/DelegatingStream.cs rename to src/ReverseProxy/Utilities/DelegatingStream.cs index 1466b620f..597f2a45a 100644 --- a/test/ReverseProxy.Tests/Common/DelegatingStream.cs +++ b/src/ReverseProxy/Utilities/DelegatingStream.cs @@ -1,5 +1,5 @@ -// Licensed to the .NET Foundation under one or more agreements. -// The .NET Foundation licenses this file to you under the MIT license. +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. using System; using System.Diagnostics; @@ -7,8 +7,9 @@ using System.Threading; using System.Threading.Tasks; -namespace Yarp.Tests.Common +namespace Yarp.ReverseProxy.Utilities { + // Taken from https://github.com/dotnet/runtime/blob/00f37bc13b4edbba1afca9e98d74432a94f5192f/src/libraries/Common/src/System/IO/DelegatingStream.cs // Forwards all calls to an inner stream except where overridden in a derived class. internal abstract class DelegatingStream : Stream { @@ -113,9 +114,9 @@ public override ValueTask ReadAsync(Memory buffer, CancellationToken return _innerStream.ReadAsync(buffer, cancellationToken); } - public override IAsyncResult BeginRead(byte[] buffer, int offset, int count, AsyncCallback callback, object state) + public override IAsyncResult BeginRead(byte[] buffer, int offset, int count, AsyncCallback? callback, object? state) { - return _innerStream.BeginRead(buffer, offset, count, callback, state); + return _innerStream.BeginRead(buffer, offset, count, callback!, state); } public override int EndRead(IAsyncResult asyncResult) @@ -167,9 +168,9 @@ public override ValueTask WriteAsync(ReadOnlyMemory buffer, CancellationTo return _innerStream.WriteAsync(buffer, cancellationToken); } - public override IAsyncResult BeginWrite(byte[] buffer, int offset, int count, AsyncCallback callback, object state) + public override IAsyncResult BeginWrite(byte[] buffer, int offset, int count, AsyncCallback? callback, object? state) { - return _innerStream.BeginWrite(buffer, offset, count, callback, state); + return _innerStream.BeginWrite(buffer, offset, count, callback!, state); } public override void EndWrite(IAsyncResult asyncResult) diff --git a/src/ReverseProxy/WebSocketsTelemetry/WebSocketCloseReason.cs b/src/ReverseProxy/WebSocketsTelemetry/WebSocketCloseReason.cs new file mode 100644 index 000000000..949048cdf --- /dev/null +++ b/src/ReverseProxy/WebSocketsTelemetry/WebSocketCloseReason.cs @@ -0,0 +1,14 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +namespace Yarp.ReverseProxy.WebSocketsTelemetry +{ + internal enum WebSocketCloseReason : int + { + Unknown, + ClientGracefulClose, + ServerGracefulClose, + ClientDisconnect, + ServerDisconnect, + } +} diff --git a/src/ReverseProxy/WebSocketsTelemetry/WebSocketsParser.cs b/src/ReverseProxy/WebSocketsTelemetry/WebSocketsParser.cs new file mode 100644 index 000000000..5959e0ccb --- /dev/null +++ b/src/ReverseProxy/WebSocketsTelemetry/WebSocketsParser.cs @@ -0,0 +1,132 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using System; +using System.Diagnostics; + +namespace Yarp.ReverseProxy.WebSocketsTelemetry +{ + internal unsafe struct WebSocketsParser + { + private const int MaskLength = 4; + private const int MinHeaderSize = 2; + private const int MaxHeaderSize = MinHeaderSize + MaskLength + sizeof(ulong); + + private fixed byte _leftoverBuffer[MaxHeaderSize - 1]; + private readonly byte _minHeaderSize; + private byte _leftover; + private ulong _bytesToSkip; + + public long MessageCount { get; private set; } + + public DateTime? CloseTime { get; private set; } + + public WebSocketsParser(bool isServer) + { + _minHeaderSize = (byte)(MinHeaderSize + (isServer ? MaskLength : 0)); + _leftover = 0; + _bytesToSkip = 0; + MessageCount = 0; + CloseTime = null; + } + + // The WebSocket Protocol: https://datatracker.ietf.org/doc/html/rfc6455#section-5.2 + // 0 1 2 3 + // 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 + // +-+-+-+-+-------+-+-------------+-------------------------------+ + // |F|R|R|R| opcode|M| Payload len | Extended payload length | + // |I|S|S|S| (4) |A| (7) | (16/64) | + // |N|V|V|V| |S| | (if payload len==126/127) | + // | |1|2|3| |K| | | + // +-+-+-+-+-------+-+-------------+ - - - - - - - - - - - - - - - + + // | Extended payload length continued, if payload len == 127 | + // + - - - - - - - - - - - - - - - +-------------------------------+ + // | |Masking-key, if MASK set to 1 | + // +-------------------------------+-------------------------------+ + // | Masking-key (continued) | Payload Data | + // +-------------------------------- - - - - - - - - - - - - - - - + + // : Payload Data continued ... : + // +---------------------------------------------------------------+ + // + // The header can be 2-10 bytes long, followed by a 4 byte mask if the message was sent by the client. + // We have to read the first 2 bytes to know how long the frame header will be. + // Since the buffer may not contain the full frame, we make use of a leftoverBuffer + // where we store leftover bytes that don't represent a complete frame header. + // On the next call to Consume, we interpret the leftover bytes as the beginning of the frame. + // As we are not interested in the actual payload data, we skip over (payload length + mask length) bytes after each header. + public void Consume(ReadOnlySpan buffer) + { + int leftover = _leftover; + var bytesToSkip = _bytesToSkip; + + while (true) + { + var toSkip = Math.Min(bytesToSkip, (ulong)buffer.Length); + buffer = buffer.Slice((int)toSkip); + bytesToSkip -= toSkip; + + var available = leftover + buffer.Length; + int headerSize = _minHeaderSize; + + if (available < headerSize) + { + break; + } + + var length = (leftover > 1 ? _leftoverBuffer[1] : buffer[1 - leftover]) & 0x7FUL; + + if (length > 125) + { + // The actual length will be encoded in 2 or 8 bytes, based on whether the length was 126 or 127 + var lengthBytes = 2 << (((int)length & 1) << 1); + headerSize += lengthBytes; + Debug.Assert(leftover < headerSize); + + if (available < headerSize) + { + break; + } + + lengthBytes += MinHeaderSize; + + length = 0; + for (var i = MinHeaderSize; i < lengthBytes; i++) + { + length <<= 8; + length |= i < leftover ? _leftoverBuffer[i] : buffer[i - leftover]; + } + } + + Debug.Assert(leftover < headerSize); + bytesToSkip = length; + + int header = leftover > 0 ? _leftoverBuffer[0] : buffer[0]; + + if ((header & 0xF) == 0x8) // CLOSE + { + CloseTime ??= DateTime.UtcNow; + } + else if ((header & 0x80) != 0) // FIN + { + MessageCount++; + } + + // Advance the buffer by the number of bytes read for the header, + // accounting for any bytes we may have read from the leftoverBuffer + buffer = buffer.Slice(headerSize - leftover); + leftover = 0; + } + + Debug.Assert(bytesToSkip == 0 || buffer.Length == 0); + _bytesToSkip = bytesToSkip; + + Debug.Assert(leftover + buffer.Length < MaxHeaderSize); + for (var i = 0; i < buffer.Length; i++, leftover++) + { + _leftoverBuffer[leftover] = buffer[i]; + } + + _leftover = (byte)leftover; + } + } +} diff --git a/src/ReverseProxy/WebSocketsTelemetry/WebSocketsTelemetry.cs b/src/ReverseProxy/WebSocketsTelemetry/WebSocketsTelemetry.cs new file mode 100644 index 000000000..6fa10793e --- /dev/null +++ b/src/ReverseProxy/WebSocketsTelemetry/WebSocketsTelemetry.cs @@ -0,0 +1,22 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using System.Diagnostics.Tracing; + +namespace Yarp.ReverseProxy.WebSocketsTelemetry +{ + [EventSource(Name = "Yarp.ReverseProxy.WebSockets")] + internal sealed class WebSocketsTelemetry : EventSource + { + public static readonly WebSocketsTelemetry Log = new(); + + [Event(1, Level = EventLevel.Informational)] + public void WebSocketClosed(long establishedTime, WebSocketCloseReason closeReason, long messagesRead, long messagesWritten) + { + if (IsEnabled(EventLevel.Informational, EventKeywords.All)) + { + WriteEvent(eventId: 1, establishedTime, closeReason, messagesRead, messagesWritten); + } + } + } +} diff --git a/src/ReverseProxy/WebSocketsTelemetry/WebSocketsTelemetryExtensions.cs b/src/ReverseProxy/WebSocketsTelemetry/WebSocketsTelemetryExtensions.cs new file mode 100644 index 000000000..21b41c344 --- /dev/null +++ b/src/ReverseProxy/WebSocketsTelemetry/WebSocketsTelemetryExtensions.cs @@ -0,0 +1,22 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using Yarp.ReverseProxy.WebSocketsTelemetry; + +namespace Microsoft.AspNetCore.Builder +{ + /// + /// extension methods to add the . + /// + public static class WebSocketsTelemetryExtensions + { + /// + /// Adds a to the request pipeline. + /// Must be added before . + /// + public static IApplicationBuilder UseWebSocketsTelemetry(this IApplicationBuilder app) + { + return app.UseMiddleware(); + } + } +} diff --git a/src/ReverseProxy/WebSocketsTelemetry/WebSocketsTelemetryMiddleware.cs b/src/ReverseProxy/WebSocketsTelemetry/WebSocketsTelemetryMiddleware.cs new file mode 100644 index 000000000..3478514cc --- /dev/null +++ b/src/ReverseProxy/WebSocketsTelemetry/WebSocketsTelemetryMiddleware.cs @@ -0,0 +1,81 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using System; +using System.Diagnostics; +using System.IO; +using System.Threading.Tasks; +using Microsoft.AspNetCore.Http; +using Microsoft.AspNetCore.Http.Features; + +namespace Yarp.ReverseProxy.WebSocketsTelemetry +{ + internal sealed class WebSocketsTelemetryMiddleware + { + private readonly RequestDelegate _next; + + public WebSocketsTelemetryMiddleware(RequestDelegate next) + { + _next = next ?? throw new ArgumentNullException(nameof(next)); + } + + public Task Invoke(HttpContext context) + { + if (WebSocketsTelemetry.Log.IsEnabled()) + { + if (context.Features.Get() is { IsUpgradableRequest: true } upgradeFeature) + { + return InvokeAsyncCore(context, upgradeFeature, _next); + } + } + + return _next(context); + } + + private static async Task InvokeAsyncCore(HttpContext context, IHttpUpgradeFeature upgradeFeature, RequestDelegate next) + { + var upgradeWrapper = new HttpUpgradeFeatureWrapper(upgradeFeature); + context.Features.Set(upgradeWrapper); + + try + { + await next(context); + } + finally + { + if (upgradeWrapper.TelemetryStream is { } telemetryStream) + { + WebSocketsTelemetry.Log.WebSocketClosed( + telemetryStream.EstablishedTime.Ticks, + telemetryStream.GetCloseReason(context), + telemetryStream.MessagesRead, + telemetryStream.MessagesWritten); + } + + context.Features.Set(upgradeFeature); + } + } + + private sealed class HttpUpgradeFeatureWrapper : IHttpUpgradeFeature + { + private readonly IHttpUpgradeFeature _upgradeFeature; + + public WebSocketsTelemetryStream? TelemetryStream { get; private set; } + + public bool IsUpgradableRequest => _upgradeFeature.IsUpgradableRequest; + + public HttpUpgradeFeatureWrapper(IHttpUpgradeFeature upgradeFeature) + { + _upgradeFeature = upgradeFeature ?? throw new ArgumentNullException(nameof(upgradeFeature)); + } + + public async Task UpgradeAsync() + { + Debug.Assert(TelemetryStream is null); + var opaqueTransport = await _upgradeFeature.UpgradeAsync(); + TelemetryStream = new WebSocketsTelemetryStream(opaqueTransport); + return TelemetryStream; + } + } + } +} diff --git a/src/ReverseProxy/WebSocketsTelemetry/WebSocketsTelemetryStream.cs b/src/ReverseProxy/WebSocketsTelemetry/WebSocketsTelemetryStream.cs new file mode 100644 index 000000000..300bed33c --- /dev/null +++ b/src/ReverseProxy/WebSocketsTelemetry/WebSocketsTelemetryStream.cs @@ -0,0 +1,101 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using System; +using System.IO; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.AspNetCore.Http; +using Yarp.ReverseProxy.Forwarder; +using Yarp.ReverseProxy.Utilities; + +namespace Yarp.ReverseProxy.WebSocketsTelemetry +{ + internal sealed class WebSocketsTelemetryStream : DelegatingStream + { + private WebSocketsParser _readParser, _writeParser; + + public DateTime EstablishedTime { get; } + public long MessagesRead => _readParser.MessageCount; + public long MessagesWritten => _writeParser.MessageCount; + + public WebSocketsTelemetryStream(Stream innerStream) + : base(innerStream) + { + EstablishedTime = DateTime.UtcNow; + _readParser = new WebSocketsParser(isServer: true); + _writeParser = new WebSocketsParser(isServer: false); + } + + public WebSocketCloseReason GetCloseReason(HttpContext context) + { + var clientCloseTime = _readParser.CloseTime; + var serverCloseTime = _writeParser.CloseTime; + + // Mutual, graceful WebSocket close. We report whichever one we saw first. + if (clientCloseTime.HasValue && serverCloseTime.HasValue) + { + return clientCloseTime.Value < serverCloseTime.Value ? WebSocketCloseReason.ClientGracefulClose : WebSocketCloseReason.ServerGracefulClose; + } + + // One side sent a WebSocket close, but we never saw a response from the other side + // It is possible an error occurred, but we saw a graceful close first, so that is the intiator + if (clientCloseTime.HasValue) + { + return WebSocketCloseReason.ClientGracefulClose; + } + if (serverCloseTime.HasValue) + { + return WebSocketCloseReason.ServerGracefulClose; + } + + return context.Features.Get()?.Error switch + { + // Either side disconnected without sending a WebSocket close + ForwarderError.UpgradeRequestClient => WebSocketCloseReason.ClientDisconnect, + ForwarderError.UpgradeRequestCanceled => WebSocketCloseReason.ClientDisconnect, + ForwarderError.UpgradeResponseClient => WebSocketCloseReason.ClientDisconnect, + ForwarderError.UpgradeResponseCanceled => WebSocketCloseReason.ClientDisconnect, + ForwarderError.UpgradeRequestDestination => WebSocketCloseReason.ServerDisconnect, + ForwarderError.UpgradeResponseDestination => WebSocketCloseReason.ServerDisconnect, + + // Both sides gracefully closed the underlying connection without sending a WebSocket close + // Neither side is doing what we recognize as WebSockets ¯\_(ツ)_/¯ + null => WebSocketCloseReason.Unknown, + + // We are not expecting any other error from HttpForwarder after a successful connection upgrade + // Technically, a user could overwrite the IForwarderErrorFeature, in which case we don't know what's going on + _ => WebSocketCloseReason.Unknown + }; + } + + public override ValueTask ReadAsync(Memory buffer, CancellationToken cancellationToken = default) + { + var readTask = base.ReadAsync(buffer, cancellationToken); + + if (readTask.IsCompletedSuccessfully) + { + var read = readTask.GetAwaiter().GetResult(); + _readParser.Consume(buffer.Span.Slice(0, read)); + return new ValueTask(read); + } + else + { + return Core(buffer, readTask); + } + + async ValueTask Core(Memory buffer, ValueTask readTask) + { + var read = await readTask; + _readParser.Consume(buffer.Span.Slice(0, read)); + return read; + } + } + + public override ValueTask WriteAsync(ReadOnlyMemory buffer, CancellationToken cancellationToken = default) + { + _writeParser.Consume(buffer.Span); + return base.WriteAsync(buffer, cancellationToken); + } + } +} diff --git a/src/TelemetryConsumption/TelemetryConsumptionExtensions.cs b/src/TelemetryConsumption/TelemetryConsumptionExtensions.cs index 95154289e..cc58dc529 100644 --- a/src/TelemetryConsumption/TelemetryConsumptionExtensions.cs +++ b/src/TelemetryConsumption/TelemetryConsumptionExtensions.cs @@ -11,15 +11,16 @@ public static class TelemetryConsumptionExtensions { #if NET /// - /// Registers all telemetry listeners (Proxy, Kestrel, Http, NameResolution, NetSecurity and Sockets). + /// Registers all telemetry listeners (Forwarder, Kestrel, Http, NameResolution, NetSecurity, Sockets and WebSockets). /// #else /// - /// Registers all telemetry listeners (Proxy and Kestrel). + /// Registers all telemetry listeners (Forwarder, Kestrel and WebSockets). /// #endif public static IServiceCollection AddTelemetryListeners(this IServiceCollection services) { + services.AddHostedService(); services.AddHostedService(); services.AddHostedService(); #if NET @@ -38,40 +39,46 @@ public static IServiceCollection AddTelemetryConsumer(this IServiceCollection se { var implementsAny = false; - if (consumer is IForwarderTelemetryConsumer) + if (consumer is IWebSocketsTelemetryConsumer webSocketsTelemetryConsumer) { - services.TryAddEnumerable(new ServiceDescriptor(typeof(IForwarderTelemetryConsumer), consumer)); + services.TryAddEnumerable(ServiceDescriptor.Singleton(webSocketsTelemetryConsumer)); implementsAny = true; } - if (consumer is IKestrelTelemetryConsumer) + if (consumer is IForwarderTelemetryConsumer forwarderTelemetryConsumer) { - services.TryAddEnumerable(new ServiceDescriptor(typeof(IKestrelTelemetryConsumer), consumer)); + services.TryAddEnumerable(ServiceDescriptor.Singleton(forwarderTelemetryConsumer)); + implementsAny = true; + } + + if (consumer is IKestrelTelemetryConsumer kestrelTelemetryConsumer) + { + services.TryAddEnumerable(ServiceDescriptor.Singleton(kestrelTelemetryConsumer)); implementsAny = true; } #if NET - if (consumer is IHttpTelemetryConsumer) + if (consumer is IHttpTelemetryConsumer httpTelemetryConsumer) { - services.TryAddEnumerable(new ServiceDescriptor(typeof(IHttpTelemetryConsumer), consumer)); + services.TryAddEnumerable(ServiceDescriptor.Singleton(httpTelemetryConsumer)); implementsAny = true; } - if (consumer is INameResolutionTelemetryConsumer) + if (consumer is INameResolutionTelemetryConsumer nameResolutionTelemetryConsumer) { - services.TryAddEnumerable(new ServiceDescriptor(typeof(INameResolutionTelemetryConsumer), consumer)); + services.TryAddEnumerable(ServiceDescriptor.Singleton(nameResolutionTelemetryConsumer)); implementsAny = true; } - if (consumer is INetSecurityTelemetryConsumer) + if (consumer is INetSecurityTelemetryConsumer netSecurityTelemetryConsumer) { - services.TryAddEnumerable(new ServiceDescriptor(typeof(INetSecurityTelemetryConsumer), consumer)); + services.TryAddEnumerable(ServiceDescriptor.Singleton(netSecurityTelemetryConsumer)); implementsAny = true; } - if (consumer is ISocketsTelemetryConsumer) + if (consumer is ISocketsTelemetryConsumer socketsTelemetryConsumer) { - services.TryAddEnumerable(new ServiceDescriptor(typeof(ISocketsTelemetryConsumer), consumer)); + services.TryAddEnumerable(ServiceDescriptor.Singleton(socketsTelemetryConsumer)); implementsAny = true; } #endif @@ -94,6 +101,12 @@ public static IServiceCollection AddTelemetryConsumer(this IServiceCo { var implementsAny = false; + if (typeof(IWebSocketsTelemetryConsumer).IsAssignableFrom(typeof(TConsumer))) + { + services.AddSingleton(services => (IWebSocketsTelemetryConsumer)services.GetRequiredService()); + implementsAny = true; + } + if (typeof(IForwarderTelemetryConsumer).IsAssignableFrom(typeof(TConsumer))) { services.AddSingleton(services => (IForwarderTelemetryConsumer)services.GetRequiredService()); diff --git a/src/TelemetryConsumption/WebSockets/IWebSocketsTelemetryConsumer.cs b/src/TelemetryConsumption/WebSockets/IWebSocketsTelemetryConsumer.cs new file mode 100644 index 000000000..798c5a3ea --- /dev/null +++ b/src/TelemetryConsumption/WebSockets/IWebSocketsTelemetryConsumer.cs @@ -0,0 +1,23 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using System; + +namespace Yarp.Telemetry.Consumption +{ + /// + /// A consumer of Yarp.ReverseProxy.WebSockets EventSource events. + /// + public interface IWebSocketsTelemetryConsumer + { + /// + /// Called when a WebSockets connection is closed. + /// + /// Timestamp when the event was fired. + /// Timestamp when the connection upgrade completed. + /// The reason the WebSocket connection closed. + /// Messages read by the destination server. + /// Messages sent by the destination server. + void OnWebSocketClosed(DateTime timestamp, DateTime establishedTime, WebSocketCloseReason closeReason, long messagesRead, long messagesWritten); + } +} diff --git a/src/TelemetryConsumption/WebSockets/WebSocketCloseReason.cs b/src/TelemetryConsumption/WebSockets/WebSocketCloseReason.cs new file mode 100644 index 000000000..0c3d987f9 --- /dev/null +++ b/src/TelemetryConsumption/WebSockets/WebSocketCloseReason.cs @@ -0,0 +1,17 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +namespace Yarp.Telemetry.Consumption +{ + /// + /// The reason the WebSocket connection closed. + /// + public enum WebSocketCloseReason : int + { + Unknown, + ClientGracefulClose, + ServerGracefulClose, + ClientDisconnect, + ServerDisconnect, + } +} diff --git a/src/TelemetryConsumption/WebSockets/WebSocketsEventListenerService.cs b/src/TelemetryConsumption/WebSockets/WebSocketsEventListenerService.cs new file mode 100644 index 000000000..fc54c0ee9 --- /dev/null +++ b/src/TelemetryConsumption/WebSockets/WebSocketsEventListenerService.cs @@ -0,0 +1,61 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using System; +using System.Collections.Generic; +using System.Collections.ObjectModel; +using System.Diagnostics; +using System.Diagnostics.Tracing; +using Microsoft.Extensions.Logging; + +namespace Yarp.Telemetry.Consumption +{ + internal interface IWebSocketsMetricsConsumer { } + + internal sealed class WebSocketsEventListenerService : EventListenerService + { + protected override string EventSourceName => "Yarp.ReverseProxy.WebSockets"; + + public WebSocketsEventListenerService(ILogger logger, IEnumerable telemetryConsumers, IEnumerable metricsConsumers) + : base(logger, telemetryConsumers, metricsConsumers) + { } + + protected override void OnEventWritten(EventWrittenEventArgs eventData) + { + const int MinEventId = 1; + const int MaxEventId = 1; + + if (eventData.EventId < MinEventId || eventData.EventId > MaxEventId) + { + return; + } + + if (TelemetryConsumers is null) + { + return; + } + +#pragma warning disable IDE0007 // Use implicit type + // Explicit type here to drop the object? signature of payload elements + ReadOnlyCollection payload = eventData.Payload!; +#pragma warning restore IDE0007 // Use implicit type + + switch (eventData.EventId) + { + case 1: + Debug.Assert(eventData.EventName == "WebSocketClosed" && payload.Count == 4); + { + var establishedTime = new DateTime((long)payload[0]); + var closeReason = (WebSocketCloseReason)payload[1]; + var messagesRead = (long)payload[2]; + var messagesWritten = (long)payload[3]; + foreach (var consumer in TelemetryConsumers) + { + consumer.OnWebSocketClosed(eventData.TimeStamp, establishedTime, closeReason, messagesRead, messagesWritten); + } + } + break; + } + } + } +} diff --git a/test/ReverseProxy.FunctionalTests/WebSocketsTelemetryTests.cs b/test/ReverseProxy.FunctionalTests/WebSocketsTelemetryTests.cs new file mode 100644 index 000000000..6abc8f475 --- /dev/null +++ b/test/ReverseProxy.FunctionalTests/WebSocketsTelemetryTests.cs @@ -0,0 +1,308 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +#nullable enable + +using System; +using System.Net.Http; +using System.Net.WebSockets; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.AspNetCore.Builder; +using Microsoft.AspNetCore.Http; +using Microsoft.Extensions.DependencyInjection; +using Xunit; +using Yarp.ReverseProxy.Common; +using Yarp.Telemetry.Consumption; + +namespace Yarp.ReverseProxy +{ + public class WebSocketsTelemetryTests + { + [Fact] + public async Task NoWebSocketsUpgrade_NoTelemetryWritten() + { + var telemetry = await TestAsync( + async uri => + { + using var client = new HttpClient(); + await client.GetStringAsync(uri); + }, + (context, webSocket) => throw new InvalidOperationException("Shouldn't be reached")); + + Assert.Null(telemetry); + } + + [Theory] + [InlineData(0, 0, 42)] + [InlineData(0, 1, 42)] + [InlineData(1, 0, 42)] + [InlineData(23, 29, 0)] + [InlineData(17, 19, 1)] + [InlineData(11, 13, 100)] + [InlineData(5, 7, 1_000)] + [InlineData(2, 3, 100_000)] + public async Task MessagesExchanged_CorrectNumberReported(int read, int written, int messageSize) + { + var startTime = DateTime.UtcNow; + + var telemetry = await TestAsync( + async uri => + { + using var client = new ClientWebSocket(); + await client.ConnectAsync(uri, CancellationToken.None); + var webSocket = new WebSocketAdapter(client); + + await Task.WhenAll( + SendMessagesAndCloseAsync(webSocket, read, messageSize), + ReceiveAllMessagesAsync(webSocket)); + }, + async (context, webSocket) => + { + await Task.WhenAll( + SendMessagesAndCloseAsync(webSocket, written, messageSize), + ReceiveAllMessagesAsync(webSocket)); + }); + + Assert.NotNull(telemetry); + Assert.InRange(telemetry!.EstablishedTime, startTime, telemetry.Timestamp); + Assert.Contains(telemetry.CloseReason, new[] { WebSocketCloseReason.ClientGracefulClose, WebSocketCloseReason.ServerGracefulClose }); + Assert.Equal(read, telemetry!.MessagesRead); + Assert.Equal(written, telemetry.MessagesWritten); + } + + public enum Behavior + { + ClosesConnection = 1, + SendsClose_WaitsForClose = 2, + SendsClose_ClosesConnection = 4 | ClosesConnection, + WaitsForClose_SendsClose = 8, + WaitsForClose_ClosesConnection = 16 | ClosesConnection, + } + + [Theory] + // Both sides close the connection - race between which is noticed first + [InlineData(Behavior.ClosesConnection, Behavior.ClosesConnection, WebSocketCloseReason.Unknown, WebSocketCloseReason.ClientDisconnect, WebSocketCloseReason.ServerDisconnect)] + // One side sends a graceful close + [InlineData(Behavior.SendsClose_ClosesConnection, Behavior.WaitsForClose_ClosesConnection, WebSocketCloseReason.ClientGracefulClose)] + [InlineData(Behavior.SendsClose_WaitsForClose, Behavior.WaitsForClose_ClosesConnection, WebSocketCloseReason.ClientGracefulClose)] + [InlineData(Behavior.WaitsForClose_ClosesConnection, Behavior.SendsClose_ClosesConnection, WebSocketCloseReason.ServerGracefulClose)] + [InlineData(Behavior.WaitsForClose_ClosesConnection, Behavior.SendsClose_WaitsForClose, WebSocketCloseReason.ServerGracefulClose)] + // One side sends a graceful close while the other disconnects - race between which is noticed first + [InlineData(Behavior.SendsClose_WaitsForClose, Behavior.ClosesConnection, WebSocketCloseReason.ClientGracefulClose, WebSocketCloseReason.ServerDisconnect)] + [InlineData(Behavior.SendsClose_ClosesConnection, Behavior.ClosesConnection, WebSocketCloseReason.ClientGracefulClose, WebSocketCloseReason.ServerDisconnect)] + [InlineData(Behavior.ClosesConnection, Behavior.SendsClose_ClosesConnection, WebSocketCloseReason.ServerGracefulClose, WebSocketCloseReason.ClientDisconnect)] + [InlineData(Behavior.ClosesConnection, Behavior.SendsClose_WaitsForClose, WebSocketCloseReason.ServerGracefulClose, WebSocketCloseReason.ClientDisconnect)] + // One side closes the connection while the other is waiting for messages + [InlineData(Behavior.ClosesConnection, Behavior.WaitsForClose_SendsClose, WebSocketCloseReason.ClientDisconnect)] + [InlineData(Behavior.ClosesConnection, Behavior.WaitsForClose_ClosesConnection, WebSocketCloseReason.ClientDisconnect)] + [InlineData(Behavior.WaitsForClose_SendsClose, Behavior.ClosesConnection, WebSocketCloseReason.ServerDisconnect)] + [InlineData(Behavior.WaitsForClose_ClosesConnection, Behavior.ClosesConnection, WebSocketCloseReason.ServerDisconnect)] + // Graceful, mutual close - other side closes as a reaction to receiving close + [InlineData(Behavior.SendsClose_WaitsForClose, Behavior.WaitsForClose_SendsClose, WebSocketCloseReason.ClientGracefulClose)] + [InlineData(Behavior.SendsClose_ClosesConnection, Behavior.WaitsForClose_SendsClose, WebSocketCloseReason.ClientGracefulClose)] + [InlineData(Behavior.WaitsForClose_SendsClose, Behavior.SendsClose_WaitsForClose, WebSocketCloseReason.ServerGracefulClose)] + [InlineData(Behavior.WaitsForClose_SendsClose, Behavior.SendsClose_ClosesConnection, WebSocketCloseReason.ServerGracefulClose)] + // Graceful, mutual close - both sides close at the same time - race between which is noticed first + [InlineData(Behavior.SendsClose_WaitsForClose, Behavior.SendsClose_WaitsForClose, WebSocketCloseReason.ClientGracefulClose, WebSocketCloseReason.ServerGracefulClose)] + [InlineData(Behavior.SendsClose_WaitsForClose, Behavior.SendsClose_ClosesConnection, WebSocketCloseReason.ClientGracefulClose, WebSocketCloseReason.ServerGracefulClose)] + [InlineData(Behavior.SendsClose_ClosesConnection, Behavior.SendsClose_WaitsForClose, WebSocketCloseReason.ClientGracefulClose, WebSocketCloseReason.ServerGracefulClose)] + [InlineData(Behavior.SendsClose_ClosesConnection, Behavior.SendsClose_ClosesConnection, WebSocketCloseReason.ClientGracefulClose, WebSocketCloseReason.ServerGracefulClose)] + public async Task ConnectionClosed_BlameAttributedCorrectly(Behavior clientBehavior, Behavior serverBehavior, params WebSocketCloseReason[] expectedReasons) + { + var telemetry = await TestAsync( + async uri => + { + using var client = new ClientWebSocket(); + + // Keep sending messages from the client in order to observe a server disconnect sooner + client.Options.KeepAliveInterval = TimeSpan.FromMilliseconds(10); + + await client.ConnectAsync(uri, CancellationToken.None); + var webSocket = new WebSocketAdapter(client); + + try + { + await ProcessAsync(webSocket, clientBehavior, client: client); + } + catch + { + Assert.True(serverBehavior.HasFlag(Behavior.ClosesConnection)); + } + }, + async (context, webSocket) => + { + try + { + await ProcessAsync(webSocket, serverBehavior, context: context); + } + catch + { + Assert.True(clientBehavior.HasFlag(Behavior.ClosesConnection)); + } + }); + + Assert.NotNull(telemetry); + Assert.Contains(telemetry!.CloseReason, expectedReasons); + + static async Task ProcessAsync(WebSocketAdapter webSocket, Behavior behavior, ClientWebSocket? client = null, HttpContext? context = null) + { + if (behavior == Behavior.SendsClose_WaitsForClose || + behavior == Behavior.SendsClose_ClosesConnection) + { + await webSocket.CloseOutputAsync(WebSocketCloseStatus.NormalClosure, "Bye"); + } + + if (behavior == Behavior.SendsClose_WaitsForClose || + behavior == Behavior.WaitsForClose_SendsClose || + behavior == Behavior.WaitsForClose_ClosesConnection) + { + await ReceiveAllMessagesAsync(webSocket); + } + + if (behavior == Behavior.WaitsForClose_SendsClose) + { + await webSocket.CloseOutputAsync(WebSocketCloseStatus.NormalClosure, "Bye"); + } + + if (behavior == Behavior.SendsClose_ClosesConnection || + behavior == Behavior.WaitsForClose_ClosesConnection || + behavior == Behavior.ClosesConnection) + { + client?.Abort(); + + if (context is not null) + { + await context.Response.Body.FlushAsync(); + context.Abort(); + } + } + } + } + + private static async Task ReceiveAllMessagesAsync(WebSocketAdapter webSocket) + { + Memory buffer = new byte[1024]; + + while (true) + { + var result = await webSocket.ReceiveAsync(buffer); + + if (result.MessageType == WebSocketMessageType.Close) + { + break; + } + } + } + + private static async Task SendMessagesAndCloseAsync(WebSocketAdapter webSocket, int messageCount, int messageSize) + { + var rng = new Random(42); + var buffer = new byte[1024]; + + for (var i = 0; i < messageCount; i++) + { + var remaining = messageSize; + + while (remaining > 1) + { + var chunkSize = Math.Min(buffer.Length, remaining - 1); + remaining -= chunkSize; + var chunk = buffer.AsMemory(0, chunkSize); + rng.NextBytes(chunk.Span); + await webSocket.SendAsync(chunk, WebSocketMessageType.Binary, endOfMessage: false); + } + + await webSocket.SendAsync(buffer.AsMemory(0, remaining), WebSocketMessageType.Binary, endOfMessage: true); + } + + await webSocket.CloseOutputAsync(WebSocketCloseStatus.NormalClosure, "Bye", CancellationToken.None); + } + + private class WebSocketAdapter + { + private readonly ClientWebSocket? _client; + private readonly WebSocket? _server; + + public WebSocketAdapter(ClientWebSocket? client = null, WebSocket? server = null) + { + Assert.True(client is null ^ server is null); + _client = client; + _server = server; + } + + public ValueTask ReceiveAsync(Memory buffer, CancellationToken cancellationToken = default) + { + return _client is not null + ? _client.ReceiveAsync(buffer, cancellationToken) + : _server!.ReceiveAsync(buffer, cancellationToken); + } + + public ValueTask SendAsync(ReadOnlyMemory buffer, WebSocketMessageType messageType, bool endOfMessage, CancellationToken cancellationToken = default) + { + return _client is not null + ? _client.SendAsync(buffer, messageType, endOfMessage, cancellationToken) + : _server!.SendAsync(buffer, messageType, endOfMessage, cancellationToken); + } + + public Task CloseOutputAsync(WebSocketCloseStatus closeStatus, string? statusDescription, CancellationToken cancellationToken = default) + { + return _client is not null + ? _client.CloseOutputAsync(closeStatus, statusDescription, cancellationToken) + : _server!.CloseOutputAsync(closeStatus, statusDescription, cancellationToken); + } + } + + private static async Task TestAsync(Func requestDelegate, Func destinationDelegate) + { + var telemetryConsumer = new TelemetryConsumer(); + + var test = new TestEnvironment( + destinationServies => { }, + destinationApp => + { + destinationApp.UseWebSockets(); + + destinationApp.Run(async context => + { + if (context.WebSockets.IsWebSocketRequest) + { + var webSocket = await context.WebSockets.AcceptWebSocketAsync(); + + await destinationDelegate(context, new WebSocketAdapter(server: webSocket)); + } + }); + }, + proxyBuilder => + { + proxyBuilder.Services.AddTelemetryConsumer(telemetryConsumer); + }, + proxyApp => + { + proxyApp.UseWebSocketsTelemetry(); + }); + + await test.Invoke(async uri => + { + var webSocketsTarget = uri.Replace("https://", "wss://").Replace("http://", "ws://"); + var webSocketsUri = new Uri(new Uri(webSocketsTarget, UriKind.Absolute), "websockets"); + + await requestDelegate(webSocketsUri); + }); + + return telemetryConsumer.Telemetry; + } + + private record WebSocketsTelemetry(DateTime Timestamp, DateTime EstablishedTime, WebSocketCloseReason CloseReason, long MessagesRead, long MessagesWritten); + + private class TelemetryConsumer : IWebSocketsTelemetryConsumer + { + public WebSocketsTelemetry? Telemetry { get; private set; } + + public void OnWebSocketClosed(DateTime timestamp, DateTime establishedTime, WebSocketCloseReason closeReason, long messagesRead, long messagesWritten) + { + Telemetry = new WebSocketsTelemetry(timestamp, establishedTime, closeReason, messagesRead, messagesWritten); + } + } + } +} diff --git a/test/ReverseProxy.Tests/Forwarder/HttpForwarderTests.cs b/test/ReverseProxy.Tests/Forwarder/HttpForwarderTests.cs index 8836f95f9..bdd330174 100644 --- a/test/ReverseProxy.Tests/Forwarder/HttpForwarderTests.cs +++ b/test/ReverseProxy.Tests/Forwarder/HttpForwarderTests.cs @@ -19,6 +19,7 @@ using Moq; using Xunit; using Yarp.Tests.Common; +using Yarp.ReverseProxy.Utilities; namespace Yarp.ReverseProxy.Forwarder.Tests {