Skip to content

Commit 94397d4

Browse files
scott-xuRob-Hague
andauthored
Implement OpenSSH strict key exchange extension (#1366)
* Implement OpenSSH strict key exchange extension * The pseudo-algorithm is only valid in the initial SSH2_MSG_KEXINIT and MUST be ignored if they are present in subsequent SSH2_MSG_KEXINIT packets. * Only send strict kex pseudo algorithm for the first kex. Strictly disable non-kex massages in strict kex mode. * Unit tests for strict kex * More unit tests * More unit tests * Correct file name * Update SessionTest_ConnectingBase.cs * More unit tests * Delete SessionTest_Connecting_ServerSendsMaxIgnoreMessagesBeforeKexInit.cs * Add a comment about throwing exception when inbound sequence number is about to wrap during init kex. * Delete SessionTest_Connecting_ServerSendsDebugMessageAfterKexInit_NoStrictKex.cs * Fix build * Update test/Renci.SshNet.Tests/Classes/SessionTest_Connected.cs --------- Co-authored-by: Rob Hague <rob.hague00@gmail.com>
1 parent 71423c1 commit 94397d4

17 files changed

+750
-58
lines changed

src/Renci.SshNet/Session.cs

Lines changed: 83 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -154,6 +154,17 @@ public class Session : ISession
154154
/// </summary>
155155
private bool _isDisconnecting;
156156

157+
/// <summary>
158+
/// Indicates whether it is the init kex.
159+
/// </summary>
160+
private bool _isInitialKex;
161+
162+
/// <summary>
163+
/// Indicates whether server supports strict key exchange.
164+
/// <see href="https://github.com/openssh/openssh-portable/blob/master/PROTOCOL"/> 1.10.
165+
/// </summary>
166+
private bool _isStrictKex;
167+
157168
private IKeyExchange _keyExchange;
158169

159170
private HashAlgorithm _serverMac;
@@ -281,35 +292,11 @@ public bool IsConnected
281292
/// </value>
282293
public byte[] SessionId { get; private set; }
283294

284-
private Message _clientInitMessage;
285-
286295
/// <summary>
287296
/// Gets the client init message.
288297
/// </summary>
289298
/// <value>The client init message.</value>
290-
public Message ClientInitMessage
291-
{
292-
get
293-
{
294-
_clientInitMessage ??= new KeyExchangeInitMessage
295-
{
296-
KeyExchangeAlgorithms = ConnectionInfo.KeyExchangeAlgorithms.Keys.ToArray(),
297-
ServerHostKeyAlgorithms = ConnectionInfo.HostKeyAlgorithms.Keys.ToArray(),
298-
EncryptionAlgorithmsClientToServer = ConnectionInfo.Encryptions.Keys.ToArray(),
299-
EncryptionAlgorithmsServerToClient = ConnectionInfo.Encryptions.Keys.ToArray(),
300-
MacAlgorithmsClientToServer = ConnectionInfo.HmacAlgorithms.Keys.ToArray(),
301-
MacAlgorithmsServerToClient = ConnectionInfo.HmacAlgorithms.Keys.ToArray(),
302-
CompressionAlgorithmsClientToServer = ConnectionInfo.CompressionAlgorithms.Keys.ToArray(),
303-
CompressionAlgorithmsServerToClient = ConnectionInfo.CompressionAlgorithms.Keys.ToArray(),
304-
LanguagesClientToServer = new[] { string.Empty },
305-
LanguagesServerToClient = new[] { string.Empty },
306-
FirstKexPacketFollows = false,
307-
Reserved = 0
308-
};
309-
310-
return _clientInitMessage;
311-
}
312-
}
299+
public Message ClientInitMessage { get; private set; }
313300

314301
/// <summary>
315302
/// Gets the server version string.
@@ -617,6 +604,8 @@ public void Connect()
617604
// Send our key exchange init.
618605
// We need to do this before starting the message listener to avoid the case where we receive the server
619606
// key exchange init and we continue the key exchange before having sent our own init.
607+
_isInitialKex = true;
608+
ClientInitMessage = BuildClientInitMessage(includeStrictKexPseudoAlgorithm: true);
620609
SendMessage(ClientInitMessage);
621610

622611
// Mark the message listener threads as started
@@ -741,6 +730,8 @@ public async Task ConnectAsync(CancellationToken cancellationToken)
741730
// Send our key exchange init.
742731
// We need to do this before starting the message listener to avoid the case where we receive the server
743732
// key exchange init and we continue the key exchange before having sent our own init.
733+
_isInitialKex = true;
734+
ClientInitMessage = BuildClientInitMessage(includeStrictKexPseudoAlgorithm: true);
744735
SendMessage(ClientInitMessage);
745736

746737
// Mark the message listener threads as started
@@ -1107,13 +1098,20 @@ internal void SendMessage(Message message)
11071098
SendPacket(data, 0, data.Length);
11081099
}
11091100

