Skip to content

Commit

Permalink
add cancellation support to TcpSocketClient
Browse files Browse the repository at this point in the history
  • Loading branch information
rdavisau committed Jan 10, 2016
1 parent e7b1482 commit 0896e56
Show file tree
Hide file tree
Showing 8 changed files with 68 additions and 18 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,11 @@ public static Task WrapNativeSocketExceptions(this Task task)
return task.ContinueWith(
t =>
{
if (!t.IsFaulted)
return t;
var ex = t.Exception.InnerException;
var ex = t.Exception?.InnerException ?? t.Exception;
throw (NativeExceptionExtensions.NativeSocketExceptions.Contains(ex.GetType()))
? new PclSocketException(ex)
Expand All @@ -35,7 +36,7 @@ public static Task<T> WrapNativeSocketExceptions<T>(this Task<T> task)
if (!t.IsFaulted)
return t.Result;
var ex = t.Exception.InnerException;
var ex = t.Exception?.InnerException ?? t.Exception;
throw (NativeExceptionExtensions.NativeSocketExceptions.Contains(ex.GetType()))
? new PclSocketException(ex)
Expand Down
32 changes: 27 additions & 5 deletions Sockets/Sockets.Implementation.NET/TcpSocketClient.cs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
using Sockets.Plugin.Abstractions;
using System.Net.Security;
using System.Security.Cryptography.X509Certificates;

using System.Threading;
using PlatformSocketException = System.Net.Sockets.SocketException;
using PclSocketException = Sockets.Plugin.Abstractions.SocketException;

Expand Down Expand Up @@ -55,11 +55,33 @@ internal TcpSocketClient(TcpClient backingClient, int bufferSize)
/// <param name="address">The address of the endpoint to connect to.</param>
/// <param name="port">The port of the endpoint to connect to.</param>
/// <param name="secure">True to enable TLS on the socket.</param>
public async Task ConnectAsync(string address, int port, bool secure = false)
/// <param name="cancellationToken">The cancellation token to cancel the operation.</param>
public async Task ConnectAsync(string address, int port, bool secure = false, CancellationToken cancellationToken = default(CancellationToken))
{
await _backingTcpClient
.ConnectAsync(address, port)
.WrapNativeSocketExceptions();
// standard connect
var connectTask =
_backingTcpClient
.ConnectAsync(address, port)
.WrapNativeSocketExceptions();

// set up cancellation trigger
var ret = new TaskCompletionSource<bool>();
var canceller = cancellationToken.Register(() => ret.SetCanceled());

// if cancellation comes before connect completes, we honour it
var okOrCancelled = await Task.WhenAny(connectTask, ret.Task);

if (okOrCancelled == ret.Task)
{
// reset the backing field.
// depending on the state of the socket this may throw ODE which it is appropriate to ignore
try { await DisconnectAsync(); } catch (ObjectDisposedException) { }

// notify that we did cancel
cancellationToken.ThrowIfCancellationRequested();
}
else
canceller.Dispose();

InitializeWriteStream();

Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Threading;
using System.Threading.Tasks;
using Windows.Foundation;
using Windows.Networking.Sockets;
Expand All @@ -11,9 +12,11 @@ namespace Sockets.Plugin
{
public static class NativeExceptionExtensions
{
public static Task WrapNativeSocketExceptionsAsTask(this IAsyncAction task)
public static Task WrapNativeSocketExceptionsAsTask(this IAsyncAction task, CancellationToken cancellationToken = default(CancellationToken))
{
var tcs = new TaskCompletionSource<bool>();
if (cancellationToken != default(CancellationToken))
cancellationToken.Register(task.Cancel);

task.Completed = delegate(IAsyncAction info, AsyncStatus status)
{
Expand All @@ -34,10 +37,13 @@ public static Task WrapNativeSocketExceptionsAsTask(this IAsyncAction task)
break;
}
};

return tcs.Task.ContinueWith(
t =>
{
if (t.IsCanceled)
cancellationToken.ThrowIfCancellationRequested();
if (!t.IsFaulted)
return t;
Expand Down
14 changes: 10 additions & 4 deletions Sockets/Sockets.Implementation.WinRT/TcpSocketClient.cs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
using System;
using System.IO;
using System.Threading;
using System.Threading.Tasks;
using Windows.Networking;
using Windows.Networking.Sockets;
Expand All @@ -22,7 +23,7 @@ public class TcpSocketClient : ITcpSocketClient
#else
private SocketProtectionLevel _secureSocketProtectionLevel = SocketProtectionLevel.Tls10;
#endif
private readonly StreamSocket _backingStreamSocket;
private StreamSocket _backingStreamSocket;
private readonly int _bufferSize;

/// <summary>
Expand Down Expand Up @@ -54,15 +55,16 @@ internal TcpSocketClient(StreamSocket nativeSocket, int bufferSize)
/// <param name="address">The address of the endpoint to connect to.</param>
/// <param name="port">The port of the endpoint to connect to.</param>
/// <param name="secure">True to enable TLS on the socket.</param>
public Task ConnectAsync(string address, int port, bool secure = false)
/// <param name="cancellationToken">The cancellation token to cancel the operation.</param>
public Task ConnectAsync(string address, int port, bool secure = false, CancellationToken cancellationToken = default(CancellationToken))
{
var hn = new HostName(address);
var sn = port.ToString();
var spl = secure ? _secureSocketProtectionLevel : SocketProtectionLevel.PlainSocket;

return _backingStreamSocket
.ConnectAsync(hn, sn, spl)
.WrapNativeSocketExceptionsAsTask();
.WrapNativeSocketExceptionsAsTask(cancellationToken);
}

/// <summary>
Expand All @@ -71,7 +73,11 @@ public Task ConnectAsync(string address, int port, bool secure = false)
/// </summary>
public Task DisconnectAsync()
{
return Task.Run(() => _backingStreamSocket.Dispose());
return Task.Run(() =>
{
_backingStreamSocket.Dispose();
_backingStreamSocket = new StreamSocket();
});
}

/// <summary>
Expand Down
4 changes: 3 additions & 1 deletion Sockets/Sockets.Plugin.Abstractions/ITcpSocketClient.cs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
using System;
using System.IO;
using System.Threading;
using System.Threading.Tasks;

namespace Sockets.Plugin.Abstractions
Expand All @@ -18,7 +19,8 @@ public interface ITcpSocketClient : IDisposable
/// <param name="address">The address of the endpoint to connect to.</param>
/// <param name="port">The port of the endpoint to connect to.</param>
/// <param name="secure">Is this socket secure?</param>
Task ConnectAsync(string address, int port, bool secure = false);
/// <param name="cancellationToken">The cancellation token to cancel the operation.</param>
Task ConnectAsync(string address, int port, bool secure = false, CancellationToken cancellationToken = default(CancellationToken));

/// <summary>
/// Disconnects from an endpoint previously connected to using <code>ConnectAsync</code>.
Expand Down
5 changes: 3 additions & 2 deletions Sockets/Sockets.Plugin/TcpSocketClient.cs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
using System;
using System.IO;
using System.Threading;
using System.Threading.Tasks;
using Sockets.Plugin.Abstractions;

Expand Down Expand Up @@ -35,8 +36,8 @@ public TcpSocketClient(int bufferSize) : this()
/// </summary>
/// <param name="address">The address of the endpoint to connect to.</param>
/// <param name="port">The port of the endpoint to connect to.</param>
/// <param name="secure">True to enable TLS on the socket.</param>
public Task ConnectAsync(string address, int port, bool secure = false)
/// <param name="cancellationToken">The cancellation token to cancel the operation.</param>
public Task ConnectAsync(string address, int port, bool secure = false, CancellationToken cancellationToken = default(CancellationToken))
{
throw new NotImplementedException(PCL.BaitWithoutSwitchMessage);
}
Expand Down

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

12 changes: 12 additions & 0 deletions Sockets/Tests/Sockets.Tests/TcpSocketClientTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -236,5 +236,17 @@ public async Task TcpSocketClient_ShouldBeAbleToDisconnectThenReconnect(int buff
await listener.StopListeningAsync();
}

[Fact]
public Task TcpSocketClient_Connect_ShouldCancelByCancellationToken()
{
var sut = new TcpSocketClient();

var cts = new CancellationTokenSource(TimeSpan.FromSeconds(5));
var ct = cts.Token;

// let's just hope no one's home :)
return Assert.ThrowsAsync<OperationCanceledException>(()=> sut.ConnectAsync("99.99.99.99", 51234, cancellationToken: cts.Token));
}

}
}

0 comments on commit 0896e56

Please sign in to comment.