Skip to content

Expose SshIdentificationReceived event #1195

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 12 commits into from
Nov 16, 2023
13 changes: 13 additions & 0 deletions src/Renci.SshNet/BaseClient.cs
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,11 @@ public TimeSpan KeepAliveInterval
/// </example>
public event EventHandler<HostKeyEventArgs> HostKeyReceived;

/// <summary>
/// Occurs when server identification received.
/// </summary>
public event EventHandler<SshIdentificationEventArgs> ServerIdentificationReceived;

/// <summary>
/// Initializes a new instance of the <see cref="BaseClient"/> class.
/// </summary>
Expand Down Expand Up @@ -390,6 +395,11 @@ private void Session_HostKeyReceived(object sender, HostKeyEventArgs e)
HostKeyReceived?.Invoke(this, e);
}

private void Session_ServerIdentificationReceived(object sender, SshIdentificationEventArgs e)
{
ServerIdentificationReceived?.Invoke(this, e);
}

/// <summary>
/// Performs application-defined tasks associated with freeing, releasing, or resetting unmanaged resources.
/// </summary>
Expand Down Expand Up @@ -532,6 +542,7 @@ private Timer CreateKeepAliveTimer(TimeSpan dueTime, TimeSpan period)
private ISession CreateAndConnectSession()
{
var session = _serviceFactory.CreateSession(ConnectionInfo, _serviceFactory.CreateSocketFactory());
session.ServerIdentificationReceived += Session_ServerIdentificationReceived;
session.HostKeyReceived += Session_HostKeyReceived;
session.ErrorOccured += Session_ErrorOccured;

Expand All @@ -550,6 +561,7 @@ private ISession CreateAndConnectSession()
private async Task<ISession> CreateAndConnectSessionAsync(CancellationToken cancellationToken)
{
var session = _serviceFactory.CreateSession(ConnectionInfo, _serviceFactory.CreateSocketFactory());
session.ServerIdentificationReceived += Session_ServerIdentificationReceived;
session.HostKeyReceived += Session_HostKeyReceived;
session.ErrorOccured += Session_ErrorOccured;

Expand All @@ -569,6 +581,7 @@ private void DisposeSession(ISession session)
{
session.ErrorOccured -= Session_ErrorOccured;
session.HostKeyReceived -= Session_HostKeyReceived;
session.ServerIdentificationReceived -= Session_ServerIdentificationReceived;
session.Dispose();
}

Expand Down
26 changes: 26 additions & 0 deletions src/Renci.SshNet/Common/SshIdentificationEventArgs.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
using System;

using Renci.SshNet.Connection;

namespace Renci.SshNet.Common
{
/// <summary>
/// Provides data for the ServerIdentificationReceived events.
/// </summary>
public class SshIdentificationEventArgs : EventArgs
{
/// <summary>
/// Initializes a new instance of the <see cref="SshIdentificationEventArgs"/> class.
/// </summary>
/// <param name="sshIdentification">The SSH identification.</param>
public SshIdentificationEventArgs(SshIdentification sshIdentification)
{
SshIdentification = sshIdentification;
}

/// <summary>
/// Gets the SSH identification.
/// </summary>
public SshIdentification SshIdentification { get; private set; }
}
}
2 changes: 1 addition & 1 deletion src/Renci.SshNet/Connection/SshIdentification.cs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ namespace Renci.SshNet.Connection
/// <summary>
/// Represents an SSH identification.
/// </summary>
internal sealed class SshIdentification
public sealed class SshIdentification
{
/// <summary>
/// Initializes a new instance of the <see cref="SshIdentification"/> class with the specified protocol version
Expand Down
5 changes: 5 additions & 0 deletions src/Renci.SshNet/ISession.cs
Original file line number Diff line number Diff line change
Expand Up @@ -260,6 +260,11 @@ internal interface ISession : IDisposable
/// </summary>
event EventHandler<ExceptionEventArgs> ErrorOccured;

/// <summary>
/// Occurs when server identification received.
/// </summary>
event EventHandler<SshIdentificationEventArgs> ServerIdentificationReceived;

/// <summary>
/// Occurs when host key received.
/// </summary>
Expand Down
9 changes: 9 additions & 0 deletions src/Renci.SshNet/Session.cs
Original file line number Diff line number Diff line change
Expand Up @@ -366,6 +366,11 @@ public Message ClientInitMessage
/// </summary>
public event EventHandler<EventArgs> Disconnected;

/// <summary>
/// Occurs when server identification received.
/// </summary>
public event EventHandler<SshIdentificationEventArgs> ServerIdentificationReceived;

/// <summary>
/// Occurs when host key received.
/// </summary>
Expand Down Expand Up @@ -624,6 +629,8 @@ public void Connect()
DisconnectReason.ProtocolVersionNotSupported);
}

ServerIdentificationReceived?.Invoke(this, new SshIdentificationEventArgs(serverIdentification));

// Register Transport response messages
RegisterMessage("SSH_MSG_DISCONNECT");
RegisterMessage("SSH_MSG_IGNORE");
Expand Down Expand Up @@ -736,6 +743,8 @@ public async Task ConnectAsync(CancellationToken cancellationToken)
DisconnectReason.ProtocolVersionNotSupported);
}

