Skip to content

Commit f9236c5

Browse files
IgorMilavecdrieseng
authored andcommitted
Add async support to SftpClient and SftpFileStream (#819)
* Add FEATURE_TAP and net472 target * Add TAP async support to SftpClient and SftpFileStream * Add async support to DnsAbstraction and SocketAbstraction * Add async support to *Connector and refactor the hierarchy * Add ConnectAsync to BaseClient
1 parent c64803a commit f9236c5

21 files changed

+1611
-70
lines changed

src/Renci.SshNet/Abstractions/DnsAbstraction.cs

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,10 @@
22
using System.Net;
33
using System.Net.Sockets;
44

5+
#if FEATURE_TAP
6+
using System.Threading.Tasks;
7+
#endif
8+
59
#if FEATURE_DNS_SYNC
610
#elif FEATURE_DNS_APM
711
using Renci.SshNet.Common;
@@ -87,5 +91,23 @@ public static IPAddress[] GetHostAddresses(string hostNameOrAddress)
8791
#endif // FEATURE_DEVICEINFORMATION_APM
8892
#endif
8993
}
94+
95+
#if FEATURE_TAP
96+
/// <summary>
97+
/// Returns the Internet Protocol (IP) addresses for the specified host.
98+
/// </summary>
99+
/// <param name="hostNameOrAddress">The host name or IP address to resolve</param>
100+
/// <returns>
101+
/// A task with result of an array of type <see cref="IPAddress"/> that holds the IP addresses for the host that
102+
/// is specified by the <paramref name="hostNameOrAddress"/> parameter.
103+
/// </returns>
104+
/// <exception cref="ArgumentNullException"><paramref name="hostNameOrAddress"/> is <c>null</c>.</exception>
105+
/// <exception cref="SocketException">An error is encountered when resolving <paramref name="hostNameOrAddress"/>.</exception>
106+
public static Task<IPAddress[]> GetHostAddressesAsync(string hostNameOrAddress)
107+
{
108+
return Dns.GetHostAddressesAsync(hostNameOrAddress);
109+
}
110+
#endif
111+
90112
}
91113
}

src/Renci.SshNet/Abstractions/SocketAbstraction.cs

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,9 @@
33
using System.Net;
44
using System.Net.Sockets;
55
using System.Threading;
6+
#if FEATURE_TAP
7+
using System.Threading.Tasks;
8+
#endif
69
using Renci.SshNet.Common;
710
using Renci.SshNet.Messages.Transport;
811

@@ -59,6 +62,13 @@ public static void Connect(Socket socket, IPEndPoint remoteEndpoint, TimeSpan co
5962
ConnectCore(socket, remoteEndpoint, connectTimeout, false);
6063
}
6164

65+
#if FEATURE_TAP
66+
public static Task ConnectAsync(Socket socket, IPEndPoint remoteEndpoint, CancellationToken cancellationToken)
67+
{
68+
return socket.ConnectAsync(remoteEndpoint, cancellationToken);
69+
}
70+
#endif
71+
6272
private static void ConnectCore(Socket socket, IPEndPoint remoteEndpoint, TimeSpan connectTimeout, bool ownsSocket)
6373
{
6474
#if FEATURE_SOCKET_EAP
@@ -317,6 +327,13 @@ public static byte[] Read(Socket socket, int size, TimeSpan timeout)
317327
return buffer;
318328
}
319329

