Skip to content

Commit 16fc92f

Browse files
authored
avoid ProtocolToken allocations in TLS handshake (#86163)
* avoild ProtocolToken allocations in TLS handshake * cleanup * UnitTests * feedback from review
1 parent 8c637de commit 16fc92f

File tree

5 files changed

+52
-40
lines changed

5 files changed

+52
-40
lines changed

src/libraries/System.Net.Security/src/System/Net/Security/SslStream.Android.cs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ public partial class SslStream
1414
{
1515
private JavaProxy.RemoteCertificateValidationResult VerifyRemoteCertificate()
1616
{
17-
ProtocolToken? alertToken = null;
17+
ProtocolToken alertToken = default;
1818
var isValid = VerifyRemoteCertificate(
1919
_sslAuthenticationOptions.CertValidationDelegate,
2020
_sslAuthenticationOptions.CertificateContext?.Trust,
@@ -31,13 +31,13 @@ private JavaProxy.RemoteCertificateValidationResult VerifyRemoteCertificate()
3131
};
3232
}
3333

34-
private bool TryGetRemoteCertificateValidationResult(out SslPolicyErrors sslPolicyErrors, out X509ChainStatusFlags chainStatus, out ProtocolToken? alertToken, out bool isValid)
34+
private bool TryGetRemoteCertificateValidationResult(out SslPolicyErrors sslPolicyErrors, out X509ChainStatusFlags chainStatus, ref ProtocolToken alertToken, out bool isValid)
3535
{
3636
JavaProxy.RemoteCertificateValidationResult? validationResult = _securityContext?.SslStreamProxy.ValidationResult;
3737
sslPolicyErrors = validationResult?.SslPolicyErrors ?? default;
3838
chainStatus = validationResult?.ChainStatus ?? default;
3939
isValid = validationResult?.IsValid ?? default;
40-
alertToken = validationResult?.AlertToken;
40+
alertToken = validationResult?.AlertToken ?? default;
4141
return validationResult is not null;
4242
}
4343

src/libraries/System.Net.Security/src/System/Net/Security/SslStream.IO.cs

Lines changed: 22 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -216,7 +216,9 @@ private async Task RenegotiateAsync<TIOAdapter>(CancellationToken cancellationTo
216216
ProtocolToken message;
217217
do
218218
{
219-
message = await ReceiveBlobAsync<TIOAdapter>(cancellationToken).ConfigureAwait(false);
219+
int frameSize = await ReceiveTlsFrameAsync<TIOAdapter>(cancellationToken).ConfigureAwait(false);
220+
ProcessTlsFrame(frameSize, out message);
221+
220222
if (message.Size > 0)
221223
{
222224
await TIOAdapter.WriteAsync(InnerStream, new ReadOnlyMemory<byte>(message.Payload!, 0, message.Size), cancellationToken).ConfigureAwait(false);
@@ -245,7 +247,7 @@ private async Task RenegotiateAsync<TIOAdapter>(CancellationToken cancellationTo
245247
private async Task ForceAuthenticationAsync<TIOAdapter>(bool receiveFirst, byte[]? reAuthenticationData, CancellationToken cancellationToken)
246248
where TIOAdapter : IReadWriteAdapter
247249
{
248-
ProtocolToken message;
250+
ProtocolToken message = default;
249251
bool handshakeCompleted = false;
250252

251253
if (reAuthenticationData == null)
@@ -256,12 +258,12 @@ private async Task ForceAuthenticationAsync<TIOAdapter>(bool receiveFirst, byte[
256258
throw new InvalidOperationException(SR.Format(SR.net_io_invalidnestedcall, "authenticate"));
257259
}
258260
}
259-
260261
try
261262
{
262263
if (!receiveFirst)
263264
{
264-
message = NextMessage(reAuthenticationData);
265+
NextMessage(reAuthenticationData, out message);
266+
265267
if (message.Size > 0)
266268
{
267269
await TIOAdapter.WriteAsync(InnerStream, new ReadOnlyMemory<byte>(message.Payload!, 0, message.Size), cancellationToken).ConfigureAwait(false);
@@ -289,7 +291,8 @@ private async Task ForceAuthenticationAsync<TIOAdapter>(bool receiveFirst, byte[
289291

290292
while (!handshakeCompleted)
291293
{
292-
message = await ReceiveBlobAsync<TIOAdapter>(cancellationToken).ConfigureAwait(false);
294+
int frameSize = await ReceiveTlsFrameAsync<TIOAdapter>(cancellationToken).ConfigureAwait(false);
295+
ProcessTlsFrame(frameSize, out message);
293296

294297
ReadOnlyMemory<byte> payload = default;
295298
if (message.Size > 0)
@@ -355,7 +358,8 @@ private async Task ForceAuthenticationAsync<TIOAdapter>(bool receiveFirst, byte[
355358

356359
}
357360

358-
private async ValueTask<ProtocolToken> ReceiveBlobAsync<TIOAdapter>(CancellationToken cancellationToken)
361+
// This method will make sure we have at least one full TLS frame buffered.
362+
private async ValueTask<int> ReceiveTlsFrameAsync<TIOAdapter>(CancellationToken cancellationToken)
359363
where TIOAdapter : IReadWriteAdapter
360364
{
361365
int frameSize = await EnsureFullTlsFrameAsync<TIOAdapter>(cancellationToken).ConfigureAwait(false);
@@ -430,11 +434,11 @@ private async ValueTask<ProtocolToken> ReceiveBlobAsync<TIOAdapter>(Cancellation
430434

431435
}
432436

433-
return ProcessBlob(frameSize);
437+
return frameSize;
434438
}
435439

436440
// Calls crypto on received data. No IO inside.
437-
private ProtocolToken ProcessBlob(int frameSize)
441+
private void ProcessTlsFrame(int frameSize, out ProtocolToken message)
438442
{
439443
int chunkSize = frameSize;
440444

@@ -467,26 +471,26 @@ private ProtocolToken ProcessBlob(int frameSize)
467471
_buffer.DiscardEncrypted(frameSize);
468472
}
469473

470-
return NextMessage(availableData.Slice(0, chunkSize));
474+
NextMessage(availableData.Slice(0, chunkSize), out message);
471475
}
472476

473477
//
474478
// This is to reset auth state on remote side.
475479
// If this write succeeds we will allow auth retrying.
476480
//
477-
private void SendAuthResetSignal(ProtocolToken? message, ExceptionDispatchInfo exception)
481+
private void SendAuthResetSignal(ReadOnlySpan<byte> alert, ExceptionDispatchInfo exception)
478482
{
479483
SetException(exception.SourceException);
480484

481-
if (message == null || message.Size == 0)
485+
if (alert.Length == 0)
482486
{
483487
//
484488
// We don't have an alert to send so cannot retry and fail prematurely.
485489
//
486490
exception.Throw();
487491
}
488492

489-
InnerStream.Write(message.Payload!, 0, message.Size);
493+
InnerStream.Write(alert);
490494

491495
exception.Throw();
492496
}
@@ -499,7 +503,7 @@ private void SendAuthResetSignal(ProtocolToken? message, ExceptionDispatchInfo e
499503
//
500504
// - Returns false if failed to verify the Remote Cert
501505
//
502-
private bool CompleteHandshake(ref ProtocolToken? alertToken, out SslPolicyErrors sslPolicyErrors, out X509ChainStatusFlags chainStatus)
506+
private bool CompleteHandshake(ref ProtocolToken alertToken, out SslPolicyErrors sslPolicyErrors, out X509ChainStatusFlags chainStatus)
503507
{
504508
ProcessHandshakeSuccess();
505509

@@ -527,7 +531,7 @@ private bool CompleteHandshake(ref ProtocolToken? alertToken, out SslPolicyError
527531
// The Java TrustManager callback is called only when the peer has a certificate. It's possible that
528532
// the peer didn't provide any certificate (for example when the peer is the client) and the validation
529533
// result hasn't been set. In that case we still need to run the verification at this point.
530-
if (TryGetRemoteCertificateValidationResult(out sslPolicyErrors, out chainStatus, out alertToken, out bool isValid))
534+
if (TryGetRemoteCertificateValidationResult(out sslPolicyErrors, out chainStatus, ref alertToken, out bool isValid))
531535
{
532536
_handshakeCompleted = isValid;
533537
return isValid;
@@ -546,23 +550,23 @@ private bool CompleteHandshake(ref ProtocolToken? alertToken, out SslPolicyError
546550

547551
private void CompleteHandshake(SslAuthenticationOptions sslAuthenticationOptions)
548552
{
549-
ProtocolToken? alertToken = null;
553+
ProtocolToken alertToken = default;
550554
if (!CompleteHandshake(ref alertToken, out SslPolicyErrors sslPolicyErrors, out X509ChainStatusFlags chainStatus))
551555
{
552556
if (sslAuthenticationOptions!.CertValidationDelegate != null)
553557
{
554558
// there may be some chain errors but the decision was made by custom callback. Details should be tracing if enabled.
555-
SendAuthResetSignal(alertToken, ExceptionDispatchInfo.Capture(new AuthenticationException(SR.net_ssl_io_cert_custom_validation, null)));
559+
SendAuthResetSignal(new ReadOnlySpan<byte>(alertToken.Payload), ExceptionDispatchInfo.Capture(new AuthenticationException(SR.net_ssl_io_cert_custom_validation, null)));
556560
}
557561
else if (sslPolicyErrors == SslPolicyErrors.RemoteCertificateChainErrors && chainStatus != X509ChainStatusFlags.NoError)
558562
{
559563
// We failed only because of chain and we have some insight.
560-
SendAuthResetSignal(alertToken, ExceptionDispatchInfo.Capture(new AuthenticationException(SR.Format(SR.net_ssl_io_cert_chain_validation, chainStatus), null)));
564+
SendAuthResetSignal(new ReadOnlySpan<byte>(alertToken.Payload), ExceptionDispatchInfo.Capture(new AuthenticationException(SR.Format(SR.net_ssl_io_cert_chain_validation, chainStatus), null)));
561565
}
562566
else
563567
{
564568
// Simple add sslPolicyErrors as crude info.
565-
SendAuthResetSignal(alertToken, ExceptionDispatchInfo.Capture(new AuthenticationException(SR.Format(SR.net_ssl_io_cert_validation, sslPolicyErrors), null)));
569+
SendAuthResetSignal(new ReadOnlySpan<byte>(alertToken.Payload), ExceptionDispatchInfo.Capture(new AuthenticationException(SR.Format(SR.net_ssl_io_cert_validation, sslPolicyErrors), null)));
566570
}
567571
}
568572
}

src/libraries/System.Net.Security/src/System/Net/Security/SslStream.Protocol.cs

Lines changed: 19 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -751,20 +751,20 @@ static DateTime GetExpiryTimestamp(SslStreamCertificateContext certificateContex
751751
}
752752

753753
//
754-
internal ProtocolToken NextMessage(ReadOnlySpan<byte> incomingBuffer)
754+
internal void NextMessage(ReadOnlySpan<byte> incomingBuffer, out ProtocolToken token)
755755
{
756756
byte[]? nextmsg = null;
757-
SecurityStatusPal status = GenerateToken(incomingBuffer, ref nextmsg);
758-
ProtocolToken token = new ProtocolToken(nextmsg, status);
757+
token.Status = GenerateToken(incomingBuffer, ref nextmsg);
758+
token.Size = nextmsg?.Length ?? 0;
759+
token.Payload = nextmsg;
759760

760761
if (NetEventSource.Log.IsEnabled())
761762
{
762763
if (token.Failed)
763764
{
764-
NetEventSource.Error(this, $"Authentication failed. Status: {status}, Exception message: {token.GetException()!.Message}");
765+
NetEventSource.Error(this, $"Authentication failed. Status: {token.Status}, Exception message: {token.GetException()!.Message}");
765766
}
766767
}
767-
return token;
768768
}
769769

770770
/*++
@@ -992,7 +992,7 @@ internal SecurityStatusPal Decrypt(Span<byte> buffer, out int outputOffset, out
992992
--*/
993993

994994
//This method validates a remote certificate.
995-
internal bool VerifyRemoteCertificate(RemoteCertificateValidationCallback? remoteCertValidationCallback, SslCertificateTrust? trust, ref ProtocolToken? alertToken, out SslPolicyErrors sslPolicyErrors, out X509ChainStatusFlags chainStatus)
995+
internal bool VerifyRemoteCertificate(RemoteCertificateValidationCallback? remoteCertValidationCallback, SslCertificateTrust? trust, ref ProtocolToken alertToken, out SslPolicyErrors sslPolicyErrors, out X509ChainStatusFlags chainStatus)
996996
{
997997
sslPolicyErrors = SslPolicyErrors.None;
998998
chainStatus = X509ChainStatusFlags.NoError;
@@ -1085,7 +1085,7 @@ internal bool VerifyRemoteCertificate(RemoteCertificateValidationCallback? remot
10851085

10861086
if (!success)
10871087
{
1088-
alertToken = CreateFatalHandshakeAlertToken(sslPolicyErrors, chain!);
1088+
CreateFatalHandshakeAlertToken(sslPolicyErrors, chain!, ref alertToken);
10891089
if (chain != null)
10901090
{
10911091
foreach (X509ChainStatus status in chain.ChainStatus)
@@ -1115,7 +1115,7 @@ internal bool VerifyRemoteCertificate(RemoteCertificateValidationCallback? remot
11151115
return success;
11161116
}
11171117

1118-
private ProtocolToken? CreateFatalHandshakeAlertToken(SslPolicyErrors sslPolicyErrors, X509Chain chain)
1118+
private void CreateFatalHandshakeAlertToken(SslPolicyErrors sslPolicyErrors, X509Chain chain, ref ProtocolToken alertToken)
11191119
{
11201120
TlsAlertMessage alertMessage;
11211121

@@ -1148,15 +1148,14 @@ internal bool VerifyRemoteCertificate(RemoteCertificateValidationCallback? remot
11481148
{
11491149
ExceptionDispatchInfo.Throw(status.Exception);
11501150
}
1151-
1152-
return null;
11531151
}
11541152

1155-
return GenerateAlertToken();
1153+
GenerateAlertToken(ref alertToken);
11561154
}
11571155

1158-
private ProtocolToken? CreateShutdownToken()
1156+
private byte[]? CreateShutdownToken()
11591157
{
1158+
byte[]? nextmsg = null;
11601159
SecurityStatusPal status;
11611160
status = SslStreamPal.ApplyShutdownToken(_securityContext!);
11621161

@@ -1173,17 +1172,21 @@ internal bool VerifyRemoteCertificate(RemoteCertificateValidationCallback? remot
11731172
return null;
11741173
}
11751174

1176-
return GenerateAlertToken();
1175+
GenerateToken(default, ref nextmsg);
1176+
1177+
return nextmsg;
11771178
}
11781179

1179-
private ProtocolToken GenerateAlertToken()
1180+
private void GenerateAlertToken(ref ProtocolToken alertToken)
11801181
{
11811182
byte[]? nextmsg = null;
11821183

11831184
SecurityStatusPal status;
11841185
status = GenerateToken(default, ref nextmsg);
11851186

1186-
return new ProtocolToken(nextmsg, status);
1187+
alertToken.Payload = nextmsg;
1188+
alertToken.Size = nextmsg?.Length ?? 0;
1189+
alertToken.Status = status;
11871190
}
11881191

11891192
private static TlsAlertMessage GetAlertMessageFromChain(X509Chain chain)
@@ -1286,7 +1289,7 @@ private void LogCertificateValidation(RemoteCertificateValidationCallback? remot
12861289
}
12871290

12881291
// ProtocolToken - used to process and handle the return codes from the SSPI wrapper
1289-
internal sealed class ProtocolToken
1292+
internal struct ProtocolToken
12901293
{
12911294
internal SecurityStatusPal Status;
12921295
internal byte[]? Payload;

src/libraries/System.Net.Security/src/System/Net/Security/SslStream.cs

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -441,9 +441,14 @@ public virtual Task ShutdownAsync()
441441
{
442442
ThrowIfExceptionalOrNotAuthenticatedOrShutdown();
443443

444-
ProtocolToken message = CreateShutdownToken()!;
444+
byte[]? message = CreateShutdownToken();
445445
_shutdown = true;
446-
return InnerStream.WriteAsync(message.Payload, default).AsTask();
446+
if (message != null)
447+
{
448+
return InnerStream.WriteAsync(message, default).AsTask();
449+
}
450+
451+
return Task.CompletedTask;
447452
}
448453
#endregion
449454

src/libraries/System.Net.Security/tests/UnitTests/Fakes/FakeSslStream.Implementation.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,7 @@ private void ReturnReadBufferIfEmpty()
9292
{
9393
}
9494

95-
private ProtocolToken? CreateShutdownToken()
95+
private byte[]? CreateShutdownToken()
9696
{
9797
return null;
9898
}

0 commit comments

Comments
 (0)