ServerIdentificationReceived?.Invoke(this, new SshIdentificationEventArgs(serverIdentification));

// Register Transport response messages
RegisterMessage("SSH_MSG_DISCONNECT");
RegisterMessage("SSH_MSG_IGNORE");
Expand Down
12 changes: 9 additions & 3 deletions test/Renci.SshNet.Tests/Classes/SessionTest_ConnectedBase.cs
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,8 @@ public abstract class SessionTest_ConnectedBase
protected Session Session { get; private set; }
protected Socket ClientSocket { get; private set; }
protected Socket ServerSocket { get; private set; }
internal SshIdentification ServerIdentification { get; private set; }
internal SshIdentification ServerIdentification { get; set; }
protected bool CallSessionConnectWhenArrange { get; set; }

[TestInitialize]
public void Setup()
Expand Down Expand Up @@ -159,6 +160,8 @@ protected virtual void SetupData()
ServerListener.Start();

ClientSocket = new DirectConnector(_socketFactory).Connect(ConnectionInfo);

CallSessionConnectWhenArrange = true;
}

private void CreateMocks()
Expand All @@ -180,7 +183,7 @@ private void SetupMocks()
_ = ServiceFactoryMock.Setup(p => p.CreateProtocolVersionExchange())
.Returns(_protocolVersionExchangeMock.Object);
_ = _protocolVersionExchangeMock.Setup(p => p.Start(Session.ClientVersion, ClientSocket, ConnectionInfo.Timeout))
.Returns(ServerIdentification);
.Returns(() => ServerIdentification);
_ = ServiceFactoryMock.Setup(p => p.CreateKeyExchange(ConnectionInfo.KeyExchangeAlgorithms, new[] { _keyExchangeAlgorithm })).Returns(_keyExchangeMock.Object);
_ = _keyExchangeMock.Setup(p => p.Name)
.Returns(_keyExchangeAlgorithm);
Expand Down Expand Up @@ -212,7 +215,10 @@ protected void Arrange()
SetupData();
SetupMocks();

Session.Connect();
if (CallSessionConnectWhenArrange)
{
Session.Connect();
}
}

protected virtual void ClientAuthentication_Callback()
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
using Microsoft.VisualStudio.TestTools.UnitTesting;

using Renci.SshNet.Connection;

namespace Renci.SshNet.Tests.Classes
{
[TestClass]
public class SessionTest_Connected_ServerIdentificationReceived : SessionTest_ConnectedBase
{
protected override void SetupData()
{
base.SetupData();

CallSessionConnectWhenArrange = false;

Session.ServerIdentificationReceived += (s, e) =>
{
if ((e.SshIdentification.SoftwareVersion.StartsWith("OpenSSH_6.5", System.StringComparison.Ordinal) || e.SshIdentification.SoftwareVersion.StartsWith("OpenSSH_6.6", System.StringComparison.Ordinal))
&& !e.SshIdentification.SoftwareVersion.StartsWith("OpenSSH_6.6.1", System.StringComparison.Ordinal))
{
_ = ConnectionInfo.KeyExchangeAlgorithms.Remove("curve25519-sha256");
_ = ConnectionInfo.KeyExchangeAlgorithms.Remove("curve25519-sha256@libssh.org");
}
};
}

protected override void Act()
{
}

[TestMethod]
[DataRow("OpenSSH_6.5")]
[DataRow("OpenSSH_6.5p1")]
[DataRow("OpenSSH_6.5 PKIX")]
[DataRow("OpenSSH_6.6")]
[DataRow("OpenSSH_6.6p1")]
[DataRow("OpenSSH_6.6 PKIX")]
public void ShouldExcludeCurve25519KexWhenServerIs(string softwareVersion)
{
ServerIdentification = new SshIdentification("2.0", softwareVersion);

Session.Connect();

Assert.IsFalse(ConnectionInfo.KeyExchangeAlgorithms.ContainsKey("curve25519-sha256"));
Assert.IsFalse(ConnectionInfo.KeyExchangeAlgorithms.ContainsKey("curve25519-sha256@libssh.org"));
}

[TestMethod]
[DataRow("OpenSSH_6.6.1")]
[DataRow("OpenSSH_6.6.1p1")]
[DataRow("OpenSSH_6.6.1 PKIX")]
[DataRow("OpenSSH_6.7")]
[DataRow("OpenSSH_6.7p1")]
[DataRow("OpenSSH_6.7 PKIX")]
public void ShouldIncludeCurve25519KexWhenServerIs(string softwareVersion)
{
ServerIdentification = new SshIdentification("2.0", softwareVersion);

Session.Connect();

Assert.IsTrue(ConnectionInfo.KeyExchangeAlgorithms.ContainsKey("curve25519-sha256"));
Assert.IsTrue(ConnectionInfo.KeyExchangeAlgorithms.ContainsKey("curve25519-sha256@libssh.org"));
}
}
}