330+
#if FEATURE_TAP
331+
public static Task<int> ReadAsync(Socket socket, byte[] buffer, int offset, int length, CancellationToken cancellationToken)
332+
{
333+
return socket.ReceiveAsync(buffer, offset, length, cancellationToken);
334+
}
335+
#endif
336+
320337
/// <summary>
321338
/// Receives data from a bound <see cref="Socket"/> into a receive buffer.
322339
/// </summary>
Lines changed: 119 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,119 @@
1+
#if FEATURE_TAP
2+
using System;
3+
using System.Net;
4+
using System.Net.Sockets;
5+
using System.Runtime.CompilerServices;
6+
using System.Threading;
7+
using System.Threading.Tasks;
8+
9+
namespace Renci.SshNet.Abstractions
10+
{
11+
// Async helpers based on https://devblogs.microsoft.com/pfxteam/awaiting-socket-operations/
12+
13+
internal static class SocketExtensions
14+
{
15+
sealed class SocketAsyncEventArgsAwaitable : SocketAsyncEventArgs, INotifyCompletion
16+
{
17+
private readonly static Action SENTINEL = () => { };
18+
19+
private bool isCancelled;
20+
private Action continuationAction;
21+
22+
public SocketAsyncEventArgsAwaitable()
23+
{
24+
Completed += delegate { SetCompleted(); };
25+
}
26+
27+
public SocketAsyncEventArgsAwaitable ExecuteAsync(Func<SocketAsyncEventArgs, bool> func)
28+
{
29+
if (!func(this))
30+
{
31+
SetCompleted();
32+
}
33+
return this;
34+
}
35+
36+
public void SetCompleted()
37+
{
38+
IsCompleted = true;
39+
var continuation = continuationAction ?? Interlocked.CompareExchange(ref continuationAction, SENTINEL, null);
40+
if (continuation != null)
41+
{
42+
continuation();
43+
}
44+
}
45+
46+
public void SetCancelled()
47+
{
48+
isCancelled = true;
49+
SetCompleted();
50+
}
51+
52+
public SocketAsyncEventArgsAwaitable GetAwaiter() { return this; }
53+
54+
public bool IsCompleted { get; private set; }
55+
56+
void INotifyCompletion.OnCompleted(Action continuation)
57+
{
58+
if (continuationAction == SENTINEL || Interlocked.CompareExchange(ref continuationAction, continuation, null) == SENTINEL)
59+
{
60+
// We have already completed; run continuation asynchronously
61+
Task.Run(continuation);
62+
}
63+
}
64+
65+
public void GetResult()
66+
{
67+
if (isCancelled)
68+
{
69+
throw new TaskCanceledException();
70+
}
71+
else if (IsCompleted)
72+
{
73+
if (SocketError != SocketError.Success)
74+
{
75+
throw new SocketException((int)SocketError);
76+
}
77+
}
78+
else
79+
{
80+
// We don't support sync/async
81+
throw new InvalidOperationException("The asynchronous operation has not yet completed.");
82+
}
83+
}
84+
}
85+
86+
public static async Task ConnectAsync(this Socket socket, IPEndPoint remoteEndpoint, CancellationToken cancellationToken)
87+
{
88+
cancellationToken.ThrowIfCancellationRequested();
89+
90+
using (var args = new SocketAsyncEventArgsAwaitable())
91+
{
92+
args.RemoteEndPoint = remoteEndpoint;
93+
94+
using (cancellationToken.Register(o => ((SocketAsyncEventArgsAwaitable)o).SetCancelled(), args, false))
95+
{
96+
await args.ExecuteAsync(socket.ConnectAsync);
97+
}
98+
}
99+
}
100+
101+
public static async Task<int> ReceiveAsync(this Socket socket, byte[] buffer, int offset, int length, CancellationToken cancellationToken)
102+
{
103+
cancellationToken.ThrowIfCancellationRequested();
104+
105+
using (var args = new SocketAsyncEventArgsAwaitable())
106+
{
107+
args.SetBuffer(buffer, offset, length);
108+
109+
using (cancellationToken.Register(o => ((SocketAsyncEventArgsAwaitable)o).SetCancelled(), args, false))
110+
{
111+
await args.ExecuteAsync(socket.ReceiveAsync);
112+
}
113+
114+
return args.BytesTransferred;
115+
}
116+
}
117+
}
118+
}
119+
#endif

