Skip to content

Plumbing for more async/await work #1281

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 9 commits into from
Jan 25, 2024
50 changes: 50 additions & 0 deletions src/Renci.SshNet/Abstractions/SocketAbstraction.Async.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
#if NET6_0_OR_GREATER

using System;
using System.Diagnostics;
using System.Net.Sockets;
using System.Threading;
using System.Threading.Tasks;

namespace Renci.SshNet.Abstractions
{
internal static partial class SocketAbstraction
{
public static ValueTask<int> ReadAsync(Socket socket, byte[] buffer, CancellationToken cancellationToken)
{
return socket.ReceiveAsync(buffer, SocketFlags.None, cancellationToken);
}

public static ValueTask SendAsync(Socket socket, ReadOnlyMemory<byte> data, CancellationToken cancellationToken = default)
{
Debug.Assert(socket != null);
Debug.Assert(data.Length > 0);

if (cancellationToken.IsCancellationRequested)
{
return ValueTask.FromCanceled(cancellationToken);
}

return SendAsyncCore(socket, data, cancellationToken);

static async ValueTask SendAsyncCore(Socket socket, ReadOnlyMemory<byte> data, CancellationToken cancellationToken)
{
do
{
try
{
var bytesSent = await socket.SendAsync(data, SocketFlags.None, cancellationToken).ConfigureAwait(false);
data = data.Slice(bytesSent);
}
catch (SocketException ex) when (IsErrorResumable(ex.SocketErrorCode))
{
// Buffer may be full; attempt a short delay and retry
await Task.Delay(30, cancellationToken).ConfigureAwait(false);
}
}
while (data.Length > 0);
}
}
}
}
#endif // NET6_0_OR_GREATER
9 changes: 2 additions & 7 deletions src/Renci.SshNet/Abstractions/SocketAbstraction.cs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

namespace Renci.SshNet.Abstractions
{
internal static class SocketAbstraction
internal static partial class SocketAbstraction
{
public static bool CanRead(Socket socket)
{
Expand Down Expand Up @@ -325,12 +325,7 @@ public static int Read(Socket socket, byte[] buffer, int offset, int size, TimeS
return totalBytesRead;
}

#if NET6_0_OR_GREATER
public static async Task<int> ReadAsync(Socket socket, byte[] buffer, CancellationToken cancellationToken)
{
return await socket.ReceiveAsync(buffer, SocketFlags.None, cancellationToken).ConfigureAwait(false);
}
#else
#if NET6_0_OR_GREATER == false
public static Task<int> ReadAsync(Socket socket, byte[] buffer, CancellationToken cancellationToken)
{
return socket.ReceiveAsync(buffer, 0, buffer.Length, cancellationToken);
Expand Down
4 changes: 4 additions & 0 deletions src/Renci.SshNet/Connection/ProtocolVersionExchange.cs
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,11 @@ public async Task<SshIdentification> StartAsync(string clientVersion, Socket soc
{
// Immediately send the identification string since the spec states both sides MUST send an identification string
// when the connection has been established
#if NET6_0_OR_GREATER
await SocketAbstraction.SendAsync(socket, Encoding.UTF8.GetBytes(clientVersion + "\x0D\x0A"), cancellationToken).ConfigureAwait(false);
#else
SocketAbstraction.Send(socket, Encoding.UTF8.GetBytes(clientVersion + "\x0D\x0A"));
#endif // NET6_0_OR_GREATER

var bytesReceived = new List<byte>();

Expand Down
32 changes: 25 additions & 7 deletions src/Renci.SshNet/Session.cs
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ public class Session : ISession
/// This is also used to ensure that <see cref="_socket"/> will not be disposed
/// while performing a given operation or set of operations on <see cref="_socket"/>.
/// </remarks>
private readonly object _socketDisposeLock = new object();
private readonly SemaphoreSlim _socketDisposeLock = new SemaphoreSlim(1, 1);

/// <summary>
/// Holds an object that is used to ensure only a single thread can connect
Expand Down Expand Up @@ -1127,12 +1127,14 @@ internal void SendMessage(Message message)
/// </para>
/// <para>
/// This method is only to be used when the connection is established, as the locking
/// overhead is not required while establising the connection.
/// overhead is not required while establishing the connection.
/// </para>
/// </remarks>
private void SendPacket(byte[] packet, int offset, int length)
{
lock (_socketDisposeLock)
_socketDisposeLock.Wait();

try
{
if (!_socket.IsConnected())
{
Expand All @@ -1141,6 +1143,10 @@ private void SendPacket(byte[] packet, int offset, int length)

SocketAbstraction.Send(_socket, packet, offset, length);
}
finally
{
_ = _socketDisposeLock.Release();
}
}

/// <summary>
Expand Down Expand Up @@ -1798,8 +1804,9 @@ internal static string ToHex(byte[] bytes)
/// </remarks>
private bool IsSocketConnected()
{
#pragma warning disable S2222 // Locks should be released on all paths
lock (_socketDisposeLock)
_socketDisposeLock.Wait();

try
{
if (!_socket.IsConnected())
{
Expand All @@ -1821,7 +1828,10 @@ private bool IsSocketConnected()
Monitor.Exit(_socketReadLock);
}
}
#pragma warning restore S2222 // Locks should be released on all paths
finally
{
_ = _socketDisposeLock.Release();
}
}

/// <summary>
Expand All @@ -1848,9 +1858,13 @@ private void SocketDisconnectAndDispose()
{
if (_socket != null)
{
lock (_socketDisposeLock)
_socketDisposeLock.Wait();

try
{
#pragma warning disable CA1508 // Avoid dead conditional code; Value could have been changed by another thread.
if (_socket != null)
#pragma warning restore CA1508 // Avoid dead conditional code
{
if (_socket.Connected)
{
Expand Down Expand Up @@ -1879,6 +1893,10 @@ private void SocketDisconnectAndDispose()
_socket = null;
}
}
finally
{
_ = _socketDisposeLock.Release();
}
}
}

Expand Down