Skip to content

Add ExecuteAsync, Fix CancelAsync Deadlock #1343

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/Renci.SshNet/Properties/AssemblyInfo.cs
Original file line number Diff line number Diff line change
Expand Up @@ -8,3 +8,4 @@
[assembly: InternalsVisibleTo("Renci.SshNet.IntegrationTests, PublicKey=0024000004800000940000000602000000240000525341310004000001000100f9194e1eb66b7e2575aaee115ee1d27bc100920e7150e43992d6f668f9737de8b9c7ae892b62b8a36dd1d57929ff1541665d101dc476d6e02390846efae7e5186eec409710fdb596e3f83740afef0d4443055937649bc5a773175b61c57615dac0f0fd10f52b52fedf76c17474cc567b3f7a79de95dde842509fb39aaf69c6c2")]
[assembly: InternalsVisibleTo("Renci.SshNet.Benchmarks, PublicKey=0024000004800000940000000602000000240000525341310004000001000100f9194e1eb66b7e2575aaee115ee1d27bc100920e7150e43992d6f668f9737de8b9c7ae892b62b8a36dd1d57929ff1541665d101dc476d6e02390846efae7e5186eec409710fdb596e3f83740afef0d4443055937649bc5a773175b61c57615dac0f0fd10f52b52fedf76c17474cc567b3f7a79de95dde842509fb39aaf69c6c2")]
[assembly: InternalsVisibleTo("DynamicProxyGenAssembly2, PublicKey=0024000004800000940000000602000000240000525341310004000001000100c547cac37abd99c8db225ef2f6c8a3602f3b3606cc9891605d02baa56104f4cfc0734aa39b93bf7852f7d9266654753cc297e7d2edfe0bac1cdcf9f717241550e0a7b191195b7667bb4f64bcb8e2121380fd1d9d46ad2d92d2d15605093924cceaf74c4861eff62abf69b9291ed0a340e113be11e6a7d3113e92484cf7045cc7")]
[assembly: InternalsVisibleTo("Renci.SshNet.IntegrationBenchmarks, PublicKey=0024000004800000940000000602000000240000525341310004000001000100f9194e1eb66b7e2575aaee115ee1d27bc100920e7150e43992d6f668f9737de8b9c7ae892b62b8a36dd1d57929ff1541665d101dc476d6e02390846efae7e5186eec409710fdb596e3f83740afef0d4443055937649bc5a773175b61c57615dac0f0fd10f52b52fedf76c17474cc567b3f7a79de95dde842509fb39aaf69c6c2")]
180 changes: 146 additions & 34 deletions src/Renci.SshNet/SshCommand.cs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
using System.Runtime.ExceptionServices;
using System.Text;
using System.Threading;
using System.Threading.Tasks;

using Renci.SshNet.Abstractions;
using Renci.SshNet.Channels;
Expand All @@ -26,6 +27,7 @@ public class SshCommand : IDisposable
private CommandAsyncResult _asyncResult;
private AsyncCallback _callback;
private EventWaitHandle _sessionErrorOccuredWaitHandle;
private EventWaitHandle _commmandCancelledWaitHandle;
private Exception _exception;
private StringBuilder _result;
private StringBuilder _error;
Expand Down Expand Up @@ -105,56 +107,72 @@ public Stream CreateInputStream()
/// <summary>
/// Gets the command execution result.
/// </summary>
#pragma warning disable S1133 // Deprecated code should be removed
[Obsolete("Please read the result from the OutputStream. I.e. new StreamReader(shell.OutputStream).ReadToEnd().")]
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why should we make it obsolete?

#pragma warning disable S1133 // Deprecated code should be removed
public string Result
{
get
{
_result ??= new StringBuilder();
return GetResult();
}
}