1110-
// increment the packet sequence number only after we're sure the packet has
1111-
// been sent; even though it's only used for the MAC, it needs to be incremented
1112-
// for each package sent.
1113-
//
1114-
// the server will use it to verify the data integrity, and as such the order in
1115-
// which messages are sent must follow the outbound packet sequence number
1116-
_outboundPacketSequence++;
1101+
if (_isStrictKex && message is NewKeysMessage)
1102+
{
1103+
_outboundPacketSequence = 0;
1104+
}
1105+
else
1106+
{
1107+
// increment the packet sequence number only after we're sure the packet has
1108+
// been sent; even though it's only used for the MAC, it needs to be incremented
1109+
// for each package sent.
1110+
//
1111+
// the server will use it to verify the data integrity, and as such the order in
1112+
// which messages are sent must follow the outbound packet sequence number
1113+
_outboundPacketSequence++;
1114+
}
11171115
}
11181116
}
11191117

@@ -1344,6 +1342,13 @@ private Message ReceiveMessage(Socket socket)
13441342

13451343
_inboundPacketSequence++;
13461344

1345+
// The below code mirrors from https://github.com/openssh/openssh-portable/commit/1edb00c58f8a6875fad6a497aa2bacf37f9e6cd5
1346+
// It ensures the integrity of key exchange process.
1347+
if (_inboundPacketSequence == uint.MaxValue && _isInitialKex)
1348+
{
1349+
throw new SshConnectionException("Inbound packet sequence number is about to wrap during initial key exchange.", DisconnectReason.KeyExchangeFailed);
1350+
}
1351+
13471352
return LoadMessage(data, messagePayloadOffset, messagePayloadLength);
13481353
}
13491354

@@ -1455,8 +1460,20 @@ internal void OnKeyExchangeInitReceived(KeyExchangeInitMessage message)
14551460

14561461
_keyExchangeCompletedWaitHandle.Reset();
14571462

1463+
if (_isInitialKex && message.KeyExchangeAlgorithms.Contains("kex-strict-s-v00@openssh.com"))
1464+
{
1465+
_isStrictKex = true;
1466+
1467+
DiagnosticAbstraction.Log(string.Format("[{0}] Enabling strict key exchange extension.", ToHex(SessionId)));
1468+
1469+
if (_inboundPacketSequence != 1)
1470+
{
1471+
throw new SshConnectionException("KEXINIT was not the first packet during strict key exchange.", DisconnectReason.KeyExchangeFailed);
1472+
}
1473+
}
1474+
14581475
// Disable messages that are not key exchange related
1459-
_sshMessageFactory.DisableNonKeyExchangeMessages();
1476+
_sshMessageFactory.DisableNonKeyExchangeMessages(_isStrictKex);
14601477

14611478
_keyExchange = _serviceFactory.CreateKeyExchange(ConnectionInfo.KeyExchangeAlgorithms,
14621479
message.KeyExchangeAlgorithms);
@@ -1533,6 +1550,17 @@ internal void OnNewKeysReceived(NewKeysMessage message)
15331550
// Enable activated messages that are not key exchange related
15341551
_sshMessageFactory.EnableActivatedMessages();
15351552

1553+
if (_isInitialKex)
1554+
{
1555+
_isInitialKex = false;
1556+
ClientInitMessage = BuildClientInitMessage(includeStrictKexPseudoAlgorithm: false);
1557+
}
1558+
1559+
if (_isStrictKex)
1560+
{
1561+
_inboundPacketSequence = 0;
1562+
}
1563+
15361564
NewKeysReceived?.Invoke(this, new MessageEventArgs<NewKeysMessage>(message));
15371565

