From e23f381deaab258bad97156211b496631f6de639 Mon Sep 17 00:00:00 2001 From: Rob Hague Date: Sun, 11 Feb 2024 09:00:38 +0000 Subject: [PATCH 1/6] Add tests and benchmarks for ShellStream.Read and Expect (#1313) Most of them are ignored because they fail. --- .../SshClientBenchmark.cs | 35 ++ .../Classes/ShellStreamTest_ReadExpect.cs | 329 ++++++++++++++++++ 2 files changed, 364 insertions(+) create mode 100644 test/Renci.SshNet.Tests/Classes/ShellStreamTest_ReadExpect.cs diff --git a/test/Renci.SshNet.IntegrationBenchmarks/SshClientBenchmark.cs b/test/Renci.SshNet.IntegrationBenchmarks/SshClientBenchmark.cs index 0ce126e0e..32659ba11 100644 --- a/test/Renci.SshNet.IntegrationBenchmarks/SshClientBenchmark.cs +++ b/test/Renci.SshNet.IntegrationBenchmarks/SshClientBenchmark.cs @@ -1,5 +1,6 @@ using BenchmarkDotNet.Attributes; +using Renci.SshNet.Common; using Renci.SshNet.IntegrationTests.TestsFixtures; namespace Renci.SshNet.IntegrationBenchmarks @@ -8,6 +9,11 @@ namespace Renci.SshNet.IntegrationBenchmarks [SimpleJob] public class SshClientBenchmark : IntegrationBenchmarkBase { + private static readonly Dictionary ShellStreamTerminalModes = new Dictionary + { + { TerminalModes.ECHO, 0 } + }; + private readonly InfrastructureFixture _infrastructureFixture; private SshClient? _sshClient; @@ -65,5 +71,34 @@ public string RunCommand() { return _sshClient!.RunCommand("echo $'test !@#$%^&*()_+{}:,./<>[];\\|'").Result; } + + [Benchmark] + public string ShellStreamReadLine() + { + using (var shellStream = _sshClient!.CreateShellStream("xterm", 80, 24, 800, 600, 1024, ShellStreamTerminalModes)) + { + shellStream.WriteLine("for i in $(seq 500); do echo \"Within cells. Interlinked. $i\"; sleep 0.001; done; echo \"Username:\";"); + + while (true) + { + var line = shellStream.ReadLine(); + + if (line.EndsWith("500", StringComparison.Ordinal)) + { + return line; + } + } + } + } + + [Benchmark] + public string ShellStreamExpect() + { + using (var shellStream = _sshClient!.CreateShellStream("xterm", 80, 24, 800, 600, 1024, ShellStreamTerminalModes)) + { + shellStream.WriteLine("for i in $(seq 500); do echo \"Within cells. Interlinked. $i\"; sleep 0.001; done; echo \"Username:\";"); + return shellStream.Expect("Username:"); + } + } } } diff --git a/test/Renci.SshNet.Tests/Classes/ShellStreamTest_ReadExpect.cs b/test/Renci.SshNet.Tests/Classes/ShellStreamTest_ReadExpect.cs new file mode 100644 index 000000000..8e0387160 --- /dev/null +++ b/test/Renci.SshNet.Tests/Classes/ShellStreamTest_ReadExpect.cs @@ -0,0 +1,329 @@ +using System; +using System.Collections.Generic; +using System.Diagnostics; +using System.Text; +using System.Text.RegularExpressions; +using System.Threading.Tasks; + +using Microsoft.VisualStudio.TestTools.UnitTesting; + +using Moq; + +using Renci.SshNet.Channels; +using Renci.SshNet.Common; + +namespace Renci.SshNet.Tests.Classes +{ + [TestClass] + public class ShellStreamTest_ReadExpect + { + private ShellStream _shellStream; + private ChannelSessionStub _channelSessionStub; + + [TestInitialize] + public void Initialize() + { + _channelSessionStub = new ChannelSessionStub(); + + var connectionInfoMock = new Mock(); + + connectionInfoMock.Setup(p => p.Encoding).Returns(Encoding.UTF8); + + var sessionMock = new Mock(); + + sessionMock.Setup(p => p.ConnectionInfo).Returns(connectionInfoMock.Object); + sessionMock.Setup(p => p.CreateChannelSession()).Returns(_channelSessionStub); + + _shellStream = new ShellStream( + sessionMock.Object, + "terminalName", + columns: 80, + rows: 24, + width: 800, + height: 600, + terminalModeValues: null, + bufferSize: 1024); + } + + [TestMethod] + public void Read_String() + { + _channelSessionStub.Receive(Encoding.UTF8.GetBytes("Hello ")); + _channelSessionStub.Receive(Encoding.UTF8.GetBytes("World!")); + + Assert.AreEqual("Hello World!", _shellStream.Read()); + } + + [TestMethod] + public void Read_Bytes() + { + _channelSessionStub.Receive(Encoding.UTF8.GetBytes("Hello ")); + _channelSessionStub.Receive(Encoding.UTF8.GetBytes("World!")); + + byte[] buffer = new byte[12]; + + Assert.AreEqual(7, _shellStream.Read(buffer, 3, 7)); + CollectionAssert.AreEqual(Encoding.UTF8.GetBytes("\0\0\0Hello W\0\0"), buffer); + + Assert.AreEqual(5, _shellStream.Read(buffer, 0, 12)); + CollectionAssert.AreEqual(Encoding.UTF8.GetBytes("orld!llo W\0\0"), buffer); + } + + [DataTestMethod] + [DataRow("\r\n")] + //[DataRow("\r")] These currently fail. + //[DataRow("\n")] + public void ReadLine(string newLine) + { + _channelSessionStub.Receive(Encoding.UTF8.GetBytes("Hello ")); + _channelSessionStub.Receive(Encoding.UTF8.GetBytes("World!")); + + // We specify a nonzero timeout to avoid waiting infinitely. + Assert.IsNull(_shellStream.ReadLine(TimeSpan.FromTicks(1))); + + _channelSessionStub.Receive(Encoding.UTF8.GetBytes(newLine)); + + Assert.AreEqual("Hello World!", _shellStream.ReadLine(TimeSpan.FromTicks(1))); + Assert.IsNull(_shellStream.ReadLine(TimeSpan.FromTicks(1))); + + _channelSessionStub.Receive(Encoding.UTF8.GetBytes("Second line!" + newLine + "Third line!" + newLine)); + + Assert.AreEqual("Second line!", _shellStream.ReadLine(TimeSpan.FromTicks(1))); + Assert.AreEqual("Third line!", _shellStream.ReadLine(TimeSpan.FromTicks(1))); + Assert.IsNull(_shellStream.ReadLine(TimeSpan.FromTicks(1))); + } + + [DataTestMethod] + [DataRow("\r\n")] + [DataRow("\r")] + [DataRow("\n")] + public void Read_MultipleLines(string newLine) + { + _channelSessionStub.Receive(Encoding.UTF8.GetBytes("Hello ")); + _channelSessionStub.Receive(Encoding.UTF8.GetBytes("World!")); + _channelSessionStub.Receive(Encoding.UTF8.GetBytes(newLine)); + _channelSessionStub.Receive(Encoding.UTF8.GetBytes("Second line!" + newLine + "Third line!" + newLine)); + + Assert.AreEqual("Hello World!" + newLine + "Second line!" + newLine + "Third line!" + newLine, _shellStream.Read()); + } + + [TestMethod] + [Ignore] // Currently returns 0 immediately + public void Read_NonEmptyArray_OnlyReturnsZeroAfterClose() + { + Task closeTask = Task.Run(async () => + { + // For the test to have meaning, we should be in + // the call to Read before closing the channel. + // Impose a short delay to make that more likely. + await Task.Delay(50); + + _channelSessionStub.Close(); + }); + + Assert.AreEqual(0, _shellStream.Read(new byte[16], 0, 16)); + Assert.AreEqual(TaskStatus.RanToCompletion, closeTask.Status); + } + + [TestMethod] + [Ignore] // Currently returns 0 immediately + public void Read_EmptyArray_OnlyReturnsZeroWhenDataAvailable() + { + Task receiveTask = Task.Run(async () => + { + // For the test to have meaning, we should be in + // the call to Read before receiving the data. + // Impose a short delay to make that more likely. + await Task.Delay(50); + + _channelSessionStub.Receive(Encoding.UTF8.GetBytes("Hello World!")); + }); + + Assert.AreEqual(0, _shellStream.Read(Array.Empty(), 0, 0)); + Assert.AreEqual(TaskStatus.RanToCompletion, receiveTask.Status); + } + + [TestMethod] + [Ignore] // Currently hangs + public void ReadLine_NoData_ReturnsNullAfterClose() + { + Task closeTask = Task.Run(async () => + { + await Task.Delay(50); + + _channelSessionStub.Close(); + }); + + Assert.IsNull(_shellStream.ReadLine()); + Assert.AreEqual(TaskStatus.RanToCompletion, closeTask.Status); + } + + [TestMethod] + [Ignore] // Fails because it returns the whole buffer i.e. "Hello World!\r\n12345" + // We might actually want to keep that behaviour, but just make the documentation clearer. + // The Expect documentation says: + // "The text available in the shell that contains all the text that ends with expected expression." + // Does that mean + // 1. the returned string ends with the expected expression; or + // 2. the returned string is all the text in the buffer, which is guaranteed to contain the expected expression? + // The current behaviour is closer to 2. I think the documentation implies 1. + // Either way, there are bugs. + public void Expect() + { + _channelSessionStub.Receive(Encoding.UTF8.GetBytes("Hello ")); + _channelSessionStub.Receive(Encoding.UTF8.GetBytes("World!")); + + Assert.IsNull(_shellStream.Expect("123", TimeSpan.FromTicks(1))); + + _channelSessionStub.Receive(Encoding.UTF8.GetBytes("\r\n12345")); + + // Both of these cases fail + // Case 1 above. + Assert.AreEqual("Hello World!\r\n123", _shellStream.Expect("123")); // Fails, returns "Hello World!\r\n12345" + Assert.AreEqual("45", _shellStream.Read()); // Passes, but should probably fail and return "" + + // Case 2 above. + Assert.AreEqual("Hello World!\r\n12345", _shellStream.Expect("123")); // Passes + Assert.AreEqual("", _shellStream.Read()); // Fails, returns "45" + } + + [TestMethod] + public void Read_MultiByte() + { + _channelSessionStub.Receive(new byte[] { 0xF0 }); + _channelSessionStub.Receive(new byte[] { 0x9F }); + _channelSessionStub.Receive(new byte[] { 0x91 }); + _channelSessionStub.Receive(new byte[] { 0x8D }); + + Assert.AreEqual("👍", _shellStream.Read()); + } + + [TestMethod] + public void ReadLine_MultiByte() + { + _channelSessionStub.Receive(new byte[] { 0xF0 }); + _channelSessionStub.Receive(new byte[] { 0x9F }); + _channelSessionStub.Receive(new byte[] { 0x91 }); + _channelSessionStub.Receive(new byte[] { 0x8D }); + _channelSessionStub.Receive(new byte[] { 0x0D }); + _channelSessionStub.Receive(new byte[] { 0x0A }); + + Assert.AreEqual("👍", _shellStream.ReadLine()); + Assert.AreEqual("", _shellStream.Read()); + } + + [TestMethod] + [Ignore] + public void Expect_Regex_MultiByte() + { + _channelSessionStub.Receive(Encoding.UTF8.GetBytes("𐓏𐓘𐓻𐓘𐓻𐓟 𐒻𐓟")); + + Assert.AreEqual("𐓏𐓘𐓻𐓘𐓻𐓟 ", _shellStream.Expect(new Regex(@"\s"))); + Assert.AreEqual("𐒻𐓟", _shellStream.Read()); + } + + [TestMethod] + [Ignore] + public void Expect_String_MultiByte() + { + _channelSessionStub.Receive(Encoding.UTF8.GetBytes("hello 你好")); + + Assert.AreEqual("hello 你好", _shellStream.Expect("你好")); + Assert.AreEqual("", _shellStream.Read()); + } + + [TestMethod] + public void Expect_Timeout() + { + Stopwatch stopwatch = Stopwatch.StartNew(); + + Assert.IsNull(_shellStream.Expect("Hello World!", TimeSpan.FromMilliseconds(200))); + + TimeSpan elapsed = stopwatch.Elapsed; + + // Account for variance in system timer resolution. + Assert.IsTrue(elapsed > TimeSpan.FromMilliseconds(180), elapsed.ToString()); + } + + private class ChannelSessionStub : IChannelSession + { + public void Receive(byte[] data) + { + DataReceived.Invoke(this, new ChannelDataEventArgs(channelNumber: 0, data)); + } + + public void Close() + { + Closed.Invoke(this, new ChannelEventArgs(channelNumber: 0)); + } + + public bool SendShellRequest() + { + return true; + } + + public bool SendPseudoTerminalRequest(string environmentVariable, uint columns, uint rows, uint width, uint height, IDictionary terminalModeValues) + { + return true; + } + + public void Dispose() + { + } + + public void Open() + { + } + + public event EventHandler DataReceived; + public event EventHandler Closed; +#pragma warning disable 0067 + public event EventHandler Exception; + public event EventHandler ExtendedDataReceived; + public event EventHandler RequestReceived; +#pragma warning restore 0067 + +#pragma warning disable IDE0025 // Use block body for property +#pragma warning disable IDE0022 // Use block body for method + public uint LocalChannelNumber => throw new NotImplementedException(); + + public uint LocalPacketSize => throw new NotImplementedException(); + + public uint RemotePacketSize => throw new NotImplementedException(); + + public bool IsOpen => throw new NotImplementedException(); + + public bool SendBreakRequest(uint breakLength) => throw new NotImplementedException(); + + public void SendData(byte[] data) => throw new NotImplementedException(); + + public void SendData(byte[] data, int offset, int size) => throw new NotImplementedException(); + + public bool SendEndOfWriteRequest() => throw new NotImplementedException(); + + public bool SendEnvironmentVariableRequest(string variableName, string variableValue) => throw new NotImplementedException(); + + public void SendEof() => throw new NotImplementedException(); + + public bool SendExecRequest(string command) => throw new NotImplementedException(); + + public bool SendExitSignalRequest(string signalName, bool coreDumped, string errorMessage, string language) => throw new NotImplementedException(); + + public bool SendExitStatusRequest(uint exitStatus) => throw new NotImplementedException(); + + public bool SendKeepAliveRequest() => throw new NotImplementedException(); + + public bool SendLocalFlowRequest(bool clientCanDo) => throw new NotImplementedException(); + + public bool SendSignalRequest(string signalName) => throw new NotImplementedException(); + + public bool SendSubsystemRequest(string subsystem) => throw new NotImplementedException(); + + public bool SendWindowChangeRequest(uint columns, uint rows, uint width, uint height) => throw new NotImplementedException(); + + public bool SendX11ForwardingRequest(bool isSingleConnection, string protocol, byte[] cookie, uint screenNumber) => throw new NotImplementedException(); +#pragma warning restore IDE0022 // Use block body for method +#pragma warning restore IDE0025 // Use block body for property + } + } +} From dec04f26e0d71deb6c2e470fcd1e4c6c020a9ecb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Wojciech=20Nag=C3=B3rski?= Date: Sun, 11 Feb 2024 20:15:11 +0100 Subject: [PATCH 2/6] 2023.0.2 version (#1314) --- src/Renci.SshNet/Renci.SshNet.csproj | 2 +- .../Renci.SshNet.IntegrationTests.csproj | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/Renci.SshNet/Renci.SshNet.csproj b/src/Renci.SshNet/Renci.SshNet.csproj index 567f5aef9..f1ce038ba 100644 --- a/src/Renci.SshNet/Renci.SshNet.csproj +++ b/src/Renci.SshNet/Renci.SshNet.csproj @@ -9,7 +9,7 @@ SSH.NET SSH.NET - 2023.0.1 + 2023.0.2 SSH.NET is a Secure Shell (SSH) library for .NET, optimized for parallelism. Copyright © Renci 2010-$([System.DateTime]::UtcNow.Year) MIT diff --git a/test/Renci.SshNet.IntegrationTests/Renci.SshNet.IntegrationTests.csproj b/test/Renci.SshNet.IntegrationTests/Renci.SshNet.IntegrationTests.csproj index 109da05d2..ed0a5d0c0 100644 --- a/test/Renci.SshNet.IntegrationTests/Renci.SshNet.IntegrationTests.csproj +++ b/test/Renci.SshNet.IntegrationTests/Renci.SshNet.IntegrationTests.csproj @@ -17,9 +17,9 @@ From 3bfac50ad0734b11b99559e6c7b2c400ba82e25a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Wojciech=20Nag=C3=B3rski?= Date: Sun, 11 Feb 2024 23:06:25 +0100 Subject: [PATCH 3/6] Improve ShellStream Expect (#1315) --- src/Renci.SshNet/ShellStream.cs | 17 ++++++++----- .../Classes/ShellStreamTest_ReadExpect.cs | 25 ++++++++----------- 2 files changed, 21 insertions(+), 21 deletions(-) diff --git a/src/Renci.SshNet/ShellStream.cs b/src/Renci.SshNet/ShellStream.cs index ce57072d5..e28935735 100644 --- a/src/Renci.SshNet/ShellStream.cs +++ b/src/Renci.SshNet/ShellStream.cs @@ -268,8 +268,9 @@ public void Expect(TimeSpan timeout, params ExpectAction[] expectActions) if (match.Success) { var result = text.Substring(0, match.Index + match.Length); + var charCount = _encoding.GetByteCount(result); - for (var i = 0; i < match.Index + match.Length && _incoming.Count > 0; i++) + for (var i = 0; i < charCount && _incoming.Count > 0; i++) { // Remove processed items from the queue _ = _incoming.Dequeue(); @@ -348,7 +349,7 @@ public string Expect(Regex regex) /// public string Expect(Regex regex, TimeSpan timeout) { - var text = string.Empty; + var result = string.Empty; while (true) { @@ -356,15 +357,18 @@ public string Expect(Regex regex, TimeSpan timeout) { if (_incoming.Count > 0) { - text = _encoding.GetString(_incoming.ToArray(), 0, _incoming.Count); + result = _encoding.GetString(_incoming.ToArray(), 0, _incoming.Count); } - var match = regex.Match(text); + var match = regex.Match(result); if (match.Success) { + result = result.Substring(0, match.Index + match.Length); + var charCount = _encoding.GetByteCount(result); + // Remove processed items from the queue - for (var i = 0; i < match.Index + match.Length && _incoming.Count > 0; i++) + for (var i = 0; i < charCount && _incoming.Count > 0; i++) { _ = _incoming.Dequeue(); } @@ -386,7 +390,7 @@ public string Expect(Regex regex, TimeSpan timeout) } } - return text; + return result; } /// @@ -471,6 +475,7 @@ public IAsyncResult BeginExpect(TimeSpan timeout, AsyncCallback callback, object if (match.Success) { var result = text.Substring(0, match.Index + match.Length); + var charCount = _encoding.GetByteCount(result); for (var i = 0; i < match.Index + match.Length && _incoming.Count > 0; i++) { diff --git a/test/Renci.SshNet.Tests/Classes/ShellStreamTest_ReadExpect.cs b/test/Renci.SshNet.Tests/Classes/ShellStreamTest_ReadExpect.cs index 8e0387160..9aa34423b 100644 --- a/test/Renci.SshNet.Tests/Classes/ShellStreamTest_ReadExpect.cs +++ b/test/Renci.SshNet.Tests/Classes/ShellStreamTest_ReadExpect.cs @@ -159,15 +159,6 @@ public void ReadLine_NoData_ReturnsNullAfterClose() } [TestMethod] - [Ignore] // Fails because it returns the whole buffer i.e. "Hello World!\r\n12345" - // We might actually want to keep that behaviour, but just make the documentation clearer. - // The Expect documentation says: - // "The text available in the shell that contains all the text that ends with expected expression." - // Does that mean - // 1. the returned string ends with the expected expression; or - // 2. the returned string is all the text in the buffer, which is guaranteed to contain the expected expression? - // The current behaviour is closer to 2. I think the documentation implies 1. - // Either way, there are bugs. public void Expect() { _channelSessionStub.Receive(Encoding.UTF8.GetBytes("Hello ")); @@ -181,10 +172,6 @@ public void Expect() // Case 1 above. Assert.AreEqual("Hello World!\r\n123", _shellStream.Expect("123")); // Fails, returns "Hello World!\r\n12345" Assert.AreEqual("45", _shellStream.Read()); // Passes, but should probably fail and return "" - - // Case 2 above. - Assert.AreEqual("Hello World!\r\n12345", _shellStream.Expect("123")); // Passes - Assert.AreEqual("", _shellStream.Read()); // Fails, returns "45" } [TestMethod] @@ -213,7 +200,6 @@ public void ReadLine_MultiByte() } [TestMethod] - [Ignore] public void Expect_Regex_MultiByte() { _channelSessionStub.Receive(Encoding.UTF8.GetBytes("𐓏𐓘𐓻𐓘𐓻𐓟 𐒻𐓟")); @@ -223,7 +209,6 @@ public void Expect_Regex_MultiByte() } [TestMethod] - [Ignore] public void Expect_String_MultiByte() { _channelSessionStub.Receive(Encoding.UTF8.GetBytes("hello 你好")); @@ -232,6 +217,16 @@ public void Expect_String_MultiByte() Assert.AreEqual("", _shellStream.Read()); } + [TestMethod] + public void Expect_String_non_ASCII_characters() + { + _channelSessionStub.Receive(Encoding.UTF8.GetBytes("Hello, こんにちは, Bonjour")); + + Assert.AreEqual("Hello, こ", _shellStream.Expect(new Regex(@"[^\u0000-\u007F]"))); + + Assert.AreEqual("んにちは, Bonjour", _shellStream.Read()); + } + [TestMethod] public void Expect_Timeout() { From 4bfcfaeb8028bd315210fb74475a55f66d8e1166 Mon Sep 17 00:00:00 2001 From: "Xu, Scott" Date: Mon, 12 Feb 2024 14:16:15 +0800 Subject: [PATCH 4/6] Support ETM (Encrypt-then-MAC) variants for HMAC --- .../Abstractions/CryptoAbstraction.cs | 58 +++++++----- src/Renci.SshNet/ConnectionInfo.cs | 4 + src/Renci.SshNet/HashInfo.cs | 13 ++- .../Security/Cryptography/HMAC.cs | 49 +++++++++++ src/Renci.SshNet/Security/IKeyExchange.cs | 5 +- src/Renci.SshNet/Security/KeyExchange.cs | 11 ++- src/Renci.SshNet/Session.cs | 88 +++++++++++++------ .../Classes/SessionTest_ConnectedBase.cs | 5 +- ...Connected_ServerAndClientDisconnectRace.cs | 5 +- 9 files changed, 169 insertions(+), 69 deletions(-) create mode 100644 src/Renci.SshNet/Security/Cryptography/HMAC.cs diff --git a/src/Renci.SshNet/Abstractions/CryptoAbstraction.cs b/src/Renci.SshNet/Abstractions/CryptoAbstraction.cs index ca187833f..8456b37a8 100644 --- a/src/Renci.SshNet/Abstractions/CryptoAbstraction.cs +++ b/src/Renci.SshNet/Abstractions/CryptoAbstraction.cs @@ -84,75 +84,85 @@ public static System.Security.Cryptography.RIPEMD160 CreateRIPEMD160() } #endif // FEATURE_HASH_RIPEMD160 - public static System.Security.Cryptography.HMACMD5 CreateHMACMD5(byte[] key) + public static HMAC CreateHMACMD5(byte[] key) { #pragma warning disable CA5351 // Do not use broken cryptographic algorithms - return new System.Security.Cryptography.HMACMD5(key); + return new HMAC(new System.Security.Cryptography.HMACMD5(key)); #pragma warning restore CA5351 // Do not use broken cryptographic algorithms } - public static HMACMD5 CreateHMACMD5(byte[] key, int hashSize) + public static HMAC CreateHMACMD5(byte[] key, int hashSize) { #pragma warning disable CA5351 // Do not use broken cryptographic algorithms - return new HMACMD5(key, hashSize); + return new HMAC(new HMACMD5(key, hashSize)); #pragma warning restore CA5351 // Do not use broken cryptographic algorithms } - public static System.Security.Cryptography.HMACSHA1 CreateHMACSHA1(byte[] key) + public static HMAC CreateHMACSHA1(byte[] key) { #pragma warning disable CA5350 // Do not use weak cryptographic algorithms - return new System.Security.Cryptography.HMACSHA1(key); + return new HMAC(new System.Security.Cryptography.HMACSHA1(key)); #pragma warning restore CA5350 // Do not use weak cryptographic algorithms } - public static HMACSHA1 CreateHMACSHA1(byte[] key, int hashSize) + public static HMAC CreateHMACSHA1(byte[] key, int hashSize) { #pragma warning disable CA5350 // Do not use weak cryptographic algorithms - return new HMACSHA1(key, hashSize); + return new HMAC(new HMACSHA1(key, hashSize)); #pragma warning restore CA5350 // Do not use weak cryptographic algorithms } - public static System.Security.Cryptography.HMACSHA256 CreateHMACSHA256(byte[] key) + public static HMAC CreateHMACSHA256(byte[] key) { - return new System.Security.Cryptography.HMACSHA256(key); + return new HMAC(new System.Security.Cryptography.HMACSHA256(key)); } - public static HMACSHA256 CreateHMACSHA256(byte[] key, int hashSize) + public static HMAC CreateHMACSHA256(byte[] key, int hashSize) { - return new HMACSHA256(key, hashSize); + return new HMAC(new HMACSHA256(key, hashSize)); } - public static System.Security.Cryptography.HMACSHA384 CreateHMACSHA384(byte[] key) + public static HMAC CreateHMACSHA256(byte[] key, bool etm) { - return new System.Security.Cryptography.HMACSHA384(key); + return new HMAC(new System.Security.Cryptography.HMACSHA256(key), etm); } - public static HMACSHA384 CreateHMACSHA384(byte[] key, int hashSize) + public static HMAC CreateHMACSHA384(byte[] key) { - return new HMACSHA384(key, hashSize); + return new HMAC(new System.Security.Cryptography.HMACSHA384(key)); } - public static System.Security.Cryptography.HMACSHA512 CreateHMACSHA512(byte[] key) + public static HMAC CreateHMACSHA384(byte[] key, int hashSize) { - return new System.Security.Cryptography.HMACSHA512(key); + return new HMAC(new HMACSHA384(key, hashSize)); } - public static HMACSHA512 CreateHMACSHA512(byte[] key, int hashSize) + public static HMAC CreateHMACSHA512(byte[] key) { - return new HMACSHA512(key, hashSize); + return new HMAC(new System.Security.Cryptography.HMACSHA512(key)); + } + + public static HMAC CreateHMACSHA512(byte[] key, int hashSize) + { + return new HMAC(new HMACSHA512(key, hashSize)); + } + + public static HMAC CreateHMACSHA512(byte[] key, bool etm) + { + return new HMAC(new System.Security.Cryptography.HMACSHA512(key), etm); } #if FEATURE_HMAC_RIPEMD160 - public static System.Security.Cryptography.HMACRIPEMD160 CreateHMACRIPEMD160(byte[] key) + public static HMAC CreateHMACRIPEMD160(byte[] key) { #pragma warning disable CA5350 // Do not use weak cryptographic algorithms - return new System.Security.Cryptography.HMACRIPEMD160(key); + return new HMAC(new System.Security.Cryptography.HMACRIPEMD160(key)); #pragma warning restore CA5350 // Do not use weak cryptographic algorithms } #else - public static global::SshNet.Security.Cryptography.HMACRIPEMD160 CreateHMACRIPEMD160(byte[] key) + public static HMAC CreateHMACRIPEMD160(byte[] key) { - return new global::SshNet.Security.Cryptography.HMACRIPEMD160(key); + return new HMAC(new global::SshNet.Security.Cryptography.HMACRIPEMD160(key)); } #endif // FEATURE_HMAC_RIPEMD160 } diff --git a/src/Renci.SshNet/ConnectionInfo.cs b/src/Renci.SshNet/ConnectionInfo.cs index b4cb0fc85..8abf0a75b 100644 --- a/src/Renci.SshNet/ConnectionInfo.cs +++ b/src/Renci.SshNet/ConnectionInfo.cs @@ -376,6 +376,7 @@ public ConnectionInfo(string host, int port, string username, ProxyTypes proxyTy #pragma warning disable IDE0200 // Remove unnecessary lambda expression; We want to prevent instantiating the HashAlgorithm objects. HmacAlgorithms = new Dictionary { + /* Encrypt-and-MAC (encrypt-and-authenticate) variants */ { "hmac-sha2-256", new HashInfo(32*8, key => CryptoAbstraction.CreateHMACSHA256(key)) }, { "hmac-sha2-512", new HashInfo(64 * 8, key => CryptoAbstraction.CreateHMACSHA512(key)) }, { "hmac-sha2-512-96", new HashInfo(64 * 8, key => CryptoAbstraction.CreateHMACSHA512(key, 96)) }, @@ -386,6 +387,9 @@ public ConnectionInfo(string host, int port, string username, ProxyTypes proxyTy { "hmac-sha1-96", new HashInfo(20*8, key => CryptoAbstraction.CreateHMACSHA1(key, 96)) }, { "hmac-md5", new HashInfo(16*8, key => CryptoAbstraction.CreateHMACMD5(key)) }, { "hmac-md5-96", new HashInfo(16*8, key => CryptoAbstraction.CreateHMACMD5(key, 96)) }, + /* Encrypt-then-MAC variants */ + { "hmac-sha2-256-etm@openssh.com", new HashInfo(32*8, key => CryptoAbstraction.CreateHMACSHA256(key, etm: true)) }, + { "hmac-sha2-512-etm@openssh.com", new HashInfo(64 * 8, key => CryptoAbstraction.CreateHMACSHA512(key, etm: true)) }, }; #pragma warning restore IDE0200 // Remove unnecessary lambda expression diff --git a/src/Renci.SshNet/HashInfo.cs b/src/Renci.SshNet/HashInfo.cs index cbbbf5fe7..a7f0e9375 100644 --- a/src/Renci.SshNet/HashInfo.cs +++ b/src/Renci.SshNet/HashInfo.cs @@ -1,6 +1,6 @@ using System; -using System.Security.Cryptography; using Renci.SshNet.Common; +using Renci.SshNet.Security.Cryptography; namespace Renci.SshNet { @@ -20,17 +20,22 @@ public class HashInfo /// /// Gets the cipher. /// - public Func HashAlgorithm { get; private set; } + public Func HMAC { get; private set; } + + /// + /// Gets a value indicating whether Encrypt-then-MAC or not. + /// + public bool ETM { get; private set; } /// /// Initializes a new instance of the class. /// /// Size of the key. /// The hash algorithm to use for a given key. - public HashInfo(int keySize, Func hash) + public HashInfo(int keySize, Func hash) { KeySize = keySize; - HashAlgorithm = key => hash(key.Take(KeySize / 8)); + HMAC = key => hash(key.Take(KeySize / 8)); } } } diff --git a/src/Renci.SshNet/Security/Cryptography/HMAC.cs b/src/Renci.SshNet/Security/Cryptography/HMAC.cs new file mode 100644 index 000000000..079089b89 --- /dev/null +++ b/src/Renci.SshNet/Security/Cryptography/HMAC.cs @@ -0,0 +1,49 @@ +using System; +using System.Security.Cryptography; + +namespace Renci.SshNet.Security.Cryptography +{ + /// + /// Represents the info for Message Authentication Code (MAC). + /// + public sealed class HMAC : IDisposable + { + /// + /// Initializes a new instance of the class. + /// + /// The hash algorithm. + public HMAC(HashAlgorithm hashAlgorithm) + : this(hashAlgorithm, etm: false) + { + } + + /// + /// Initializes a new instance of the class. + /// + /// The hash algorithm. + /// to enable encrypt-then-MAC, to use encrypt-and-MAC. + public HMAC( + HashAlgorithm hashAlgorithm, + bool etm) + { + HashAlgorithm = hashAlgorithm; + ETM = etm; + } + + /// + public void Dispose() + { + HashAlgorithm?.Dispose(); + } + + /// + /// Gets the hash algorithem. + /// + public HashAlgorithm HashAlgorithm { get; private set; } + + /// + /// Gets a value indicating whether enable encryption-to-mac or encryption-then-mac. + /// + public bool ETM { get; private set; } + } +} diff --git a/src/Renci.SshNet/Security/IKeyExchange.cs b/src/Renci.SshNet/Security/IKeyExchange.cs index 7ffd2f465..545a5f872 100644 --- a/src/Renci.SshNet/Security/IKeyExchange.cs +++ b/src/Renci.SshNet/Security/IKeyExchange.cs @@ -1,5 +1,4 @@ using System; -using System.Security.Cryptography; using Renci.SshNet.Common; using Renci.SshNet.Compression; @@ -69,7 +68,7 @@ public interface IKeyExchange : IDisposable /// /// The server hash algorithm. /// - HashAlgorithm CreateServerHash(); + HMAC CreateServerHash(); /// /// Creates the client-side hash algorithm to use. @@ -77,7 +76,7 @@ public interface IKeyExchange : IDisposable /// /// The client hash algorithm. /// - HashAlgorithm CreateClientHash(); + HMAC CreateClientHash(); /// /// Creates the compression algorithm to use to deflate data. diff --git a/src/Renci.SshNet/Security/KeyExchange.cs b/src/Renci.SshNet/Security/KeyExchange.cs index f24dcafe1..d1544ad33 100644 --- a/src/Renci.SshNet/Security/KeyExchange.cs +++ b/src/Renci.SshNet/Security/KeyExchange.cs @@ -1,7 +1,6 @@ using System; using System.Collections.Generic; using System.Linq; -using System.Security.Cryptography; using Renci.SshNet.Abstractions; using Renci.SshNet.Common; @@ -103,7 +102,7 @@ from a in message.MacAlgorithmsClientToServer select a).FirstOrDefault(); if (string.IsNullOrEmpty(clientHmacAlgorithmName)) { - throw new SshConnectionException("Server HMAC algorithm not found", DisconnectReason.KeyExchangeFailed); + throw new SshConnectionException("Client HMAC algorithm not found", DisconnectReason.KeyExchangeFailed); } session.ConnectionInfo.CurrentClientHmacAlgorithm = clientHmacAlgorithmName; @@ -221,7 +220,7 @@ public Cipher CreateClientCipher() /// /// The server-side hash algorithm. /// - public HashAlgorithm CreateServerHash() + public HMAC CreateServerHash() { // Resolve Session ID var sessionId = Session.SessionId ?? ExchangeHash; @@ -235,7 +234,7 @@ public HashAlgorithm CreateServerHash() Session.ToHex(Session.SessionId), Session.ConnectionInfo.CurrentServerHmacAlgorithm)); - return _serverHashInfo.HashAlgorithm(serverKey); + return _serverHashInfo.HMAC(serverKey); } /// @@ -244,7 +243,7 @@ public HashAlgorithm CreateServerHash() /// /// The client-side hash algorithm. /// - public HashAlgorithm CreateClientHash() + public HMAC CreateClientHash() { // Resolve Session ID var sessionId = Session.SessionId ?? ExchangeHash; @@ -258,7 +257,7 @@ public HashAlgorithm CreateClientHash() Session.ToHex(Session.SessionId), Session.ConnectionInfo.CurrentClientHmacAlgorithm)); - return _clientHashInfo.HashAlgorithm(clientKey); + return _clientHashInfo.HMAC(clientKey); } /// diff --git a/src/Renci.SshNet/Session.cs b/src/Renci.SshNet/Session.cs index 193b10028..3afd4b4fb 100644 --- a/src/Renci.SshNet/Session.cs +++ b/src/Renci.SshNet/Session.cs @@ -3,7 +3,6 @@ using System.Globalization; using System.Linq; using System.Net.Sockets; -using System.Security.Cryptography; using System.Text; using System.Threading; using System.Threading.Tasks; @@ -172,9 +171,9 @@ public class Session : ISession private IKeyExchange _keyExchange; - private HashAlgorithm _serverMac; + private HMAC _serverMac; - private HashAlgorithm _clientMac; + private HMAC _clientMac; private Cipher _clientCipher; @@ -300,20 +299,20 @@ public Message ClientInitMessage get { _clientInitMessage ??= new KeyExchangeInitMessage - { - KeyExchangeAlgorithms = ConnectionInfo.KeyExchangeAlgorithms.Keys.ToArray(), - ServerHostKeyAlgorithms = ConnectionInfo.HostKeyAlgorithms.Keys.ToArray(), - EncryptionAlgorithmsClientToServer = ConnectionInfo.Encryptions.Keys.ToArray(), - EncryptionAlgorithmsServerToClient = ConnectionInfo.Encryptions.Keys.ToArray(), - MacAlgorithmsClientToServer = ConnectionInfo.HmacAlgorithms.Keys.ToArray(), - MacAlgorithmsServerToClient = ConnectionInfo.HmacAlgorithms.Keys.ToArray(), - CompressionAlgorithmsClientToServer = ConnectionInfo.CompressionAlgorithms.Keys.ToArray(), - CompressionAlgorithmsServerToClient = ConnectionInfo.CompressionAlgorithms.Keys.ToArray(), - LanguagesClientToServer = new[] { string.Empty }, - LanguagesServerToClient = new[] { string.Empty }, - FirstKexPacketFollows = false, - Reserved = 0 - }; + { + KeyExchangeAlgorithms = ConnectionInfo.KeyExchangeAlgorithms.Keys.ToArray(), + ServerHostKeyAlgorithms = ConnectionInfo.HostKeyAlgorithms.Keys.ToArray(), + EncryptionAlgorithmsClientToServer = ConnectionInfo.Encryptions.Keys.ToArray(), + EncryptionAlgorithmsServerToClient = ConnectionInfo.Encryptions.Keys.ToArray(), + MacAlgorithmsClientToServer = ConnectionInfo.HmacAlgorithms.Keys.ToArray(), + MacAlgorithmsServerToClient = ConnectionInfo.HmacAlgorithms.Keys.ToArray(), + CompressionAlgorithmsClientToServer = ConnectionInfo.CompressionAlgorithms.Keys.ToArray(), + CompressionAlgorithmsServerToClient = ConnectionInfo.CompressionAlgorithms.Keys.ToArray(), + LanguagesClientToServer = new[] { string.Empty }, + LanguagesServerToClient = new[] { string.Empty }, + FirstKexPacketFollows = false, + Reserved = 0 + }; return _clientInitMessage; } @@ -1063,20 +1062,43 @@ internal void SendMessage(Message message) byte[] hash = null; var packetDataOffset = 4; // first four bytes are reserved for outbound packet sequence - if (_clientMac != null) + if (_clientMac != null && !_clientMac.ETM) { // write outbound packet sequence to start of packet data Pack.UInt32ToBigEndian(_outboundPacketSequence, packetData); // calculate packet hash - hash = _clientMac.ComputeHash(packetData); + hash = _clientMac.HashAlgorithm.ComputeHash(packetData); } + var packetDataLengthSize = 4; + // Encrypt packet data if (_clientCipher != null) { - packetData = _clientCipher.Encrypt(packetData, packetDataOffset, packetData.Length - packetDataOffset); - packetDataOffset = 0; + if (_clientMac != null && _clientMac.ETM) + { + var encryptedData = _clientCipher.Encrypt(packetData, packetDataOffset + packetDataLengthSize, packetData.Length - packetDataOffset - packetDataLengthSize); + + packetData = new byte[packetDataOffset + packetDataLengthSize + encryptedData.Length]; + + // write outbound packet sequence to start of packet data + Pack.UInt32ToBigEndian(_outboundPacketSequence, packetData); + + // write packet length + Pack.UInt32ToBigEndian((uint) encryptedData.Length, packetData, packetDataOffset); + + // write encrypted data + Buffer.BlockCopy(encryptedData, 0, packetData, packetDataOffset + packetDataLengthSize, encryptedData.Length); + + // calculate packet hash + hash = _clientMac.HashAlgorithm.ComputeHash(packetData); + } + else + { + packetData = _clientCipher.Encrypt(packetData, packetDataOffset, packetData.Length - packetDataOffset); + packetDataOffset = 0; + } } if (packetData.Length > MaximumSshPacketSize) @@ -1197,7 +1219,7 @@ private Message ReceiveMessage(Socket socket) // Determine the size of the first block, which is 8 or cipher block size (whichever is larger) bytes var blockSize = _serverCipher is null ? (byte) 8 : Math.Max((byte) 8, _serverCipher.MinimumSize); - var serverMacLength = _serverMac != null ? _serverMac.HashSize/8 : 0; + var serverMacLength = _serverMac != null ? _serverMac.HashAlgorithm.HashSize/8 : 0; byte[] data; uint packetLength; @@ -1215,7 +1237,7 @@ private Message ReceiveMessage(Socket socket) return null; } - if (_serverCipher != null) + if (_serverCipher != null && !_serverMac.ETM) { firstBlock = _serverCipher.Decrypt(firstBlock); } @@ -1257,6 +1279,20 @@ private Message ReceiveMessage(Socket socket) } } + // validate encrypted message against MAC + if (_serverMac != null && _serverMac.ETM) + { + var clientHash = _serverMac.HashAlgorithm.ComputeHash(data, blockSize, data.Length - blockSize - serverMacLength); + var serverHash = data.Take(data.Length - serverMacLength, serverMacLength); + + // TODO Add IsEqualTo overload that takes left+right index and number of bytes to compare. + // TODO That way we can eliminate the extra allocation of the Take above. + if (!serverHash.IsEqualTo(clientHash)) + { + throw new SshConnectionException("MAC error", DisconnectReason.MacError); + } + } + if (_serverCipher != null) { var numberOfBytesToDecrypt = data.Length - (blockSize + inboundPacketSequenceLength + serverMacLength); @@ -1271,10 +1307,10 @@ private Message ReceiveMessage(Socket socket) var messagePayloadLength = (int) packetLength - paddingLength - paddingLengthFieldLength; var messagePayloadOffset = inboundPacketSequenceLength + packetLengthFieldLength + paddingLengthFieldLength; - // validate message against MAC - if (_serverMac != null) + // validate decrpted message against MAC + if (_serverMac != null && !_serverMac.ETM) { - var clientHash = _serverMac.ComputeHash(data, 0, data.Length - serverMacLength); + var clientHash = _serverMac.HashAlgorithm.ComputeHash(data, 0, data.Length - serverMacLength); var serverHash = data.Take(data.Length - serverMacLength, serverMacLength); // TODO Add IsEqualTo overload that takes left+right index and number of bytes to compare. diff --git a/test/Renci.SshNet.Tests/Classes/SessionTest_ConnectedBase.cs b/test/Renci.SshNet.Tests/Classes/SessionTest_ConnectedBase.cs index 6331f7b9c..3ed966253 100644 --- a/test/Renci.SshNet.Tests/Classes/SessionTest_ConnectedBase.cs +++ b/test/Renci.SshNet.Tests/Classes/SessionTest_ConnectedBase.cs @@ -3,7 +3,6 @@ using System.Globalization; using System.Net; using System.Net.Sockets; -using System.Security.Cryptography; using Microsoft.VisualStudio.TestTools.UnitTesting; @@ -216,9 +215,9 @@ private void SetupMocks() _ = _keyExchangeMock.Setup(p => p.CreateClientCipher()) .Returns((Cipher) null); _ = _keyExchangeMock.Setup(p => p.CreateServerHash()) - .Returns((HashAlgorithm) null); + .Returns((HMAC) null); _ = _keyExchangeMock.Setup(p => p.CreateClientHash()) - .Returns((HashAlgorithm) null); + .Returns((HMAC) null); _ = _keyExchangeMock.Setup(p => p.CreateCompressor()) .Returns((Compressor) null); _ = _keyExchangeMock.Setup(p => p.CreateDecompressor()) diff --git a/test/Renci.SshNet.Tests/Classes/SessionTest_Connected_ServerAndClientDisconnectRace.cs b/test/Renci.SshNet.Tests/Classes/SessionTest_Connected_ServerAndClientDisconnectRace.cs index 11cda2d90..7e4fed4fc 100644 --- a/test/Renci.SshNet.Tests/Classes/SessionTest_Connected_ServerAndClientDisconnectRace.cs +++ b/test/Renci.SshNet.Tests/Classes/SessionTest_Connected_ServerAndClientDisconnectRace.cs @@ -3,7 +3,6 @@ using System.Globalization; using System.Net; using System.Net.Sockets; -using System.Security.Cryptography; using Microsoft.VisualStudio.TestTools.UnitTesting; using Moq; using Renci.SshNet.Common; @@ -164,9 +163,9 @@ private void SetupMocks() _ = _keyExchangeMock.Setup(p => p.CreateClientCipher()) .Returns((Cipher) null); _ = _keyExchangeMock.Setup(p => p.CreateServerHash()) - .Returns((HashAlgorithm) null); + .Returns((HMAC) null); _ = _keyExchangeMock.Setup(p => p.CreateClientHash()) - .Returns((HashAlgorithm) null); + .Returns((HMAC) null); _ = _keyExchangeMock.Setup(p => p.CreateCompressor()) .Returns((Compressor) null); _ = _keyExchangeMock.Setup(p => p.CreateDecompressor()) From 1063894f7f4bf2d70a39232959123a515631a5b6 Mon Sep 17 00:00:00 2001 From: "Xu, Scott" Date: Mon, 12 Feb 2024 16:08:30 +0800 Subject: [PATCH 5/6] Add intergration test for HmacSha2_256_ETM and HmacSha2_512_ETM --- test/Renci.SshNet.IntegrationTests/HmacTests.cs | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/test/Renci.SshNet.IntegrationTests/HmacTests.cs b/test/Renci.SshNet.IntegrationTests/HmacTests.cs index 993e5ec98..9a89c7089 100644 --- a/test/Renci.SshNet.IntegrationTests/HmacTests.cs +++ b/test/Renci.SshNet.IntegrationTests/HmacTests.cs @@ -58,6 +58,18 @@ public void HmacSha2_512() DoTest(MessageAuthenticationCodeAlgorithm.HmacSha2_512); } + [TestMethod] + public void HmacSha2_256_Etm() + { + DoTest(MessageAuthenticationCodeAlgorithm.HmacSha2_256_Etm); + } + + [TestMethod] + public void HmacSha2_512_Etm() + { + DoTest(MessageAuthenticationCodeAlgorithm.HmacSha2_512_Etm); + } + private void DoTest(MessageAuthenticationCodeAlgorithm macAlgorithm) { _remoteSshdConfig.ClearMessageAuthenticationCodeAlgorithms() From c01854d8f2fc777ab6f2c0d8ebbaeb9975301089 Mon Sep 17 00:00:00 2001 From: scott-xu Date: Mon, 12 Feb 2024 16:17:07 +0800 Subject: [PATCH 6/6] typo --- src/Renci.SshNet/Session.cs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/Renci.SshNet/Session.cs b/src/Renci.SshNet/Session.cs index 3afd4b4fb..96ee1e2e2 100644 --- a/src/Renci.SshNet/Session.cs +++ b/src/Renci.SshNet/Session.cs @@ -1307,7 +1307,7 @@ private Message ReceiveMessage(Socket socket) var messagePayloadLength = (int) packetLength - paddingLength - paddingLengthFieldLength; var messagePayloadOffset = inboundPacketSequenceLength + packetLengthFieldLength + paddingLengthFieldLength; - // validate decrpted message against MAC + // validate decrypted message against MAC if (_serverMac != null && !_serverMac.ETM) { var clientHash = _serverMac.HashAlgorithm.ComputeHash(data, 0, data.Length - serverMacLength);