Skip to content

Commit 4542e09

Browse files
wfurtstephentoub
andauthored
add zero byte read to SslStream (#87563)
* add zero byte read to SslStream * fix test * Apply suggestions from code review Co-authored-by: Stephen Toub <stoub@microsoft.com> * feedback * add back missing line --------- Co-authored-by: Stephen Toub <stoub@microsoft.com>
1 parent 8b25fd3 commit 4542e09

File tree

4 files changed

+42
-49
lines changed

4 files changed

+42
-49
lines changed

src/libraries/Common/tests/StreamConformanceTests/System/IO/StreamConformanceTests.cs

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2739,6 +2739,8 @@ public abstract class WrappingConnectedStreamConformanceTests : ConnectedStreamC
27392739
/// </summary>
27402740
protected virtual bool ZeroByteReadPerformsZeroByteReadOnUnderlyingStream => false;
27412741

2742+
protected virtual bool ExtraZeroByteReadsAllowed => false;
2743+
27422744
[Theory]
27432745
[InlineData(false)]
27442746
[InlineData(true)]
@@ -2938,7 +2940,7 @@ public virtual async Task ZeroByteRead_PerformsZeroByteReadOnUnderlyingStreamWhe
29382940
using StreamPair innerStreams = ConnectedStreams.CreateBidirectional();
29392941
(Stream innerWriteable, Stream innerReadable) = GetReadWritePair(innerStreams);
29402942

2941-
var tracker = new ZeroByteReadTrackingStream(innerReadable);
2943+
var tracker = new ZeroByteReadTrackingStream(innerReadable, ExtraZeroByteReadsAllowed);
29422944
using StreamPair streams = await CreateWrappedConnectedStreamsAsync((innerWriteable, tracker));
29432945

29442946
(Stream writeable, Stream readable) = GetReadWritePair(streams);
@@ -2993,9 +2995,11 @@ public virtual async Task ZeroByteRead_PerformsZeroByteReadOnUnderlyingStreamWhe
29932995
private sealed class ZeroByteReadTrackingStream : DelegatingStream
29942996
{
29952997
private TaskCompletionSource? _signal;
2998+
private bool _extraZeroByteReadsAllowed;
29962999

2997-
public ZeroByteReadTrackingStream(Stream innerStream) : base(innerStream)
3000+
public ZeroByteReadTrackingStream(Stream innerStream, bool extraZeroByteReadsAllowed = false) : base(innerStream)
29983001
{
3002+
_extraZeroByteReadsAllowed = extraZeroByteReadsAllowed;
29993003
}
30003004

30013005
public Task WaitForZeroByteReadAsync()
@@ -3014,13 +3018,13 @@ private void CheckForZeroByteRead(int bufferLength)
30143018
if (bufferLength == 0)
30153019
{
30163020
var signal = _signal;
3017-
if (signal is null)
3021+
if (signal is null && !_extraZeroByteReadsAllowed)
30183022
{
30193023
throw new Exception("Unexpected zero byte read");
30203024
}
30213025

30223026
_signal = null;
3023-
signal.SetResult();
3027+
signal?.SetResult();
30243028
}
30253029
}
30263030

src/libraries/System.Net.Http/tests/FunctionalTests/ResponseStreamZeroByteReadTests.cs

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,12 @@ public async Task ZeroByteRead_IssuesZeroByteReadOnUnderlyingStream(StreamConfor
123123

124124
using HttpResponseMessage response = await clientTask.WaitAsync(TestHelper.PassingTestTimeout);
125125
using Stream clientStream = response.Content.ReadAsStream();
126-
Assert.False(sawZeroByteRead.Task.IsCompleted);
126+
127+
if (!useSsl)
128+
{
129+
// SslStream does zero byte reads under the covers
130+
Assert.False(sawZeroByteRead.Task.IsCompleted);
131+
}
127132

128133
Task<int> zeroByteReadTask = Task.Run(() => StreamConformanceTests.ReadAsync(readMode, clientStream, Array.Empty<byte>(), 0, 0, CancellationToken.None));
129134
Assert.False(zeroByteReadTask.IsCompleted);

src/libraries/System.Net.Security/src/System/Net/Security/SslStream.IO.cs

Lines changed: 27 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,8 @@ public partial class SslStream
2727
private const int HandshakeTypeOffsetSsl2 = 2; // Offset of HelloType in Sslv2 and Unified frames
2828
private const int HandshakeTypeOffsetTls = 5; // Offset of HelloType in Sslv3 and TLS frames
2929

30+
private const int UnknownTlsFrameLength = int.MaxValue; // frame too short to determine length
31+
3032
private bool _receivedEOF;
3133

3234
// Used by Telemetry to ensure we log connection close exactly once
@@ -211,12 +213,10 @@ private async Task RenegotiateAsync<TIOAdapter>(CancellationToken cancellationTo
211213
throw SslStreamPal.GetException(status);
212214
}
213215

214-
_buffer.EnsureAvailableSpace(InitialHandshakeBufferSize);
215-
216216
ProtocolToken message;
217217
do
218218
{
219-
int frameSize = await ReceiveTlsFrameAsync<TIOAdapter>(cancellationToken).ConfigureAwait(false);
219+
int frameSize = await ReceiveHandshakeFrameAsync<TIOAdapter>(cancellationToken).ConfigureAwait(false);
220220
ProcessTlsFrame(frameSize, out message);
221221

222222
if (message.Size > 0)
@@ -291,7 +291,7 @@ private async Task ForceAuthenticationAsync<TIOAdapter>(bool receiveFirst, byte[
291291

292292
while (!handshakeCompleted)
293293
{
294-
int frameSize = await ReceiveTlsFrameAsync<TIOAdapter>(cancellationToken).ConfigureAwait(false);
294+
int frameSize = await ReceiveHandshakeFrameAsync<TIOAdapter>(cancellationToken).ConfigureAwait(false);
295295
ProcessTlsFrame(frameSize, out message);
296296

297297
ReadOnlyMemory<byte> payload = default;
@@ -359,10 +359,10 @@ private async Task ForceAuthenticationAsync<TIOAdapter>(bool receiveFirst, byte[
359359
}
360360

361361
// This method will make sure we have at least one full TLS frame buffered.
362-
private async ValueTask<int> ReceiveTlsFrameAsync<TIOAdapter>(CancellationToken cancellationToken)
362+
private async ValueTask<int> ReceiveHandshakeFrameAsync<TIOAdapter>(CancellationToken cancellationToken)
363363
where TIOAdapter : IReadWriteAdapter
364364
{
365-
int frameSize = await EnsureFullTlsFrameAsync<TIOAdapter>(cancellationToken).ConfigureAwait(false);
365+
int frameSize = await EnsureFullTlsFrameAsync<TIOAdapter>(cancellationToken, InitialHandshakeBufferSize).ConfigureAwait(false);
366366

367367
if (frameSize == 0)
368368
{
@@ -699,38 +699,27 @@ private void ReturnReadBufferIfEmpty()
699699

700700
private bool HaveFullTlsFrame(out int frameSize)
701701
{
702-
if (_buffer.EncryptedLength < TlsFrameHelper.HeaderSize)
703-
{
704-
frameSize = int.MaxValue;
705-
return false;
706-
}
707-
708702
frameSize = GetFrameSize(_buffer.EncryptedReadOnlySpan);
709703
return _buffer.EncryptedLength >= frameSize;
710704
}
711705

712706
[AsyncMethodBuilder(typeof(PoolingAsyncValueTaskMethodBuilder<>))]
713-
private async ValueTask<int> EnsureFullTlsFrameAsync<TIOAdapter>(CancellationToken cancellationToken)
707+
private async ValueTask<int> EnsureFullTlsFrameAsync<TIOAdapter>(CancellationToken cancellationToken, int estimatedSize)
714708
where TIOAdapter : IReadWriteAdapter
715709
{
716-
int frameSize;
717-
if (HaveFullTlsFrame(out frameSize))
710+
if (HaveFullTlsFrame(out int frameSize))
718711
{
719712
return frameSize;
720713
}
721714

722-
if (frameSize != int.MaxValue)
723-
{
724-
// make sure we have space for the whole frame
725-
_buffer.EnsureAvailableSpace(frameSize - _buffer.EncryptedLength);
726-
}
727-
else
728-
{
729-
// move existing data to the beginning of the buffer (they will
730-
// be couple of bytes only, otherwise we would have entire
731-
// header and know exact size)
732-
_buffer.EnsureAvailableSpace(_buffer.Capacity - _buffer.EncryptedLength);
733-
}
715+
await TIOAdapter.ReadAsync(InnerStream, Memory<byte>.Empty, cancellationToken).ConfigureAwait(false);
716+
717+
// If we don't have enough data to determine the frame size, use the provided estimate
718+
// (e.g. a full TLS frame for reads, and a somewhat shorter frame for handshake / renegotiation).
719+
// If we do know the frame size, ensure we have space for the whole frame.
720+
_buffer.EnsureAvailableSpace(frameSize == UnknownTlsFrameLength ?
721+
estimatedSize :
722+
frameSize - _buffer.EncryptedLength);
734723

735724
while (_buffer.EncryptedLength < frameSize)
736725
{
@@ -806,6 +795,7 @@ private SecurityStatusPal DecryptData(int frameSize)
806795
private async ValueTask<int> ReadAsyncInternal<TIOAdapter>(Memory<byte> buffer, CancellationToken cancellationToken)
807796
where TIOAdapter : IReadWriteAdapter
808797
{
798+
809799
// Throw first if we already have exception.
810800
// Check for disposal is not atomic so we will check again below.
811801
ThrowIfExceptionalOrNotAuthenticated();
@@ -819,11 +809,12 @@ private async ValueTask<int> ReadAsyncInternal<TIOAdapter>(Memory<byte> buffer,
819809
try
820810
{
821811
int processedLength = 0;
812+
int nextTlsFrameLength = UnknownTlsFrameLength;
822813

823814
if (_buffer.DecryptedLength != 0)
824815
{
825816
processedLength = CopyDecryptedData(buffer);
826-
if (processedLength == buffer.Length || !HaveFullTlsFrame(out _))
817+
if (processedLength == buffer.Length || !HaveFullTlsFrame(out nextTlsFrameLength))
827818
{
828819
// We either filled whole buffer or used all buffered frames.
829820
return processedLength;
@@ -832,32 +823,19 @@ private async ValueTask<int> ReadAsyncInternal<TIOAdapter>(Memory<byte> buffer,
832823
buffer = buffer.Slice(processedLength);
833824
}
834825

835-
if (_receivedEOF)
826+
if (_receivedEOF && nextTlsFrameLength == UnknownTlsFrameLength)
836827
{
828+
// there should be no frames waiting for processing
837829
Debug.Assert(_buffer.EncryptedLength == 0);
838830
// We received EOF during previous read but had buffered data to return.
839831
return 0;
840832
}
841833

842-
if (buffer.Length == 0 && _buffer.ActiveLength == 0)
843-
{
844-
// User requested a zero-byte read, and we have no data available in the buffer for processing.
845-
// This zero-byte read indicates their desire to trade off the extra cost of a zero-byte read
846-
// for reduced memory consumption when data is not immediately available.
847-
// So, we will issue our own zero-byte read against the underlying stream and defer buffer allocation
848-
// until data is actually available from the underlying stream.
849-
// Note that if the underlying stream does not supporting blocking on zero byte reads, then this will
850-
// complete immediately and won't save any memory, but will still function correctly.
851-
await TIOAdapter.ReadAsync(InnerStream, Memory<byte>.Empty, cancellationToken).ConfigureAwait(false);
852-
}
853-
854834
Debug.Assert(_buffer.DecryptedLength == 0);
855835

856-
_buffer.EnsureAvailableSpace(ReadBufferSize - _buffer.ActiveLength);
857-
858836
while (true)
859837
{
860-
int payloadBytes = await EnsureFullTlsFrameAsync<TIOAdapter>(cancellationToken).ConfigureAwait(false);
838+
int payloadBytes = await EnsureFullTlsFrameAsync<TIOAdapter>(cancellationToken, ReadBufferSize).ConfigureAwait(false);
861839
if (payloadBytes == 0)
862840
{
863841
_receivedEOF = true;
@@ -1009,6 +987,11 @@ private int CopyDecryptedData(Memory<byte> buffer)
1009987
// Returns TLS Frame size including header size.
1010988
private int GetFrameSize(ReadOnlySpan<byte> buffer)
1011989
{
990+
if (buffer.Length < TlsFrameHelper.HeaderSize)
991+
{
992+
return UnknownTlsFrameLength;
993+
}
994+
1012995
if (!TlsFrameHelper.TryGetFrameHeader(buffer, ref _lastFrame.Header))
1013996
{
1014997
throw new IOException(SR.net_ssl_io_frame);

src/libraries/System.Net.Security/tests/FunctionalTests/SslStreamConformanceTests.cs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ public abstract class SslStreamConformanceTests : WrappingConnectedStreamConform
1616
protected override bool BlocksOnZeroByteReads => true;
1717
protected override bool ZeroByteReadPerformsZeroByteReadOnUnderlyingStream => true;
1818
protected override Type UnsupportedConcurrentExceptionType => typeof(NotSupportedException);
19+
protected override bool ExtraZeroByteReadsAllowed => true;
1920

2021
protected virtual SslProtocols GetSslProtocols() => SslProtocols.None;
2122

0 commit comments

Comments
 (0)