Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 34 additions & 3 deletions src/Grpc.Net.Client/Balancer/Internal/BalancerHttpHandler.cs
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,12 @@
using System;
using System.Diagnostics;
using System.IO;
using System.Net;
using System.Net.Http;
using System.Threading;
using System.Threading.Tasks;
using Grpc.Shared;
using Microsoft.Extensions.Logging;

namespace Grpc.Net.Client.Balancer.Internal;

Expand All @@ -37,11 +39,13 @@ internal class BalancerHttpHandler : DelegatingHandler
internal const string IsSocketsHttpHandlerSetupKey = "IsSocketsHttpHandlerSetup";

private readonly ConnectionManager _manager;
private readonly ILogger _logger;

public BalancerHttpHandler(HttpMessageHandler innerHandler, ConnectionManager manager)
: base(innerHandler)
{
_manager = manager;
_logger = manager.LoggerFactory.CreateLogger<BalancerHttpHandler>();
}

internal static bool IsSocketsHttpHandlerSetup(SocketsHttpHandler socketsHttpHandler)
Expand All @@ -54,7 +58,9 @@ value is bool isEnabled &&
}
}

internal static void ConfigureSocketsHttpHandlerSetup(SocketsHttpHandler socketsHttpHandler)
internal static void ConfigureSocketsHttpHandlerSetup(
SocketsHttpHandler socketsHttpHandler,
Func<SocketsHttpConnectionContext, CancellationToken, ValueTask<Stream>> connectCallback)
{
// We're modifying the SocketsHttpHandler and nothing prevents two threads from creating a
// channel with the same handler on different threads.
Expand All @@ -67,15 +73,17 @@ internal static void ConfigureSocketsHttpHandlerSetup(SocketsHttpHandler sockets
{
Debug.Assert(socketsHttpHandler.ConnectCallback == null, "ConnectCallback should be null to get to this point.");

socketsHttpHandler.ConnectCallback = OnConnect;
socketsHttpHandler.ConnectCallback = connectCallback;
socketsHttpHandler.Properties[IsSocketsHttpHandlerSetupKey] = true;
}
}
}

