Skip to content
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
using System.Runtime.CompilerServices;
using Microsoft.AspNetCore.Http;
using Microsoft.AspNetCore.Http.Abstractions;
using Microsoft.Extensions.Internal;
using Microsoft.Extensions.DependencyInjection;

namespace Microsoft.AspNetCore.Builder;

Expand All @@ -21,6 +21,7 @@ public static class UseMiddlewareExtensions
internal const string InvokeAsyncMethodName = "InvokeAsync";

private static readonly MethodInfo GetServiceInfo = typeof(UseMiddlewareExtensions).GetMethod(nameof(GetService), BindingFlags.NonPublic | BindingFlags.Static)!;
private static readonly MethodInfo GetKeyedServiceInfo = typeof(UseMiddlewareExtensions).GetMethod(nameof(GetKeyedService), BindingFlags.NonPublic | BindingFlags.Static)!;

// We're going to keep all public constructors and public methods on middleware
private const DynamicallyAccessedMemberTypes MiddlewareAccessibility =
Expand Down Expand Up @@ -209,19 +210,70 @@ private static Func<T, HttpContext, IServiceProvider, Task> ReflectionFallback<T
}
}

// Performance optimization: Precompute and cache the key results for each parameter
var precomputedKeys = new object?[parameters.Length];
for (var i = 1; i < parameters.Length; i++)
{
var hasServiceKey = TryGetServiceKey(parameters[i], out object? key);
precomputedKeys[i] = key;
}

return (middleware, context, serviceProvider) =>
{
var methodArguments = new object[parameters.Length];
methodArguments[0] = context;
for (var i = 1; i < parameters.Length; i++)
{
methodArguments[i] = GetService(serviceProvider, parameters[i].ParameterType, methodInfo.DeclaringType!);
var key = precomputedKeys[i];
var parameterType = parameters[i].ParameterType;
var declaringType = methodInfo.DeclaringType!;

methodArguments[i] = key == null ? GetService(serviceProvider, parameterType, declaringType) : GetKeyedService(serviceProvider, key, parameterType, declaringType);
}

return (Task)methodInfo.Invoke(middleware, BindingFlags.DoNotWrapExceptions, binder: null, methodArguments, culture: null)!;
};
}

private static bool TryGetServiceKey(ParameterInfo parameterInfo, out object? key)
{
if (parameterInfo.CustomAttributes != null)
{
foreach (var attribute in parameterInfo.GetCustomAttributes(true))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

parameterInfo.OfType<FromKeyedServicesAttribute>().FirstOrDefault() might be more straightforward here.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@captainsafia, @BrennanConroy, the changes within this pull-request target main, and thus the .NET 9.0 version. If I'd want the changes to be available in the latest .NET 8.x.x version as well, what would the appropriate steps for me to take?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@captainsafia, @BrennanConroy, any follow-up on my question above?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think this is a backport candidate. This change is a feature, not fixing a major bug/regression introduced in 8.0.

{
if (attribute is FromKeyedServicesAttribute keyed)
{
key = keyed.Key;

return true;
}
}
}

key = null;

return false;
}

private static UnaryExpression GetMethodArgument(ParameterInfo parameter, ParameterExpression providerArg, Type parameterType, Type? declaringType)
{
var parameterTypeExpression = new List<Expression>() { providerArg };
var hasServiceKey = TryGetServiceKey(parameter, out object? key);

if (hasServiceKey)
{
parameterTypeExpression.Add(Expression.Constant(key, typeof(object)));
}

parameterTypeExpression.Add(Expression.Constant(parameterType, typeof(Type)));
parameterTypeExpression.Add(Expression.Constant(declaringType, typeof(Type)));

var getServiceCall = Expression.Call(hasServiceKey ? GetKeyedServiceInfo : GetServiceInfo, parameterTypeExpression);
var methodArgument = Expression.Convert(getServiceCall, parameterType);

return methodArgument;
}

