Skip to content

Commit

Permalink
Fixes dotnet#5105
Browse files Browse the repository at this point in the history
Add support for HttpRequestMessage objects containing StreamContent to
the AddStandardHedgingHandler() resilience API.

This change does not update any public API contracts. It updates
internal and private API contracts only.

Link to issue: dotnet#5105
  • Loading branch information
Adam Hammond committed Apr 16, 2024
1 parent 5fc05f8 commit 51c0c53
Show file tree
Hide file tree
Showing 5 changed files with 316 additions and 89 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
// The .NET Foundation licenses this file to you under the MIT license.

using System;
using System.IO;
using System.Net.Http;
using System.Threading.Tasks;
using Microsoft.Shared.Diagnostics;
Expand All @@ -25,10 +26,15 @@ protected override async ValueTask<Outcome<TResult>> ExecuteCore<TResult, TState
Throw.InvalidOperationException("The HTTP request message was not found in the resilience context.");
}

using var snapshot = RequestMessageSnapshot.Create(request);

context.Properties.Set(ResilienceKeys.RequestSnapshot, snapshot);

return await callback(context, state).ConfigureAwait(context.ContinueOnCapturedContext);
try
{
using var snapshot = await RequestMessageSnapshot.CreateAsync(request).ConfigureAwait(context.ContinueOnCapturedContext);
context.Properties.Set(ResilienceKeys.RequestSnapshot, snapshot);
return await callback(context, state).ConfigureAwait(context.ContinueOnCapturedContext);
}
catch (IOException e)
{
return Outcome.FromException<TResult>(e);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@

using System;
using System.Diagnostics.CodeAnalysis;
using System.IO;
using System.Net.Http;
using System.Threading.Tasks;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.DependencyInjection.Extensions;
using Microsoft.Extensions.Http.Resilience;
Expand Down Expand Up @@ -88,26 +90,48 @@ public static IStandardHedgingHandlerBuilder AddStandardHedgingHandler(this IHtt
Throw.InvalidOperationException("Request message snapshot is not attached to the resilience context.");
}
var requestMessage = snapshot.CreateRequestMessage();
// The secondary request message should use the action resilience context
requestMessage.SetResilienceContext(args.ActionContext);
// replace the request message
args.ActionContext.Properties.Set(ResilienceKeys.RequestMessage, requestMessage);
// if a routing strategy has been configured but it does not return the next route, then no more routes
// are availabe, stop hedging
Uri? route;
if (args.PrimaryContext.Properties.TryGetValue(ResilienceKeys.RoutingStrategy, out var routingPipeline))
{
if (!routingPipeline.TryGetNextRoute(out var route))
if (!routingPipeline.TryGetNextRoute(out route))
{
// no routes left, stop hedging
return null;
}
requestMessage.RequestUri = requestMessage.RequestUri!.ReplaceHost(route);
}
else
{
route = null;
}
return async () =>
{
Outcome<HttpResponseMessage>? actionResult = null;
try
{
var requestMessage = await snapshot.CreateRequestMessageAsync().ConfigureAwait(false);
// The secondary request message should use the action resilience context
requestMessage.SetResilienceContext(args.ActionContext);
// replace the request message
args.ActionContext.Properties.Set(ResilienceKeys.RequestMessage, requestMessage);
if (route != null)
{
// replace the RequestUri of the request per the routing strategy
requestMessage.RequestUri = requestMessage.RequestUri!.ReplaceHost(route);
}
}
catch (IOException e)
{
actionResult = Outcome.FromException<HttpResponseMessage>(e);
}
return () => args.Callback(args.ActionContext);
return actionResult ?? await args.Callback(args.ActionContext).ConfigureAwait(args.ActionContext.ContinueOnCapturedContext);
};
};
});

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@

using System;
using System.Collections.Generic;
using System.IO;
using System.Net.Http;
using System.Threading.Tasks;
using Microsoft.Extensions.ObjectPool;
using Microsoft.Shared.Diagnostics;
using Microsoft.Shared.Pools;
Expand All @@ -22,21 +24,40 @@ internal sealed class RequestMessageSnapshot : IResettable, IDisposable
private Version? _version;
private HttpContent? _content;

public static RequestMessageSnapshot Create(HttpRequestMessage request)
[System.Diagnostics.CodeAnalysis.SuppressMessage("Resilience", "EA0014:The async method doesn't support cancellation", Justification = "Past the point of no cancellation.")]
public static async Task<RequestMessageSnapshot> CreateAsync(HttpRequestMessage request)
{
_ = Throw.IfNull(request);

var snapshot = _snapshots.Get();
snapshot.Initialize(request);
await snapshot.InitializeAsync(request).ConfigureAwait(false);
return snapshot;
}

