Skip to content

Commit

Permalink
[#590] ConnectionFactory: support CancellationToken
Browse files Browse the repository at this point in the history
  • Loading branch information
xinchen10 committed Sep 5, 2024
1 parent 1362ac5 commit d5a0fcb
Show file tree
Hide file tree
Showing 12 changed files with 154 additions and 151 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,9 @@
<StartupObject />
</PropertyGroup>
<ItemGroup>
<Compile Include="..\..\..\test\Common\Extensions.cs">
<Link>Extensions.cs</Link>
</Compile>
<Compile Include="Properties\AssemblyInfo.cs" />
<Compile Include="..\..\..\test\Common\TestAmqpBroker.cs">
<Link>TestAmqpBroker.cs</Link>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,9 @@
<Reference Include="System.Xml" />
</ItemGroup>
<ItemGroup>
<Compile Include="..\..\..\test\Common\Extensions.cs">
<Link>Extensions.cs</Link>
</Compile>
<Compile Include="Program.cs" />
<Compile Include="Properties\AssemblyInfo.cs" />
</ItemGroup>
Expand Down
36 changes: 2 additions & 34 deletions Examples/PeerToPeer/PeerToPeer.Certificate/Program.cs
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ static void Main(string[] args)
Console.WriteLine("Starting server...");
ContainerHost host = new ContainerHost(address);
var listener = host.Listeners[0];
listener.SSL.Certificate = GetCertificate("localhost");
listener.SSL.Certificate = Test.Common.Extensions.GetCertificate("localhost");
listener.SSL.ClientCertificateRequired = true;
listener.SSL.RemoteCertificateValidationCallback = ValidateServerCertificate;
listener.SASL.EnableExternalMechanism = true;
Expand All @@ -50,7 +50,7 @@ static void Main(string[] args)

Console.WriteLine("Starting client...");
ConnectionFactory factory = new ConnectionFactory();
factory.SSL.ClientCertificates.Add(GetCertificate("localhost"));
factory.SSL.ClientCertificates.Add(Test.Common.Extensions.GetCertificate("localhost"));
factory.SSL.RemoteCertificateValidationCallback = ValidateServerCertificate;
factory.SASL.Profile = SaslProfile.External;
Console.WriteLine("Sending message...");
Expand All @@ -73,38 +73,6 @@ static bool ValidateServerCertificate(object sender, X509Certificate certificate
return true;
}

static X509Certificate2 GetCertificate(string certFindValue)
{
StoreLocation[] locations = new StoreLocation[] { StoreLocation.LocalMachine, StoreLocation.CurrentUser };
foreach (StoreLocation location in locations)
{
X509Store store = new X509Store(StoreName.My, location);
store.Open(OpenFlags.OpenExistingOnly);

X509Certificate2Collection collection = store.Certificates.Find(
X509FindType.FindBySubjectName,
certFindValue,
false);

if (collection.Count == 0)
{
collection = store.Certificates.Find(
X509FindType.FindByThumbprint,
certFindValue,
false);
}

store.Close();

if (collection.Count > 0)
{
return collection[0];
}
}

throw new ArgumentException("No certificate can be found using the find value " + certFindValue);
}

class MessageProcessor : IMessageProcessor
{
int IMessageProcessor.Credit
Expand Down
55 changes: 45 additions & 10 deletions src/Net/ConnectionFactory.cs
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ namespace Amqp
using System.Net.Security;
using System.Security.Authentication;
using System.Security.Cryptography.X509Certificates;
using System.Threading;
using System.Threading.Tasks;
using Amqp.Framing;
using Amqp.Handler;
Expand Down Expand Up @@ -103,7 +104,7 @@ internal SslSettings SslInternal
/// <returns>A task for the connection creation operation. On success, the result is an AMQP <see cref="Connection"/></returns>
public Task<Connection> CreateAsync(Address address, IHandler handler)
{
return this.CreateAsync(address, null, null, handler);
return this.CreateAsync(address, null, null, handler, CancellationToken.None);
}

/// <summary>
Expand All @@ -116,7 +117,19 @@ public Task<Connection> CreateAsync(Address address, IHandler handler)
/// <remarks>The Open object, when provided, is used as is, and not augmented by the AMQP settings.</remarks>
public Task<Connection> CreateAsync(Address address, Open open = null, OnOpened onOpened = null)
{
return this.CreateAsync(address, open, onOpened, null);
return this.CreateAsync(address, open, onOpened, null, CancellationToken.None);
}

/// <summary>
/// Creates a new connection with an optional protocol handler.
/// </summary>
/// <param name="address">The address of remote endpoint to connect to.</param>
/// <param name="cancellationToken">The cancellation token associated with the async operation.</param>
/// <param name="handler">The protocol handler.</param>
/// <returns>A task for the connection creation operation. On success, the result is an AMQP <see cref="Connection"/></returns>
public Task<Connection> CreateAsync(Address address, CancellationToken cancellationToken, IHandler handler = null)
{
return this.CreateAsync(address, null, null, handler, cancellationToken);
}

internal async Task ConnectAsync(Address address, SaslProfile saslProfile, Open open, Connection connection)
Expand All @@ -133,14 +146,14 @@ internal async Task ConnectAsync(Address address, SaslProfile saslProfile, Open
}
}

IAsyncTransport transport = await this.CreateTransportAsync(address, saslProfile, connection.Handler).ConfigureAwait(false);
IAsyncTransport transport = await this.CreateTransportAsync(address, saslProfile, connection.Handler, CancellationToken.None).ConfigureAwait(false);
connection.Init(this.BufferManager, this.AMQP, transport, open);

AsyncPump pump = new AsyncPump(this.BufferManager, transport);
pump.Start(connection);
}