if (OutputStream != null && OutputStream.Length > 0)
internal string GetResult()
{
_result ??= new StringBuilder();

if (OutputStream != null && OutputStream.Length > 0)
{
using (var sr = new StreamReader(OutputStream,
_encoding,
detectEncodingFromByteOrderMarks: true,
bufferSize: 1024,
leaveOpen: true))
{
using (var sr = new StreamReader(OutputStream,
_encoding,
detectEncodingFromByteOrderMarks: true,
bufferSize: 1024,
leaveOpen: true))
{
_ = _result.Append(sr.ReadToEnd());
}
_ = _result.Append(sr.ReadToEnd());
}

return _result.ToString();
}

return _result.ToString();
}

/// <summary>
/// Gets the command execution error.
/// </summary>
#pragma warning disable S1133 // Deprecated code should be removed
[Obsolete("Please read the error result from the ExtendedOutputStream. I.e. new StreamReader(shell.ExtendedOutputStream).ReadToEnd().")]
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why should we make it obsolete?

#pragma warning disable S1133 // Deprecated code should be removed
public string Error
{
get
{
if (_hasError)
{
_error ??= new StringBuilder();
return GetError();
}
}

if (ExtendedOutputStream != null && ExtendedOutputStream.Length > 0)
internal string GetError()
{
if (_hasError)
{
_error ??= new StringBuilder();

if (ExtendedOutputStream != null && ExtendedOutputStream.Length > 0)
{
using (var sr = new StreamReader(ExtendedOutputStream,
_encoding,
detectEncodingFromByteOrderMarks: true,
bufferSize: 1024,
leaveOpen: true))
{
using (var sr = new StreamReader(ExtendedOutputStream,
_encoding,
detectEncodingFromByteOrderMarks: true,
bufferSize: 1024,
leaveOpen: true))
{
_ = _error.Append(sr.ReadToEnd());
}
_ = _error.Append(sr.ReadToEnd());
}

return _error.ToString();
}

return string.Empty;
return _error.ToString();
}

return string.Empty;
}

/// <summary>
Expand Down Expand Up @@ -186,6 +204,7 @@ internal SshCommand(ISession session, string commandText, Encoding encoding)
_encoding = encoding;
CommandTimeout = Session.InfiniteTimeSpan;
_sessionErrorOccuredWaitHandle = new AutoResetEvent(initialState: false);
_commmandCancelledWaitHandle = new AutoResetEvent(initialState: false);

_session.Disconnected += Session_Disconnected;
_session.ErrorOccured += Session_ErrorOccured;
Expand Down Expand Up @@ -348,21 +367,109 @@ public string EndExecute(IAsyncResult asyncResult)
_channel = null;

commandAsyncResult.EndCalled = true;

#pragma warning disable CS0618
return Result;
#pragma warning disable CS0618
}
}

/// <summary>
/// Cancels command execution in asynchronous scenarios.
/// Waits for the pending asynchronous command execution to complete.
/// </summary>
public void CancelAsync()
/// <param name="asyncResult">The reference to the pending asynchronous request to finish.</param>
/// <returns>Command execution exit status.</returns>
/// <example>
/// <code source="..\..\src\Renci.SshNet.Tests\Classes\SshCommandTest.cs" region="Example SshCommand CreateCommand BeginExecute IsCompleted EndExecute" language="C#" title="Asynchronous Command Execution" />
/// </example>
/// <exception cref="ArgumentException">Either the IAsyncResult object did not come from the corresponding async method on this type, or EndExecute was called multiple times with the same IAsyncResult.</exception>
/// <exception cref="ArgumentNullException"><paramref name="asyncResult"/> is <c>null</c>.</exception>
public int EndExecuteWithStatus(IAsyncResult asyncResult)
{
if (_channel is not null && _channel.IsOpen && _asyncResult is not null)
if (asyncResult == null)
{
// TODO: check with Oleg if we shouldn't dispose the channel and uninitialize it ?
_channel.Dispose();
throw new ArgumentNullException(nameof(asyncResult));
}

var commandAsyncResult = asyncResult switch
{
CommandAsyncResult result when result == _asyncResult => result,
_ => throw new ArgumentException(
$"The {nameof(IAsyncResult)} object was not returned from the corresponding asynchronous method on this class.")
};

lock (_endExecuteLock)
{
if (commandAsyncResult.EndCalled)
{
throw new ArgumentException("EndExecute can only be called once for each asynchronous operation.");
}

// wait for operation to complete (or time out)
WaitOnHandle(_asyncResult.AsyncWaitHandle);
UnsubscribeFromEventsAndDisposeChannel(_channel);
_channel = null;

commandAsyncResult.EndCalled = true;

return ExitStatus;
}
}

