From b38c2698045f44073a76b29c8e9126adbb379b79 Mon Sep 17 00:00:00 2001 From: Smit Patel Date: Thu, 11 Nov 2021 12:41:16 -0800 Subject: [PATCH] Query: Optimize Contains to Any recursively (#26599) - Also unwrap convert node around projection when translating Contains Resolves #26593 --- ...yableMethodTranslatingExpressionVisitor.cs | 15 +- .../QueryOptimizingExpressionVisitor.cs | 7 +- .../Query/SimpleQueryTestBase.cs | 131 +++++++++++++++++- .../Query/SimpleQuerySqlServerTest.cs | 65 +++++++++ 4 files changed, 215 insertions(+), 3 deletions(-) diff --git a/src/EFCore.Relational/Query/RelationalQueryableMethodTranslatingExpressionVisitor.cs b/src/EFCore.Relational/Query/RelationalQueryableMethodTranslatingExpressionVisitor.cs index 9adff93a033..d41c2aa5d77 100644 --- a/src/EFCore.Relational/Query/RelationalQueryableMethodTranslatingExpressionVisitor.cs +++ b/src/EFCore.Relational/Query/RelationalQueryableMethodTranslatingExpressionVisitor.cs @@ -317,7 +317,20 @@ private static ShapedQueryExpression CreateShapedQueryExpression(IEntityType ent selectExpression.ClearOrdering(); } - if (source.ShaperExpression is ProjectionBindingExpression projectionBindingExpression) + var shaperExpression = source.ShaperExpression; + if (!(AppContext.TryGetSwitch("Microsoft.EntityFrameworkCore.Issue26593", out var enabled) + && enabled)) + { + // No need to check ConvertChecked since this is convert node which we may have added during projection + if (shaperExpression is UnaryExpression { NodeType: ExpressionType.Convert } unaryExpression + && unaryExpression.Operand.Type.IsNullableType() + && unaryExpression.Operand.Type.UnwrapNullableType() == unaryExpression.Type) + { + shaperExpression = unaryExpression.Operand; + } + } + + if (shaperExpression is ProjectionBindingExpression projectionBindingExpression) { var projection = selectExpression.GetProjection(projectionBindingExpression); if (projection is SqlExpression sqlExpression) diff --git a/src/EFCore/Query/Internal/QueryOptimizingExpressionVisitor.cs b/src/EFCore/Query/Internal/QueryOptimizingExpressionVisitor.cs index a1358d8caaf..d1463185cb2 100644 --- a/src/EFCore/Query/Internal/QueryOptimizingExpressionVisitor.cs +++ b/src/EFCore/Query/Internal/QueryOptimizingExpressionVisitor.cs @@ -285,8 +285,13 @@ protected override Expression VisitMethodCall(MethodCallExpression methodCallExp anyLambdaParameter, methodCallExpression.Arguments[1]), anyLambdaParameter); + var source = methodCallExpression.Arguments[0]; + if (!(AppContext.TryGetSwitch("Microsoft.EntityFrameworkCore.Issue26593", out var enabled) && enabled)) + { + source = Visit(source); + } - return Expression.Call(null, anyMethod, new[] { methodCallExpression.Arguments[0], anyLambda }); + return Expression.Call(null, anyMethod, new[] { source, anyLambda }); } var @object = default(Expression); diff --git a/test/EFCore.Specification.Tests/Query/SimpleQueryTestBase.cs b/test/EFCore.Specification.Tests/Query/SimpleQueryTestBase.cs index 865e9dd7715..1d10ea74f0c 100644 --- a/test/EFCore.Specification.Tests/Query/SimpleQueryTestBase.cs +++ b/test/EFCore.Specification.Tests/Query/SimpleQueryTestBase.cs @@ -347,7 +347,7 @@ public virtual async Task IsDeleted_query_filter_with_conversion_to_int_works(bo protected class Context26428 : DbContext { public Context26428(DbContextOptions options) - : base(options) + : base(options) { } @@ -422,5 +422,134 @@ protected class Location } #nullable disable + + [ConditionalTheory] + [MemberData(nameof(IsAsyncData))] + public virtual async Task Unwrap_convert_node_over_projection_when_translating_contains_over_subquery(bool async) + { + var contextFactory = await InitializeAsync(seed: c => c.Seed()); + using var context = contextFactory.CreateContext(); + + var currentUserId = 1; + + var currentUserGroupIds = context.Memberships + .Where(m => m.UserId == currentUserId) + .Select(m => m.GroupId); + + var hasMembership = context.Memberships + .Where(m => currentUserGroupIds.Contains(m.GroupId)) + .Select(m => m.User); + + var query = context.Users + .Select(u => new + { + HasAccess = hasMembership.Contains(u) + }); + + var users = async + ? await query.ToListAsync() + : query.ToList(); + } + + [ConditionalTheory] + [MemberData(nameof(IsAsyncData))] + public virtual async Task Unwrap_convert_node_over_projection_when_translating_contains_over_subquery_2(bool async) + { + var contextFactory = await InitializeAsync(seed: c => c.Seed()); + using var context = contextFactory.CreateContext(); + + var currentUserId = 1; + + var currentUserGroupIds = context.Memberships + .Where(m => m.UserId == currentUserId) + .Select(m => m.Group); + + var hasMembership = context.Memberships + .Where(m => currentUserGroupIds.Contains(m.Group)) + .Select(m => m.User); + + var query = context.Users + .Select(u => new + { + HasAccess = hasMembership.Contains(u) + }); + + var users = async + ? await query.ToListAsync() + : query.ToList(); + } + + [ConditionalTheory] + [MemberData(nameof(IsAsyncData))] + public virtual async Task Unwrap_convert_node_over_projection_when_translating_contains_over_subquery_3(bool async) + { + var contextFactory = await InitializeAsync(seed: c => c.Seed()); + using var context = contextFactory.CreateContext(); + + var currentUserId = 1; + + var currentUserGroupIds = context.Memberships + .Where(m => m.UserId == currentUserId) + .Select(m => m.GroupId); + + var hasMembership = context.Memberships + .Where(m => currentUserGroupIds.Contains(m.GroupId)) + .Select(m => m.User); + + var query = context.Users + .Select(u => new + { + HasAccess = hasMembership.Any(e => e == u) + }); + + var users = async + ? await query.ToListAsync() + : query.ToList(); + } + + protected class Context26593 : DbContext + { + public Context26593(DbContextOptions options) + : base(options) + { + } + + public DbSet Users { get; set; } + public DbSet Groups { get; set; } + public DbSet Memberships { get; set; } + + public void Seed() + { + var user = new User(); + var group = new Group(); + var membership = new Membership { Group = group, User = user }; + AddRange(user, group, membership); + + SaveChanges(); + } + } + + protected class User + { + public int Id { get; set; } + + public ICollection Memberships { get; set; } + } + + protected class Group + { + public int Id { get; set; } + + public ICollection Memberships { get; set; } + } + + protected class Membership + { + public int Id { get; set; } + public User User { get; set; } + public int UserId { get; set; } + public Group Group { get; set; } + public int GroupId { get; set; } + } } } diff --git a/test/EFCore.SqlServer.FunctionalTests/Query/SimpleQuerySqlServerTest.cs b/test/EFCore.SqlServer.FunctionalTests/Query/SimpleQuerySqlServerTest.cs index 4fad38a5246..821ba1ab853 100644 --- a/test/EFCore.SqlServer.FunctionalTests/Query/SimpleQuerySqlServerTest.cs +++ b/test/EFCore.SqlServer.FunctionalTests/Query/SimpleQuerySqlServerTest.cs @@ -63,5 +63,70 @@ public override async Task Comparing_byte_column_to_enum_in_vb_creating_double_c FROM [Food] AS [f] WHERE [f].[Taste] = CAST(1 AS tinyint)"); } + + public override async Task Unwrap_convert_node_over_projection_when_translating_contains_over_subquery(bool async) + { + await base.Unwrap_convert_node_over_projection_when_translating_contains_over_subquery(async); + + AssertSql( + @"@__currentUserId_0='1' + +SELECT CASE + WHEN EXISTS ( + SELECT 1 + FROM [Memberships] AS [m] + INNER JOIN [Users] AS [u0] ON [m].[UserId] = [u0].[Id] + WHERE EXISTS ( + SELECT 1 + FROM [Memberships] AS [m0] + WHERE ([m0].[UserId] = @__currentUserId_0) AND ([m0].[GroupId] = [m].[GroupId])) AND ([u0].[Id] = [u].[Id])) THEN CAST(1 AS bit) + ELSE CAST(0 AS bit) +END AS [HasAccess] +FROM [Users] AS [u]"); + } + + public override async Task Unwrap_convert_node_over_projection_when_translating_contains_over_subquery_2(bool async) + { + await base.Unwrap_convert_node_over_projection_when_translating_contains_over_subquery_2(async); + + AssertSql( + @"@__currentUserId_0='1' + +SELECT CASE + WHEN EXISTS ( + SELECT 1 + FROM [Memberships] AS [m] + INNER JOIN [Groups] AS [g] ON [m].[GroupId] = [g].[Id] + INNER JOIN [Users] AS [u0] ON [m].[UserId] = [u0].[Id] + WHERE EXISTS ( + SELECT 1 + FROM [Memberships] AS [m0] + INNER JOIN [Groups] AS [g0] ON [m0].[GroupId] = [g0].[Id] + WHERE ([m0].[UserId] = @__currentUserId_0) AND ([g0].[Id] = [g].[Id])) AND ([u0].[Id] = [u].[Id])) THEN CAST(1 AS bit) + ELSE CAST(0 AS bit) +END AS [HasAccess] +FROM [Users] AS [u]"); + } + + public override async Task Unwrap_convert_node_over_projection_when_translating_contains_over_subquery_3(bool async) + { + await base.Unwrap_convert_node_over_projection_when_translating_contains_over_subquery_3(async); + + AssertSql( + @"@__currentUserId_0='1' + +SELECT CASE + WHEN EXISTS ( + SELECT 1 + FROM [Memberships] AS [m] + INNER JOIN [Users] AS [u0] ON [m].[UserId] = [u0].[Id] + WHERE EXISTS ( + SELECT 1 + FROM [Memberships] AS [m0] + WHERE ([m0].[UserId] = @__currentUserId_0) AND ([m0].[GroupId] = [m].[GroupId])) AND ([u0].[Id] = [u].[Id])) THEN CAST(1 AS bit) + ELSE CAST(0 AS bit) +END AS [HasAccess] +FROM [Users] AS [u]"); + } } }