15381566
// Signal that key exchange completed
@@ -2067,7 +2095,28 @@ private void Reset()
20672095
private static SshConnectionException CreateConnectionAbortedByServerException()
20682096
{
20692097
return new SshConnectionException("An established connection was aborted by the server.",
2070-
DisconnectReason.ConnectionLost);
2098+
DisconnectReason.ConnectionLost);
2099+
}
2100+
2101+
private KeyExchangeInitMessage BuildClientInitMessage(bool includeStrictKexPseudoAlgorithm)
2102+
{
2103+
return new KeyExchangeInitMessage
2104+
{
2105+
KeyExchangeAlgorithms = includeStrictKexPseudoAlgorithm ?
2106+
ConnectionInfo.KeyExchangeAlgorithms.Keys.Concat(["kex-strict-c-v00@openssh.com"]).ToArray() :
2107+
ConnectionInfo.KeyExchangeAlgorithms.Keys.ToArray(),
2108+
ServerHostKeyAlgorithms = ConnectionInfo.HostKeyAlgorithms.Keys.ToArray(),
2109+
EncryptionAlgorithmsClientToServer = ConnectionInfo.Encryptions.Keys.ToArray(),
2110+
EncryptionAlgorithmsServerToClient = ConnectionInfo.Encryptions.Keys.ToArray(),
2111+
MacAlgorithmsClientToServer = ConnectionInfo.HmacAlgorithms.Keys.ToArray(),
2112+
MacAlgorithmsServerToClient = ConnectionInfo.HmacAlgorithms.Keys.ToArray(),
2113+
CompressionAlgorithmsClientToServer = ConnectionInfo.CompressionAlgorithms.Keys.ToArray(),
2114+
CompressionAlgorithmsServerToClient = ConnectionInfo.CompressionAlgorithms.Keys.ToArray(),
2115+
LanguagesClientToServer = new[] { string.Empty },
2116+
LanguagesServerToClient = new[] { string.Empty },
2117+
FirstKexPacketFollows = false,
2118+
Reserved = 0,
2119+
};
20712120
}
20722121

20732122
private bool _disposed;

src/Renci.SshNet/SshMessageFactory.cs

Lines changed: 28 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -115,16 +115,41 @@ public Message Create(byte messageNumber)
115115
return enabledMessageMetadata.Create();
116116
}
117117

118-
public void DisableNonKeyExchangeMessages()
118+
/// <summary>
119+
/// Disables non-KeyExchange messages.
120+
/// </summary>
121+
/// <param name="strict">
122+
/// <see langword="true"/> to indicate the strict key exchange mode; otherwise <see langword="false"/>.
123+
/// <para>In strict key exchange mode, only below messages are allowed:</para>
124+
/// <list type="bullet">
125+
/// <item>SSH_MSG_KEXINIT -> 20</item>
126+
/// <item>SSH_MSG_NEWKEYS -> 21</item>
127+
/// <item>SSH_MSG_DISCONNECT -> 1</item>
128+
/// </list>
129+
/// <para>Note:</para>
130+
/// <para> The relevant KEX Reply MSG will be allowed from a sub class of KeyExchange class.</para>
131+
/// <para> For example, it calls <c>Session.RegisterMessage("SSH_MSG_KEX_ECDH_REPLY");</c> if the curve25519-sha256 KEX algorithm is selected per negotiation.</para>
132+
/// </param>
133+
public void DisableNonKeyExchangeMessages(bool strict)
119134
{
120135
for (var i = 0; i < AllMessages.Length; i++)
121136
{
122137
var messageMetadata = AllMessages[i];
123138

124139
var messageNumber = messageMetadata.Number;
125-
if (messageNumber is (> 2 and < 20) or > 30)
140+
if (strict)
141+
{
142+
if (messageNumber is not 20 and not 21 and not 1)
143+
{
144+
_enabledMessagesByNumber[messageNumber] = null;
145+
}
146+
}
147+
else
126148
{
127-
_enabledMessagesByNumber[messageNumber] = null;
149+
if (messageNumber is (> 2 and < 20) or > 30)
150+
{
151+
_enabledMessagesByNumber[messageNumber] = null;
152+
}
128153
}
129154
}
130155
}

test/Renci.SshNet.Tests/Classes/SessionTest_ConnectToServerFails.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,7 @@ public void IsConnectedShouldReturnFalse()
8787
}
8888

