Skip to content

Read the underlying buffer in SshDataStream #1638

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
May 4, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 24 additions & 0 deletions src/Renci.SshNet/Common/Extensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,9 @@ namespace Renci.SshNet.Common
/// </summary>
internal static class Extensions
{
#pragma warning disable S4136 // Method overloads should be grouped together
internal static byte[] ToArray(this ServiceName serviceName)
#pragma warning restore S4136 // Method overloads should be grouped together
{
switch (serviceName)
{
Expand Down Expand Up @@ -382,6 +384,28 @@ internal static bool Remove<TKey, TValue>(this Dictionary<TKey, TValue> dictiona
value = default;
return false;
}

internal static ArraySegment<T> Slice<T>(this ArraySegment<T> arraySegment, int index)
{
return new ArraySegment<T>(arraySegment.Array, arraySegment.Offset + index, arraySegment.Count - index);
}

internal static ArraySegment<T> Slice<T>(this ArraySegment<T> arraySegment, int index, int count)
{
return new ArraySegment<T>(arraySegment.Array, arraySegment.Offset + index, count);
}

internal static T[] ToArray<T>(this ArraySegment<T> arraySegment)
{
if (arraySegment.Count == 0)
{
return Array.Empty<T>();
}

var array = new T[arraySegment.Count];
Array.Copy(arraySegment.Array, arraySegment.Offset, array, 0, arraySegment.Count);
return array;
}
#endif
}
}
111 changes: 56 additions & 55 deletions src/Renci.SshNet/Common/SshDataStream.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
using System;
using System.Buffers.Binary;
using System.Diagnostics;
using System.Globalization;
using System.IO;
using System.Numerics;
Expand Down Expand Up @@ -27,7 +29,7 @@ public SshDataStream(int capacity)
/// <param name="buffer">The array of unsigned bytes from which to create the current stream.</param>
/// <exception cref="ArgumentNullException"><paramref name="buffer"/> is <see langword="null"/>.</exception>
public SshDataStream(byte[] buffer)
: base(buffer)
: base(buffer ?? throw new ArgumentNullException(nameof(buffer)), 0, buffer.Length, writable: true, publiclyVisible: true)
{
}

Expand All @@ -39,7 +41,7 @@ public SshDataStream(byte[] buffer)
/// <param name="count">The number of bytes to load.</param>
/// <exception cref="ArgumentNullException"><paramref name="buffer"/> is <see langword="null"/>.</exception>
public SshDataStream(byte[] buffer, int offset, int count)
: base(buffer, offset, count)
: base(buffer, offset, count, writable: true, publiclyVisible: true)
{
}

Expand All @@ -58,19 +60,6 @@ public bool IsEndOfData
}

#if NETFRAMEWORK || NETSTANDARD2_0
private int Read(Span<byte> buffer)
{
var sharedBuffer = System.Buffers.ArrayPool<byte>.Shared.Rent(buffer.Length);

var numRead = Read(sharedBuffer, 0, buffer.Length);

sharedBuffer.AsSpan(0, numRead).CopyTo(buffer);

System.Buffers.ArrayPool<byte>.Shared.Return(sharedBuffer);

return numRead;
}

private void Write(ReadOnlySpan<byte> buffer)
{
var sharedBuffer = System.Buffers.ArrayPool<byte>.Shared.Rent(buffer.Length);
Expand All @@ -90,7 +79,7 @@ private void Write(ReadOnlySpan<byte> buffer)
public void Write(uint value)
{
Span<byte> bytes = stackalloc byte[4];
System.Buffers.Binary.BinaryPrimitives.WriteUInt32BigEndian(bytes, value);
BinaryPrimitives.WriteUInt32BigEndian(bytes, value);
Write(bytes);
}

Expand All @@ -101,7 +90,7 @@ public void Write(uint value)
public void Write(ulong value)
{
Span<byte> bytes = stackalloc byte[8];
System.Buffers.Binary.BinaryPrimitives.WriteUInt64BigEndian(bytes, value);
BinaryPrimitives.WriteUInt64BigEndian(bytes, value);
Write(bytes);
}

Expand Down Expand Up @@ -137,6 +126,7 @@ public void Write(byte[] data)
/// <exception cref="ArgumentNullException"><paramref name="encoding"/> is <see langword="null"/>.</exception>
public void Write(string s, Encoding encoding)
{
ThrowHelper.ThrowIfNull(s);
ThrowHelper.ThrowIfNull(encoding);

#if NETSTANDARD2_1 || NET
Expand All @@ -153,12 +143,21 @@ public void Write(string s, Encoding encoding)
}

/// <summary>
/// Reads a byte array from the SSH data stream.
/// Reads a length-prefixed byte array from the SSH data stream.
/// </summary>
/// <returns>
/// The byte array read from the SSH data stream.
/// </returns>
public byte[] ReadBinary()
{
return ReadBinarySegment().ToArray();
}

/// <summary>
/// Reads a length-prefixed byte array from the SSH data stream,
/// returned as a view over the underlying buffer.
/// </summary>
internal ArraySegment<byte> ReadBinarySegment()
Copy link
Collaborator Author

@Rob-Hague Rob-Hague May 3, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

later this will be used in SshData to remove large array allocations when reading e.g. ChannelDataMessage (the motivation for this change)

{
var length = ReadUInt32();

Expand All @@ -167,7 +166,23 @@ public byte[] ReadBinary()
throw new NotSupportedException(string.Format(CultureInfo.CurrentCulture, "Data longer than {0} is not supported.", int.MaxValue));
}

return ReadBytes((int)length);
var buffer = GetRemainingBuffer().Slice(0, (int)length);

Position += length;

return buffer;
}