public HttpRequestMessage CreateRequestMessage()
[System.Diagnostics.CodeAnalysis.SuppressMessage("Resilience", "EA0014:The async method doesn't support cancellation", Justification = "Past the point of no cancellation.")]
public async Task<HttpRequestMessage> CreateRequestMessageAsync()
{
if (IsReset())
{
throw new InvalidOperationException($"{nameof(CreateRequestMessageAsync)}() cannot be called on a snapshot object that has been reset and has not been initialized");
}

var clone = new HttpRequestMessage(_method!, _requestUri)
{
Content = _content,
Version = _version!
};

if (_content is StreamContent)
{
(HttpContent? content, HttpContent? clonedContent) = await CloneContentAsync(_content).ConfigureAwait(false);
_content = content;
clone.Content = clonedContent;
}
else
{
clone.Content = _content;
}

#if NET5_0_OR_GREATER
foreach (var prop in _properties)
{
Expand All @@ -56,6 +77,7 @@ public HttpRequestMessage CreateRequestMessage()
return clone;
}

[System.Diagnostics.CodeAnalysis.SuppressMessage("Critical Bug", "S2952:Classes should \"Dispose\" of members from the classes' own \"Dispose\" methods", Justification = "Handled by ObjectPool")]
bool IResettable.TryReset()
{
_properties.Clear();
Expand All @@ -64,24 +86,76 @@ bool IResettable.TryReset()
_method = null;
_version = null;
_requestUri = null;
if (_content is StreamContent)
{
// a snapshot's StreamContent is always a unique copy (deep clone)
// therefore, it is safe to dispose when snapshot is no longer needed
_content.Dispose();
}

_content = null;

return true;
}

void IDisposable.Dispose() => _snapshots.Return(this);

private void Initialize(HttpRequestMessage request)
[System.Diagnostics.CodeAnalysis.SuppressMessage("Resilience", "EA0014:The async method doesn't support cancellation", Justification = "Past the point of no cancellation.")]
private static async Task<(HttpContent? content, HttpContent? clonedContent)> CloneContentAsync(HttpContent? content)
{
if (request.Content is StreamContent)
HttpContent? clonedContent = null;
if (content != null)
{
Throw.InvalidOperationException($"{nameof(StreamContent)} content cannot by cloned.");
HttpContent originalContent = content;
Stream originalRequestBody = await content.ReadAsStreamAsync().ConfigureAwait(false);
MemoryStream clonedRequestBody = new MemoryStream();
await originalRequestBody.CopyToAsync(clonedRequestBody).ConfigureAwait(false);
clonedRequestBody.Position = 0;
if (originalRequestBody.CanSeek)
{
originalRequestBody.Position = 0;
}
else
{
originalRequestBody = new MemoryStream();
await clonedRequestBody.CopyToAsync(originalRequestBody).ConfigureAwait(false);
originalRequestBody.Position = 0;
clonedRequestBody.Position = 0;
}

clonedContent = new StreamContent(clonedRequestBody);
content = new StreamContent(originalRequestBody);
foreach (KeyValuePair<string, IEnumerable<string>> header in originalContent.Headers)
{
_ = clonedContent.Headers.TryAddWithoutValidation(header.Key, header.Value);
_ = content.Headers.TryAddWithoutValidation(header.Key, header.Value);
}
}

return (content, clonedContent);
}

private bool IsReset()
{
return _method == null;
}

[System.Diagnostics.CodeAnalysis.SuppressMessage("Resilience", "EA0014:The async method doesn't support cancellation", Justification = "Past the point of no cancellation.")]
private async Task InitializeAsync(HttpRequestMessage request)
{
_method = request.Method;
_version = request.Version;
_requestUri = request.RequestUri;
_content = request.Content;
if (request.Content is StreamContent)
{
(HttpContent? requestContent, HttpContent? clonedRequestContent) = await CloneContentAsync(request.Content).ConfigureAwait(false);
_content = clonedRequestContent;
request.Content = requestContent;
}
else
{
_content = request.Content;
}

// headers
_headers.AddRange(request.Headers);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ public void Configure_ValidConfigurationSection_ShouldInitialize()
}

[Fact]
public void ActionGenerator_Ok()
public async Task ActionGenerator_Ok()
{
var options = Builder.Services.BuildServiceProvider().GetRequiredService<IOptionsMonitor<HttpStandardHedgingResilienceOptions>>().Get(Builder.Name);
var generator = options.Hedging.ActionGenerator;
Expand All @@ -115,7 +115,7 @@ public void ActionGenerator_Ok()
generator.Invoking(g => g(args)).Should().Throw<InvalidOperationException>().WithMessage("Request message snapshot is not attached to the resilience context.");

using var request = new HttpRequestMessage();
using var snapshot = RequestMessageSnapshot.Create(request);
using var snapshot = await RequestMessageSnapshot.CreateAsync(request).ConfigureAwait(false);
primary.Properties.Set(ResilienceKeys.RequestSnapshot, snapshot);
generator.Invoking(g => g(args)).Should().NotThrow();
}
Expand Down
Loading

0 comments on commit 51c0c53

Please sign in to comment.