Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 48 additions & 11 deletions src/EFCore/Query/Internal/ParameterExtractingExpressionVisitor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,9 @@ public class ParameterExtractingExpressionVisitor : ExpressionVisitor
private static readonly bool UseOldBehavior35100 =
AppContext.TryGetSwitch("Microsoft.EntityFrameworkCore.Issue35100", out var enabled35100) && enabled35100;

private static readonly bool UseOldBehavior37176 =
AppContext.TryGetSwitch("Microsoft.EntityFrameworkCore.Issue37176", out var enabled37176) && enabled37176;
Comment on lines +37 to +38
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can't you just reuse the above one? Seems like overkill to add a switch for each method

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Well technically these are two separate changes... this new one only fixes additional recognition for MemoryExtensions.Contains, and in theory could be problematic, causing users to need to roll back without also rolling back the general MemoryExtensions.Contains support.

I tend not to worry too much about quirks (they're in servicing branches only), but if you really prefer it I'll remove the new switch and keep only the existing one for MemoryExtensions.Contains.


/// <summary>
/// This is an internal API that supports the Entity Framework Core infrastructure and not subject to
/// the same compatibility standards as public APIs. It may be changed or removed without notice in
Expand Down Expand Up @@ -210,15 +213,32 @@ protected override Expression VisitMethodCall(MethodCallExpression methodCallExp
switch (method.Name)
{
case nameof(MemoryExtensions.Contains)
when methodCallExpression.Arguments is [var arg0, var arg1] &&
TryUnwrapSpanImplicitCast(arg0, out var unwrappedArg0):
when UseOldBehavior37176
&& methodCallExpression.Arguments is [var arg0, var arg1]
&& TryUnwrapSpanImplicitCast(arg0, out var unwrappedArg0):
{
return Visit(
Expression.Call(
EnumerableMethods.Contains.MakeGenericMethod(method.GetGenericArguments()[0]),
unwrappedArg0, arg1));
}

// In .NET 10, MemoryExtensions.Contains has an overload that accepts a third, optional comparer, in addition to the older
// overload that accepts two parameters only.
case nameof(MemoryExtensions.Contains)
when !UseOldBehavior37176
&& methodCallExpression.Arguments is [var spanArg, var valueArg, ..]
&& (methodCallExpression.Arguments.Count is 2
|| methodCallExpression.Arguments.Count is 3
&& methodCallExpression.Arguments[2] is ConstantExpression { Value: null })
&& TryUnwrapSpanImplicitCast(spanArg, out var unwrappedSpanArg):
{
return Visit(
Expression.Call(
EnumerableMethods.Contains.MakeGenericMethod(method.GetGenericArguments()[0]),
unwrappedSpanArg, valueArg));
}

case nameof(MemoryExtensions.SequenceEqual)
when methodCallExpression.Arguments is [var arg0, var arg1]
&& TryUnwrapSpanImplicitCast(arg0, out var unwrappedArg0)
Expand All @@ -231,20 +251,37 @@ protected override Expression VisitMethodCall(MethodCallExpression methodCallExp

static bool TryUnwrapSpanImplicitCast(Expression expression, [NotNullWhen(true)] out Expression? result)
{
if (expression is MethodCallExpression
switch (expression)
{
// With newer versions of the SDK, the implicit cast is represented as a MethodCallExpression;
// with older versions, it's a Convert node.
case MethodCallExpression
{
Method: { Name: "op_Implicit", DeclaringType: { IsGenericType: true } implicitCastDeclaringType },
Arguments: [var unwrapped]
} when implicitCastDeclaringType.GetGenericTypeDefinition() is var genericTypeDefinition
&& (genericTypeDefinition == typeof(Span<>) || genericTypeDefinition == typeof(ReadOnlySpan<>)):
{
result = unwrapped;
return true;
}
&& implicitCastDeclaringType.GetGenericTypeDefinition() is var genericTypeDefinition
&& (genericTypeDefinition == typeof(Span<>) || genericTypeDefinition == typeof(ReadOnlySpan<>)))
{
result = unwrapped;
return true;
}

result = null;
return false;
case UnaryExpression
{
NodeType: ExpressionType.Convert,
Operand: var unwrapped,
Type: { IsGenericType: true } convertType
} when !UseOldBehavior37176 && convertType.GetGenericTypeDefinition() is var genericTypeDefinition
&& (genericTypeDefinition == typeof(Span<>) || genericTypeDefinition == typeof(ReadOnlySpan<>)):
{
result = unwrapped;
return true;
}

default:
result = null;
return false;
}
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -371,6 +371,34 @@ public virtual Task Column_collection_of_bools_Contains(bool async)
async,
ss => ss.Set<PrimitiveCollectionsEntity>().Where(c => c.Bools.Contains(true)));

// C# 14 first-class spans caused MemoryExtensions.Contains to get resolved instead of Enumerable.Contains.
// The following tests that the various overloads are all supported.
[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task Contains_on_Enumerable(bool async)
=> AssertQuery(
async,
ss => ss.Set<PrimitiveCollectionsEntity>().Where(c => Enumerable.Contains(new[] { 10, 999 }, c.Int)));

// C# 14 first-class spans caused MemoryExtensions.Contains to get resolved instead of Enumerable.Contains.
// The following tests that the various overloads are all supported.
[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task Contains_on_MemoryExtensions(bool async)
=> AssertQuery(
async,
ss => ss.Set<PrimitiveCollectionsEntity>().Where(c => MemoryExtensions.Contains(new[] { 10, 999 }, c.Int)));

// Note that we don't test EF 8/9 with .NET 10; this test is here for completeness/documentation purposes.
#if NET10_0_OR_GREATER
[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task Contains_with_MemoryExtensions_with_null_comparer(bool async)
=> AssertQuery(
async,
ss => ss.Set<PrimitiveCollectionsEntity>().Where(c => MemoryExtensions.Contains(new[] { 10, 999 }, c.Int, comparer: null)));
#endif

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task Column_collection_Count_method(bool async)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -462,6 +462,45 @@ await context.Database.SqlQuery<string>($"SELECT [Bools] AS [Value] FROM [Primit
.SingleAsync());
}

public override async Task Contains_on_Enumerable(bool async)
{
await base.Contains_on_Enumerable(async);

AssertSql(
"""
SELECT [p].[Id], [p].[Bool], [p].[Bools], [p].[DateTime], [p].[DateTimes], [p].[Enum], [p].[Enums], [p].[Int], [p].[Ints], [p].[NullableInt], [p].[NullableInts], [p].[NullableString], [p].[NullableStrings], [p].[String], [p].[Strings]
FROM [PrimitiveCollectionsEntity] AS [p]
WHERE [p].[Int] IN (10, 999)
""");
}


public override async Task Contains_on_MemoryExtensions(bool async)
{
await base.Contains_on_MemoryExtensions(async);

AssertSql(
"""
SELECT [p].[Id], [p].[Bool], [p].[Bools], [p].[DateTime], [p].[DateTimes], [p].[Enum], [p].[Enums], [p].[Int], [p].[Ints], [p].[NullableInt], [p].[NullableInts], [p].[NullableString], [p].[NullableStrings], [p].[String], [p].[Strings]
FROM [PrimitiveCollectionsEntity] AS [p]
WHERE [p].[Int] IN (10, 999)
""");
}

#if NET10_0_OR_GREATER
public override async Task Contains_with_MemoryExtensions_with_null_comparer(bool async)
{
await base.Contains_with_MemoryExtensions_with_null_comparer(async);

AssertSql(
"""
SELECT [p].[Id], [p].[Bool], [p].[Bools], [p].[DateTime], [p].[DateTimes], [p].[Enum], [p].[Enums], [p].[Int], [p].[Ints], [p].[NullableInt], [p].[NullableInts], [p].[NullableString], [p].[NullableStrings], [p].[String], [p].[Strings]
FROM [PrimitiveCollectionsEntity] AS [p]
WHERE [p].[Int] IN (10, 999)
""");
}
#endif

public override Task Column_collection_Count_method(bool async)
=> AssertCompatibilityLevelTooLow(() => base.Column_collection_Count_method(async));

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -635,6 +635,45 @@ await context.Database.SqlQuery<string>($"SELECT [Bools] AS [Value] FROM [Primit
.SingleAsync());
}

public override async Task Contains_on_Enumerable(bool async)
{
await base.Contains_on_Enumerable(async);

AssertSql(
"""
SELECT [p].[Id], [p].[Bool], [p].[Bools], [p].[DateTime], [p].[DateTimes], [p].[Enum], [p].[Enums], [p].[Int], [p].[Ints], [p].[NullableInt], [p].[NullableInts], [p].[NullableString], [p].[NullableStrings], [p].[String], [p].[Strings]
FROM [PrimitiveCollectionsEntity] AS [p]
WHERE [p].[Int] IN (10, 999)
""");
}


public override async Task Contains_on_MemoryExtensions(bool async)
{
await base.Contains_on_MemoryExtensions(async);

AssertSql(
"""
SELECT [p].[Id], [p].[Bool], [p].[Bools], [p].[DateTime], [p].[DateTimes], [p].[Enum], [p].[Enums], [p].[Int], [p].[Ints], [p].[NullableInt], [p].[NullableInts], [p].[NullableString], [p].[NullableStrings], [p].[String], [p].[Strings]
FROM [PrimitiveCollectionsEntity] AS [p]
WHERE [p].[Int] IN (10, 999)
""");
}

#if NET10_0_OR_GREATER
public override async Task Contains_with_MemoryExtensions_with_null_comparer(bool async)
{
await base.Contains_with_MemoryExtensions_with_null_comparer(async);

AssertSql(
"""
SELECT [p].[Id], [p].[Bool], [p].[Bools], [p].[DateTime], [p].[DateTimes], [p].[Enum], [p].[Enums], [p].[Int], [p].[Ints], [p].[NullableInt], [p].[NullableInts], [p].[NullableString], [p].[NullableStrings], [p].[String], [p].[Strings]
FROM [PrimitiveCollectionsEntity] AS [p]
WHERE [p].[Int] IN (10, 999)
""");
}
#endif

public override async Task Column_collection_Count_method(bool async)
{
await base.Column_collection_Count_method(async);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -625,6 +625,45 @@ FROM json_each("p"."Bools") AS "b"
""");
}

public override async Task Contains_on_Enumerable(bool async)
{
await base.Contains_on_Enumerable(async);

AssertSql(
"""
SELECT "p"."Id", "p"."Bool", "p"."Bools", "p"."DateTime", "p"."DateTimes", "p"."Enum", "p"."Enums", "p"."Int", "p"."Ints", "p"."NullableInt", "p"."NullableInts", "p"."NullableString", "p"."NullableStrings", "p"."String", "p"."Strings"
FROM "PrimitiveCollectionsEntity" AS "p"
WHERE "p"."Int" IN (10, 999)
""");
}


public override async Task Contains_on_MemoryExtensions(bool async)
{
await base.Contains_on_MemoryExtensions(async);

AssertSql(
"""
SELECT "p"."Id", "p"."Bool", "p"."Bools", "p"."DateTime", "p"."DateTimes", "p"."Enum", "p"."Enums", "p"."Int", "p"."Ints", "p"."NullableInt", "p"."NullableInts", "p"."NullableString", "p"."NullableStrings", "p"."String", "p"."Strings"
FROM "PrimitiveCollectionsEntity" AS "p"
WHERE "p"."Int" IN (10, 999)
""");
}

#if NET10_0_OR_GREATER
public override async Task Contains_with_MemoryExtensions_with_null_comparer(bool async)
{
await base.Contains_with_MemoryExtensions_with_null_comparer(async);

AssertSql(
"""
SELECT "p"."Id", "p"."Bool", "p"."Bools", "p"."DateTime", "p"."DateTimes", "p"."Enum", "p"."Enums", "p"."Int", "p"."Ints", "p"."NullableInt", "p"."NullableInts", "p"."NullableString", "p"."NullableStrings", "p"."String", "p"."Strings"
FROM "PrimitiveCollectionsEntity" AS "p"
WHERE "p"."Int" IN (10, 999)
""");
}
#endif

public override async Task Column_collection_Count_method(bool async)
{
await base.Column_collection_Count_method(async);
Expand Down