Skip to content

Commit

Permalink
Query: Optimize Contains to Any recursively (#26599)
Browse files Browse the repository at this point in the history
- Also unwrap convert node around projection when translating Contains

Resolves #26593
  • Loading branch information
smitpatel authored Nov 11, 2021
1 parent cb6d523 commit b38c269
Show file tree
Hide file tree
Showing 4 changed files with 215 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -317,7 +317,20 @@ private static ShapedQueryExpression CreateShapedQueryExpression(IEntityType ent
selectExpression.ClearOrdering();
}

if (source.ShaperExpression is ProjectionBindingExpression projectionBindingExpression)
var shaperExpression = source.ShaperExpression;
if (!(AppContext.TryGetSwitch("Microsoft.EntityFrameworkCore.Issue26593", out var enabled)
&& enabled))
{
// No need to check ConvertChecked since this is convert node which we may have added during projection
if (shaperExpression is UnaryExpression { NodeType: ExpressionType.Convert } unaryExpression
&& unaryExpression.Operand.Type.IsNullableType()
&& unaryExpression.Operand.Type.UnwrapNullableType() == unaryExpression.Type)
{
shaperExpression = unaryExpression.Operand;
}
}

if (shaperExpression is ProjectionBindingExpression projectionBindingExpression)
{
var projection = selectExpression.GetProjection(projectionBindingExpression);
if (projection is SqlExpression sqlExpression)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -285,8 +285,13 @@ protected override Expression VisitMethodCall(MethodCallExpression methodCallExp
anyLambdaParameter,
methodCallExpression.Arguments[1]),
anyLambdaParameter);
var source = methodCallExpression.Arguments[0];
if (!(AppContext.TryGetSwitch("Microsoft.EntityFrameworkCore.Issue26593", out var enabled) && enabled))
{
source = Visit(source);
}

return Expression.Call(null, anyMethod, new[] { methodCallExpression.Arguments[0], anyLambda });
return Expression.Call(null, anyMethod, new[] { source, anyLambda });
}

