Skip to content

Replace TlsStream type by using SslStream directly #106451

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Aug 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
106 changes: 0 additions & 106 deletions src/libraries/Common/src/System/Net/TlsStream.cs

This file was deleted.

2 changes: 0 additions & 2 deletions src/libraries/System.Net.Mail/src/System.Net.Mail.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -110,8 +110,6 @@
Link="Common\System\Net\DebugSafeHandleZeroOrMinusOneIsInvalid.cs" />
<Compile Include="$(CommonPath)System\Net\DebugSafeHandle.cs"
Link="Common\System\Net\DebugSafeHandle.cs" />
<Compile Include="$(CommonPath)System\Net\TlsStream.cs"
Link="Common\System\Net\TlsStream.cs" />
<Compile Include="$(CommonPath)System\Net\InternalException.cs"
Link="Common\System\Net\InternalException.cs" />
<Compile Include="$(CommonPath)System\Net\ExceptionCheck.cs"
Expand Down
78 changes: 48 additions & 30 deletions src/libraries/System.Net.Mail/src/System/Net/Mail/SmtpConnection.cs
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ internal sealed partial class SmtpConnection
private readonly EventHandler? _onCloseHandler;
internal SmtpTransport? _parent;
private readonly SmtpClient? _client;
private NetworkStream? _networkStream;
private Stream? _stream;
internal TcpClient? _tcpClient;
private SmtpReplyReaderFactory? _responseReader;

Expand Down Expand Up @@ -82,7 +82,7 @@ internal X509CertificateCollection? ClientCertificates
internal void InitializeConnection(string host, int port)
{
_tcpClient!.Connect(host, port);
_networkStream = _tcpClient.GetStream();
_stream = _tcpClient.GetStream();
}

internal IAsyncResult BeginInitializeConnection(string host, int port, AsyncCallback? callback, object? state)
Expand All @@ -93,7 +93,7 @@ internal IAsyncResult BeginInitializeConnection(string host, int port, AsyncCall
internal void EndInitializeConnection(IAsyncResult result)
{
_tcpClient!.EndConnect(result);
_networkStream = _tcpClient.GetStream();
_stream = _tcpClient.GetStream();
}

internal IAsyncResult BeginGetConnection(ContextAwareResult outerResult, AsyncCallback? callback, object? state, string host, int port)
Expand All @@ -105,18 +105,18 @@ internal IAsyncResult BeginGetConnection(ContextAwareResult outerResult, AsyncCa

internal IAsyncResult BeginFlush(AsyncCallback? callback, object? state)
{
return _networkStream!.BeginWrite(_bufferBuilder.GetBuffer(), 0, _bufferBuilder.Length, callback, state);
return _stream!.BeginWrite(_bufferBuilder.GetBuffer(), 0, _bufferBuilder.Length, callback, state);
}

internal void EndFlush(IAsyncResult result)
{
_networkStream!.EndWrite(result);
_stream!.EndWrite(result);
_bufferBuilder.Reset();
}

internal void Flush()
{
_networkStream!.Write(_bufferBuilder.GetBuffer(), 0, _bufferBuilder.Length);
_stream!.Write(_bufferBuilder.GetBuffer(), 0, _bufferBuilder.Length);
_bufferBuilder.Reset();
}

Expand Down Expand Up @@ -150,7 +150,7 @@ private void ShutdownConnection(bool isAbort)
finally
{
//free cbt buffer
_networkStream?.Close();
_stream?.Close();
_tcpClient.Dispose();
}
}
Expand Down Expand Up @@ -190,7 +190,7 @@ internal void GetConnection(string host, int port)
}

InitializeConnection(host, port);
_responseReader = new SmtpReplyReaderFactory(_networkStream!);
_responseReader = new SmtpReplyReaderFactory(_stream!);

LineInfo info = _responseReader.GetNextReplyReader().ReadLine();

Expand Down Expand Up @@ -225,17 +225,25 @@ internal void GetConnection(string host, int port)
if (!_serverSupportsStartTls)
{
// Either TLS is already established or server does not support TLS
if (!(_networkStream is TlsStream))
if (!(_stream is SslStream))
{
throw new SmtpException(SR.MailServerDoesNotSupportStartTls);
}
}