private static Func<T, HttpContext, IServiceProvider, Task> CompileExpression<T>(MethodInfo methodInfo, ParameterInfo[] parameters)
{
Debug.Assert(RuntimeFeature.IsDynamicCodeSupported, "Use compiled expression when dynamic code is supported.");
Expand Down Expand Up @@ -262,21 +314,14 @@ private static Func<T, HttpContext, IServiceProvider, Task> CompileExpression<T>
methodArguments[0] = httpContextArg;
for (var i = 1; i < parameters.Length; i++)
{
var parameterType = parameters[i].ParameterType;
var parameter = parameters[i];
var parameterType = parameter.ParameterType;
if (parameterType.IsByRef)
{
throw new NotSupportedException(Resources.FormatException_InvokeDoesNotSupportRefOrOutParams(InvokeMethodName));
}

var parameterTypeExpression = new Expression[]
{
providerArg,
Expression.Constant(parameterType, typeof(Type)),
Expression.Constant(methodInfo.DeclaringType, typeof(Type))
};

var getServiceCall = Expression.Call(GetServiceInfo, parameterTypeExpression);
methodArguments[i] = Expression.Convert(getServiceCall, parameterType);
methodArguments[i] = GetMethodArgument(parameter, providerArg, parameterType, methodInfo.DeclaringType);
}

Expression middlewareInstanceArg = instanceArg;
Expand All @@ -294,12 +339,20 @@ private static Func<T, HttpContext, IServiceProvider, Task> CompileExpression<T>

private static object GetService(IServiceProvider sp, Type type, Type middleware)
{
var service = sp.GetService(type);
if (service == null)
var service = sp.GetService(type) ?? throw new InvalidOperationException(Resources.FormatException_InvokeMiddlewareNoService(type, middleware));

return service;
}

private static object GetKeyedService(IServiceProvider sp, object key, Type type, Type middleware)
{
if (sp is IKeyedServiceProvider ksp)
{
throw new InvalidOperationException(Resources.FormatException_InvokeMiddlewareNoService(type, middleware));
var service = ksp.GetKeyedService(type, key) ?? throw new InvalidOperationException(Resources.FormatException_InvokeMiddlewareNoService(type, middleware));

return service;
}

return service;
throw new InvalidOperationException(Resources.Exception_KeyedServicesNotSupported);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,8 @@ Microsoft.AspNetCore.Http.HttpResponse</Description>
<Reference Include="Microsoft.AspNetCore.Http.Features" />
<Reference Include="Microsoft.Net.Http.Headers" />
<Reference Include="Microsoft.Extensions.Logging.Abstractions" />
<Reference Include="Microsoft.Extensions.DependencyInjection.Abstractions" />

<Compile Include="$(SharedSourceRoot)ActivatorUtilities\*.cs" />
<Compile Include="$(SharedSourceRoot)ParameterDefaultValue\*.cs" />
<Compile Include="$(SharedSourceRoot)PropertyHelper\**\*.cs" />
<Compile Include="$(SharedSourceRoot)\UrlDecoder\UrlDecoder.cs" Link="UrlDecoder.cs" />
Expand Down
6 changes: 6 additions & 0 deletions src/Http/Http.Abstractions/src/Resources.resx
Original file line number Diff line number Diff line change
Expand Up @@ -162,4 +162,10 @@
<data name="RouteValueDictionary_DuplicatePropertyName" xml:space="preserve">
<value>The type '{0}' defines properties '{1}' and '{2}' which differ only by casing. This is not supported by {3} which uses case-insensitive comparisons.</value>
</data>
<data name="Exception_KeyedServicesNotSupported" xml:space="preserve">
<value>This service provider doesn't support keyed services.</value>
</data>
<data name="Exception_NoServiceRegistered" xml:space="preserve">
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry, there is this one too 😢

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hehe, don't sweat it, should've checked it. Will run by all resources 👍

<value>No service for type '{0}' has been registered.</value>
</data>
</root>
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
<Reference Include="Microsoft.AspNetCore.Routing" />
<Reference Include="Microsoft.AspNetCore.TestHost" />
<Reference Include="Mono.TextTemplating" />
<Reference Include="Microsoft.Extensions.DependencyInjection.Abstractions" />
</ItemGroup>

<ItemGroup>
Expand Down
100 changes: 100 additions & 0 deletions src/Http/Http.Abstractions/test/UseMiddlewareTest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

using Microsoft.AspNetCore.Builder;
using Microsoft.AspNetCore.Http.Abstractions;
using Microsoft.Extensions.DependencyInjection;

namespace Microsoft.AspNetCore.Http;

Expand Down Expand Up @@ -130,6 +131,32 @@ public async Task UseMiddleware_ThrowsIfArgCantBeResolvedFromContainer()
exception.Message);
}

[Fact]
public async Task UseMiddleware_ThrowsIfKeyedArgCantBeResolvedFromContainer()
{
var builder = new ApplicationBuilder(new DummyKeyedServiceProvider());
builder.UseMiddleware(typeof(MiddlewareKeyedInjectInvokeNoService));
var app = builder.Build();
var exception = await Assert.ThrowsAsync<InvalidOperationException>(() => app(new DefaultHttpContext()));
Assert.Equal(
Resources.FormatException_InvokeMiddlewareNoService(
typeof(object),
typeof(MiddlewareKeyedInjectInvokeNoService)),
exception.Message);
}