var @object = default(Expression);
Expand Down
131 changes: 130 additions & 1 deletion test/EFCore.Specification.Tests/Query/SimpleQueryTestBase.cs
Original file line number Diff line number Diff line change
Expand Up @@ -347,7 +347,7 @@ public virtual async Task IsDeleted_query_filter_with_conversion_to_int_works(bo
protected class Context26428 : DbContext
{
public Context26428(DbContextOptions options)
: base(options)
: base(options)
{
}

Expand Down Expand Up @@ -422,5 +422,134 @@ protected class Location
}

#nullable disable

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual async Task Unwrap_convert_node_over_projection_when_translating_contains_over_subquery(bool async)
{
var contextFactory = await InitializeAsync<Context26593>(seed: c => c.Seed());
using var context = contextFactory.CreateContext();

var currentUserId = 1;

var currentUserGroupIds = context.Memberships
.Where(m => m.UserId == currentUserId)
.Select(m => m.GroupId);

var hasMembership = context.Memberships
.Where(m => currentUserGroupIds.Contains(m.GroupId))
.Select(m => m.User);

var query = context.Users
.Select(u => new
{
HasAccess = hasMembership.Contains(u)
});

var users = async
? await query.ToListAsync()
: query.ToList();
}

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual async Task Unwrap_convert_node_over_projection_when_translating_contains_over_subquery_2(bool async)
{
var contextFactory = await InitializeAsync<Context26593>(seed: c => c.Seed());
using var context = contextFactory.CreateContext();

var currentUserId = 1;

var currentUserGroupIds = context.Memberships
.Where(m => m.UserId == currentUserId)
.Select(m => m.Group);

var hasMembership = context.Memberships
.Where(m => currentUserGroupIds.Contains(m.Group))
.Select(m => m.User);

var query = context.Users
.Select(u => new
{
HasAccess = hasMembership.Contains(u)
});

var users = async
? await query.ToListAsync()
: query.ToList();
}

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual async Task Unwrap_convert_node_over_projection_when_translating_contains_over_subquery_3(bool async)
{
var contextFactory = await InitializeAsync<Context26593>(seed: c => c.Seed());
using var context = contextFactory.CreateContext();

var currentUserId = 1;

var currentUserGroupIds = context.Memberships
.Where(m => m.UserId == currentUserId)
.Select(m => m.GroupId);

var hasMembership = context.Memberships
.Where(m => currentUserGroupIds.Contains(m.GroupId))
.Select(m => m.User);

var query = context.Users
.Select(u => new
{
HasAccess = hasMembership.Any(e => e == u)
});

var users = async
? await query.ToListAsync()
: query.ToList();
}

protected class Context26593 : DbContext
{
public Context26593(DbContextOptions options)
: base(options)
{
}

public DbSet<User> Users { get; set; }
public DbSet<Group> Groups { get; set; }
public DbSet<Membership> Memberships { get; set; }

public void Seed()
{
var user = new User();
var group = new Group();
var membership = new Membership { Group = group, User = user };
AddRange(user, group, membership);

SaveChanges();
}
}

protected class User
{
public int Id { get; set; }

public ICollection<Membership> Memberships { get; set; }
}

protected class Group
{
public int Id { get; set; }

public ICollection<Membership> Memberships { get; set; }
}

protected class Membership
{
public int Id { get; set; }
public User User { get; set; }
public int UserId { get; set; }
public Group Group { get; set; }
public int GroupId { get; set; }
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -63,5 +63,70 @@ public override async Task Comparing_byte_column_to_enum_in_vb_creating_double_c
FROM [Food] AS [f]
WHERE [f].[Taste] = CAST(1 AS tinyint)");
}

public override async Task Unwrap_convert_node_over_projection_when_translating_contains_over_subquery(bool async)
{
await base.Unwrap_convert_node_over_projection_when_translating_contains_over_subquery(async);

AssertSql(
@"@__currentUserId_0='1'
SELECT CASE
WHEN EXISTS (
SELECT 1
FROM [Memberships] AS [m]
INNER JOIN [Users] AS [u0] ON [m].[UserId] = [u0].[Id]
WHERE EXISTS (
SELECT 1
FROM [Memberships] AS [m0]
WHERE ([m0].[UserId] = @__currentUserId_0) AND ([m0].[GroupId] = [m].[GroupId])) AND ([u0].[Id] = [u].[Id])) THEN CAST(1 AS bit)
ELSE CAST(0 AS bit)
END AS [HasAccess]
FROM [Users] AS [u]");
}

public override async Task Unwrap_convert_node_over_projection_when_translating_contains_over_subquery_2(bool async)
{
await base.Unwrap_convert_node_over_projection_when_translating_contains_over_subquery_2(async);

AssertSql(
@"@__currentUserId_0='1'
SELECT CASE
WHEN EXISTS (
SELECT 1
FROM [Memberships] AS [m]
INNER JOIN [Groups] AS [g] ON [m].[GroupId] = [g].[Id]
INNER JOIN [Users] AS [u0] ON [m].[UserId] = [u0].[Id]
WHERE EXISTS (
SELECT 1
FROM [Memberships] AS [m0]
INNER JOIN [Groups] AS [g0] ON [m0].[GroupId] = [g0].[Id]
WHERE ([m0].[UserId] = @__currentUserId_0) AND ([g0].[Id] = [g].[Id])) AND ([u0].[Id] = [u].[Id])) THEN CAST(1 AS bit)
ELSE CAST(0 AS bit)
END AS [HasAccess]
FROM [Users] AS [u]");
}

public override async Task Unwrap_convert_node_over_projection_when_translating_contains_over_subquery_3(bool async)
{
await base.Unwrap_convert_node_over_projection_when_translating_contains_over_subquery_3(async);

AssertSql(
@"@__currentUserId_0='1'
SELECT CASE
WHEN EXISTS (
SELECT 1
FROM [Memberships] AS [m]
INNER JOIN [Users] AS [u0] ON [m].[UserId] = [u0].[Id]
WHERE EXISTS (
SELECT 1
FROM [Memberships] AS [m0]
WHERE ([m0].[UserId] = @__currentUserId_0) AND ([m0].[GroupId] = [m].[GroupId])) AND ([u0].[Id] = [u].[Id])) THEN CAST(1 AS bit)
ELSE CAST(0 AS bit)
END AS [HasAccess]
FROM [Users] AS [u]");
}
}
}

0 comments on commit b38c269

Please sign in to comment.