StartTlsCommand.Send(this);
TlsStream tlsStream = new TlsStream(_networkStream!, _tcpClient!.Client, host, _clientCertificates);
tlsStream.AuthenticateAsClient();
_networkStream = tlsStream;
_responseReader = new SmtpReplyReaderFactory(_networkStream);
#pragma warning disable SYSLIB0014 // ServicePointManager is obsolete
SslStream sslStream = new SslStream(_stream!, false, ServicePointManager.ServerCertificateValidationCallback);

sslStream.AuthenticateAsClient(
host,
_clientCertificates,
(SslProtocols)ServicePointManager.SecurityProtocol, // enums use same values
ServicePointManager.CheckCertificateRevocationList);
#pragma warning restore SYSLIB0014 // ServicePointManager is obsolete

_stream = sslStream;
_responseReader = new SmtpReplyReaderFactory(_stream);

// According to RFC 3207: The client SHOULD send an EHLO command
// as the first command after a successful TLS negotiation.
Expand Down Expand Up @@ -362,7 +370,7 @@ internal static void EndGetConnection(IAsyncResult result)

internal Stream GetClosableStream()
{
ClosableStream cs = new ClosableStream(_networkStream!, _onCloseHandler);
ClosableStream cs = new ClosableStream(_stream!, _onCloseHandler);
_isStreamOpen = true;
return cs;
}
Expand Down Expand Up @@ -460,7 +468,7 @@ private static void InitializeConnectionCallback(IAsyncResult result)