#if NET5_0_OR_GREATER
private static async ValueTask<Stream> OnConnect(SocketsHttpConnectionContext context, CancellationToken cancellationToken)
internal async ValueTask<Stream> OnConnect(SocketsHttpConnectionContext context, CancellationToken cancellationToken)
{
Log.StartingConnectCallback(_logger, context.DnsEndPoint);

if (!context.InitialRequestMessage.TryGetOption<Subchannel>(SubchannelKey, out var subchannel))
{
throw new InvalidOperationException($"Unable to get subchannel from {nameof(HttpRequestMessage)}.");
Expand Down Expand Up @@ -133,6 +141,7 @@ protected override async Task<HttpResponseMessage> SendAsync(
request.SetOption(CurrentAddressKey, address);
#endif

Log.SendingRequest(_logger, request.RequestUri);
var responseMessageTask = base.SendAsync(request, cancellationToken);
result.SubchannelCallTracker?.Start();

Expand Down Expand Up @@ -161,5 +170,27 @@ protected override async Task<HttpResponseMessage> SendAsync(
throw;
}
}

internal static class Log
{
private static readonly Action<ILogger, Uri, Exception?> _sendingRequest =
LoggerMessage.Define<Uri>(LogLevel.Trace, new EventId(1, "SendingRequest"), "Sending request {RequestUri}.");

private static readonly Action<ILogger, string, Exception?> _startingConnectCallback =
LoggerMessage.Define<string>(LogLevel.Trace, new EventId(2, "StartingConnectCallback"), "Starting connect callback for {Endpoint}.");

public static void SendingRequest(ILogger logger, Uri requestUri)
{
_sendingRequest(logger, requestUri, null);
}

public static void StartingConnectCallback(ILogger logger, DnsEndPoint endpoint)
{
if (logger.IsEnabled(LogLevel.Trace))
{
_startingConnectCallback(logger, $"{endpoint.Host}:{endpoint.Port}", null);
}
}
}
}
#endif
Original file line number Diff line number Diff line change
Expand Up @@ -80,10 +80,9 @@ public SocketConnectivitySubchannelTransport(
_socketConnectedTimer = new Timer(OnCheckSocketConnection, state: null, Timeout.InfiniteTimeSpan, Timeout.InfiniteTimeSpan);
}

public object Lock => _subchannel.Lock;
private object Lock => _subchannel.Lock;
public BalancerAddress? CurrentAddress => _currentAddress;
public TimeSpan? ConnectTimeout { get; }
public bool HasStream { get; }

// For testing. Take a copy under lock for thread-safety.
internal IReadOnlyList<ActiveStream> GetActiveStreams()
Expand Down Expand Up @@ -264,13 +263,21 @@ public async ValueTask<Stream> GetStreamAsync(BalancerAddress address, Cancellat
Socket? socket = null;
lock (Lock)
{
if (_initialSocket != null &&
_initialSocketAddress != null &&
Equals(_initialSocketAddress, address))
if (_initialSocket != null)
{
var socketAddressMatch = Equals(_initialSocketAddress, address);

socket = _initialSocket;
_initialSocket = null;
_initialSocketAddress = null;

// Double check the address matches the socket address and only use socket on match.
// Not sure if this is possible in practice, but better safe than sorry.
if (!socketAddressMatch)
{
socket.Dispose();
socket = null;
}
}
}

Expand All @@ -288,6 +295,8 @@ public async ValueTask<Stream> GetStreamAsync(BalancerAddress address, Cancellat

if (socket == null)
{
SocketConnectivitySubchannelTransportLog.ConnectingOnCreateStream(_logger, _subchannel.Id, address);

socket = new Socket(SocketType.Stream, ProtocolType.Tcp) { NoDelay = true };
await socket.ConnectAsync(address.EndPoint, cancellationToken).ConfigureAwait(false);
}
Expand All @@ -300,6 +309,7 @@ public async ValueTask<Stream> GetStreamAsync(BalancerAddress address, Cancellat
lock (Lock)
{
_activeStreams.Add(new ActiveStream(address, socket, stream));
SocketConnectivitySubchannelTransportLog.StreamCreated(_logger, _subchannel.Id, address, _activeStreams.Count);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

logging inside the lock troubles me; is this intentional? would it perhaps be better to snag the Count inside the lock and log outside?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think this is a concern. There is logging inside locks throughout gRPC and ASP.NET Core.

}

return stream;
Expand Down Expand Up @@ -331,7 +341,7 @@ private void OnStreamDisposed(Stream streamWrapper)
if (t.Stream == streamWrapper)
{
_activeStreams.RemoveAt(i);
SocketConnectivitySubchannelTransportLog.DisposingStream(_logger, _subchannel.Id, t.Address);
SocketConnectivitySubchannelTransportLog.DisposingStream(_logger, _subchannel.Id, t.Address, _activeStreams.Count);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

hmmm, maybe ignore my last comment about lock; I can't see a nice / clean way of avoiding this one, so... if we're in that situation anyway, maybe just live with it


// If the last active streams is removed then there is no active connection.
disconnect = _activeStreams.Count == 0;
Expand Down Expand Up @@ -399,15 +409,21 @@ internal static class SocketConnectivitySubchannelTransportLog
private static readonly Action<ILogger, int, BalancerAddress, Exception?> _creatingStream =
LoggerMessage.Define<int, BalancerAddress>(LogLevel.Trace, new EventId(7, "CreatingStream"), "Subchannel id '{SubchannelId}' creating stream for {Address}.");

private static readonly Action<ILogger, int, BalancerAddress, Exception?> _disposingStream =
LoggerMessage.Define<int, BalancerAddress>(LogLevel.Trace, new EventId(8, "DisposingStream"), "Subchannel id '{SubchannelId}' disposing stream for {Address}.");
private static readonly Action<ILogger, int, BalancerAddress, int, Exception?> _disposingStream =
LoggerMessage.Define<int, BalancerAddress, int>(LogLevel.Trace, new EventId(8, "DisposingStream"), "Subchannel id '{SubchannelId}' disposing stream for {Address}. Transport has {ActiveStreams} active streams.");

private static readonly Action<ILogger, int, Exception?> _disposingTransport =
LoggerMessage.Define<int>(LogLevel.Trace, new EventId(9, "DisposingTransport"), "Subchannel id '{SubchannelId}' disposing transport.");

private static readonly Action<ILogger, int, Exception> _errorOnDisposingStream =
LoggerMessage.Define<int>(LogLevel.Error, new EventId(10, "ErrorOnDisposingStream"), "Subchannel id '{SubchannelId}' unexpected error when reacting to transport stream dispose.");

private static readonly Action<ILogger, int, BalancerAddress, Exception?> _connectingOnCreateStream =
LoggerMessage.Define<int, BalancerAddress>(LogLevel.Trace, new EventId(11, "ConnectingOnCreateStream"), "Subchannel id '{SubchannelId}' doesn't have a connected socket available. Connecting new stream socket for {Address}.");

private static readonly Action<ILogger, int, BalancerAddress, int, Exception?> _streamCreated =
LoggerMessage.Define<int, BalancerAddress, int>(LogLevel.Trace, new EventId(12, "StreamCreated"), "Subchannel id '{SubchannelId}' created stream for {Address}. Transport has {ActiveStreams} active streams.");

public static void ConnectingSocket(ILogger logger, int subchannelId, BalancerAddress address)
{
_connectingSocket(logger, subchannelId, address, null);
Expand Down Expand Up @@ -443,9 +459,9 @@ public static void CreatingStream(ILogger logger, int subchannelId, BalancerAddr
_creatingStream(logger, subchannelId, address, null);
}

public static void DisposingStream(ILogger logger, int subchannelId, BalancerAddress address)
public static void DisposingStream(ILogger logger, int subchannelId, BalancerAddress address, int activeStreams)
{
_disposingStream(logger, subchannelId, address, null);
_disposingStream(logger, subchannelId, address, activeStreams, null);
}

public static void DisposingTransport(ILogger logger, int subchannelId)
Expand All @@ -457,6 +473,16 @@ public static void ErrorOnDisposingStream(ILogger logger, int subchannelId, Exce
{
_errorOnDisposingStream(logger, subchannelId, ex);
}

public static void ConnectingOnCreateStream(ILogger logger, int subchannelId, BalancerAddress address)
{
_connectingOnCreateStream(logger, subchannelId, address, null);
}

public static void StreamCreated(ILogger logger, int subchannelId, BalancerAddress address, int activeStreams)
{
_streamCreated(logger, subchannelId, address, activeStreams, null);
}
}
#endif
#endif
22 changes: 20 additions & 2 deletions src/Grpc.Net.Client/Balancer/Internal/StreamWrapper.cs
Original file line number Diff line number Diff line change
Expand Up @@ -18,16 +18,18 @@

#if SUPPORT_LOAD_BALANCING
using System;
using System.Diagnostics;
using System.IO;
using System.Threading;
using System.Threading.Tasks;

namespace Grpc.Net.Client.Balancer.Internal;

internal class StreamWrapper : Stream
internal sealed class StreamWrapper : Stream
{
private readonly Stream _inner;
private readonly Action<Stream> _onDispose;
private bool _disposed;

public StreamWrapper(Stream inner, Action<Stream> onDispose)
{
Expand Down Expand Up @@ -86,13 +88,29 @@ public override ValueTask<int> ReadAsync(Memory<byte> buffer, CancellationToken
public override Task CopyToAsync(Stream destination, int bufferSize, CancellationToken cancellationToken) =>
_inner.CopyToAsync(destination, bufferSize, cancellationToken);

public override async ValueTask DisposeAsync()
{
await base.DisposeAsync().ConfigureAwait(false);

// Avoid invoking dispose callback multiple times.
if (_disposed)
{
_onDispose(this);
await _inner.DisposeAsync().ConfigureAwait(false);
_disposed = true;
}
}

protected override void Dispose(bool disposing)
{
base.Dispose(disposing);
if (disposing)

// Avoid invoking dispose callback multiple times.
if (disposing && !_disposed)
{
_onDispose(this);
_inner.Dispose();
_disposed = true;
}
}
}
Expand Down
7 changes: 4 additions & 3 deletions src/Grpc.Net.Client/GrpcChannel.cs
Original file line number Diff line number Diff line change
Expand Up @@ -456,15 +456,16 @@ private HttpMessageInvoker CreateInternalHttpInvoker(HttpMessageHandler? handler
#endif

#if SUPPORT_LOAD_BALANCING
BalancerHttpHandler balancerHttpHandler;
handler = balancerHttpHandler = new BalancerHttpHandler(handler, ConnectionManager);

if (HttpHandlerType == HttpHandlerType.SocketsHttpHandler)
{
var socketsHttpHandler = HttpRequestHelpers.GetHttpHandlerType<SocketsHttpHandler>(handler);
CompatibilityHelpers.Assert(socketsHttpHandler != null, "Should have handler with this handler type.");

BalancerHttpHandler.ConfigureSocketsHttpHandlerSetup(socketsHttpHandler);
BalancerHttpHandler.ConfigureSocketsHttpHandlerSetup(socketsHttpHandler, balancerHttpHandler.OnConnect);
}

handler = new BalancerHttpHandler(handler, ConnectionManager);
#endif

// Use HttpMessageInvoker instead of HttpClient because it is faster
Expand Down