src/Renci.SshNet/BaseClient.cs

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
11
using System;
22
using System.Net.Sockets;
33
using System.Threading;
4+
#if FEATURE_TAP
5+
using System.Threading.Tasks;
6+
#endif
47
using Renci.SshNet.Abstractions;
58
using Renci.SshNet.Common;
69
using Renci.SshNet.Messages.Transport;
@@ -239,6 +242,63 @@ public void Connect()
239242
StartKeepAliveTimer();
240243
}
241244

245+
#if FEATURE_TAP
246+
/// <summary>
247+
/// Asynchronously connects client to the server.
248+
/// </summary>
249+
/// <param name="cancellationToken">The <see cref="CancellationToken"/> to observe.</param>
250+
/// <returns>A <see cref="Task"/> that represents the asynchronous connect operation.
251+
/// </returns>
252+
/// <exception cref="InvalidOperationException">The client is already connected.</exception>
253+
/// <exception cref="ObjectDisposedException">The method was called after the client was disposed.</exception>
254+
/// <exception cref="SocketException">Socket connection to the SSH server or proxy server could not be established, or an error occurred while resolving the hostname.</exception>
255+
/// <exception cref="SshConnectionException">SSH session could not be established.</exception>
256+
/// <exception cref="SshAuthenticationException">Authentication of SSH session failed.</exception>
257+
/// <exception cref="ProxyException">Failed to establish proxy connection.</exception>
258+
public async Task ConnectAsync(CancellationToken cancellationToken)
259+
{
260+
CheckDisposed();
261+
cancellationToken.ThrowIfCancellationRequested();
262+
263+
// TODO (see issue #1758):
264+
// we're not stopping the keep-alive timer and disposing the session here
265+
//
266+
// we could do this but there would still be side effects as concrete
267+
// implementations may still hang on to the original session
268+
//
269+
// therefore it would be better to actually invoke the Disconnect method
270+
// (and then the Dispose on the session) but even that would have side effects
271+
// eg. it would remove all forwarded ports from SshClient
272+
//
273+
// I think we should modify our concrete clients to better deal with a
274+
// disconnect. In case of SshClient this would mean not removing the
275+
// forwarded ports on disconnect (but only on dispose ?) and link a
276+
// forwarded port with a client instead of with a session
277+
//
278+
// To be discussed with Oleg (or whoever is interested)
279+
if (IsSessionConnected())
280+
throw new InvalidOperationException("The client is already connected.");
281+
282+
OnConnecting();
283+
284+
Session = await CreateAndConnectSessionAsync(cancellationToken).ConfigureAwait(false);
285+
try
286+
{
287+
// Even though the method we invoke makes you believe otherwise, at this point only
288+
// the SSH session itself is connected.
289+
OnConnected();
290+
}
291+
catch
292+
{
293+
// Only dispose the session as Disconnect() would have side-effects (such as remove forwarded
294+
// ports in SshClient).
295+
DisposeSession();
296+
throw;
297+
}
298+
StartKeepAliveTimer();
299+
}
300+
#endif
301+
242302
/// <summary>
243303
/// Disconnects client from the server.
244304
/// </summary>
@@ -473,6 +533,26 @@ private ISession CreateAndConnectSession()
473533
}
474534
}
475535

