Skip to content

Don't dispose channel when completing SshCommand #1596

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Feb 19, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 21 additions & 20 deletions src/Renci.SshNet/SshCommand.cs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ public class SshCommand : IDisposable
private readonly ISession _session;
private readonly Encoding _encoding;

private IChannelSession? _channel;
private IChannelSession _channel;
private TaskCompletionSource<object>? _tcs;
private CancellationTokenSource? _cts;
private CancellationTokenRegistration _tokenRegistration;
Expand Down Expand Up @@ -142,14 +142,14 @@ public int? ExitStatus
/// </example>
public Stream CreateInputStream()
{
if (_channel == null)
if (!_channel.IsOpen)
{
throw new InvalidOperationException($"The input stream can be used only after calling BeginExecute and before calling EndExecute.");
throw new InvalidOperationException("The input stream can be used only during execution.");
}

if (_inputStream != null)
{
throw new InvalidOperationException($"The input stream already exists.");
throw new InvalidOperationException("The input stream already exists.");
}

_inputStream = new ChannelInputStream(_channel);
Expand Down Expand Up @@ -226,6 +226,7 @@ internal SshCommand(ISession session, string commandText, Encoding encoding)
ExtendedOutputStream = new PipeStream();
_session.Disconnected += Session_Disconnected;
_session.ErrorOccured += Session_ErrorOccured;
_channel = _session.CreateChannelSession();
}

/// <summary>
Expand Down Expand Up @@ -257,6 +258,8 @@ public Task ExecuteAsync(CancellationToken cancellationToken = default)
throw new InvalidOperationException("Asynchronous operation is already in progress.");
}

UnsubscribeFromChannelEvents(dispose: true);

OutputStream.Dispose();
ExtendedOutputStream.Dispose();

Expand All @@ -265,6 +268,7 @@ public Task ExecuteAsync(CancellationToken cancellationToken = default)
// so we just need to reinitialise them for subsequent executions.
OutputStream = new PipeStream();
ExtendedOutputStream = new PipeStream();
_channel = _session.CreateChannelSession();
}

_exitStatus = default;
Expand All @@ -282,7 +286,6 @@ public Task ExecuteAsync(CancellationToken cancellationToken = default)
_tcs = new TaskCompletionSource<object>(TaskCreationOptions.RunContinuationsAsynchronously);
_userToken = cancellationToken;

_channel = _session.CreateChannelSession();
_channel.DataReceived += Channel_DataReceived;
_channel.ExtendedDataReceived += Channel_ExtendedDataReceived;
_channel.RequestReceived += Channel_RequestReceived;
Expand Down Expand Up @@ -542,7 +545,10 @@ private void SetAsyncComplete(bool setResult = true)
}
}

UnsubscribeFromEventsAndDisposeChannel();
// We don't dispose the channel here to avoid a race condition
// where SSH_MSG_CHANNEL_CLOSE arrives before _channel starts
// waiting for a response in _channel.SendExecRequest().
UnsubscribeFromChannelEvents(dispose: false);

OutputStream.Dispose();
ExtendedOutputStream.Dispose();
Expand All @@ -568,7 +574,7 @@ private void Channel_RequestReceived(object? sender, ChannelRequestEventArgs e)

Debug.Assert(!exitSignalInfo.WantReply, "exit-signal is want_reply := false by definition.");
}
else if (e.Info.WantReply && _channel?.RemoteChannelNumber is uint remoteChannelNumber)
else if (e.Info.WantReply && sender is IChannel { RemoteChannelNumber: uint remoteChannelNumber })
{
var replyMessage = new ChannelFailureMessage(remoteChannelNumber);
_session.SendMessage(replyMessage);
Expand All @@ -591,29 +597,24 @@ private void Channel_DataReceived(object? sender, ChannelDataEventArgs e)
}

/// <summary>
/// Unsubscribes the current <see cref="SshCommand"/> from channel events, and disposes
/// the <see cref="_channel"/>.
/// Unsubscribes the current <see cref="SshCommand"/> from channel events, and optionally,
/// disposes <see cref="_channel"/>.
/// </summary>
private void UnsubscribeFromEventsAndDisposeChannel()
private void UnsubscribeFromChannelEvents(bool dispose)
{
var channel = _channel;

if (channel is null)
{
return;
}

_channel = null;

// unsubscribe from events as we do not want to be signaled should these get fired
// during the dispose of the channel
channel.DataReceived -= Channel_DataReceived;
channel.ExtendedDataReceived -= Channel_ExtendedDataReceived;
channel.RequestReceived -= Channel_RequestReceived;
channel.Closed -= Channel_Closed;

// actually dispose the channel
channel.Dispose();
if (dispose)
{
channel.Dispose();
}
}

/// <summary>
Expand Down Expand Up @@ -645,7 +646,7 @@ protected virtual void Dispose(bool disposing)

// unsubscribe from channel events to ensure other objects that we're going to dispose
// are not accessed while disposing
UnsubscribeFromEventsAndDisposeChannel();
UnsubscribeFromChannelEvents(dispose: true);

_inputStream?.Dispose();
_inputStream = null;
Expand Down
72 changes: 0 additions & 72 deletions test/Renci.SshNet.Tests/Classes/SshCommandTest_EndExecute.cs

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ protected override void OnInit()

private void Arrange()
{
_sessionMock = new Mock<ISession>(MockBehavior.Strict);
_sessionMock = new Mock<ISession>();
_commandText = new Random().Next().ToString(CultureInfo.InvariantCulture);
_encoding = Encoding.UTF8;
_asyncResult = null;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -81,12 +81,6 @@ private void Act()
_actual = _sshCommand.EndExecute(_asyncResult);
}

[TestMethod]
public void ChannelSessionShouldBeDisposedOnce()
{
_channelSessionMock.Verify(p => p.Dispose(), Times.Once);
}

[TestMethod]
public void EndExecuteShouldReturnAllDataReceivedInSpecifiedEncoding()
{
Expand Down