Skip to content

Commit 7c07b10

Browse files
authored
Read the underlying buffer in SshDataStream (#1638)
SshDataStream is a MemoryStream, so we can access the buffer directly. Also simplify some usage in PrivateKeyFile.
1 parent 6039e12 commit 7c07b10

File tree

6 files changed

+129
-159
lines changed

6 files changed

+129
-159
lines changed

src/Renci.SshNet/Common/Extensions.cs

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,9 @@ namespace Renci.SshNet.Common
1919
/// </summary>
2020
internal static class Extensions
2121
{
22+
#pragma warning disable S4136 // Method overloads should be grouped together
2223
internal static byte[] ToArray(this ServiceName serviceName)
24+
#pragma warning restore S4136 // Method overloads should be grouped together
2325
{
2426
switch (serviceName)
2527
{
@@ -382,6 +384,28 @@ internal static bool Remove<TKey, TValue>(this Dictionary<TKey, TValue> dictiona
382384
value = default;
383385
return false;
384386
}
387+
388+
internal static ArraySegment<T> Slice<T>(this ArraySegment<T> arraySegment, int index)
389+
{
390+
return new ArraySegment<T>(arraySegment.Array, arraySegment.Offset + index, arraySegment.Count - index);
391+
}
392+
393+
internal static ArraySegment<T> Slice<T>(this ArraySegment<T> arraySegment, int index, int count)
394+
{
395+
return new ArraySegment<T>(arraySegment.Array, arraySegment.Offset + index, count);
396+
}
397+
398+
internal static T[] ToArray<T>(this ArraySegment<T> arraySegment)
399+
{
400+
if (arraySegment.Count == 0)
401+
{
402+
return Array.Empty<T>();
403+
}
404+
405+
var array = new T[arraySegment.Count];
406+
Array.Copy(arraySegment.Array, arraySegment.Offset, array, 0, arraySegment.Count);
407+
return array;
408+
}
385409
#endif
386410
}
387411
}

src/Renci.SshNet/Common/SshDataStream.cs

Lines changed: 56 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
11
using System;
2+
using System.Buffers.Binary;
3+
using System.Diagnostics;
24
using System.Globalization;
35
using System.IO;
46
using System.Numerics;
@@ -27,7 +29,7 @@ public SshDataStream(int capacity)
2729
/// <param name="buffer">The array of unsigned bytes from which to create the current stream.</param>
2830
/// <exception cref="ArgumentNullException"><paramref name="buffer"/> is <see langword="null"/>.</exception>
2931
public SshDataStream(byte[] buffer)
30-
: base(buffer)
32+
: base(buffer ?? throw new ArgumentNullException(nameof(buffer)), 0, buffer.Length, writable: true, publiclyVisible: true)
3133
{
3234
}
3335

@@ -39,7 +41,7 @@ public SshDataStream(byte[] buffer)
3941
/// <param name="count">The number of bytes to load.</param>
4042
/// <exception cref="ArgumentNullException"><paramref name="buffer"/> is <see langword="null"/>.</exception>
4143
public SshDataStream(byte[] buffer, int offset, int count)
42-
: base(buffer, offset, count)
44+
: base(buffer, offset, count, writable: true, publiclyVisible: true)
4345
{
4446
}
4547

@@ -58,19 +60,6 @@ public bool IsEndOfData
5860
}
5961

6062
#if NETFRAMEWORK || NETSTANDARD2_0
61-
private int Read(Span<byte> buffer)
62-
{
63-
var sharedBuffer = System.Buffers.ArrayPool<byte>.Shared.Rent(buffer.Length);
64-
65-
var numRead = Read(sharedBuffer, 0, buffer.Length);
66-
67-
sharedBuffer.AsSpan(0, numRead).CopyTo(buffer);
68-
69-
System.Buffers.ArrayPool<byte>.Shared.Return(sharedBuffer);
70-
71-
return numRead;
72-
}
73-
7463
private void Write(ReadOnlySpan<byte> buffer)
7564
{
7665
var sharedBuffer = System.Buffers.ArrayPool<byte>.Shared.Rent(buffer.Length);
@@ -90,7 +79,7 @@ private void Write(ReadOnlySpan<byte> buffer)
9079
public void Write(uint value)
9180
{
9281
Span<byte> bytes = stackalloc byte[4];
93-
System.Buffers.Binary.BinaryPrimitives.WriteUInt32BigEndian(bytes, value);
82+
BinaryPrimitives.WriteUInt32BigEndian(bytes, value);
9483
Write(bytes);
9584
}
9685

@@ -101,7 +90,7 @@ public void Write(uint value)
10190
public void Write(ulong value)
10291
{
10392
Span<byte> bytes = stackalloc byte[8];
104-
System.Buffers.Binary.BinaryPrimitives.WriteUInt64BigEndian(bytes, value);
93+
BinaryPrimitives.WriteUInt64BigEndian(bytes, value);
10594
Write(bytes);
10695
}
10796

@@ -137,6 +126,7 @@ public void Write(byte[] data)
137126
/// <exception cref="ArgumentNullException"><paramref name="encoding"/> is <see langword="null"/>.</exception>
138127
public void Write(string s, Encoding encoding)
139128
{
129+
ThrowHelper.ThrowIfNull(s);
140130
ThrowHelper.ThrowIfNull(encoding);
141131

142132
#if NETSTANDARD2_1 || NET
@@ -153,12 +143,21 @@ public void Write(string s, Encoding encoding)
153143
}
154144

155145
/// <summary>
156-
/// Reads a byte array from the SSH data stream.
146+
/// Reads a length-prefixed byte array from the SSH data stream.
157147
/// </summary>
158148
/// <returns>
159149
/// The byte array read from the SSH data stream.
160150
/// </returns>
161151
public byte[] ReadBinary()
152+
{
153+
return ReadBinarySegment().ToArray();
154+
}
155+
156+
/// <summary>
157+
/// Reads a length-prefixed byte array from the SSH data stream,
158+
/// returned as a view over the underlying buffer.
159+
/// </summary>
160+
internal ArraySegment<byte> ReadBinarySegment()
162161
{
163162
var length = ReadUInt32();
164163

@@ -167,7 +166,23 @@ public byte[] ReadBinary()
167166
throw new NotSupportedException(string.Format(CultureInfo.CurrentCulture, "Data longer than {0} is not supported.", int.MaxValue));
168167
}
169168

170-
return ReadBytes((int)length);
169+
var buffer = GetRemainingBuffer().Slice(0, (int)length);
170+
171+
Position += length;
172+
173+
return buffer;
174+
}
175+
176+
/// <summary>
177+
/// Gets a view over the remaining data in the underlying buffer.
178+
/// </summary>
179+
private ArraySegment<byte> GetRemainingBuffer()
180+
{
181+
var success = TryGetBuffer(out var buffer);
182+
183+
Debug.Assert(success, "Expected buffer to be publicly visible");
184+
185+
return buffer.Slice((int)Position);
171186
}
172187

173188
/// <summary>
@@ -205,11 +220,11 @@ public void WriteBinary(byte[] buffer, int offset, int count)
205220
/// </returns>
206221
public BigInteger ReadBigInt()
207222
{
208-
var data = ReadBinary();
209-
210223
#if NETSTANDARD2_1 || NET
224+
var data = ReadBinarySegment();
211225
return new BigInteger(data, isBigEndian: true);
212226
#else
227+
var data = ReadBinary();
213228
Array.Reverse(data);
214229
return new BigInteger(data);
215230
#endif
@@ -223,9 +238,9 @@ public BigInteger ReadBigInt()
223238
/// </returns>
224239
public ushort ReadUInt16()
225240
{
226-
Span<byte> bytes = stackalloc byte[2];
227-
ReadBytes(bytes);
228-
return System.Buffers.Binary.BinaryPrimitives.ReadUInt16BigEndian(bytes);
241+
var ret = BinaryPrimitives.ReadUInt16BigEndian(GetRemainingBuffer());
242+
Position += sizeof(ushort);
243+
return ret;
229244
}
230245

231246
/// <summary>
@@ -236,9 +251,9 @@ public ushort ReadUInt16()
236251
/// </returns>
237252
public uint ReadUInt32()
238253
{
239-
Span<byte> span = stackalloc byte[4];
240-
ReadBytes(span);
241-
return System.Buffers.Binary.BinaryPrimitives.ReadUInt32BigEndian(span);
254+
var ret = BinaryPrimitives.ReadUInt32BigEndian(GetRemainingBuffer());
255+
Position += sizeof(uint);
256+
return ret;
242257
}
243258

244259
/// <summary>
@@ -249,9 +264,9 @@ public uint ReadUInt32()
249264
/// </returns>
250265
public ulong ReadUInt64()
251266
{
252-
Span<byte> span = stackalloc byte[8];
253-
ReadBytes(span);
254-
return System.Buffers.Binary.BinaryPrimitives.ReadUInt64BigEndian(span);
267+
var ret = BinaryPrimitives.ReadUInt64BigEndian(GetRemainingBuffer());
268+
Position += sizeof(ulong);
269+
return ret;
255270
}
256271

257272
/// <summary>
@@ -265,19 +280,13 @@ public string ReadString(Encoding encoding = null)
265280
{
266281
encoding ??= Encoding.UTF8;
267282

268-
var length = ReadUInt32();
269-
270-
if (length > int.MaxValue)
271-
{
272-
throw new NotSupportedException(string.Format(CultureInfo.CurrentCulture, "Strings longer than {0} is not supported.", int.MaxValue));
273-
}
283+
var bytes = ReadBinarySegment();
274284

275-
var bytes = ReadBytes((int)length);
276-
return encoding.GetString(bytes, 0, bytes.Length);
285+
return encoding.GetString(bytes.Array, bytes.Offset, bytes.Count);
277286
}
278287

279288
/// <summary>
280-
/// Writes the stream contents to a byte array, regardless of the <see cref="MemoryStream.Position"/>.
289+
/// Retrieves the stream contents as a byte array, regardless of the <see cref="MemoryStream.Position"/>.
281290
/// </summary>
282291
/// <returns>
283292
/// This method returns the contents of the <see cref="SshDataStream"/> as a byte array.
@@ -288,9 +297,15 @@ public string ReadString(Encoding encoding = null)
288297
/// </remarks>
289298
public override byte[] ToArray()
290299
{
291-
if (Capacity == Length)
300+
var success = TryGetBuffer(out var buffer);
301+
302+
Debug.Assert(success, "Expected buffer to be publicly visible");
303+
304+
if (buffer.Offset == 0 &&
305+
buffer.Count == buffer.Array.Length &&
306+
buffer.Count == Length)
292307
{
293-
return GetBuffer();
308+
return buffer.Array;
294309
}
295310

296311
return base.ToArray();
@@ -315,19 +330,5 @@ internal byte[] ReadBytes(int length)
315330

316331
return data;
317332
}
318-
319-
/// <summary>
320-
/// Reads data into the specified <paramref name="buffer" />.
321-
/// </summary>
322-
/// <param name="buffer">The buffer to read into.</param>
323-
/// <exception cref="ArgumentOutOfRangeException"><paramref name="buffer"/> is larger than the total of bytes available.</exception>
324-
private void ReadBytes(Span<byte> buffer)
325-
{
326-
var bytesRead = Read(buffer);
327-
if (bytesRead < buffer.Length)
328-
{
329-
throw new ArgumentOutOfRangeException(nameof(buffer), string.Format(CultureInfo.InvariantCulture, "The requested length ({0}) is greater than the actual number of bytes read ({1}).", buffer.Length, bytesRead));
330-
}
331-
}
332333
}
333334
}