536+
#if FEATURE_TAP
537+
private async Task<ISession> CreateAndConnectSessionAsync(CancellationToken cancellationToken)
538+
{
539+
var session = _serviceFactory.CreateSession(ConnectionInfo, _serviceFactory.CreateSocketFactory());
540+
session.HostKeyReceived += Session_HostKeyReceived;
541+
session.ErrorOccured += Session_ErrorOccured;
542+
543+
try
544+
{
545+
await session.ConnectAsync(cancellationToken).ConfigureAwait(false);
546+
return session;
547+
}
548+
catch
549+
{
550+
DisposeSession(session);
551+
throw;
552+
}
553+
}
554+
#endif
555+
476556
private void DisposeSession(ISession session)
477557
{
478558
session.ErrorOccured -= Session_ErrorOccured;

src/Renci.SshNet/Connection/ConnectorBase.cs

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,11 @@
44
using System;
55
using System.Net;
66
using System.Net.Sockets;
7+
using System.Threading;
8+
9+
#if FEATURE_TAP
10+
using System.Threading.Tasks;
11+
#endif
712

813
namespace Renci.SshNet.Connection
914
{
@@ -21,6 +26,10 @@ protected ConnectorBase(ISocketFactory socketFactory)
2126

2227
public abstract Socket Connect(IConnectionInfo connectionInfo);
2328

29+
#if FEATURE_TAP
30+
public abstract Task<Socket> ConnectAsync(IConnectionInfo connectionInfo, CancellationToken cancellationToken);
31+
#endif
32+
2433
/// <summary>
2534
/// Establishes a socket connection to the specified host and port.
2635
/// </summary>
@@ -54,6 +63,42 @@ protected Socket SocketConnect(string host, int port, TimeSpan timeout)
5463
}
5564
}
5665

66+
#if FEATURE_TAP
67+
/// <summary>
68+
/// Establishes a socket connection to the specified host and port.
69+
/// </summary>
70+
/// <param name="host">The host name of the server to connect to.</param>
71+
/// <param name="port">The port to connect to.</param>
72+
/// <param name="cancellationToken">The cancellation token to observe.</param>
73+
/// <exception cref="SshOperationTimeoutException">The connection failed to establish within the configured <see cref="ConnectionInfo.Timeout"/>.</exception>
74+
/// <exception cref="SocketException">An error occurred trying to establish the connection.</exception>
75+
protected async Task<Socket> SocketConnectAsync(string host, int port, CancellationToken cancellationToken)
76+
{
77+
cancellationToken.ThrowIfCancellationRequested();
78+
79+
var ipAddress = (await DnsAbstraction.GetHostAddressesAsync(host).ConfigureAwait(false))[0];
80+
var ep = new IPEndPoint(ipAddress, port);
81+
82+
DiagnosticAbstraction.Log(string.Format("Initiating connection to '{0}:{1}'.", host, port));
83+
84+
var socket = SocketFactory.Create(ep.AddressFamily, SocketType.Stream, ProtocolType.Tcp);
85+
try
86+
{
87+
await SocketAbstraction.ConnectAsync(socket, ep, cancellationToken).ConfigureAwait(false);
88+
89+
const int socketBufferSize = 2 * Session.MaximumSshPacketSize;
90+
socket.SendBufferSize = socketBufferSize;
91+
socket.ReceiveBufferSize = socketBufferSize;
92+
return socket;
93+
}
94+
catch (Exception)
95+
{
96+
socket.Dispose();
97+
throw;
98+
}
99+
}
100+
#endif
101+
57102
protected static byte SocketReadByte(Socket socket)
58103
{
59104
var buffer = new byte[1];

src/Renci.SshNet/Connection/DirectConnector.cs

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
using System.Net.Sockets;
2+
using System.Threading;
23

34
namespace Renci.SshNet.Connection
45
{
5-
internal class DirectConnector : ConnectorBase
6+
internal sealed class DirectConnector : ConnectorBase
67
{
78
public DirectConnector(ISocketFactory socketFactory) : base(socketFactory)
89
{
@@ -12,5 +13,12 @@ public override Socket Connect(IConnectionInfo connectionInfo)
1213
{
1314
return SocketConnect(connectionInfo.Host, connectionInfo.Port, connectionInfo.Timeout);
1415
}
16+
17+
#if FEATURE_TAP
18+
public override System.Threading.Tasks.Task<Socket> ConnectAsync(IConnectionInfo connectionInfo, CancellationToken cancellationToken)
19+
{
20+
return SocketConnectAsync(connectionInfo.Host, connectionInfo.Port, cancellationToken);
21+
}
22+
#endif
1523
}
1624
}

0 commit comments

Comments
 (0)