Skip to content

add zero byte read to SslStream #87563

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Jun 19, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -2739,6 +2739,8 @@ public abstract class WrappingConnectedStreamConformanceTests : ConnectedStreamC
/// </summary>
protected virtual bool ZeroByteReadPerformsZeroByteReadOnUnderlyingStream => false;

protected virtual bool ExtraZeroByteReadsAllowed => false;

[Theory]
[InlineData(false)]
[InlineData(true)]
Expand Down Expand Up @@ -2938,7 +2940,7 @@ public virtual async Task ZeroByteRead_PerformsZeroByteReadOnUnderlyingStreamWhe
using StreamPair innerStreams = ConnectedStreams.CreateBidirectional();
(Stream innerWriteable, Stream innerReadable) = GetReadWritePair(innerStreams);

var tracker = new ZeroByteReadTrackingStream(innerReadable);
var tracker = new ZeroByteReadTrackingStream(innerReadable, ExtraZeroByteReadsAllowed);
using StreamPair streams = await CreateWrappedConnectedStreamsAsync((innerWriteable, tracker));

(Stream writeable, Stream readable) = GetReadWritePair(streams);
Expand Down Expand Up @@ -2993,9 +2995,11 @@ public virtual async Task ZeroByteRead_PerformsZeroByteReadOnUnderlyingStreamWhe
private sealed class ZeroByteReadTrackingStream : DelegatingStream
{
private TaskCompletionSource? _signal;
private bool _extraZeroByteReadsAllowed;

public ZeroByteReadTrackingStream(Stream innerStream) : base(innerStream)
public ZeroByteReadTrackingStream(Stream innerStream, bool extraZeroByteReadsAllowed = false) : base(innerStream)
{
_extraZeroByteReadsAllowed = extraZeroByteReadsAllowed;
}

public Task WaitForZeroByteReadAsync()
Expand All @@ -3014,13 +3018,13 @@ private void CheckForZeroByteRead(int bufferLength)
if (bufferLength == 0)
{
var signal = _signal;
if (signal is null)
if (signal is null && !_extraZeroByteReadsAllowed)
{
throw new Exception("Unexpected zero byte read");
}

_signal = null;
signal.SetResult();
signal?.SetResult();
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,12 @@ public async Task ZeroByteRead_IssuesZeroByteReadOnUnderlyingStream(StreamConfor

using HttpResponseMessage response = await clientTask.WaitAsync(TestHelper.PassingTestTimeout);
using Stream clientStream = response.Content.ReadAsStream();
Assert.False(sawZeroByteRead.Task.IsCompleted);

if (!useSsl)
{
// SslStream does zero byte reads under the covers
Assert.False(sawZeroByteRead.Task.IsCompleted);
}

Task<int> zeroByteReadTask = Task.Run(() => StreamConformanceTests.ReadAsync(readMode, clientStream, Array.Empty<byte>(), 0, 0, CancellationToken.None));
Assert.False(zeroByteReadTask.IsCompleted);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@ public partial class SslStream
private const int HandshakeTypeOffsetSsl2 = 2; // Offset of HelloType in Sslv2 and Unified frames
private const int HandshakeTypeOffsetTls = 5; // Offset of HelloType in Sslv3 and TLS frames

private const int UnknownTlsFrameLength = int.MaxValue; // frame too short to determine length

private bool _receivedEOF;

// Used by Telemetry to ensure we log connection close exactly once
Expand Down Expand Up @@ -211,12 +213,10 @@ private async Task RenegotiateAsync<TIOAdapter>(CancellationToken cancellationTo
throw SslStreamPal.GetException(status);
}

_buffer.EnsureAvailableSpace(InitialHandshakeBufferSize);

ProtocolToken message;
do
{
int frameSize = await ReceiveTlsFrameAsync<TIOAdapter>(cancellationToken).ConfigureAwait(false);
int frameSize = await ReceiveHandshakeFrameAsync<TIOAdapter>(cancellationToken).ConfigureAwait(false);
ProcessTlsFrame(frameSize, out message);

if (message.Size > 0)
Expand Down Expand Up @@ -291,7 +291,7 @@ private async Task ForceAuthenticationAsync<TIOAdapter>(bool receiveFirst, byte[

while (!handshakeCompleted)
{
int frameSize = await ReceiveTlsFrameAsync<TIOAdapter>(cancellationToken).ConfigureAwait(false);
int frameSize = await ReceiveHandshakeFrameAsync<TIOAdapter>(cancellationToken).ConfigureAwait(false);
ProcessTlsFrame(frameSize, out message);

ReadOnlyMemory<byte> payload = default;
Expand Down Expand Up @@ -359,10 +359,10 @@ private async Task ForceAuthenticationAsync<TIOAdapter>(bool receiveFirst, byte[
}

// This method will make sure we have at least one full TLS frame buffered.
private async ValueTask<int> ReceiveTlsFrameAsync<TIOAdapter>(CancellationToken cancellationToken)
private async ValueTask<int> ReceiveHandshakeFrameAsync<TIOAdapter>(CancellationToken cancellationToken)
where TIOAdapter : IReadWriteAdapter
{
int frameSize = await EnsureFullTlsFrameAsync<TIOAdapter>(cancellationToken).ConfigureAwait(false);
int frameSize = await EnsureFullTlsFrameAsync<TIOAdapter>(cancellationToken, InitialHandshakeBufferSize).ConfigureAwait(false);

if (frameSize == 0)
{
Expand Down Expand Up @@ -699,38 +699,27 @@ private void ReturnReadBufferIfEmpty()

private bool HaveFullTlsFrame(out int frameSize)
{
if (_buffer.EncryptedLength < TlsFrameHelper.HeaderSize)
{
frameSize = int.MaxValue;
return false;
}

frameSize = GetFrameSize(_buffer.EncryptedReadOnlySpan);
return _buffer.EncryptedLength >= frameSize;
}

[AsyncMethodBuilder(typeof(PoolingAsyncValueTaskMethodBuilder<>))]
private async ValueTask<int> EnsureFullTlsFrameAsync<TIOAdapter>(CancellationToken cancellationToken)
private async ValueTask<int> EnsureFullTlsFrameAsync<TIOAdapter>(CancellationToken cancellationToken, int estimatedSize)
where TIOAdapter : IReadWriteAdapter
{
int frameSize;
if (HaveFullTlsFrame(out frameSize))
if (HaveFullTlsFrame(out int frameSize))
{
return frameSize;
}

if (frameSize != int.MaxValue)
{
// make sure we have space for the whole frame
_buffer.EnsureAvailableSpace(frameSize - _buffer.EncryptedLength);
}
else
{
// move existing data to the beginning of the buffer (they will
// be couple of bytes only, otherwise we would have entire
// header and know exact size)
_buffer.EnsureAvailableSpace(_buffer.Capacity - _buffer.EncryptedLength);
}
await TIOAdapter.ReadAsync(InnerStream, Memory<byte>.Empty, cancellationToken).ConfigureAwait(false);

// If we don't have enough data to determine the frame size, use the provided estimate
// (e.g. a full TLS frame for reads, and a somewhat shorter frame for handshake / renegotiation).
// If we do know the frame size, ensure we have space for the whole frame.
_buffer.EnsureAvailableSpace(frameSize == UnknownTlsFrameLength ?
estimatedSize :
frameSize - _buffer.EncryptedLength);

while (_buffer.EncryptedLength < frameSize)
{
Expand Down Expand Up @@ -806,6 +795,7 @@ private SecurityStatusPal DecryptData(int frameSize)
private async ValueTask<int> ReadAsyncInternal<TIOAdapter>(Memory<byte> buffer, CancellationToken cancellationToken)
where TIOAdapter : IReadWriteAdapter
{

// Throw first if we already have exception.
// Check for disposal is not atomic so we will check again below.
ThrowIfExceptionalOrNotAuthenticated();
Expand All @@ -819,11 +809,12 @@ private async ValueTask<int> ReadAsyncInternal<TIOAdapter>(Memory<byte> buffer,
try
{
int processedLength = 0;
int nextTlsFrameLength = UnknownTlsFrameLength;

if (_buffer.DecryptedLength != 0)
{
processedLength = CopyDecryptedData(buffer);
if (processedLength == buffer.Length || !HaveFullTlsFrame(out _))
if (processedLength == buffer.Length || !HaveFullTlsFrame(out nextTlsFrameLength))
{
// We either filled whole buffer or used all buffered frames.
return processedLength;
Expand All @@ -832,32 +823,19 @@ private async ValueTask<int> ReadAsyncInternal<TIOAdapter>(Memory<byte> buffer,
buffer = buffer.Slice(processedLength);
}

if (_receivedEOF)
if (_receivedEOF && nextTlsFrameLength == UnknownTlsFrameLength)
{
// there should be no frames waiting for processing
Debug.Assert(_buffer.EncryptedLength == 0);
// We received EOF during previous read but had buffered data to return.
return 0;
}

if (buffer.Length == 0 && _buffer.ActiveLength == 0)
{
// User requested a zero-byte read, and we have no data available in the buffer for processing.
// This zero-byte read indicates their desire to trade off the extra cost of a zero-byte read
// for reduced memory consumption when data is not immediately available.
// So, we will issue our own zero-byte read against the underlying stream and defer buffer allocation
// until data is actually available from the underlying stream.
// Note that if the underlying stream does not supporting blocking on zero byte reads, then this will
// complete immediately and won't save any memory, but will still function correctly.
await TIOAdapter.ReadAsync(InnerStream, Memory<byte>.Empty, cancellationToken).ConfigureAwait(false);
}

Debug.Assert(_buffer.DecryptedLength == 0);

_buffer.EnsureAvailableSpace(ReadBufferSize - _buffer.ActiveLength);

while (true)
{
int payloadBytes = await EnsureFullTlsFrameAsync<TIOAdapter>(cancellationToken).ConfigureAwait(false);
int payloadBytes = await EnsureFullTlsFrameAsync<TIOAdapter>(cancellationToken, ReadBufferSize).ConfigureAwait(false);
if (payloadBytes == 0)
{
_receivedEOF = true;
Expand Down Expand Up @@ -1009,6 +987,11 @@ private int CopyDecryptedData(Memory<byte> buffer)
// Returns TLS Frame size including header size.
private int GetFrameSize(ReadOnlySpan<byte> buffer)
{
if (buffer.Length < TlsFrameHelper.HeaderSize)
{
return UnknownTlsFrameLength;
}

if (!TlsFrameHelper.TryGetFrameHeader(buffer, ref _lastFrame.Header))
{
throw new IOException(SR.net_ssl_io_frame);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ public abstract class SslStreamConformanceTests : WrappingConnectedStreamConform
protected override bool BlocksOnZeroByteReads => true;
protected override bool ZeroByteReadPerformsZeroByteReadOnUnderlyingStream => true;
protected override Type UnsupportedConcurrentExceptionType => typeof(NotSupportedException);
protected override bool ExtraZeroByteReadsAllowed => true;

protected virtual SslProtocols GetSslProtocols() => SslProtocols.None;

Expand Down