/// <summary>
/// Executes the the command asynchronously.
/// </summary>
/// <returns>Exit status of the operation.</returns>
public Task<int> ExecuteAsync()
{
return ExecuteAsync(forceKill: false, default);
}

/// <summary>
/// Executes the the command asynchronously.
/// </summary>
/// <param name="cancellationToken">The <see cref="CancellationToken"/> to observe.</param>
/// <returns>Exit status of the operation.</returns>
public Task<int> ExecuteAsync(CancellationToken cancellationToken)
{
return ExecuteAsync(forceKill: false, cancellationToken);
}

/// <summary>
/// Executes the the command asynchronously.
/// </summary>
/// <param name="forceKill">if true send SIGKILL instead of SIGTERM to cancel the command.</param>
/// <param name="cancellationToken">The <see cref="CancellationToken"/> to observe.</param>
/// <returns>Exit status of the operation.</returns>
public async Task<int> ExecuteAsync(bool forceKill, CancellationToken cancellationToken)
{
#if NET || NETSTANDARD2_1_OR_GREATER
await using var ctr = cancellationToken.Register(() => CancelAsync(forceKill), useSynchronizationContext: false).ConfigureAwait(continueOnCapturedContext: false);
#else
using var ctr = cancellationToken.Register(() => CancelAsync(forceKill), useSynchronizationContext: false);
#endif // NET || NETSTANDARD2_1_OR_GREATER

try
{
var status = await Task<int>.Factory.FromAsync(BeginExecute(), EndExecuteWithStatus).ConfigureAwait(false);
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I want to add Async overloaded methods but I think this is not a real async code. I've prepared a POC here: #1318

I know that your approach is easier to implement, but this is similar to nuget: https://www.nuget.org/packages/Renci.SshNet.Async#versions-body-tab

You can find code here: https://github.com/JohnTheGr8/Renci.SshNet.Async/blob/master/Renci.SshNet.Async/SshNetExtensions.cs

cancellationToken.ThrowIfCancellationRequested();

return status;
}
catch (Exception) when (cancellationToken.IsCancellationRequested)
{
throw new OperationCanceledException("Command execution has been cancelled.", cancellationToken);
}
}

/// <summary>
/// Cancels command execution in asynchronous scenarios.
/// </summary>
/// <param name="forceKill">if true send SIGKILL instead of SIGTERM.</param>
public void CancelAsync(bool forceKill = false)
{
var signal = forceKill ? "KILL" : "TERM";
_ = _channel?.SendExitSignalRequest(signal, coreDumped: false, "Command execution has been cancelled.", "en");
_ = _commmandCancelledWaitHandle.Set();
}

/// <summary>
Expand Down Expand Up @@ -506,6 +613,7 @@ private void WaitOnHandle(WaitHandle waitHandle)
var waitHandles = new[]
{
_sessionErrorOccuredWaitHandle,
_commmandCancelledWaitHandle,
waitHandle
};

Expand All @@ -515,7 +623,8 @@ private void WaitOnHandle(WaitHandle waitHandle)
case 0:
ExceptionDispatchInfo.Capture(_exception).Throw();
break;
case 1:
case 1: // Command cancelled
case 2:
// Specified waithandle was signaled
break;
case WaitHandle.WaitTimeout:
Expand Down Expand Up @@ -620,6 +729,9 @@ protected virtual void Dispose(bool disposing)
_sessionErrorOccuredWaitHandle = null;
}

_commmandCancelledWaitHandle?.Dispose();
_commmandCancelledWaitHandle = null;