[Fact]
public async Task UseMiddleware_ThrowsIfServiceProviderIsNotAIKeyedServiceProvider()
{
var builder = new ApplicationBuilder(new DummyServiceProvider());
builder.UseMiddleware(typeof(MiddlewareKeyedInjectInvokeNoService));
var app = builder.Build();
var exception = await Assert.ThrowsAsync<InvalidOperationException>(() => app(new DefaultHttpContext()));
Assert.Equal(
Resources.Exception_KeyedServicesNotSupported,
exception.Message);
}

[Fact]
public void UseMiddlewareWithInvokeArg()
{
Expand All @@ -139,6 +166,17 @@ public void UseMiddlewareWithInvokeArg()
app(new DefaultHttpContext());
}

[Fact]
public void UseMiddlewareWithInvokeKeyedArg()
{
var keyedServiceProvider = new DummyKeyedServiceProvider();
keyedServiceProvider.AddKeyedService("test", typeof(DummyKeyedServiceProvider), keyedServiceProvider);
var builder = new ApplicationBuilder(keyedServiceProvider);
builder.UseMiddleware(typeof(MiddlewareKeyedInjectInvoke));
var app = builder.Build();
app(new DefaultHttpContext());
}

[Fact]
public void UseMiddlewareWithInvokeWithOutAndRefThrows()
{
Expand Down Expand Up @@ -274,6 +312,54 @@ private class DummyServiceProvider : IServiceProvider
}
}

private class DummyKeyedServiceProvider : IKeyedServiceProvider
{
private readonly Dictionary<object, Tuple<Type, object>> _services = new Dictionary<object, Tuple<Type, object>>();

public DummyKeyedServiceProvider()
{

}

public void AddKeyedService(object key, Type type, object value) => _services[key] = new Tuple<Type, object>(type, value);

public object? GetKeyedService(Type serviceType, object? serviceKey)
{
if (_services.TryGetValue(serviceKey!, out var value))
{
return value.Item2;
}

return null;
}

public object GetRequiredKeyedService(Type serviceType, object? serviceKey)
{
var service = GetKeyedService(serviceType, serviceKey);

if (service == null)
{
throw new InvalidOperationException(Resources.FormatException_NoServiceRegistered(serviceType));
}

return service;
}

public object? GetService(Type serviceType)
{
if (serviceType == typeof(IServiceProvider))
{
return this;
}

if (_services.TryGetValue(serviceType, out var value))
{
return value;
}
return null;
}
}

public class MiddlewareInjectWithOutAndRefParams
{
public MiddlewareInjectWithOutAndRefParams(RequestDelegate next) { }
Expand All @@ -293,13 +379,27 @@ public MiddlewareInjectInvokeNoService(RequestDelegate next) { }
public Task Invoke(HttpContext context, object value) => Task.CompletedTask;
}

private class MiddlewareKeyedInjectInvokeNoService
{
public MiddlewareKeyedInjectInvokeNoService(RequestDelegate next) { }

public Task Invoke(HttpContext context, [FromKeyedServices("test")] object value) => Task.CompletedTask;
}

private class MiddlewareInjectInvoke
{
public MiddlewareInjectInvoke(RequestDelegate next) { }

public Task Invoke(HttpContext context, IServiceProvider provider) => Task.CompletedTask;
}

private class MiddlewareKeyedInjectInvoke
{
public MiddlewareKeyedInjectInvoke(RequestDelegate next) { }

public Task Invoke(HttpContext context, [FromKeyedServices("test")] IKeyedServiceProvider provider) => Task.CompletedTask;
}

private class MiddlewareNoParametersStub
{
public MiddlewareNoParametersStub(RequestDelegate next) { }
Expand Down
25 changes: 23 additions & 2 deletions src/Http/Http.Abstractions/test/UsePathBaseExtensionsTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,6 @@ public IServiceProvider ApplicationServices
public IFeatureCollection ServerFeatures => _wrappedBuilder.ServerFeatures;
public RequestDelegate Build() => _wrappedBuilder.Build();
public IApplicationBuilder New() => _wrappedBuilder.New();

}

[Theory]
Expand Down Expand Up @@ -238,6 +237,28 @@ private static HttpContext CreateRequest(string pathBase, string requestPath)

private static ApplicationBuilder CreateBuilder()
{
return new ApplicationBuilder(serviceProvider: null!);
return new ApplicationBuilder(new DummyServiceProvider());
}

private class DummyServiceProvider : IServiceProvider
{
private readonly Dictionary<Type, object> _services = new Dictionary<Type, object>();

public void AddService(Type type, object value) => _services[type] = value;

public object? GetService(Type serviceType)
{
if (serviceType == typeof(IServiceProvider))
{
return this;
}

if (_services.TryGetValue(serviceType, out var value))
{
return value;
}

return null;
}
}
}