Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,201 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using System;
using System.IO;
using System.ServiceModel;
using System.ServiceModel.Channels;
using System.Text;
using System.Threading;
using System.Threading.Tasks;
using Infrastructure.Common;
using Xunit;

public static class HttpStreamingAbortTests
{
private const int BufferSize = 1024;
private const int LargeStreamSize = 500000; // 500KB

[WcfFact]
[OuterLoop]
public static void HttpStreaming_Abort_During_Response_Receiving()
{
// This test validates that calling Abort() on an HTTP channel works correctly
// when the channel is in the middle of receiving a streamed response.
// If Abort() doesn't propagate correctly, the test will timeout with a TimeoutException.
// If Abort() works correctly, a CommunicationObjectAbortedException should be thrown.

ChannelFactory<IWcfService> factory = null;
IWcfService serviceProxy = null;
CustomBinding binding = null;
Stream responseStream = null;
Exception caughtException = null;

try
{
// *** SETUP *** \\
// Create a binding with streamed transfer mode
binding = new CustomBinding(
new TextMessageEncodingBindingElement(),
new HttpTransportBindingElement
{
TransferMode = TransferMode.StreamedResponse,
MaxReceivedMessageSize = 1024 * 1024 // 1 MB
});

// Set a reasonable ReceiveTimeout - if Abort() doesn't work, this will cause a timeout
binding.ReceiveTimeout = TimeSpan.FromSeconds(10);

factory = new ChannelFactory<IWcfService>(binding, new EndpointAddress(Endpoints.CustomTextEncoderStreamed_Address));
serviceProxy = factory.CreateChannel();

// Create a large string to ensure the response takes time to read
string testString = new string('a', LargeStreamSize);

// *** EXECUTE *** \\
// Start the call to get a stream response
responseStream = serviceProxy.GetStreamFromString(testString);

// Start reading a small amount from the stream to ensure we're in the receiving phase
byte[] buffer = new byte[BufferSize];
int bytesRead = responseStream.Read(buffer, 0, buffer.Length);

// Verify we actually received some data
Assert.True(bytesRead > 0, "Expected to read some data from the stream");

// Now abort the channel while we're in the middle of receiving the response
// This should cause the ongoing read operation to be cancelled
((ICommunicationObject)serviceProxy).Abort();

// Try to continue reading from the stream
// If Abort() works correctly, this should throw an exception
// If Abort() doesn't work, this will hang until the ReceiveTimeout expires
try
{
while (responseStream.Read(buffer, 0, buffer.Length) > 0)
{
// Keep reading
}
}
catch (Exception ex)
{
caughtException = ex;
}

// *** VALIDATE *** \\
// We expect an exception to be thrown after Abort() is called
Assert.NotNull(caughtException);

// The exception should be related to the communication object being aborted
// It could be CommunicationObjectAbortedException or an IOException wrapping it
Assert.True(
caughtException is CommunicationObjectAbortedException ||
caughtException is IOException ||
caughtException is CommunicationException,
$"Expected CommunicationObjectAbortedException, IOException, or CommunicationException, but got: {caughtException.GetType().Name}");
}
catch (TimeoutException)
{
// If we get a TimeoutException, it means Abort() didn't work correctly
Assert.Fail("Test timed out, which indicates that Abort() did not properly cancel the ongoing stream read operation.");
}
finally
{
// *** ENSURE CLEANUP *** \\
responseStream?.Dispose();
ScenarioTestHelpers.CloseCommunicationObjects((ICommunicationObject)serviceProxy, factory);
}
}

[WcfFact]
[OuterLoop]
public static async Task HttpStreaming_Abort_During_Async_Response_Receiving()
{
// This test validates that calling Abort() on an HTTP channel works correctly
// when the channel is in the middle of receiving a streamed response asynchronously.
// If Abort() doesn't propagate correctly, the test will timeout with a TimeoutException.
// If Abort() works correctly, a CommunicationObjectAbortedException should be thrown.

ChannelFactory<IWcfService> factory = null;
IWcfService serviceProxy = null;
CustomBinding binding = null;
Stream responseStream = null;
Exception caughtException = null;

try
{
// *** SETUP *** \\
// Create a binding with streamed transfer mode
binding = new CustomBinding(
new TextMessageEncodingBindingElement(),
new HttpTransportBindingElement
{
TransferMode = TransferMode.StreamedResponse,
MaxReceivedMessageSize = 1024 * 1024 // 1 MB
});

// Set a reasonable ReceiveTimeout - if Abort() doesn't work, this will cause a timeout
binding.ReceiveTimeout = TimeSpan.FromSeconds(10);

factory = new ChannelFactory<IWcfService>(binding, new EndpointAddress(Endpoints.CustomTextEncoderStreamed_Address));
serviceProxy = factory.CreateChannel();

// Create a large string to ensure the response takes time to read
string testString = new string('a', LargeStreamSize);

// *** EXECUTE *** \\
// Start the call to get a stream response
responseStream = serviceProxy.GetStreamFromString(testString);

// Start reading a small amount from the stream to ensure we're in the receiving phase
byte[] buffer = new byte[BufferSize];
int bytesRead = await responseStream.ReadAsync(buffer, 0, buffer.Length);

// Verify we actually received some data
Assert.True(bytesRead > 0, "Expected to read some data from the stream");

// Now abort the channel while we're in the middle of receiving the response
((ICommunicationObject)serviceProxy).Abort();

// Try to continue reading from the stream asynchronously
// If Abort() works correctly, this should throw an exception
// If Abort() doesn't work, this will hang until the ReceiveTimeout expires
try
{
while ((await responseStream.ReadAsync(buffer, 0, buffer.Length)) > 0)
{
// Keep reading
}
}
catch (Exception ex)
{
caughtException = ex;
}

// *** VALIDATE *** \\
// We expect an exception to be thrown after Abort() is called
Assert.NotNull(caughtException);

// The exception should be related to the communication object being aborted
// It could be CommunicationObjectAbortedException or an IOException wrapping it
Assert.True(
caughtException is CommunicationObjectAbortedException ||
caughtException is IOException ||
caughtException is CommunicationException ||
caughtException is OperationCanceledException,
$"Expected CommunicationObjectAbortedException, IOException, CommunicationException, or OperationCanceledException, but got: {caughtException.GetType().Name}");
}
catch (TimeoutException)
{
// If we get a TimeoutException, it means Abort() didn't work correctly
Assert.Fail("Test timed out, which indicates that Abort() did not properly cancel the ongoing stream read operation.");
}
finally
{
// *** ENSURE CLEANUP *** \\
responseStream?.Dispose();
ScenarioTestHelpers.CloseCommunicationObjects((ICommunicationObject)serviceProxy, factory);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -1131,7 +1131,7 @@ public async Task<Message> ReceiveReplyAsync(TimeoutHelper timeoutHelper)
try
{
_timeoutHelper = timeoutHelper;
var responseHelper = new HttpResponseMessageHelper(_httpResponseMessage, _factory);
var responseHelper = new HttpResponseMessageHelper(_httpResponseMessage, _factory, _httpSendCts);
var replyMessage = await responseHelper.ParseIncomingResponse(timeoutHelper);
TryCompleteHttpRequest(_httpRequestMessage);
return replyMessage;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,28 @@ namespace System.ServiceModel.Channels
{
internal class HttpResponseMessageHelper
{
private static readonly Action<object> s_cancelCts = state =>
{
try
{
((CancellationTokenSource)state).Cancel();
}
catch (ObjectDisposedException)
{
// CancellationTokenSource may have been disposed by the time this callback executes
// due to a race condition between timeout/abort and cleanup
}
};

private readonly HttpChannelFactory<IRequestChannel> _factory;
private readonly MessageEncoder _encoder;
private readonly HttpRequestMessage _httpRequestMessage;
private readonly HttpResponseMessage _httpResponseMessage;
private readonly CancellationTokenSource _httpSendCts;
private string _contentType;
private long _contentLength;

public HttpResponseMessageHelper(HttpResponseMessage httpResponseMessage, HttpChannelFactory<IRequestChannel> factory)
public HttpResponseMessageHelper(HttpResponseMessage httpResponseMessage, HttpChannelFactory<IRequestChannel> factory, CancellationTokenSource httpSendCts = null)
{
Contract.Assert(httpResponseMessage != null);
Contract.Assert(httpResponseMessage.RequestMessage != null);
Expand All @@ -33,35 +47,53 @@ public HttpResponseMessageHelper(HttpResponseMessage httpResponseMessage, HttpCh
_httpRequestMessage = httpResponseMessage.RequestMessage;
_factory = factory;
_encoder = factory.MessageEncoderFactory.Encoder;
_httpSendCts = httpSendCts;
}

internal async Task<Message> ParseIncomingResponse(TimeoutHelper timeoutHelper)
{
ValidateAuthentication();
ValidateResponseStatusCode();
bool hasContent = await ValidateContentTypeAsync(timeoutHelper);
Message message = null;
// If we have an httpSendCts, register the timeout token to cancel it
// This allows both timeout and abort to cancel stream operations
CancellationTokenRegistration? timeoutCancellationRegistration = null;
if (_httpSendCts != null)
{
var timeoutToken = await timeoutHelper.GetCancellationTokenAsync();
timeoutCancellationRegistration = timeoutToken.UnsafeRegister(s_cancelCts, _httpSendCts);
}

if (!hasContent)
try
{
if (_encoder.MessageVersion == MessageVersion.None)
ValidateAuthentication();
ValidateResponseStatusCode();
bool hasContent = await ValidateContentTypeAsync(timeoutHelper);
Message message = null;

if (!hasContent)
{
message = new NullMessage();
if (_encoder.MessageVersion == MessageVersion.None)
{
message = new NullMessage();
}
else
{
return null;
}
}
else
{
return null;
message = await ReadStreamAsMessageAsync(timeoutHelper);
}

var exception = ProcessHttpAddressing(message);
Contract.Assert(exception == null, "ProcessHttpAddressing should not set an exception after parsing a response message.");

return message;
}
else
finally
{
message = await ReadStreamAsMessageAsync(timeoutHelper);
// Unregister the timeout callback to prevent memory leaks and avoid invoking the callback after the operation completes
timeoutCancellationRegistration?.Dispose();
}

var exception = ProcessHttpAddressing(message);
Contract.Assert(exception == null, "ProcessHttpAddressing should not set an exception after parsing a response message.");

return message;
}

private Exception ProcessHttpAddressing(Message message)
Expand Down Expand Up @@ -188,7 +220,7 @@ private async Task<Message> ReadChunkedBufferedMessageAsync(Task<Stream> inputSt
{
try
{
return await _encoder.ReadMessageAsync(await inputStreamTask, _factory.BufferManager, _factory.MaxBufferSize, _contentType, await timeoutHelper.GetCancellationTokenAsync());
return await _encoder.ReadMessageAsync(await inputStreamTask, _factory.BufferManager, _factory.MaxBufferSize, _contentType, await GetCancellationTokenAsync(timeoutHelper));
}
catch (XmlException xmlException)
{
Expand All @@ -212,7 +244,7 @@ private async Task<Message> ReadBufferedMessageAsync(Task<Stream> inputStreamTas
byte[] buffer = messageBuffer.Array;
int offset = 0;
int count = messageBuffer.Count;
var ct = await timeoutHelper.GetCancellationTokenAsync();
var ct = await GetCancellationTokenAsync(timeoutHelper);

while (count > 0)
{
Expand Down Expand Up @@ -273,7 +305,7 @@ private async Task<Message> DecodeBufferedMessageAsync(ArraySegment<byte> buffer
{
try
{
var ct = await timeoutHelper.GetCancellationTokenAsync();
var ct = await GetCancellationTokenAsync(timeoutHelper);
// if we're chunked, make sure we've consumed the whole body
if (_contentLength == -1 && buffer.Count == _factory.MaxReceivedMessageSize)
{
Expand Down Expand Up @@ -304,6 +336,18 @@ private async Task<Message> DecodeBufferedMessageAsync(ArraySegment<byte> buffer
}
}

private async Task<CancellationToken> GetCancellationTokenAsync(TimeoutHelper timeoutHelper)
{
// If no httpSendCts is provided, just use the timeout token
if (_httpSendCts == null)
{
return await timeoutHelper.GetCancellationTokenAsync();
}

// Use the _httpSendCts.Token for all operations
return _httpSendCts.Token;
}

private async Task<Stream> GetStreamAsync(TimeoutHelper timeoutHelper)
{
var content = _httpResponseMessage.Content;
Expand All @@ -313,7 +357,7 @@ private async Task<Stream> GetStreamAsync(TimeoutHelper timeoutHelper)
{
contentStream = await content.ReadAsStreamAsync();
_contentLength = content.Headers.ContentLength.HasValue ? content.Headers.ContentLength.Value : -1;
var cancellationToken = await timeoutHelper.GetCancellationTokenAsync();
var cancellationToken = await GetCancellationTokenAsync(timeoutHelper);
if (_contentLength <= 0)
{
var preReadBuffer = new byte[1];
Expand Down
Loading