Skip to content

Add support for customising the creation of Kestrel listen sockets #32827

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 9 commits into from
May 28, 2021
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,6 @@ Microsoft.AspNetCore.Server.Kestrel.Transport.Sockets.SocketTransportFactory.Bin
~Microsoft.AspNetCore.Server.Kestrel.Transport.Sockets.SocketTransportFactory.SocketTransportFactory(Microsoft.Extensions.Options.IOptions<Microsoft.AspNetCore.Server.Kestrel.Transport.Sockets.SocketTransportOptions!>! options, Microsoft.Extensions.Logging.ILoggerFactory! loggerFactory) -> void
static Microsoft.AspNetCore.Hosting.WebHostBuilderSocketExtensions.UseSockets(this Microsoft.AspNetCore.Hosting.IWebHostBuilder! hostBuilder) -> Microsoft.AspNetCore.Hosting.IWebHostBuilder!
static Microsoft.AspNetCore.Hosting.WebHostBuilderSocketExtensions.UseSockets(this Microsoft.AspNetCore.Hosting.IWebHostBuilder! hostBuilder, System.Action<Microsoft.AspNetCore.Server.Kestrel.Transport.Sockets.SocketTransportOptions!>! configureOptions) -> Microsoft.AspNetCore.Hosting.IWebHostBuilder!
static Microsoft.AspNetCore.Server.Kestrel.Transport.Sockets.SocketTransportOptions.CreateDefaultBoundListenSocket(System.Net.EndPoint! endpoint) -> System.Net.Sockets.Socket!
Microsoft.AspNetCore.Server.Kestrel.Transport.Sockets.SocketTransportOptions.CreateBoundListenSocket.get -> System.Func<System.Net.EndPoint!, System.Net.Sockets.Socket!>!
Microsoft.AspNetCore.Server.Kestrel.Transport.Sockets.SocketTransportOptions.CreateBoundListenSocket.set -> void
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

using System;
using System.Buffers;
using System.ComponentModel;
using System.Diagnostics;
using System.IO.Pipelines;
using System.Net;
Expand All @@ -23,7 +24,6 @@ internal sealed class SocketConnectionListener : IConnectionListener
private Socket? _listenSocket;
private int _settingsIndex;
private readonly SocketTransportOptions _options;
private SafeSocketHandle? _socketHandle;

public EndPoint EndPoint { get; private set; }

Expand Down Expand Up @@ -92,43 +92,13 @@ internal void Bind()
}

Socket listenSocket;

switch (EndPoint)
try
{
case FileHandleEndPoint fileHandle:
_socketHandle = new SafeSocketHandle((IntPtr)fileHandle.FileHandle, ownsHandle: true);
listenSocket = new Socket(_socketHandle);
break;
case UnixDomainSocketEndPoint unix:
listenSocket = new Socket(unix.AddressFamily, SocketType.Stream, ProtocolType.Unspecified);
BindSocket();
break;
case IPEndPoint ip:
listenSocket = new Socket(ip.AddressFamily, SocketType.Stream, ProtocolType.Tcp);

// Kestrel expects IPv6Any to bind to both IPv6 and IPv4
if (ip.Address == IPAddress.IPv6Any)
{
listenSocket.DualMode = true;
}
BindSocket();
break;
default:
listenSocket = new Socket(EndPoint.AddressFamily, SocketType.Stream, ProtocolType.Tcp);
BindSocket();
break;
listenSocket = _options.CreateBoundListenSocket(EndPoint);
}

void BindSocket()
catch (SocketException e) when (e.SocketErrorCode == SocketError.AddressAlreadyInUse)
{
try
{
listenSocket.Bind(EndPoint);
}
catch (SocketException e) when (e.SocketErrorCode == SocketError.AddressAlreadyInUse)
{
throw new AddressInUseException(e.Message, e);
}
throw new AddressInUseException(e.Message, e);
}

Debug.Assert(listenSocket.LocalEndPoint != null);
Expand Down Expand Up @@ -193,17 +163,13 @@ void BindSocket()
public ValueTask UnbindAsync(CancellationToken cancellationToken = default)
{
_listenSocket?.Dispose();

_socketHandle?.Dispose();
return default;
}