8989
[TestMethod]
90-
public void SendMessageShouldThrowShhConnectionException()
90+
public void SendMessageShouldThrowSshConnectionException()
9191
{
9292
try
9393
{
@@ -189,7 +189,7 @@ public void ISession_MessageListenerCompletedShouldBeSignaled()
189189
}
190190

191191
[TestMethod]
192-
public void ISession_SendMessageShouldThrowShhConnectionException()
192+
public void ISession_SendMessageShouldThrowSshConnectionException()
193193
{
194194
var session = (ISession)_session;
195195

test/Renci.SshNet.Tests/Classes/SessionTest_Connected.cs

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
using System;
2+
using System.Linq;
23
using System.Threading;
34
using Microsoft.VisualStudio.TestTools.UnitTesting;
45
using Moq;
@@ -30,6 +31,31 @@ public void ClientVersionIsRenciSshNet()
3031
Assert.AreEqual("SSH-2.0-Renci.SshNet.SshClient.0.0.1", Session.ClientVersion);
3132
}
3233

34+
[TestMethod]
35+
public void IncludeStrictKexPseudoAlgorithmInInitKex()
36+
{
37+
Assert.IsTrue(ServerBytesReceivedRegister.Count > 0);
38+
39+
var kexInitMessage = new KeyExchangeInitMessage();
40+
kexInitMessage.Load(ServerBytesReceivedRegister[0], 4 + 1 + 1, ServerBytesReceivedRegister[0].Length - 4 - 1 - 1);
41+
Assert.IsTrue(kexInitMessage.KeyExchangeAlgorithms.Contains("kex-strict-c-v00@openssh.com"));
42+
}
43+
44+
[TestMethod]
45+
public void ShouldNotIncludeStrictKexPseudoAlgorithmInSubsequentKex()
46+
{
47+
ServerBytesReceivedRegister.Clear();
48+
Session.SendMessage(Session.ClientInitMessage);
49+
50+
Thread.Sleep(100);
51+
52+
Assert.IsTrue(ServerBytesReceivedRegister.Count > 0);
53+
54+
var kexInitMessage = new KeyExchangeInitMessage();
55+
kexInitMessage.Load(ServerBytesReceivedRegister[0], 4 + 1 + 1, ServerBytesReceivedRegister[0].Length - 4 - 1 - 1);
56+
Assert.IsFalse(kexInitMessage.KeyExchangeAlgorithms.Contains("kex-strict-c-v00@openssh.com"));
57+
}
58+
3359
[TestMethod]
3460
public void ConnectionInfoShouldReturnConnectionInfoPassedThroughConstructor()
3561
{

test/Renci.SshNet.Tests/Classes/SessionTest_ConnectedBase.cs

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -46,8 +46,7 @@ public abstract class SessionTest_ConnectedBase
4646
protected Session Session { get; private set; }
4747
protected Socket ClientSocket { get; private set; }
4848
protected Socket ServerSocket { get; private set; }
49-
internal SshIdentification ServerIdentification { get; set; }
50-
protected bool CallSessionConnectWhenArrange { get; set; }
49+
protected SshIdentification ServerIdentification { get; private set; }
5150

5251
/// <summary>
5352
/// Should the "server" wait for the client kexinit before sending its own.
@@ -163,8 +162,6 @@ protected virtual void SetupData()
163162

164163
ClientSocket = new DirectConnector(_socketFactory).Connect(ConnectionInfo);
165164

166-
CallSessionConnectWhenArrange = true;
167-
168165
void SendKeyExchangeInit()
169166
{
170167
var keyExchangeInitMessage = new KeyExchangeInitMessage
@@ -204,7 +201,7 @@ private void SetupMocks()
204201
_ = ServiceFactoryMock.Setup(p => p.CreateProtocolVersionExchange())
205202
.Returns(_protocolVersionExchangeMock.Object);
206203
_ = _protocolVersionExchangeMock.Setup(p => p.Start(Session.ClientVersion, ClientSocket, ConnectionInfo.Timeout))
207-
.Returns(() => ServerIdentification);
204+
.Returns(ServerIdentification);
208205
_ = ServiceFactoryMock.Setup(p => p.CreateKeyExchange(ConnectionInfo.KeyExchangeAlgorithms, new[] { _keyExchangeAlgorithm })).Returns(_keyExchangeMock.Object);
209206
_ = _keyExchangeMock.Setup(p => p.Name)
210207
.Returns(_keyExchangeAlgorithm);
@@ -252,10 +249,7 @@ protected void Arrange()
252249
SetupData();
253250
SetupMocks();
254251

255-
if (CallSessionConnectWhenArrange)
256-
{
257-
Session.Connect();
258-
}
252+
Session.Connect();
259253
}
260254

261255
protected virtual void ClientAuthentication_Callback()

0 commit comments

Comments
 (0)