async Task<IAsyncTransport> CreateTransportAsync(Address address, SaslProfile saslProfile, IHandler handler)
async Task<IAsyncTransport> CreateTransportAsync(Address address, SaslProfile saslProfile, IHandler handler, CancellationToken cancellationToken)
{
IAsyncTransport transport;
TransportProvider provider;
Expand All @@ -151,7 +164,7 @@ async Task<IAsyncTransport> CreateTransportAsync(Address address, SaslProfile sa
else if (TcpTransport.MatchScheme(address.Scheme))
{
TcpTransport tcpTransport = new TcpTransport(this.BufferManager);
await tcpTransport.ConnectAsync(address, this, handler).ConfigureAwait(false);
await tcpTransport.ConnectAsync(address, this, handler, cancellationToken).ConfigureAwait(false);
transport = tcpTransport;
}
#if NETFX
Expand Down Expand Up @@ -183,7 +196,7 @@ async Task<IAsyncTransport> CreateTransportAsync(Address address, SaslProfile sa
return transport;
}

async Task<Connection> CreateAsync(Address address, Open open, OnOpened onOpened, IHandler handler)
async Task<Connection> CreateAsync(Address address, Open open, OnOpened onOpened, IHandler handler, CancellationToken cancellationToken)
{
SaslProfile saslProfile = null;
if (address.User != null)
Expand All @@ -195,9 +208,9 @@ async Task<Connection> CreateAsync(Address address, Open open, OnOpened onOpened
saslProfile = this.saslSettings.Profile;
}

IAsyncTransport transport = await this.CreateTransportAsync(address, saslProfile, handler).ConfigureAwait(false);
IAsyncTransport transport = await this.CreateTransportAsync(address, saslProfile, handler, cancellationToken).ConfigureAwait(false);

var tcs = new ConnectTaskCompletionSource(this, address, open, onOpened, handler, transport);
var tcs = new ConnectTaskCompletionSource(this, address, open, onOpened, handler, transport, cancellationToken);
return await tcs.Task.ConfigureAwait(false);
}

Expand Down Expand Up @@ -283,12 +296,20 @@ sealed class ConnectTaskCompletionSource : TaskCompletionSource<Connection>
{
readonly ConnectionFactory factory;
readonly OnOpened onOpened;
readonly IAsyncTransport transport;
readonly CancellationTokenRegistration ctr;
Connection connection;

public ConnectTaskCompletionSource(ConnectionFactory factory, Address address, Open open, OnOpened onOpened, IHandler handler, IAsyncTransport transport)
public ConnectTaskCompletionSource(ConnectionFactory factory, Address address, Open open,
OnOpened onOpened, IHandler handler, IAsyncTransport transport, CancellationToken cancellationToken)
{
this.factory = factory;
this.onOpened = onOpened;
this.transport = transport;
if (cancellationToken.CanBeCanceled)
{
this.ctr = cancellationToken.Register(o => ((ConnectTaskCompletionSource)o).OnCancel(), this);
}

this.connection = new Connection(this.factory.BufferManager, this.factory.AMQP, address, transport, open, this.OnOpen, handler);
AsyncPump pump = new AsyncPump(this.factory.BufferManager, transport);
Expand All @@ -297,16 +318,30 @@ public ConnectTaskCompletionSource(ConnectionFactory factory, Address address, O

void OnOpen(IConnection connection, Open open)
{
this.ctr.Dispose();
if (this.onOpened != null)
{
this.onOpened(connection, open);
}

this.TrySetResult(this.connection);
if (!this.TrySetResult(this.connection))
{
this.transport.Close();
}
}

void OnCancel()
{
this.ctr.Dispose();
if (this.TrySetCanceled())
{
this.transport.Close();
}
}

void OnException(Exception exception)
{
this.ctr.Dispose();
this.TrySetException(exception);
}
}
Expand Down
64 changes: 46 additions & 18 deletions src/Net/SocketExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ namespace Amqp
using System.Collections.Generic;
using System.Net;
using System.Net.Sockets;
using System.Threading;
using System.Threading.Tasks;

static class SocketExtensions
Expand Down Expand Up @@ -64,26 +65,28 @@ ULONG is an unsigned 32 bit integer

public static void Complete<T>(object sender, SocketAsyncEventArgs args, bool throwOnError, T result)
{
var tcs = (TaskCompletionSource<T>)args.UserToken;
args.UserToken = null;
if (tcs == null)
using (var tcs = (SocketTaskCompletionSource<T>)args.UserToken)
{
return;
}

if (args.SocketError != SocketError.Success && throwOnError)
{
tcs.TrySetException(new SocketException((int)args.SocketError));
}
else
{
tcs.TrySetResult(result);
args.UserToken = null;
if (tcs == null)
{
return;
}

if (args.SocketError != SocketError.Success && throwOnError)
{
tcs.TrySetException(new SocketException((int)args.SocketError));
}
else
{
tcs.TrySetResult(result);
}
}
}

public static Task ConnectAsync(this Socket socket, IPAddress addr, int port)
public static Task ConnectAsync(this Socket socket, IPAddress addr, int port, CancellationToken cancellationToken)
{
var tcs = new TaskCompletionSource<int>();
var tcs = new SocketTaskCompletionSource<int>(cancellationToken);
var args = new SocketAsyncEventArgs();
args.RemoteEndPoint = new IPEndPoint(addr, port);
args.UserToken = tcs;
Expand All @@ -99,7 +102,7 @@ public static Task ConnectAsync(this Socket socket, IPAddress addr, int port)

public static Task<int> ReceiveAsync(this Socket socket, SocketAsyncEventArgs args, byte[] buffer, int offset, int count)
{
var tcs = new TaskCompletionSource<int>();
var tcs = new SocketTaskCompletionSource<int>(CancellationToken.None);
args.SetBuffer(buffer, offset, count);
args.UserToken = tcs;
if (!socket.ReceiveAsync(args))
Expand All @@ -112,7 +115,7 @@ public static Task<int> ReceiveAsync(this Socket socket, SocketAsyncEventArgs ar

public static Task<int> SendAsync(this Socket socket, SocketAsyncEventArgs args, IList<ArraySegment<byte>> buffers)
{
var tcs = new TaskCompletionSource<int>();
var tcs = new SocketTaskCompletionSource<int>(CancellationToken.None);
args.SetBuffer(null, 0, 0);
args.BufferList = buffers;
args.UserToken = tcs;
Expand All @@ -126,7 +129,7 @@ public static Task<int> SendAsync(this Socket socket, SocketAsyncEventArgs args,

public static Task<Socket> AcceptAsync(this Socket socket, SocketAsyncEventArgs args, SocketFlags flags)
{
var tcs = new TaskCompletionSource<Socket>();
var tcs = new SocketTaskCompletionSource<Socket>(CancellationToken.None);
args.UserToken = tcs;
if (!socket.AcceptAsync(args))
{
Expand All @@ -135,5 +138,30 @@ public static Task<Socket> AcceptAsync(this Socket socket, SocketAsyncEventArgs

return tcs.Task;
}

sealed class SocketTaskCompletionSource<T> : TaskCompletionSource<T>, IDisposable
{
readonly CancellationTokenRegistration ctr;

public SocketTaskCompletionSource(CancellationToken ct)
{
if (ct.CanBeCanceled)
{
this.ctr = ct.Register(o => OnCancel(o), this);
}
}

public void Dispose()
{
this.ctr.Dispose();
}

static void OnCancel(object state)
{
var thisPtr = (SocketTaskCompletionSource<T>)state;
thisPtr.ctr.Dispose();
thisPtr.TrySetCanceled();
}
}
}
}
7 changes: 4 additions & 3 deletions src/Net/TcpTransport.cs
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ namespace Amqp
using System.Net;
using System.Net.Security;
using System.Net.Sockets;
using System.Threading;
using System.Threading.Tasks;
using Amqp.Handler;

Expand Down Expand Up @@ -57,10 +58,10 @@ public void Connect(Connection connection, Address address, bool noVerification)
factory.SSL.RemoteCertificateValidationCallback = noneCertValidator;
}

this.ConnectAsync(address, factory, connection.Handler).ConfigureAwait(false).GetAwaiter().GetResult();
this.ConnectAsync(address, factory, connection.Handler, CancellationToken.None).ConfigureAwait(false).GetAwaiter().GetResult();
}

public async Task ConnectAsync(Address address, ConnectionFactory factory, IHandler handler)
public async Task ConnectAsync(Address address, ConnectionFactory factory, IHandler handler, CancellationToken cancellationToken)
{
IPAddress[] ipAddresses;
IPAddress ip;
Expand Down Expand Up @@ -88,7 +89,7 @@ public async Task ConnectAsync(Address address, ConnectionFactory factory, IHand
socket = new Socket(ipAddresses[i].AddressFamily, SocketType.Stream, ProtocolType.Tcp);
try
{
await socket.ConnectAsync(ipAddresses[i], address.Port).ConfigureAwait(false);
await socket.ConnectAsync(ipAddresses[i], address.Port, cancellationToken).ConfigureAwait(false);

exception = null;
break;
Expand Down
23 changes: 1 addition & 22 deletions test/Common/ContainerHostTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -1189,7 +1189,7 @@ public void ContainerHostX509PrincipalTest()

try
{
cert = GetCertificate(StoreLocation.LocalMachine, StoreName.My, "localhost");
cert = Test.Common.Extensions.GetCertificate("localhost");
}
catch (PlatformNotSupportedException)
{
Expand Down Expand Up @@ -1456,27 +1456,6 @@ public void EncodeDecodeMessageWithAmqpValueTest()

Assert.AreEqual((message.BodySection as AmqpValue).Value, (copy.BodySection as AmqpValue).Value);
}

public static X509Certificate2 GetCertificate(StoreLocation storeLocation, StoreName storeName, string certFindValue)
{
X509Store store = new X509Store(storeName, storeLocation);
store.Open(OpenFlags.OpenExistingOnly);
X509Certificate2Collection collection = store.Certificates.Find(
X509FindType.FindBySubjectName,
certFindValue,
false);
if (collection.Count == 0)
{
throw new ArgumentException("No certificate can be found using the find value " + certFindValue);
}

#if DOTNET
store.Dispose();
#else
store.Close();
#endif
return collection[0];
}
}

class TestMessageProcessor : IMessageProcessor
Expand Down
Loading

0 comments on commit d5a0fcb

Please sign in to comment.