Skip to content

Commit 2a27452

Browse files
authored
improve SslStream exception after disposal (#79329)
* improve SslStream exception after disposal * add tests * add StreamUse * fix cleanup * fix condition * avoid casting
1 parent c3d1dd9 commit 2a27452

File tree

4 files changed

+88
-25
lines changed

4 files changed

+88
-25
lines changed

src/libraries/System.Net.Security/src/System/Net/Security/SslStream.IO.cs

Lines changed: 24 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -54,10 +54,10 @@ private void CloseInternal()
5454

5555
// Ensure a Read or Auth operation is not in progress,
5656
// block potential future read and auth operations since SslStream is disposing.
57-
// This leaves the _nestedRead = 1 and _nestedAuth = 1, but that's ok, since
57+
// This leaves the _nestedRead = 2 and _nestedAuth = 2, but that's ok, since
5858
// subsequent operations check the _exception sentinel first
59-
if (Interlocked.Exchange(ref _nestedRead, 1) == 0 &&
60-
Interlocked.Exchange(ref _nestedAuth, 1) == 0)
59+
if (Interlocked.Exchange(ref _nestedRead, StreamDisposed) == StreamNotInUse &&
60+
Interlocked.Exchange(ref _nestedAuth, StreamDisposed) == StreamNotInUse)
6161
{
6262
_buffer.ReturnBuffer();
6363
}
@@ -162,19 +162,22 @@ private async Task ReplyOnReAuthenticationAsync<TIOAdapter>(byte[]? buffer, Canc
162162
private async Task RenegotiateAsync<TIOAdapter>(CancellationToken cancellationToken)
163163
where TIOAdapter : IReadWriteAdapter
164164
{
165-
if (Interlocked.Exchange(ref _nestedAuth, 1) == 1)
165+
if (Interlocked.CompareExchange(ref _nestedAuth, StreamInUse, StreamNotInUse) != StreamNotInUse)
166166
{
167+
ObjectDisposedException.ThrowIf(_nestedAuth == StreamDisposed, this);
167168
throw new InvalidOperationException(SR.Format(SR.net_io_invalidnestedcall, "authenticate"));
168169
}
169170

170-
if (Interlocked.Exchange(ref _nestedRead, 1) == 1)
171+
if (Interlocked.CompareExchange(ref _nestedRead, StreamInUse, StreamNotInUse) != StreamNotInUse)
171172
{
173+
ObjectDisposedException.ThrowIf(_nestedRead == StreamDisposed, this);
172174
throw new NotSupportedException(SR.Format(SR.net_io_invalidnestedcall, "read"));
173175
}
174176

175-
if (Interlocked.Exchange(ref _nestedWrite, 1) == 1)
177+
// Write is different since we do not do anything special in Dispose
178+
if (Interlocked.Exchange(ref _nestedWrite, StreamInUse) != StreamNotInUse)
176179
{
177-
_nestedRead = 0;
180+
_nestedRead = StreamNotInUse;
178181
throw new NotSupportedException(SR.Format(SR.net_io_invalidnestedcall, "write"));
179182
}
180183

@@ -231,8 +234,8 @@ private async Task RenegotiateAsync<TIOAdapter>(CancellationToken cancellationTo
231234
_buffer.ReturnBuffer();
232235
}
233236

234-
_nestedRead = 0;
235-
_nestedWrite = 0;
237+
_nestedRead = StreamNotInUse;
238+
_nestedWrite = StreamNotInUse;
236239
_isRenego = false;
237240
// We will not release _nestedAuth at this point to prevent another renegotiation attempt.
238241
}
@@ -248,7 +251,7 @@ private async Task ForceAuthenticationAsync<TIOAdapter>(bool receiveFirst, byte[
248251
if (reAuthenticationData == null)
249252
{
250253
// prevent nesting only when authentication functions are called explicitly. e.g. handle renegotiation transparently.
251-
if (Interlocked.Exchange(ref _nestedAuth, 1) == 1)
254+
if (Interlocked.Exchange(ref _nestedAuth, StreamInUse) == StreamInUse)
252255
{
253256
throw new InvalidOperationException(SR.Format(SR.net_io_invalidnestedcall, "authenticate"));
254257
}
@@ -335,7 +338,7 @@ private async Task ForceAuthenticationAsync<TIOAdapter>(bool receiveFirst, byte[
335338
{
336339
if (reAuthenticationData == null)
337340
{
338-
_nestedAuth = 0;
341+
_nestedAuth = StreamNotInUse;
339342
_isRenego = false;
340343
}
341344
}
@@ -500,7 +503,7 @@ private bool CompleteHandshake(ref ProtocolToken? alertToken, out SslPolicyError
500503
{
501504
ProcessHandshakeSuccess();
502505

503-
if (_nestedAuth != 1)
506+
if (_nestedAuth != StreamInUse)
504507
{
505508
if (NetEventSource.Log.IsEnabled()) NetEventSource.Error(this, $"Ignoring unsolicited renegotiated certificate.");
506509
// ignore certificates received outside of handshake or requested renegotiation.
@@ -769,13 +772,16 @@ private SecurityStatusPal DecryptData(int frameSize)
769772
private async ValueTask<int> ReadAsyncInternal<TIOAdapter>(Memory<byte> buffer, CancellationToken cancellationToken)
770773
where TIOAdapter : IReadWriteAdapter
771774
{
772-
if (Interlocked.Exchange(ref _nestedRead, 1) == 1)
775+
// Throw first if we already have exception.
776+
// Check for disposal is not atomic so we will check again below.
777+
ThrowIfExceptionalOrNotAuthenticated();
778+
779+
if (Interlocked.CompareExchange(ref _nestedRead, StreamInUse, StreamNotInUse) != StreamNotInUse)
773780
{
781+
ObjectDisposedException.ThrowIf(_nestedRead == StreamDisposed, this);
774782
throw new NotSupportedException(SR.Format(SR.net_io_invalidnestedcall, "read"));
775783
}
776784

777-
ThrowIfExceptionalOrNotAuthenticated();
778-
779785
try
780786
{
781787
int processedLength = 0;
@@ -910,7 +916,7 @@ private async ValueTask<int> ReadAsyncInternal<TIOAdapter>(Memory<byte> buffer,
910916
finally
911917
{
912918
ReturnReadBufferIfEmpty();
913-
_nestedRead = 0;
919+
_nestedRead = StreamNotInUse;
914920
}
915921
}
916922

@@ -925,7 +931,7 @@ private async ValueTask WriteAsyncInternal<TIOAdapter>(ReadOnlyMemory<byte> buff
925931
return;
926932
}
927933

928-
if (Interlocked.Exchange(ref _nestedWrite, 1) == 1)
934+
if (Interlocked.Exchange(ref _nestedWrite, StreamInUse) == StreamInUse)
929935
{
930936
throw new NotSupportedException(SR.Format(SR.net_io_invalidnestedcall, "write"));
931937
}
@@ -948,7 +954,7 @@ private async ValueTask WriteAsyncInternal<TIOAdapter>(ReadOnlyMemory<byte> buff
948954
}
949955
finally
950956
{
951-
_nestedWrite = 0;
957+
_nestedWrite = StreamNotInUse;
952958
}
953959
}
954960

src/libraries/System.Net.Security/src/System/Net/Security/SslStream.cs

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -170,6 +170,11 @@ public void ReturnBuffer()
170170
}
171171
}
172172

173+
// used to track ussage in _nested* variables bellow
174+
private const int StreamNotInUse = 0;
175+
private const int StreamInUse = 1;
176+
private const int StreamDisposed = 2;
177+
173178
private int _nestedWrite;
174179
private int _nestedRead;
175180

@@ -703,7 +708,7 @@ public override async ValueTask DisposeAsync()
703708
public override int ReadByte()
704709
{
705710
ThrowIfExceptionalOrNotAuthenticated();
706-
if (Interlocked.Exchange(ref _nestedRead, 1) == 1)
711+
if (Interlocked.Exchange(ref _nestedRead, StreamInUse) == StreamInUse)
707712
{
708713
throw new NotSupportedException(SR.Format(SR.net_io_invalidnestedcall, "read"));
709714
}
@@ -724,7 +729,7 @@ public override int ReadByte()
724729
// Regardless of whether we were able to read a byte from the buffer,
725730
// reset the read tracking. If we weren't able to read a byte, the
726731
// subsequent call to Read will set the flag again.
727-
_nestedRead = 0;
732+
_nestedRead = StreamNotInUse;
728733
}
729734

730735
// Otherwise, fall back to reading a byte via Read, the same way Stream.ReadByte does.

src/libraries/System.Net.Security/tests/FunctionalTests/SslStreamDisposeTest.cs

Lines changed: 55 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,8 @@
22
// The .NET Foundation licenses this file to you under the MIT license.
33

44
using System.IO;
5-
using System.Net.Test.Common;
65
using System.Security.Cryptography.X509Certificates;
6+
using System.Threading;
77
using System.Threading.Tasks;
88

99
using Xunit;
@@ -12,13 +12,13 @@ namespace System.Net.Security.Tests
1212
{
1313
using Configuration = System.Net.Test.Common.Configuration;
1414

15-
public abstract class SslStreamDisposeTest
15+
public class SslStreamDisposeTest
1616
{
1717
[Fact]
1818
public async Task DisposeAsync_NotConnected_ClosesStream()
1919
{
2020
bool disposed = false;
21-
var stream = new SslStream(new DelegateStream(disposeFunc: _ => disposed = true), false, delegate { return true; });
21+
var stream = new SslStream(new DelegateStream(disposeFunc: _ => disposed = true, canReadFunc: () => true, canWriteFunc: () => true), false, delegate { return true; });
2222

2323
Assert.False(disposed);
2424
await stream.DisposeAsync();
@@ -50,5 +50,57 @@ await TestConfiguration.WhenAllOrAnyFailedWithTimeout(
5050
await serverStream.DisposeAsync();
5151
Assert.NotEqual(0, trackingStream2.TimesCalled(nameof(Stream.DisposeAsync)));
5252
}
53+
54+
[Theory]
55+
[InlineData(true)]
56+
[InlineData(false)]
57+
public async Task Dispose_PendingReadAsync_ThrowsODE(bool bufferedRead)
58+
{
59+
using CancellationTokenSource cts = new CancellationTokenSource();
60+
cts.CancelAfter(TestConfiguration.PassingTestTimeout);
61+
62+
(SslStream client, SslStream server) = TestHelper.GetConnectedSslStreams(leaveInnerStreamOpen: true);
63+
using (client)
64+
using (server)
65+
using (X509Certificate2 serverCertificate = Configuration.Certificates.GetServerCertificate())
66+
using (X509Certificate2 clientCertificate = Configuration.Certificates.GetClientCertificate())
67+
{
68+
SslClientAuthenticationOptions clientOptions = new SslClientAuthenticationOptions()
69+
{
70+
TargetHost = Guid.NewGuid().ToString("N"),
71+
};
72+
clientOptions.RemoteCertificateValidationCallback = (sender, certificate, chain, sslPolicyErrors) => true;
73+
74+
SslServerAuthenticationOptions serverOptions = new SslServerAuthenticationOptions()
75+
{
76+
ServerCertificate = serverCertificate,
77+
};
78+
79+
await TestConfiguration.WhenAllOrAnyFailedWithTimeout(
80+
client.AuthenticateAsClientAsync(clientOptions, default),
81+
server.AuthenticateAsServerAsync(serverOptions, default));
82+
83+
await TestHelper.PingPong(client, server, cts.Token);
84+
85+
await server.WriteAsync("PINGPONG"u8.ToArray(), cts.Token);
86+
var readBuffer = new byte[1024];
87+
88+
Task<int>? task = null;
89+
if (bufferedRead)
90+
{
91+
// This will read everything into internal buffer. Following ReadAsync will not need IO.
92+
task = client.ReadAsync(readBuffer, 0, 4, cts.Token);
93+
client.Dispose();
94+
int readLength = await task.ConfigureAwait(false);
95+
Assert.Equal(4, readLength);
96+
}
97+
else
98+
{
99+
client.Dispose();
100+
}
101+
102+
await Assert.ThrowsAnyAsync<ObjectDisposedException>(() => client.ReadAsync(readBuffer, cts.Token).AsTask());
103+
}
104+
}
53105
}
54106
}

src/libraries/System.Net.Security/tests/FunctionalTests/TestHelper.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -51,10 +51,10 @@ public static bool AllowAnyServerCertificate(object sender, X509Certificate cert
5151
return true;
5252
}
5353

54-
public static (SslStream ClientStream, SslStream ServerStream) GetConnectedSslStreams()
54+
public static (SslStream ClientStream, SslStream ServerStream) GetConnectedSslStreams(bool leaveInnerStreamOpen = false)
5555
{
5656
(Stream clientStream, Stream serverStream) = GetConnectedStreams();
57-
return (new SslStream(clientStream), new SslStream(serverStream));
57+
return (new SslStream(clientStream, leaveInnerStreamOpen), new SslStream(serverStream, leaveInnerStreamOpen));
5858
}
5959

6060
public static (Stream ClientStream, Stream ServerStream) GetConnectedStreams()

0 commit comments

Comments
 (0)