_isDisposed = true;
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,21 +55,21 @@ public string ConnectAndRunCommand()
{
using var sshClient = new SshClient(_infrastructureFixture.SshServerHostName, _infrastructureFixture.SshServerPort, _infrastructureFixture.User.UserName, _infrastructureFixture.User.Password);
sshClient.Connect();
return sshClient.RunCommand("echo $'test !@#$%^&*()_+{}:,./<>[];\\|'").Result;
return sshClient.RunCommand("echo $'test !@#$%^&*()_+{}:,./<>[];\\|'").GetResult();
}

[Benchmark]
public async Task<string> ConnectAsyncAndRunCommand()
{
using var sshClient = new SshClient(_infrastructureFixture.SshServerHostName, _infrastructureFixture.SshServerPort, _infrastructureFixture.User.UserName, _infrastructureFixture.User.Password);
await sshClient.ConnectAsync(CancellationToken.None).ConfigureAwait(false);
return sshClient.RunCommand("echo $'test !@#$%^&*()_+{}:,./<>[];\\|'").Result;
return sshClient.RunCommand("echo $'test !@#$%^&*()_+{}:,./<>[];\\|'").GetResult();
}

[Benchmark]
public string RunCommand()
{
return _sshClient!.RunCommand("echo $'test !@#$%^&*()_+{}:,./<>[];\\|'").Result;
return _sshClient!.RunCommand("echo $'test !@#$%^&*()_+{}:,./<>[];\\|'").GetResult();
}

[Benchmark]
Expand Down
8 changes: 4 additions & 4 deletions test/Renci.SshNet.IntegrationTests/AuthenticationTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -32,13 +32,13 @@ public void TearDown()
// Reset the password back to the "regular" password.
using (var cmd = client.RunCommand($"echo \"{Users.Regular.Password}\n{Users.Regular.Password}\" | sudo passwd " + Users.Regular.UserName))
{
Assert.AreEqual(0, cmd.ExitStatus, cmd.Error);
Assert.AreEqual(0, cmd.ExitStatus, cmd.GetError());
}

// Remove password expiration
using (var cmd = client.RunCommand($"sudo chage --expiredate -1 " + Users.Regular.UserName))
{
Assert.AreEqual(0, cmd.ExitStatus, cmd.Error);
Assert.AreEqual(0, cmd.ExitStatus, cmd.GetError());
}
}
}
Expand Down Expand Up @@ -324,13 +324,13 @@ public void KeyboardInteractive_PasswordExpired()
// the "regular" password.
using (var cmd = client.RunCommand($"echo \"{temporaryPassword}\n{temporaryPassword}\" | sudo passwd " + Users.Regular.UserName))
{
Assert.AreEqual(0, cmd.ExitStatus, cmd.Error);
Assert.AreEqual(0, cmd.ExitStatus, cmd.GetError());
}

// Force the password to expire immediately
using (var cmd = client.RunCommand($"sudo chage -d 0 " + Users.Regular.UserName))
{
Assert.AreEqual(0, cmd.ExitStatus, cmd.Error);
Assert.AreEqual(0, cmd.ExitStatus, cmd.GetError());
}
}

Expand Down
4 changes: 2 additions & 2 deletions test/Renci.SshNet.IntegrationTests/ConnectivityTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ public void Common_DisposeAfterLossOfNetworkConnectivity()
hostNetworkConnectionDisabled = true;
WaitForConnectionInterruption(client);
}

Assert.IsNotNull(errorOccurred);
Assert.AreEqual(typeof(SshConnectionException), errorOccurred.GetType());

Expand Down Expand Up @@ -309,7 +309,7 @@ public void Common_DetectSessionKilledOnServer()
var command = $"sudo ps --no-headers -u {client.ConnectionInfo.Username} -f | grep \"{client.ConnectionInfo.Username}@notty\" | awk '{{print $2}}' | xargs sudo kill -9";
var sshCommand = adminClient.CreateCommand(command);
var result = sshCommand.Execute();
Assert.AreEqual(0, sshCommand.ExitStatus, sshCommand.Error);
Assert.AreEqual(0, sshCommand.ExitStatus, sshCommand.GetError());
}

Assert.IsTrue(errorOccurredSignaled.WaitOne(200));
Expand Down
Loading