/// <summary>
/// Gets a view over the remaining data in the underlying buffer.
/// </summary>
private ArraySegment<byte> GetRemainingBuffer()
{
var success = TryGetBuffer(out var buffer);

Debug.Assert(success, "Expected buffer to be publicly visible");

return buffer.Slice((int)Position);
}

/// <summary>
Expand Down Expand Up @@ -205,11 +220,11 @@ public void WriteBinary(byte[] buffer, int offset, int count)
/// </returns>
public BigInteger ReadBigInt()
{
var data = ReadBinary();

#if NETSTANDARD2_1 || NET
var data = ReadBinarySegment();
return new BigInteger(data, isBigEndian: true);
#else
var data = ReadBinary();
Array.Reverse(data);
return new BigInteger(data);
#endif
Expand All @@ -223,9 +238,9 @@ public BigInteger ReadBigInt()
/// </returns>
public ushort ReadUInt16()
{
Span<byte> bytes = stackalloc byte[2];
ReadBytes(bytes);
return System.Buffers.Binary.BinaryPrimitives.ReadUInt16BigEndian(bytes);
var ret = BinaryPrimitives.ReadUInt16BigEndian(GetRemainingBuffer());
Position += sizeof(ushort);
return ret;
}

/// <summary>
Expand All @@ -236,9 +251,9 @@ public ushort ReadUInt16()
/// </returns>
public uint ReadUInt32()
{
Span<byte> span = stackalloc byte[4];
ReadBytes(span);
return System.Buffers.Binary.BinaryPrimitives.ReadUInt32BigEndian(span);
var ret = BinaryPrimitives.ReadUInt32BigEndian(GetRemainingBuffer());
Position += sizeof(uint);
return ret;
}

/// <summary>
Expand All @@ -249,9 +264,9 @@ public uint ReadUInt32()
/// </returns>
public ulong ReadUInt64()
{
Span<byte> span = stackalloc byte[8];
ReadBytes(span);
return System.Buffers.Binary.BinaryPrimitives.ReadUInt64BigEndian(span);
var ret = BinaryPrimitives.ReadUInt64BigEndian(GetRemainingBuffer());
Position += sizeof(ulong);
return ret;
}

/// <summary>
Expand All @@ -265,19 +280,13 @@ public string ReadString(Encoding encoding = null)
{
encoding ??= Encoding.UTF8;

var length = ReadUInt32();

if (length > int.MaxValue)
{
throw new NotSupportedException(string.Format(CultureInfo.CurrentCulture, "Strings longer than {0} is not supported.", int.MaxValue));
}
var bytes = ReadBinarySegment();

var bytes = ReadBytes((int)length);
return encoding.GetString(bytes, 0, bytes.Length);
return encoding.GetString(bytes.Array, bytes.Offset, bytes.Count);
}

/// <summary>
/// Writes the stream contents to a byte array, regardless of the <see cref="MemoryStream.Position"/>.
/// Retrieves the stream contents as a byte array, regardless of the <see cref="MemoryStream.Position"/>.
/// </summary>
/// <returns>
/// This method returns the contents of the <see cref="SshDataStream"/> as a byte array.
Expand All @@ -288,9 +297,15 @@ public string ReadString(Encoding encoding = null)
/// </remarks>
public override byte[] ToArray()
{
if (Capacity == Length)
var success = TryGetBuffer(out var buffer);

Debug.Assert(success, "Expected buffer to be publicly visible");

if (buffer.Offset == 0 &&
buffer.Count == buffer.Array.Length &&
buffer.Count == Length)
{
return GetBuffer();
return buffer.Array;
}

return base.ToArray();
Expand All @@ -315,19 +330,5 @@ internal byte[] ReadBytes(int length)

return data;
}

/// <summary>
/// Reads data into the specified <paramref name="buffer" />.
/// </summary>
/// <param name="buffer">The buffer to read into.</param>
/// <exception cref="ArgumentOutOfRangeException"><paramref name="buffer"/> is larger than the total of bytes available.</exception>
private void ReadBytes(Span<byte> buffer)
{
var bytesRead = Read(buffer);
if (bytesRead < buffer.Length)
{
throw new ArgumentOutOfRangeException(nameof(buffer), string.Format(CultureInfo.InvariantCulture, "The requested length ({0}) is greater than the actual number of bytes read ({1}).", buffer.Length, bytesRead));
}
}
}
}
39 changes: 19 additions & 20 deletions src/Renci.SshNet/PrivateKeyFile.OpenSSH.cs
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ public OpenSSH(byte[] data, string? passPhrase)
/// </summary>
public Key Parse()
{
var keyReader = new SshDataReader(_data);
var keyReader = new SshDataStream(_data);

// check magic header
var authMagic = "openssh-key-v1\0"u8;
Expand Down Expand Up @@ -171,7 +171,7 @@ public Key Parse()
// now parse the data we called the private key, it actually contains the public key again
// so we need to parse through it to get the private key bytes, plus there's some
// validation we need to do.
var privateKeyReader = new SshDataReader(privateKeyBytes);
var privateKeyReader = new SshDataStream(privateKeyBytes);

// check ints should match, they wouldn't match for example if the wrong passphrase was supplied
var checkInt1 = (int)privateKeyReader.ReadUInt32();
Expand All @@ -196,33 +196,29 @@ public Key Parse()
// https://datatracker.ietf.org/doc/html/draft-miller-ssh-agent-11#section-3.2.3

// ENC(A)
_ = privateKeyReader.ReadBignum2();
_ = privateKeyReader.ReadBinarySegment();

// k || ENC(A)
unencryptedPrivateKey = privateKeyReader.ReadBignum2();
unencryptedPrivateKey = privateKeyReader.ReadBinary();
parsedKey = new ED25519Key(unencryptedPrivateKey);
break;
case "ecdsa-sha2-nistp256":
case "ecdsa-sha2-nistp384":
case "ecdsa-sha2-nistp521":
// curve
var len = (int)privateKeyReader.ReadUInt32();
var curve = Encoding.ASCII.GetString(privateKeyReader.ReadBytes(len));
var curve = privateKeyReader.ReadString(Encoding.ASCII);

// public key
publicKey = privateKeyReader.ReadBignum2();
publicKey = privateKeyReader.ReadBinary();

// private key
unencryptedPrivateKey = privateKeyReader.ReadBignum2();
unencryptedPrivateKey = privateKeyReader.ReadBinary();
parsedKey = new EcdsaKey(curve, publicKey, unencryptedPrivateKey.TrimLeadingZeros());
break;
case "ssh-rsa":
var modulus = privateKeyReader.ReadBignum(); // n
var exponent = privateKeyReader.ReadBignum(); // e
var d = privateKeyReader.ReadBignum(); // d
var inverseQ = privateKeyReader.ReadBignum(); // iqmp
var p = privateKeyReader.ReadBignum(); // p
var q = privateKeyReader.ReadBignum(); // q
var modulus = privateKeyReader.ReadBigInt();
var exponent = privateKeyReader.ReadBigInt();
var d = privateKeyReader.ReadBigInt();
var inverseQ = privateKeyReader.ReadBigInt();
var p = privateKeyReader.ReadBigInt();
var q = privateKeyReader.ReadBigInt();
parsedKey = new RsaKey(modulus, exponent, d, p, q, inverseQ);
break;
default:
Expand All @@ -233,14 +229,17 @@ public Key Parse()

// The list of privatekey/comment pairs is padded with the bytes 1, 2, 3, ...
// until the total length is a multiple of the cipher block size.
var padding = privateKeyReader.ReadBytes();
for (var i = 0; i < padding.Length; i++)
int b, i = 0;

while ((b = privateKeyReader.ReadByte()) != -1)
{
if ((int)padding[i] != i + 1)
if (b != i + 1)
{
throw new SshException("Padding of openssh key format contained wrong byte at position: " +
i.ToString(CultureInfo.InvariantCulture));
}

i++;
}

return parsedKey;
Expand Down
Loading
Loading