private void Handshake()
{
_connection._responseReader = new SmtpReplyReaderFactory(_connection._networkStream!);
_connection._responseReader = new SmtpReplyReaderFactory(_connection._stream!);

SmtpReplyReader reader = _connection.Reader!.GetNextReplyReader();
IAsyncResult result = reader.BeginReadLine(s_handshakeCallback, this);
Expand Down Expand Up @@ -533,10 +541,10 @@ private bool SendEHello()
{
_connection._extensions = EHelloCommand.EndSend(result);
_connection.ParseExtensions(_connection._extensions);
// If we already have a TlsStream, this is the second EHLO cmd
// If we already have a SslStream, this is the second EHLO cmd
// that we sent after TLS handshake compelted. So skip TLS and
// continue with Authenticate.
if (_connection._networkStream is TlsStream)
if (_connection._stream is SslStream)
{
Authenticate();
return true;
Expand All @@ -547,7 +555,7 @@ private bool SendEHello()
if (!_connection._serverSupportsStartTls)
{
// Either TLS is already established or server does not support TLS
if (!(_connection._networkStream is TlsStream))
if (!(_connection._stream is SslStream))
{
throw new SmtpException(SR.MailServerDoesNotSupportStartTls);
}
Expand Down Expand Up @@ -579,7 +587,7 @@ private static void SendEHelloCallback(IAsyncResult result)
// If we already have a SSlStream, this is the second EHLO cmd
// that we sent after TLS handshake compelted. So skip TLS and
// continue with Authenticate.
if (thisPtr._connection._networkStream is TlsStream)
if (thisPtr._connection._stream is SslStream)
{
thisPtr.Authenticate();
return;
Expand All @@ -606,7 +614,7 @@ private static void SendEHelloCallback(IAsyncResult result)
if (!thisPtr._connection._serverSupportsStartTls)
{
// Either TLS is already established or server does not support TLS
if (!(thisPtr._connection._networkStream is TlsStream))
if (!(thisPtr._connection._stream is SslStream))
{
throw new SmtpException(SR.MailServerDoesNotSupportStartTls);
}
Expand Down Expand Up @@ -663,7 +671,7 @@ private bool SendStartTls()
if (result.CompletedSynchronously)
{
StartTlsCommand.EndSend(result);
TlsStreamAuthenticate();
SslStreamAuthenticate();
return true;
}
return false;
Expand All @@ -677,7 +685,7 @@ private static void SendStartTlsCallback(IAsyncResult result)
try
{
StartTlsCommand.EndSend(result);
thisPtr.TlsStreamAuthenticate();
thisPtr.SslStreamAuthenticate();
}
catch (Exception e)
{
Expand All @@ -686,29 +694,39 @@ private static void SendStartTlsCallback(IAsyncResult result)
}
}

private bool TlsStreamAuthenticate()
private bool SslStreamAuthenticate()
{
_connection._networkStream = new TlsStream(_connection._networkStream!, _connection._tcpClient!.Client, _host, _connection._clientCertificates);
IAsyncResult result = ((TlsStream)_connection._networkStream).BeginAuthenticateAsClient(TlsStreamAuthenticateCallback, this);
#pragma warning disable SYSLIB0014 // ServicePointManager is obsolete
_connection._stream = new SslStream(_connection._stream!, false, ServicePointManager.ServerCertificateValidationCallback);

IAsyncResult result = ((SslStream)_connection._stream).BeginAuthenticateAsClient(
_host,
_connection._clientCertificates,
(SslProtocols)ServicePointManager.SecurityProtocol, // enums use same values
ServicePointManager.CheckCertificateRevocationList,
SslStreamAuthenticateCallback,
this);
#pragma warning restore SYSLIB0014 // ServicePointManager is obsolete

if (result.CompletedSynchronously)
{
((TlsStream)_connection._networkStream).EndAuthenticateAsClient(result);
_connection._responseReader = new SmtpReplyReaderFactory(_connection._networkStream);
((SslStream)_connection._stream).EndAuthenticateAsClient(result);
_connection._responseReader = new SmtpReplyReaderFactory(_connection._stream);
SendEHello();
return true;
}
return false;
}

private static void TlsStreamAuthenticateCallback(IAsyncResult result)
private static void SslStreamAuthenticateCallback(IAsyncResult result)
{
if (!result.CompletedSynchronously)
{
ConnectAndHandshakeAsyncResult thisPtr = (ConnectAndHandshakeAsyncResult)result.AsyncState!;
try
{
(thisPtr._connection._networkStream as TlsStream)!.EndAuthenticateAsClient(result);
thisPtr._connection._responseReader = new SmtpReplyReaderFactory(thisPtr._connection._networkStream);
(thisPtr._connection._stream as SslStream)!.EndAuthenticateAsClient(result);
thisPtr._connection._responseReader = new SmtpReplyReaderFactory(thisPtr._connection._stream);
thisPtr.SendEHello();
}
catch (Exception e)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -112,8 +112,6 @@
Link="ProductionCode\BufferBuilder.cs" />
<Compile Include="$(CommonPath)DisableRuntimeMarshalling.cs"
Link="Common\DisableRuntimeMarshalling.cs" />
<Compile Include="$(CommonPath)System\Net\TlsStream.cs"
Link="Common\System\Net\TlsStream.cs" />
<Compile Include="$(CommonPath)System\Net\InternalException.cs"
Link="Common\System\Net\InternalException.cs" />
<Compile Include="$(CommonPath)System\Net\LazyAsyncResult.cs"
Expand All @@ -140,8 +138,8 @@
Link="Common\System\HexConverter.cs" />
<Compile Include="$(CommonPath)System\Obsoletions.cs"
Link="Common\System\Obsoletions.cs" />
<Compile Include="$(CommonPath)System\Text\ValueStringBuilder.cs"
Link="Common\System\Text\ValueStringBuilder.cs" />
<Compile Include="$(CommonPath)System\Text\ValueStringBuilder.cs"
Link="Common\System\Text\ValueStringBuilder.cs" />
</ItemGroup>
<!-- Unix specific files -->
<ItemGroup Condition="'$(TargetPlatformIdentifier)' == 'unix'">
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -81,8 +81,6 @@
Link="Common\System\Net\ContextAwareResult.cs" />
<Compile Include="$(CommonPath)System\Net\ExceptionCheck.cs"
Link="Common\System\Net\ExceptionCheck.cs" />
<Compile Include="$(CommonPath)System\Net\TlsStream.cs"
Link="Common\System\Net\TlsStream.cs" />
<Compile Include="$(CommonPath)System\Net\SecurityProtocol.cs"
Link="Common\System\Net\SecurityProtocol.cs" />
<Compile Include="$(CommonPath)System\NotImplemented.cs"
Expand Down
Loading
Loading