Skip to content

avoid ProtocolToken allocations in TLS handshake #86163

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
May 18, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ public partial class SslStream
{
private JavaProxy.RemoteCertificateValidationResult VerifyRemoteCertificate()
{
ProtocolToken? alertToken = null;
ProtocolToken alertToken = default;
var isValid = VerifyRemoteCertificate(
_sslAuthenticationOptions.CertValidationDelegate,
_sslAuthenticationOptions.CertificateContext?.Trust,
Expand All @@ -31,13 +31,13 @@ private JavaProxy.RemoteCertificateValidationResult VerifyRemoteCertificate()
};
}

private bool TryGetRemoteCertificateValidationResult(out SslPolicyErrors sslPolicyErrors, out X509ChainStatusFlags chainStatus, out ProtocolToken? alertToken, out bool isValid)
private bool TryGetRemoteCertificateValidationResult(out SslPolicyErrors sslPolicyErrors, out X509ChainStatusFlags chainStatus, ref ProtocolToken alertToken, out bool isValid)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: A bunch of these look like they could remain as out (e.g., NextMessage, ProcessTlsFrame)

{
JavaProxy.RemoteCertificateValidationResult? validationResult = _securityContext?.SslStreamProxy.ValidationResult;
sslPolicyErrors = validationResult?.SslPolicyErrors ?? default;
chainStatus = validationResult?.ChainStatus ?? default;
isValid = validationResult?.IsValid ?? default;
alertToken = validationResult?.AlertToken;
alertToken = validationResult?.AlertToken ?? default;
return validationResult is not null;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -216,7 +216,9 @@ private async Task RenegotiateAsync<TIOAdapter>(CancellationToken cancellationTo
ProtocolToken message;
do
{
message = await ReceiveBlobAsync<TIOAdapter>(cancellationToken).ConfigureAwait(false);
int frameSize = await ReceiveTlsFrameAsync<TIOAdapter>(cancellationToken).ConfigureAwait(false);
ProcessTlsFrame(frameSize, out message);

if (message.Size > 0)
{
await TIOAdapter.WriteAsync(InnerStream, new ReadOnlyMemory<byte>(message.Payload!, 0, message.Size), cancellationToken).ConfigureAwait(false);
Expand Down Expand Up @@ -245,7 +247,7 @@ private async Task RenegotiateAsync<TIOAdapter>(CancellationToken cancellationTo
private async Task ForceAuthenticationAsync<TIOAdapter>(bool receiveFirst, byte[]? reAuthenticationData, CancellationToken cancellationToken)
where TIOAdapter : IReadWriteAdapter
{
ProtocolToken message;
ProtocolToken message = default;
bool handshakeCompleted = false;

if (reAuthenticationData == null)
Expand All @@ -256,12 +258,12 @@ private async Task ForceAuthenticationAsync<TIOAdapter>(bool receiveFirst, byte[
throw new InvalidOperationException(SR.Format(SR.net_io_invalidnestedcall, "authenticate"));
}
}

try
{
if (!receiveFirst)
{
message = NextMessage(reAuthenticationData);
NextMessage(reAuthenticationData, out message);

if (message.Size > 0)
{
await TIOAdapter.WriteAsync(InnerStream, new ReadOnlyMemory<byte>(message.Payload!, 0, message.Size), cancellationToken).ConfigureAwait(false);
Expand Down Expand Up @@ -289,7 +291,8 @@ private async Task ForceAuthenticationAsync<TIOAdapter>(bool receiveFirst, byte[

while (!handshakeCompleted)
{
message = await ReceiveBlobAsync<TIOAdapter>(cancellationToken).ConfigureAwait(false);
int frameSize = await ReceiveTlsFrameAsync<TIOAdapter>(cancellationToken).ConfigureAwait(false);
ProcessTlsFrame(frameSize, out message);

ReadOnlyMemory<byte> payload = default;
if (message.Size > 0)
Expand Down Expand Up @@ -355,7 +358,8 @@ private async Task ForceAuthenticationAsync<TIOAdapter>(bool receiveFirst, byte[

}

private async ValueTask<ProtocolToken> ReceiveBlobAsync<TIOAdapter>(CancellationToken cancellationToken)
// This method will make sure we have at least one full TLS frame buffered.
private async ValueTask<int> ReceiveTlsFrameAsync<TIOAdapter>(CancellationToken cancellationToken)
where TIOAdapter : IReadWriteAdapter
{
int frameSize = await EnsureFullTlsFrameAsync<TIOAdapter>(cancellationToken).ConfigureAwait(false);
Expand Down Expand Up @@ -430,11 +434,11 @@ private async ValueTask<ProtocolToken> ReceiveBlobAsync<TIOAdapter>(Cancellation

}

return ProcessBlob(frameSize);
return frameSize;
}

// Calls crypto on received data. No IO inside.
private ProtocolToken ProcessBlob(int frameSize)
private void ProcessTlsFrame(int frameSize, out ProtocolToken message)
{
int chunkSize = frameSize;

Expand Down Expand Up @@ -467,26 +471,26 @@ private ProtocolToken ProcessBlob(int frameSize)
_buffer.DiscardEncrypted(frameSize);
}

return NextMessage(availableData.Slice(0, chunkSize));
NextMessage(availableData.Slice(0, chunkSize), out message);
}

//
// This is to reset auth state on remote side.
// If this write succeeds we will allow auth retrying.
//
private void SendAuthResetSignal(ProtocolToken? message, ExceptionDispatchInfo exception)
private void SendAuthResetSignal(ReadOnlySpan<byte> alert, ExceptionDispatchInfo exception)
{
SetException(exception.SourceException);

if (message == null || message.Size == 0)
if (alert.Length == 0)
{
//
// We don't have an alert to send so cannot retry and fail prematurely.
//
exception.Throw();
}

InnerStream.Write(message.Payload!, 0, message.Size);
InnerStream.Write(alert);

exception.Throw();
}
Expand All @@ -499,7 +503,7 @@ private void SendAuthResetSignal(ProtocolToken? message, ExceptionDispatchInfo e
//
// - Returns false if failed to verify the Remote Cert
//
private bool CompleteHandshake(ref ProtocolToken? alertToken, out SslPolicyErrors sslPolicyErrors, out X509ChainStatusFlags chainStatus)
private bool CompleteHandshake(ref ProtocolToken alertToken, out SslPolicyErrors sslPolicyErrors, out X509ChainStatusFlags chainStatus)
{
ProcessHandshakeSuccess();

Expand Down Expand Up @@ -527,7 +531,7 @@ private bool CompleteHandshake(ref ProtocolToken? alertToken, out SslPolicyError
// The Java TrustManager callback is called only when the peer has a certificate. It's possible that
// the peer didn't provide any certificate (for example when the peer is the client) and the validation
// result hasn't been set. In that case we still need to run the verification at this point.
if (TryGetRemoteCertificateValidationResult(out sslPolicyErrors, out chainStatus, out alertToken, out bool isValid))
if (TryGetRemoteCertificateValidationResult(out sslPolicyErrors, out chainStatus, ref alertToken, out bool isValid))
{
_handshakeCompleted = isValid;
return isValid;
Expand All @@ -546,23 +550,23 @@ private bool CompleteHandshake(ref ProtocolToken? alertToken, out SslPolicyError

private void CompleteHandshake(SslAuthenticationOptions sslAuthenticationOptions)
{
ProtocolToken? alertToken = null;
ProtocolToken alertToken = default;
if (!CompleteHandshake(ref alertToken, out SslPolicyErrors sslPolicyErrors, out X509ChainStatusFlags chainStatus))
{
if (sslAuthenticationOptions!.CertValidationDelegate != null)
{
// there may be some chain errors but the decision was made by custom callback. Details should be tracing if enabled.
SendAuthResetSignal(alertToken, ExceptionDispatchInfo.Capture(new AuthenticationException(SR.net_ssl_io_cert_custom_validation, null)));
SendAuthResetSignal(new ReadOnlySpan<byte>(alertToken.Payload), ExceptionDispatchInfo.Capture(new AuthenticationException(SR.net_ssl_io_cert_custom_validation, null)));
}
else if (sslPolicyErrors == SslPolicyErrors.RemoteCertificateChainErrors && chainStatus != X509ChainStatusFlags.NoError)
{
// We failed only because of chain and we have some insight.
SendAuthResetSignal(alertToken, ExceptionDispatchInfo.Capture(new AuthenticationException(SR.Format(SR.net_ssl_io_cert_chain_validation, chainStatus), null)));
SendAuthResetSignal(new ReadOnlySpan<byte>(alertToken.Payload), ExceptionDispatchInfo.Capture(new AuthenticationException(SR.Format(SR.net_ssl_io_cert_chain_validation, chainStatus), null)));
}
else
{
// Simple add sslPolicyErrors as crude info.
SendAuthResetSignal(alertToken, ExceptionDispatchInfo.Capture(new AuthenticationException(SR.Format(SR.net_ssl_io_cert_validation, sslPolicyErrors), null)));
SendAuthResetSignal(new ReadOnlySpan<byte>(alertToken.Payload), ExceptionDispatchInfo.Capture(new AuthenticationException(SR.Format(SR.net_ssl_io_cert_validation, sslPolicyErrors), null)));
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -751,20 +751,20 @@ static DateTime GetExpiryTimestamp(SslStreamCertificateContext certificateContex
}

//
internal ProtocolToken NextMessage(ReadOnlySpan<byte> incomingBuffer)
internal void NextMessage(ReadOnlySpan<byte> incomingBuffer, out ProtocolToken token)
{
byte[]? nextmsg = null;
SecurityStatusPal status = GenerateToken(incomingBuffer, ref nextmsg);
ProtocolToken token = new ProtocolToken(nextmsg, status);
token.Status = GenerateToken(incomingBuffer, ref nextmsg);
token.Size = nextmsg?.Length ?? 0;
token.Payload = nextmsg;

if (NetEventSource.Log.IsEnabled())
{
if (token.Failed)
{
NetEventSource.Error(this, $"Authentication failed. Status: {status}, Exception message: {token.GetException()!.Message}");
NetEventSource.Error(this, $"Authentication failed. Status: {token.Status}, Exception message: {token.GetException()!.Message}");
}
}
return token;
}

/*++
Expand Down Expand Up @@ -992,7 +992,7 @@ internal SecurityStatusPal Decrypt(Span<byte> buffer, out int outputOffset, out
--*/

//This method validates a remote certificate.
internal bool VerifyRemoteCertificate(RemoteCertificateValidationCallback? remoteCertValidationCallback, SslCertificateTrust? trust, ref ProtocolToken? alertToken, out SslPolicyErrors sslPolicyErrors, out X509ChainStatusFlags chainStatus)
internal bool VerifyRemoteCertificate(RemoteCertificateValidationCallback? remoteCertValidationCallback, SslCertificateTrust? trust, ref ProtocolToken alertToken, out SslPolicyErrors sslPolicyErrors, out X509ChainStatusFlags chainStatus)
{
sslPolicyErrors = SslPolicyErrors.None;
chainStatus = X509ChainStatusFlags.NoError;
Expand Down Expand Up @@ -1085,7 +1085,7 @@ internal bool VerifyRemoteCertificate(RemoteCertificateValidationCallback? remot

if (!success)
{
alertToken = CreateFatalHandshakeAlertToken(sslPolicyErrors, chain!);
CreateFatalHandshakeAlertToken(sslPolicyErrors, chain!, ref alertToken);
if (chain != null)
{
foreach (X509ChainStatus status in chain.ChainStatus)
Expand Down Expand Up @@ -1115,7 +1115,7 @@ internal bool VerifyRemoteCertificate(RemoteCertificateValidationCallback? remot
return success;
}

private ProtocolToken? CreateFatalHandshakeAlertToken(SslPolicyErrors sslPolicyErrors, X509Chain chain)
private void CreateFatalHandshakeAlertToken(SslPolicyErrors sslPolicyErrors, X509Chain chain, ref ProtocolToken alertToken)
{
TlsAlertMessage alertMessage;

Expand Down Expand Up @@ -1148,15 +1148,14 @@ internal bool VerifyRemoteCertificate(RemoteCertificateValidationCallback? remot
{
ExceptionDispatchInfo.Throw(status.Exception);
}

return null;
}

return GenerateAlertToken();
GenerateAlertToken(ref alertToken);
}

private ProtocolToken? CreateShutdownToken()
private byte[]? CreateShutdownToken()
{
byte[]? nextmsg = null;
SecurityStatusPal status;
status = SslStreamPal.ApplyShutdownToken(_securityContext!);

Expand All @@ -1173,17 +1172,21 @@ internal bool VerifyRemoteCertificate(RemoteCertificateValidationCallback? remot
return null;
}

return GenerateAlertToken();
GenerateToken(default, ref nextmsg);

return nextmsg;
}

private ProtocolToken GenerateAlertToken()
private void GenerateAlertToken(ref ProtocolToken alertToken)
{
byte[]? nextmsg = null;

SecurityStatusPal status;
status = GenerateToken(default, ref nextmsg);

return new ProtocolToken(nextmsg, status);
alertToken.Payload = nextmsg;
alertToken.Size = nextmsg?.Length ?? 0;
alertToken.Status = status;
}

private static TlsAlertMessage GetAlertMessageFromChain(X509Chain chain)
Expand Down Expand Up @@ -1286,7 +1289,7 @@ private void LogCertificateValidation(RemoteCertificateValidationCallback? remot
}

// ProtocolToken - used to process and handle the return codes from the SSPI wrapper
internal sealed class ProtocolToken
internal struct ProtocolToken
{
internal SecurityStatusPal Status;
internal byte[]? Payload;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -441,9 +441,14 @@ public virtual Task ShutdownAsync()
{
ThrowIfExceptionalOrNotAuthenticatedOrShutdown();

ProtocolToken message = CreateShutdownToken()!;
byte[]? message = CreateShutdownToken();
_shutdown = true;
return InnerStream.WriteAsync(message.Payload, default).AsTask();
if (message != null)
{
return InnerStream.WriteAsync(message, default).AsTask();
}

return Task.CompletedTask;
}
#endregion

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ private void ReturnReadBufferIfEmpty()
{
}

private ProtocolToken? CreateShutdownToken()
private byte[]? CreateShutdownToken()
{
return null;
}
Expand Down