Skip to content

Use full TLS record size for application data on Windows #95595

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -928,7 +928,7 @@ internal void ProcessHandshakeSuccess()

_headerSize = streamSizes.Header;
_trailerSize = streamSizes.Trailer;
_maxDataSize = checked(streamSizes.MaximumMessage - (_headerSize + _trailerSize));
_maxDataSize = streamSizes.MaximumMessage;
Debug.Assert(_maxDataSize > 0, "_maxDataSize > 0");

SslStreamPal.QueryContextConnectionInfo(_securityContext!, ref _connectionInfo);
Expand All @@ -942,18 +942,6 @@ internal void ProcessHandshakeSuccess()
#endif
}

/*++
Encrypt - Encrypts our bytes before we send them over the wire

PERF: make more efficient, this does an extra copy when the offset
is non-zero.

Input:
buffer - bytes for sending
offset -
size -
output - Encrypted bytes
--*/
internal ProtocolToken Encrypt(ReadOnlyMemory<byte> buffer)
{
if (NetEventSource.Log.IsEnabled()) NetEventSource.DumpBuffer(this, buffer.Span);
Expand Down Expand Up @@ -1337,7 +1325,7 @@ internal void EnsureAvailableSpace(int size)

var oldPayload = Payload;

Payload = RentBuffer? ArrayPool<byte>.Shared.Rent(Size + size) : new byte[Size + size];
Payload = RentBuffer ? ArrayPool<byte>.Shared.Rent(Size + size) : new byte[Size + size];
if (oldPayload != null)
{
oldPayload.AsSpan<byte>().CopyTo(Payload);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
using System.Collections.Generic;
using System.ComponentModel;
using System.Diagnostics;
using System.Runtime.CompilerServices;
using System.Runtime.InteropServices;
using System.Security.Authentication;
using System.Security.Authentication.ExtendedProtection;
Expand Down Expand Up @@ -49,7 +50,8 @@ public static Exception GetException(SecurityStatusPal status)

private static byte[] InitSessionTokenBuffer()
{
var schannelSessionToken = new Interop.SChannel.SCHANNEL_SESSION_TOKEN() {
var schannelSessionToken = new Interop.SChannel.SCHANNEL_SESSION_TOKEN()
{
dwTokenType = Interop.SChannel.SCHANNEL_SESSION,
dwFlags = Interop.SChannel.SSL_SESSION_DISABLE_RECONNECTS,
};
Expand All @@ -61,7 +63,7 @@ public static void VerifyPackageInfo()
SSPIWrapper.GetVerifyPackageInfo(GlobalSSPI.SSPISecureChannel, SecurityPackage, true);
}

private static unsafe void SetAlpn(ref InputSecurityBuffers inputBuffers, List<SslApplicationProtocol> alpn, Span<byte> localBuffer)
private static void SetAlpn(ref InputSecurityBuffers inputBuffers, List<SslApplicationProtocol> alpn, Span<byte> localBuffer)
{
if (alpn.Count == 1 && alpn[0] == SslApplicationProtocol.Http11)
{
Expand All @@ -82,7 +84,7 @@ private static unsafe void SetAlpn(ref InputSecurityBuffers inputBuffers, List<S
else
{
int protocolLength = Interop.Sec_Application_Protocols.GetProtocolLength(alpn);
int bufferLength = sizeof(Interop.Sec_Application_Protocols) + protocolLength;
int bufferLength = Unsafe.SizeOf<Interop.Sec_Application_Protocols>() + protocolLength;

Span<byte> alpnBuffer = bufferLength <= localBuffer.Length ? localBuffer : new byte[bufferLength];
Interop.Sec_Application_Protocols.SetProtocols(alpnBuffer, alpn, protocolLength);
Expand All @@ -99,7 +101,7 @@ public static SecurityStatusPal SelectApplicationProtocol(
throw new PlatformNotSupportedException(nameof(SelectApplicationProtocol));
}

public static unsafe ProtocolToken AcceptSecurityContext(
public static ProtocolToken AcceptSecurityContext(
ref SafeFreeCredentials? credentialsHandle,
ref SafeDeleteSslContext? context,
ReadOnlySpan<byte> inputBuffer,
Expand Down Expand Up @@ -141,7 +143,7 @@ public static bool TryUpdateClintCertificate(
return false;
}

public static unsafe ProtocolToken InitializeSecurityContext(
public static ProtocolToken InitializeSecurityContext(
ref SafeFreeCredentials? credentialsHandle,
ref SafeDeleteSslContext? context,
string? targetName,
Expand Down Expand Up @@ -445,32 +447,32 @@ public static unsafe ProtocolToken EncryptMessage(SafeDeleteSslContext securityC
input.Span.CopyTo(token.AvailableSpan.Slice(headerSize, input.Length));

const int NumSecBuffers = 4; // header + data + trailer + empty
Interop.SspiCli.SecBuffer* unmanagedBuffer = stackalloc Interop.SspiCli.SecBuffer[NumSecBuffers];
Span<Interop.SspiCli.SecBuffer> unmanagedBuffers = stackalloc Interop.SspiCli.SecBuffer[NumSecBuffers];
Interop.SspiCli.SecBufferDesc sdcInOut = new Interop.SspiCli.SecBufferDesc(NumSecBuffers)
{
pBuffers = unmanagedBuffer
pBuffers = Unsafe.AsPointer(ref MemoryMarshal.GetReference(unmanagedBuffers))
};
fixed (byte* outputPtr = token.Payload)
{
Interop.SspiCli.SecBuffer* headerSecBuffer = &unmanagedBuffer[0];
headerSecBuffer->BufferType = SecurityBufferType.SECBUFFER_STREAM_HEADER;
headerSecBuffer->pvBuffer = (IntPtr)outputPtr;
headerSecBuffer->cbBuffer = headerSize;
ref Interop.SspiCli.SecBuffer headerSecBuffer = ref unmanagedBuffers[0];
headerSecBuffer.BufferType = SecurityBufferType.SECBUFFER_STREAM_HEADER;
headerSecBuffer.pvBuffer = (IntPtr)outputPtr;
headerSecBuffer.cbBuffer = headerSize;

Interop.SspiCli.SecBuffer* dataSecBuffer = &unmanagedBuffer[1];
dataSecBuffer->BufferType = SecurityBufferType.SECBUFFER_DATA;
dataSecBuffer->pvBuffer = (IntPtr)(outputPtr + headerSize);
dataSecBuffer->cbBuffer = input.Length;
ref Interop.SspiCli.SecBuffer dataSecBuffer = ref unmanagedBuffers[1];
dataSecBuffer.BufferType = SecurityBufferType.SECBUFFER_DATA;
dataSecBuffer.pvBuffer = (IntPtr)(outputPtr + headerSize);
dataSecBuffer.cbBuffer = input.Length;

Interop.SspiCli.SecBuffer* trailerSecBuffer = &unmanagedBuffer[2];
trailerSecBuffer->BufferType = SecurityBufferType.SECBUFFER_STREAM_TRAILER;
trailerSecBuffer->pvBuffer = (IntPtr)(outputPtr + headerSize + input.Length);
trailerSecBuffer->cbBuffer = trailerSize;
ref Interop.SspiCli.SecBuffer trailerSecBuffer = ref unmanagedBuffers[2];
trailerSecBuffer.BufferType = SecurityBufferType.SECBUFFER_STREAM_TRAILER;
trailerSecBuffer.pvBuffer = (IntPtr)(outputPtr + headerSize + input.Length);
trailerSecBuffer.cbBuffer = trailerSize;

Interop.SspiCli.SecBuffer* emptySecBuffer = &unmanagedBuffer[3];
emptySecBuffer->BufferType = SecurityBufferType.SECBUFFER_EMPTY;
emptySecBuffer->cbBuffer = 0;
emptySecBuffer->pvBuffer = IntPtr.Zero;
ref Interop.SspiCli.SecBuffer emptySecBuffer = ref unmanagedBuffers[3];
emptySecBuffer.BufferType = SecurityBufferType.SECBUFFER_EMPTY;
emptySecBuffer.cbBuffer = 0;
emptySecBuffer.pvBuffer = IntPtr.Zero;

int errorCode = GlobalSSPI.SSPISecureChannel.EncryptMessage(securityContext, ref sdcInOut, 0);

Expand All @@ -483,10 +485,10 @@ public static unsafe ProtocolToken EncryptMessage(SafeDeleteSslContext securityC
return token;
}

Debug.Assert(headerSecBuffer->cbBuffer >= 0 && dataSecBuffer->cbBuffer >= 0 && trailerSecBuffer->cbBuffer >= 0);
Debug.Assert(checked(headerSecBuffer->cbBuffer + dataSecBuffer->cbBuffer + trailerSecBuffer->cbBuffer) <= token.Payload!.Length);
Debug.Assert(headerSecBuffer.cbBuffer >= 0 && dataSecBuffer.cbBuffer >= 0 && trailerSecBuffer.cbBuffer >= 0);
Debug.Assert(checked(headerSecBuffer.cbBuffer + dataSecBuffer.cbBuffer + trailerSecBuffer.cbBuffer) <= token.Payload!.Length);

token.Size = checked(headerSecBuffer->cbBuffer + dataSecBuffer->cbBuffer + trailerSecBuffer->cbBuffer);
token.Size = checked(headerSecBuffer.cbBuffer + dataSecBuffer.cbBuffer + trailerSecBuffer.cbBuffer);
token.Status = new SecurityStatusPal(SecurityStatusPalErrorCode.OK);
}

Expand All @@ -496,25 +498,26 @@ public static unsafe ProtocolToken EncryptMessage(SafeDeleteSslContext securityC
public static unsafe SecurityStatusPal DecryptMessage(SafeDeleteSslContext? securityContext, Span<byte> buffer, out int offset, out int count)
{
const int NumSecBuffers = 4; // data + empty + empty + empty
fixed (byte* bufferPtr = buffer)

Span<Interop.SspiCli.SecBuffer> unmanagedBuffers = stackalloc Interop.SspiCli.SecBuffer[NumSecBuffers];
for (int i = 1; i < NumSecBuffers; i++)
{
Interop.SspiCli.SecBuffer* unmanagedBuffer = stackalloc Interop.SspiCli.SecBuffer[NumSecBuffers];
Interop.SspiCli.SecBuffer* dataBuffer = &unmanagedBuffer[0];
dataBuffer->BufferType = SecurityBufferType.SECBUFFER_DATA;
dataBuffer->pvBuffer = (IntPtr)bufferPtr;
dataBuffer->cbBuffer = buffer.Length;
ref Interop.SspiCli.SecBuffer emptyBuffer = ref unmanagedBuffers[i];
emptyBuffer.BufferType = SecurityBufferType.SECBUFFER_EMPTY;
emptyBuffer.pvBuffer = IntPtr.Zero;
emptyBuffer.cbBuffer = 0;
}

for (int i = 1; i < NumSecBuffers; i++)
{
Interop.SspiCli.SecBuffer* emptyBuffer = &unmanagedBuffer[i];
emptyBuffer->BufferType = SecurityBufferType.SECBUFFER_EMPTY;
emptyBuffer->pvBuffer = IntPtr.Zero;
emptyBuffer->cbBuffer = 0;
}
fixed (byte* bufferPtr = buffer)
{
ref Interop.SspiCli.SecBuffer dataBuffer = ref unmanagedBuffers[0];
dataBuffer.BufferType = SecurityBufferType.SECBUFFER_DATA;
dataBuffer.pvBuffer = (IntPtr)bufferPtr;
dataBuffer.cbBuffer = buffer.Length;

Interop.SspiCli.SecBufferDesc sdcInOut = new Interop.SspiCli.SecBufferDesc(NumSecBuffers)
{
pBuffers = unmanagedBuffer
pBuffers = Unsafe.AsPointer(ref MemoryMarshal.GetReference(unmanagedBuffers))
};
Interop.SECURITY_STATUS errorCode = (Interop.SECURITY_STATUS)GlobalSSPI.SSPISecureChannel.DecryptMessage(securityContext!, ref sdcInOut, out _);

Expand All @@ -525,12 +528,12 @@ public static unsafe SecurityStatusPal DecryptMessage(SafeDeleteSslContext? secu
for (int i = 0; i < NumSecBuffers; i++)
{
// Successfully decoded data and placed it at the following position in the buffer,
if ((errorCode == Interop.SECURITY_STATUS.OK && unmanagedBuffer[i].BufferType == SecurityBufferType.SECBUFFER_DATA)
if ((errorCode == Interop.SECURITY_STATUS.OK && unmanagedBuffers[i].BufferType == SecurityBufferType.SECBUFFER_DATA)
// or we failed to decode the data, here is the encoded data.
|| (errorCode != Interop.SECURITY_STATUS.OK && unmanagedBuffer[i].BufferType == SecurityBufferType.SECBUFFER_EXTRA))
|| (errorCode != Interop.SECURITY_STATUS.OK && unmanagedBuffers[i].BufferType == SecurityBufferType.SECBUFFER_EXTRA))
{
offset = (int)((byte*)unmanagedBuffer[i].pvBuffer - bufferPtr);
count = unmanagedBuffer[i].cbBuffer;
offset = (int)((byte*)unmanagedBuffers[i].pvBuffer - bufferPtr);
count = unmanagedBuffers[i].cbBuffer;

// output is ignored on Windows. We always decrypt in place and we set outputOffset to indicate where the data start.
Debug.Assert(offset >= 0 && count >= 0, $"Expected offset and count greater than 0, got {offset} and {count}");
Expand Down