src/Renci.SshNet/PrivateKeyFile.OpenSSH.cs

Lines changed: 19 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ public OpenSSH(byte[] data, string? passPhrase)
3232
/// </summary>
3333
public Key Parse()
3434
{
35-
var keyReader = new SshDataReader(_data);
35+
var keyReader = new SshDataStream(_data);
3636

3737
// check magic header
3838
var authMagic = "openssh-key-v1\0"u8;
@@ -171,7 +171,7 @@ public Key Parse()
171171
// now parse the data we called the private key, it actually contains the public key again
172172
// so we need to parse through it to get the private key bytes, plus there's some
173173
// validation we need to do.
174-
var privateKeyReader = new SshDataReader(privateKeyBytes);
174+
var privateKeyReader = new SshDataStream(privateKeyBytes);
175175

176176
// check ints should match, they wouldn't match for example if the wrong passphrase was supplied
177177
var checkInt1 = (int)privateKeyReader.ReadUInt32();
@@ -196,33 +196,29 @@ public Key Parse()
196196
// https://datatracker.ietf.org/doc/html/draft-miller-ssh-agent-11#section-3.2.3
197197

198198
// ENC(A)
199-
_ = privateKeyReader.ReadBignum2();
199+
_ = privateKeyReader.ReadBinarySegment();
200200

201201
// k || ENC(A)
202-
unencryptedPrivateKey = privateKeyReader.ReadBignum2();
202+
unencryptedPrivateKey = privateKeyReader.ReadBinary();
203203
parsedKey = new ED25519Key(unencryptedPrivateKey);
204204
break;
205205
case "ecdsa-sha2-nistp256":
206206
case "ecdsa-sha2-nistp384":
207207
case "ecdsa-sha2-nistp521":
208-
// curve
209-
var len = (int)privateKeyReader.ReadUInt32();
210-
var curve = Encoding.ASCII.GetString(privateKeyReader.ReadBytes(len));
208+
var curve = privateKeyReader.ReadString(Encoding.ASCII);
211209

212-
// public key
213-
publicKey = privateKeyReader.ReadBignum2();
210+
publicKey = privateKeyReader.ReadBinary();
214211

215-
// private key
216-
unencryptedPrivateKey = privateKeyReader.ReadBignum2();
212+
unencryptedPrivateKey = privateKeyReader.ReadBinary();
217213
parsedKey = new EcdsaKey(curve, publicKey, unencryptedPrivateKey.TrimLeadingZeros());
218214
break;
219215
case "ssh-rsa":
220-
var modulus = privateKeyReader.ReadBignum(); // n
221-
var exponent = privateKeyReader.ReadBignum(); // e
222-
var d = privateKeyReader.ReadBignum(); // d
223-
var inverseQ = privateKeyReader.ReadBignum(); // iqmp
224-
var p = privateKeyReader.ReadBignum(); // p
225-
var q = privateKeyReader.ReadBignum(); // q
216+
var modulus = privateKeyReader.ReadBigInt();
217+
var exponent = privateKeyReader.ReadBigInt();
218+
var d = privateKeyReader.ReadBigInt();
219+
var inverseQ = privateKeyReader.ReadBigInt();
220+
var p = privateKeyReader.ReadBigInt();
221+
var q = privateKeyReader.ReadBigInt();
226222
parsedKey = new RsaKey(modulus, exponent, d, p, q, inverseQ);
227223
break;
228224
default:
@@ -233,14 +229,17 @@ public Key Parse()
233229

234230
// The list of privatekey/comment pairs is padded with the bytes 1, 2, 3, ...
235231
// until the total length is a multiple of the cipher block size.
236-
var padding = privateKeyReader.ReadBytes();
237-
for (var i = 0; i < padding.Length; i++)
232+
int b, i = 0;
233+
234+
while ((b = privateKeyReader.ReadByte()) != -1)
238235
{
239-
if ((int)padding[i] != i + 1)
236+
if (b != i + 1)
240237
{
241238
throw new SshException("Padding of openssh key format contained wrong byte at position: " +
242239
i.ToString(CultureInfo.InvariantCulture));
243240
}
241+
242+
i++;
244243
}
245244

246245
return parsedKey;

0 commit comments

Comments
 (0)