public ValueTask DisposeAsync()
{
_listenSocket?.Dispose();

_socketHandle?.Dispose();

// Dispose the memory pool
_memoryPool.Dispose();

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@

using System;
using System.Buffers;
using System.Net;
using System.Net.Sockets;
using Microsoft.AspNetCore.Connections;

namespace Microsoft.AspNetCore.Server.Kestrel.Transport.Sockets
{
Expand Down Expand Up @@ -65,6 +68,78 @@ public class SocketTransportOptions
/// </remarks>
public bool UnsafePreferInlineScheduling { get; set; }

/// <summary>
/// A function used to create a new <see cref="Socket"/> to listen with. If
/// not set, <see cref="CreateDefaultBoundListenSocket" /> is used.
/// </summary>
/// <remarks>
/// Implementors are expected to call <see cref="Socket.Bind"/> on the
/// <see cref="Socket"/>. Please note that <see cref="CreateDefaultBoundListenSocket"/>
/// calls <see cref="Socket.Bind"/> as part of its implementation, so implementors
/// using this method do not need to call it again.
/// </remarks>
public Func<EndPoint, Socket> CreateBoundListenSocket { get; set; } = CreateDefaultBoundListenSocket;

/// <summary>
/// Creates a default instance of <see cref="Socket"/> for the given <see cref="EndPoint"/>
/// that can be used by a connection listener to listen for inbound requests. <see cref="Socket.Bind"/>
/// is called by this method.
/// </summary>
/// <param name="endpoint">
/// An <see cref="EndPoint"/>.
/// </param>
/// <returns>
/// A <see cref="Socket"/> instance.
/// </returns>
public static Socket CreateDefaultBoundListenSocket(EndPoint endpoint)
{
Socket listenSocket;
switch (endpoint)
{
case FileHandleEndPoint fileHandle:
// We're passing "ownsHandle: true" here even though we don't necessarily
// own the handle because Socket.Dispose will clean-up everything safely.
// If the handle was already closed or disposed then the socket will
// be torn down gracefully, and if the caller never cleans up their handle
// then we'll do it for them.
//
// If we don't do this then we run the risk of Kestrel hanging because the
// the underlying socket is never closed and the transport manager can hang
// when it attempts to stop.
listenSocket = new Socket(
new SafeSocketHandle((IntPtr)fileHandle.FileHandle, ownsHandle: true)
);
break;
case UnixDomainSocketEndPoint unix:
listenSocket = new Socket(unix.AddressFamily, SocketType.Stream, ProtocolType.Unspecified);
break;
case IPEndPoint ip:
listenSocket = new Socket(ip.AddressFamily, SocketType.Stream, ProtocolType.Tcp);

// Kestrel expects IPv6Any to bind to both IPv6 and IPv4
if (ip.Address == IPAddress.IPv6Any)
{
listenSocket.DualMode = true;
}

break;
default:
listenSocket = new Socket(endpoint.AddressFamily, SocketType.Stream, ProtocolType.Tcp);
break;
}

// we only call Bind on sockets that were _not_ created
// using a file handle; the handle is already bound
// to an underlying socket so doing it again causes the
// underlying PAL call to throw
if (!(endpoint is FileHandleEndPoint))
{
listenSocket.Bind(endpoint);
}

return listenSocket;
}

internal Func<MemoryPool<byte>> MemoryPoolFactory { get; set; } = System.Buffers.PinnedBlockMemoryPoolFactory.Create;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
using System;
using System.Collections.Generic;
using System.Net;
using System.Net.Sockets;
using System.Runtime.InteropServices;
using System.Threading.Tasks;
using Microsoft.AspNetCore.Builder;
using Microsoft.AspNetCore.Connections;
using Microsoft.AspNetCore.Hosting;
using Microsoft.AspNetCore.Http;
using Microsoft.AspNetCore.Server.Kestrel.FunctionalTests;
using Microsoft.AspNetCore.Server.Kestrel.Transport.Sockets;
using Microsoft.AspNetCore.Testing;
using Microsoft.Extensions.Hosting;
using Xunit;

namespace Sockets.BindTests
{
public class SocketTransportOptionsTests : LoggedTestBase
{
[Theory]
[MemberData(nameof(GetEndpoints))]
public async Task SocketTransportCallsCreateBoundListenSocket(EndPoint endpointToTest)
{
var wasCalled = false;

Socket CreateListenSocket(EndPoint endpoint)
{
wasCalled = true;
return SocketTransportOptions.CreateDefaultBoundListenSocket(endpoint);
}

using var host = CreateWebHost(
endpointToTest,
options =>
{
options.CreateBoundListenSocket = CreateListenSocket;
}
);

await host.StartAsync();
Assert.True(wasCalled, $"Expected {nameof(SocketTransportOptions.CreateBoundListenSocket)} to be called.");
await host.StopAsync();
}

[Theory]
[MemberData(nameof(GetEndpoints))]
public void CreateDefaultBoundListenSocket_BindsForAllEndPoints(EndPoint endpoint)
{
using var listenSocket = SocketTransportOptions.CreateDefaultBoundListenSocket(endpoint);
Assert.NotNull(listenSocket.LocalEndPoint);
}

// static to ensure that the underlying handle doesn't get disposed
// when a local reference is GCed by the iterator in GetEndPoints
private static Socket _fileHandleSocket;

public static IEnumerable<object[]> GetEndpoints()
{
// IPv4
yield return new object[] {new IPEndPoint(IPAddress.Loopback, 0)};
// IPv6
yield return new object[] {new IPEndPoint(IPAddress.IPv6Loopback, 0)};
// Unix sockets
if (!RuntimeInformation.IsOSPlatform(OSPlatform.Windows))
{
yield return new object[]
{
new UnixDomainSocketEndPoint($"/tmp/{DateTime.UtcNow:yyyyMMddTHHmmss.fff}.sock")
};
}

// file handle
// slightly messy but allows us to create a FileHandleEndPoint
// from the underlying OS handle used by the socket
_fileHandleSocket = new(
AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp
);
_fileHandleSocket.Bind(new IPEndPoint(IPAddress.Loopback, 0));
yield return new object[]
{
new FileHandleEndPoint((ulong) _fileHandleSocket.Handle, FileHandleType.Auto)
};

// TODO: other endpoint types?
}

private IHost CreateWebHost(EndPoint endpoint, Action<SocketTransportOptions> configureSocketOptions) =>
TransportSelector.GetHostBuilder()
.ConfigureWebHost(
webHostBuilder =>
{
webHostBuilder
.UseSockets(configureSocketOptions)
.UseKestrel(options => options.Listen(endpoint))
.Configure(
app => app.Run(ctx => ctx.Response.WriteAsync("Hello World"))
);
}
)
.ConfigureServices(AddTestLogging)
.Build();
}
}