Skip to content

Fix NativeAOT with minimal actions #35167

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Aug 9, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 13 additions & 13 deletions src/Http/Http.Extensions/src/RequestDelegateFactory.cs
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,8 @@ namespace Microsoft.AspNetCore.Http
/// </summary>
public static partial class RequestDelegateFactory
{
private static readonly NullabilityInfoContext NullabilityContext = new NullabilityInfoContext();
private static readonly NullabilityInfoContext NullabilityContext = new();
private static readonly TryParseMethodCache TryParseMethodCache = new();

private static readonly MethodInfo ExecuteTaskOfTMethod = typeof(RequestDelegateFactory).GetMethod(nameof(ExecuteTask), BindingFlags.NonPublic | BindingFlags.Static)!;
private static readonly MethodInfo ExecuteTaskOfStringMethod = typeof(RequestDelegateFactory).GetMethod(nameof(ExecuteTaskOfString), BindingFlags.NonPublic | BindingFlags.Static)!;
Expand All @@ -48,7 +49,6 @@ public static partial class RequestDelegateFactory
private static readonly ParameterExpression HttpContextExpr = Expression.Parameter(typeof(HttpContext), "httpContext");
private static readonly ParameterExpression BodyValueExpr = Expression.Parameter(typeof(object), "bodyValue");
private static readonly ParameterExpression WasParamCheckFailureExpr = Expression.Variable(typeof(bool), "wasParamCheckFailure");
private static readonly ParameterExpression TempSourceStringExpr = TryParseMethodCache.TempSourceStringExpr;

private static readonly MemberExpression RequestServicesExpr = Expression.Property(HttpContextExpr, nameof(HttpContext.RequestServices));
private static readonly MemberExpression HttpRequestExpr = Expression.Property(HttpContextExpr, nameof(HttpContext.Request));
Expand All @@ -61,8 +61,8 @@ public static partial class RequestDelegateFactory
private static readonly MemberExpression StatusCodeExpr = Expression.Property(HttpResponseExpr, nameof(HttpResponse.StatusCode));
private static readonly MemberExpression CompletedTaskExpr = Expression.Property(null, (PropertyInfo)GetMemberInfo<Func<Task>>(() => Task.CompletedTask));

private static readonly BinaryExpression TempSourceStringNotNullExpr = Expression.NotEqual(TempSourceStringExpr, Expression.Constant(null));
private static readonly BinaryExpression TempSourceStringNullExpr = Expression.Equal(TempSourceStringExpr, Expression.Constant(null));
private static readonly BinaryExpression TempSourceStringNotNullExpr = Expression.NotEqual(TryParseMethodCache.TempSourceStringExpr, Expression.Constant(null));
private static readonly BinaryExpression TempSourceStringNullExpr = Expression.Equal(TryParseMethodCache.TempSourceStringExpr, Expression.Constant(null));

/// <summary>
/// Creates a <see cref="RequestDelegate"/> implementation for <paramref name="action"/>.
Expand Down Expand Up @@ -174,7 +174,7 @@ public static RequestDelegate Create(MethodInfo methodInfo, Func<HttpContext, ob

if (factoryContext.UsingTempSourceString)
{
responseWritingMethodCall = Expression.Block(new[] { TempSourceStringExpr }, responseWritingMethodCall);
responseWritingMethodCall = Expression.Block(new[] { TryParseMethodCache.TempSourceStringExpr }, responseWritingMethodCall);
}

return HandleRequestBodyAndCompileRequestDelegate(responseWritingMethodCall, factoryContext);
Expand Down Expand Up @@ -555,7 +555,7 @@ private static Expression BindParameterFromValue(ParameterInfo parameter, Expres
Expression.IfThen(Expression.Equal(argument, Expression.Constant(null)),
Expression.Block(
Expression.Assign(WasParamCheckFailureExpr, Expression.Constant(true)),
Expression.Call(LogRequiredParameterNotProvidedMethod,
Expression.Call(LogRequiredParameterNotProvidedMethod,
HttpContextExpr, Expression.Constant(parameter.ParameterType.Name), Expression.Constant(parameter.Name))
)
)
Expand Down Expand Up @@ -642,7 +642,7 @@ private static Expression BindParameterFromValue(ParameterInfo parameter, Expres
var failBlock = Expression.Block(
Expression.Assign(WasParamCheckFailureExpr, Expression.Constant(true)),
Expression.Call(LogParameterBindingFailureMethod,
HttpContextExpr, parameterTypeNameConstant, parameterNameConstant, TempSourceStringExpr));
HttpContextExpr, parameterTypeNameConstant, parameterNameConstant, TryParseMethodCache.TempSourceStringExpr));

var tryParseCall = tryParseMethodCall(parsedValue);

Expand Down Expand Up @@ -681,14 +681,14 @@ private static Expression BindParameterFromValue(ParameterInfo parameter, Expres
var fullParamCheckBlock = !isOptional
? Expression.Block(
// tempSourceString = httpContext.RequestValue["id"];
Expression.Assign(TempSourceStringExpr, valueExpression),
Expression.Assign(TryParseMethodCache.TempSourceStringExpr, valueExpression),
// if (tempSourceString == null) { ... } only produced when parameter is required
checkRequiredParaseableParameterBlock,
// if (tempSourceString != null) { ... }
ifNotNullTryParse)
ifNotNullTryParse)
: Expression.Block(
// tempSourceString = httpContext.RequestValue["id"];
Expression.Assign(TempSourceStringExpr, valueExpression),
Expression.Assign(TryParseMethodCache.TempSourceStringExpr, valueExpression),
// if (tempSourceString != null) { ... }
ifNotNullTryParse);

Expand Down Expand Up @@ -740,7 +740,7 @@ private static Expression BindParameterFromBody(ParameterInfo parameter, bool al
Expression.Equal(argument, Expression.Constant(null)),
Expression.Block(
Expression.Assign(WasParamCheckFailureExpr, Expression.Constant(true)),
Expression.Call(LogRequiredParameterNotProvidedMethod,
Expression.Call(LogRequiredParameterNotProvidedMethod,
HttpContextExpr, Expression.Constant(parameter.ParameterType.Name), Expression.Constant(parameter.Name))
)
)
Expand Down Expand Up @@ -867,8 +867,8 @@ static async Task ExecuteAwaited(Task<string> task, HttpContext httpContext)

private static Task ExecuteWriteStringResponseAsync(HttpContext httpContext, string text)
{
SetPlaintextContentType(httpContext);
return httpContext.Response.WriteAsync(text);
SetPlaintextContentType(httpContext);
return httpContext.Response.WriteAsync(text);
}

private static Task ExecuteValueTask(ValueTask task)
Expand Down
71 changes: 62 additions & 9 deletions src/Http/Http.Extensions/test/TryParseMethodCacheTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -27,12 +27,13 @@ public class TryParseMethodCacheTests
[InlineData(typeof(ulong))]
public void FindTryParseMethod_ReturnsTheExpectedTryParseMethodWithInvariantCulture(Type @type)
{
var methodFound = TryParseMethodCache.FindTryParseMethod(@type);
var methodFound = new TryParseMethodCache().FindTryParseMethod(@type);

Assert.NotNull(methodFound);

var call = methodFound!(Expression.Variable(type, "parsedValue"));
var parameters = call.Method.GetParameters();
var call = methodFound!(Expression.Variable(type, "parsedValue")) as MethodCallExpression;
Assert.NotNull(call);
var parameters = call!.Method.GetParameters();

Assert.Equal(4, parameters.Length);
Assert.Equal(typeof(string), parameters[0].ParameterType);
Expand All @@ -49,12 +50,13 @@ public void FindTryParseMethod_ReturnsTheExpectedTryParseMethodWithInvariantCult
[InlineData(typeof(TimeSpan))]
public void FindTryParseMethod_ReturnsTheExpectedTryParseMethodWithInvariantCultureDateType(Type @type)
{
var methodFound = TryParseMethodCache.FindTryParseMethod(@type);
var methodFound = new TryParseMethodCache().FindTryParseMethod(@type);

Assert.NotNull(methodFound);

var call = methodFound!(Expression.Variable(type, "parsedValue"));
var parameters = call.Method.GetParameters();
var call = methodFound!(Expression.Variable(type, "parsedValue")) as MethodCallExpression;
Assert.NotNull(call);
var parameters = call!.Method.GetParameters();

if (@type == typeof(TimeSpan))
{
Expand All @@ -77,12 +79,13 @@ public void FindTryParseMethod_ReturnsTheExpectedTryParseMethodWithInvariantCult
[InlineData(typeof(TryParsableInvariantRecord))]
public void FindTryParseMethod_ReturnsTheExpectedTryParseMethodWithInvariantCultureCustomType(Type @type)
{
var methodFound = TryParseMethodCache.FindTryParseMethod(@type);
var methodFound = new TryParseMethodCache().FindTryParseMethod(@type);

Assert.NotNull(methodFound);

var call = methodFound!(Expression.Variable(type, "parsedValue"));
var parameters = call.Method.GetParameters();
var call = methodFound!(Expression.Variable(type, "parsedValue")) as MethodCallExpression;
Assert.NotNull(call);
var parameters = call!.Method.GetParameters();

Assert.Equal(3, parameters.Length);
Assert.Equal(typeof(string), parameters[0].ParameterType);
Expand All @@ -91,6 +94,56 @@ public void FindTryParseMethod_ReturnsTheExpectedTryParseMethodWithInvariantCult
Assert.True(((call.Arguments[1] as ConstantExpression)!.Value as CultureInfo)!.Equals(CultureInfo.InvariantCulture));
}

[Fact]
public void FindTryParseMethodForEnums()
{
var type = typeof(Choice);
var methodFound = new TryParseMethodCache().FindTryParseMethod(type);

Assert.NotNull(methodFound);

var call = methodFound!(Expression.Variable(type, "parsedValue")) as MethodCallExpression;
Assert.NotNull(call);
var method = call!.Method;
var parameters = method.GetParameters();

// By default, we use Enum.TryParse<T>
Assert.True(method.IsGenericMethod);
Assert.Equal(2, parameters.Length);
Assert.Equal(typeof(string), parameters[0].ParameterType);
Assert.True(parameters[1].IsOut);
}

[Fact]
public void FindTryParseMethodForEnumsWhenNonGenericEnumParseIsUsed()
{
var type = typeof(Choice);
var cache = new TryParseMethodCache(preferNonGenericEnumParseOverload: true);
var methodFound = cache.FindTryParseMethod(type);

Assert.NotNull(methodFound);

var parsedValue = Expression.Variable(type, "parsedValue");
var block = methodFound!(parsedValue) as BlockExpression;
Assert.NotNull(block);
Assert.Equal(typeof(bool), block!.Type);

var parseEnum = Expression.Lambda<Func<string, Choice>>(Expression.Block(new[] { parsedValue },
block,
parsedValue), cache.TempSourceStringExpr).Compile();

Assert.Equal(Choice.One, parseEnum("One"));
Assert.Equal(Choice.Two, parseEnum("Two"));
Assert.Equal(Choice.Three, parseEnum("Three"));
}

enum Choice
{
One,
Two,
Three
}

private record TryParsableInvariantRecord(int value)
{
public static bool TryParse(string? value, IFormatProvider formatProvider, out TryParsableInvariantRecord? result)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ internal class EndpointMetadataApiDescriptionProvider : IApiDescriptionProvider
private readonly EndpointDataSource _endpointDataSource;
private readonly IHostEnvironment _environment;
private readonly IServiceProviderIsService? _serviceProviderIsService;
private readonly TryParseMethodCache TryParseMethodCache = new();

// Executes before MVC's DefaultApiDescriptionProvider and GrpcHttpApiDescriptionProvider for no particular reason.
public int Order => -1100;
Expand Down
90 changes: 73 additions & 17 deletions src/Shared/TryParseMethodCache.cs
Original file line number Diff line number Diff line change
Expand Up @@ -15,35 +15,64 @@

namespace Microsoft.AspNetCore.Http
{
internal static class TryParseMethodCache
internal sealed class TryParseMethodCache
{
private static readonly MethodInfo EnumTryParseMethod = GetEnumTryParseMethod();
private readonly MethodInfo _enumTryParseMethod;

// Since this is shared source, the cache won't be shared between RequestDelegateFactory and the ApiDescriptionProvider sadly :(
private static readonly ConcurrentDictionary<Type, Func<Expression, MethodCallExpression>?> MethodCallCache = new();
internal static readonly ParameterExpression TempSourceStringExpr = Expression.Variable(typeof(string), "tempSourceString");
private readonly ConcurrentDictionary<Type, Func<Expression, Expression>?> _methodCallCache = new();

public static bool HasTryParseMethod(ParameterInfo parameter)
internal readonly ParameterExpression TempSourceStringExpr = Expression.Variable(typeof(string), "tempSourceString");

public TryParseMethodCache() : this(preferNonGenericEnumParseOverload: false)
{
}

// This is for testing
public TryParseMethodCache(bool preferNonGenericEnumParseOverload)
{
_enumTryParseMethod = GetEnumTryParseMethod(preferNonGenericEnumParseOverload);
}

public bool HasTryParseMethod(ParameterInfo parameter)
{
var nonNullableParameterType = Nullable.GetUnderlyingType(parameter.ParameterType) ?? parameter.ParameterType;
return FindTryParseMethod(nonNullableParameterType) is not null;
}

public static Func<Expression, MethodCallExpression>? FindTryParseMethod(Type type)
public Func<Expression, Expression>? FindTryParseMethod(Type type)
{
static Func<Expression, MethodCallExpression>? Finder(Type type)
Func<Expression, Expression>? Finder(Type type)
{
MethodInfo? methodInfo;

if (type.IsEnum)
{
methodInfo = EnumTryParseMethod.MakeGenericMethod(type);
if (methodInfo != null)
if (_enumTryParseMethod.IsGenericMethod)
{
methodInfo = _enumTryParseMethod.MakeGenericMethod(type);

return (expression) => Expression.Call(methodInfo!, TempSourceStringExpr, expression);
}

return null;
return (expression) =>
{
var enumAsObject = Expression.Variable(typeof(object), "enumAsObject");
var success = Expression.Variable(typeof(bool), "success");

// object enumAsObject;
// bool success;
// success = Enum.TryParse(type, tempSourceString, out enumAsObject);
// parsedValue = success ? (Type)enumAsObject : default;
// return success;

return Expression.Block(new[] { success, enumAsObject },
Expression.Assign(success, Expression.Call(_enumTryParseMethod, Expression.Constant(type), TempSourceStringExpr, enumAsObject)),
Expression.Assign(expression,
Expression.Condition(success, Expression.Convert(enumAsObject, type), Expression.Default(type))),
success);
};

}

if (TryGetDateTimeTryParseMethod(type, out methodInfo))
Expand Down Expand Up @@ -87,32 +116,59 @@ public static bool HasTryParseMethod(ParameterInfo parameter)
return null;
}

return MethodCallCache.GetOrAdd(type, Finder);
return _methodCallCache.GetOrAdd(type, Finder);
}

private static MethodInfo GetEnumTryParseMethod()
private static MethodInfo GetEnumTryParseMethod(bool preferNonGenericEnumParseOverload)
{
var staticEnumMethods = typeof(Enum).GetMethods(BindingFlags.Public | BindingFlags.Static);

// With NativeAOT, if there's no static usage of Enum.TryParse<T>, it will be removed
// we fallback to the non-generic version if that is the case
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there a reason the non-generic one does not get trimmed?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I assume it's because something else uses it. I have no idea what.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we force it to be kept (e.g using DynamicDependency)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure if it's needed TBH, I think typeof(Enum).GetMethods(BindingFlags.Public | BindingFlags.Static) is visible to the linker but I'd need to make sure.

cc @MichalStrehovsky @eerhardt

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is surprising to me.

var staticEnumMethods = typeof(Enum).GetMethods(BindingFlags.Public | BindingFlags.Static);

That should "mark" all public static methods on Enum, and the AOT compiler shouldn't trim any of them.

Is it possible to build up the GetMethod query correctly so you only select the one method you are looking for?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

typeof(Enum).GetMethods(BindingFlags.Public | BindingFlags.Static) is visible to the linker but I'd need to make sure.

Yes. It is a bug in NativeAOT: dotnet/runtimelab#1402

Once this bug gets fixed, this will be broken again.

FWIW, I think fallbacks like this are wrong way to fix trimming and AOT friendliness issues.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I didn't realize this is going to return TryParse<T> if that method got reflection-rooted through other means. The fact that it's not just rooted from GetMethod is kind of a bug but it only happens because the compiler figured there's not much it can do there - by default the AOT compiler will try to enable metadata for the method and generate a shared instantiation (over reference types), but the constraints on TryParse<T> rule out reference types. Right now, because it can't make a usable method body, it will not generate the metadata either. But we can fix it to create the metadata.

Yes, this is likely going to fail at MakeGenericMethod later.

Copy link
Member Author

@davidfowl davidfowl Aug 10, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Once this bug gets fixed, this will be broken again.

Thanks. I'll look into that.

FWIW, I think fallbacks like this are wrong way to fix trimming and AOT friendliness issues.

It just needs to work well enough (or barely work) to enable people to try it out and give feedback. Seems like we need something a bit more creative here (like a check for IsDynamicCodeCompiled maybe). Remember it's using the expression tree interpreter here so the bar is low 😄

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

IsDynamicCodeSupported is what I did in the DataContractSerializer as a temporary workaround while we figure things out: https://github.com/dotnet/runtimelab/pull/692/files#diff-5c36d3b2cc664a8e51d53ac77390c98eca051ca63b13443233afc38ccee5e65fR60-R64

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yep, this is what I'll do. FWIW I have a pretty good handle on most of the AOT friendly problems that will plague ASP.NET Core (like JSON 😄 ) and have prototypes for how to tackle most of them.

MethodInfo? genericCandidate = null;
MethodInfo? nonGenericCandidate = null;

foreach (var method in staticEnumMethods)
{
if (!method.IsGenericMethod || method.Name != nameof(Enum.TryParse) || method.ReturnType != typeof(bool))
if (method.Name != nameof(Enum.TryParse) || method.ReturnType != typeof(bool))
{
continue;
}

var tryParseParameters = method.GetParameters();

if (tryParseParameters.Length == 2 &&
// Enum.TryParse<T>(string, out object)
if (method.IsGenericMethod &&
tryParseParameters.Length == 2 &&
tryParseParameters[0].ParameterType == typeof(string) &&
tryParseParameters[1].IsOut)
{
return method;
genericCandidate = method;
}

// Enum.TryParse(type, string, out object)
if (!method.IsGenericMethod &&
tryParseParameters.Length == 3 &&
tryParseParameters[0].ParameterType == typeof(Type) &&
tryParseParameters[1].ParameterType == typeof(string) &&
tryParseParameters[2].IsOut)
{
nonGenericCandidate = method;
}
}

if (genericCandidate is null && nonGenericCandidate is null)
{
Debug.Fail("No suitable System.Enum.TryParse method found.");
throw new MissingMethodException("No suitable System.Enum.TryParse method found.");
}

if (preferNonGenericEnumParseOverload)
{
return nonGenericCandidate!;
}

Debug.Fail("static bool System.Enum.TryParse<TEnum>(string? value, out TEnum result) not found.");
throw new Exception("static bool System.Enum.TryParse<TEnum>(string? value, out TEnum result) not found.");
return genericCandidate ?? nonGenericCandidate!;
}

private static bool TryGetDateTimeTryParseMethod(Type type, [NotNullWhen(true)] out MethodInfo? methodInfo)
Expand Down