Skip to content

Commit

Permalink
Fix to #30575 - Multiple LeftJoins (GroupJoins) lead to GroupJoin Exc…
Browse files Browse the repository at this point in the history
…eption when the same where is used twice (#30794)

Problem is in QueryableMethodNormalizingExpressionVisitor and specifically part where we convert from GroupJoin-SelectMany-DefaultIfEmpty into left join (SelectManyVerifyingExpressionVisitor). We check if the collection selector is correlated, and we do that by looking at parameters in that lambda. Problem is that the affected queries reference outside variable that gets parameterized and that breaks the correlation finding logic. Fix is to add a step that scans entire query and identifies external parameters before we try to normalize GJSMDIE into LeftJoins, so that those external parameters are not counted as correlated.

Fixes #30575
  • Loading branch information
maumar authored May 2, 2023
1 parent 099b1e2 commit 03541ad
Show file tree
Hide file tree
Showing 4 changed files with 585 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -675,6 +675,9 @@ protected override Expression VisitMethodCall(MethodCallExpression methodCallExp

private sealed class SelectManyVerifyingExpressionVisitor : ExpressionVisitor
{
private static readonly bool UseOldBehavior30575
= AppContext.TryGetSwitch("Microsoft.EntityFrameworkCore.Issue30575", out var enabled30575) && enabled30575;

private readonly List<ParameterExpression> _allowedParameters = new();
private readonly ISet<string> _allowedMethods = new HashSet<string> { nameof(Queryable.Where), nameof(Queryable.AsQueryable) };

Expand Down Expand Up @@ -774,9 +777,20 @@ protected override Expression VisitMethodCall(MethodCallExpression methodCallExp

protected override Expression VisitParameter(ParameterExpression parameterExpression)
{
if (_allowedParameters.Contains(parameterExpression))
if (!UseOldBehavior30575)
{
return parameterExpression;
if (_allowedParameters.Contains(parameterExpression)
|| parameterExpression.Name?.StartsWith(QueryCompilationContext.QueryParameterPrefix, StringComparison.Ordinal) == true)
{
return parameterExpression;
}
}
else
{
if (_allowedParameters.Contains(parameterExpression))
{
return parameterExpression;
}
}

if (parameterExpression == _rootParameter)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3785,4 +3785,139 @@ public virtual Task Prune_does_not_throw_null_ref(bool async)
select l2.Level1_Required_Id).DefaultIfEmpty()
from l1 in ss.Set<Level1>().Where(x => x.Id != ids)
select l1);


[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task GroupJoin_SelectMany_DefaultIfEmpty_with_predicate_using_closure(bool async)
{
var prm = 10;

return AssertQuery(
async,
ss => from l1 in ss.Set<Level1>()
join l2 in ss.Set<Level2>() on l1.Id equals l2.Level1_Optional_Id into grouping
from l2 in grouping.Where(x => x.Id != prm).DefaultIfEmpty()
select new { Id1 = l1.Id, Id2 = (int?)l2.Id },
elementSorter: e => (e.Id1, e.Id2),
elementAsserter: (e, a) =>
{
Assert.Equal(e.Id1, a.Id1);
Assert.Equal(e.Id2, a.Id2);
});
}

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task GroupJoin_SelectMany_with_predicate_using_closure(bool async)
{
var prm = 10;

return AssertQuery(
async,
ss => from l1 in ss.Set<Level1>()
join l2 in ss.Set<Level2>() on l1.Id equals l2.Level1_Optional_Id into grouping
from l2 in grouping.Where(x => x.Id != prm)
select new { Id1 = l1.Id, Id2 = l2.Id },
elementSorter: e => (e.Id1, e.Id2),
elementAsserter: (e, a) =>
{
Assert.Equal(e.Id1, a.Id1);
Assert.Equal(e.Id2, a.Id2);
});
}

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task GroupJoin_SelectMany_DefaultIfEmpty_with_predicate_using_closure_nested(bool async)
{
var prm1 = 10;
var prm2 = 20;

return AssertQuery(
async,
ss => from l1 in ss.Set<Level1>()
join l2 in ss.Set<Level2>() on l1.Id equals l2.Level1_Optional_Id into grouping1
from l2 in grouping1.Where(x => x.Id != prm1).DefaultIfEmpty()
join l3 in ss.Set<Level3>() on l2.Id equals l3.Level2_Optional_Id into grouping2
from l3 in grouping2.Where(x => x.Id != prm2).DefaultIfEmpty()
select new { Id1 = l1.Id, Id2 = (int?)l2.Id, Id3 = (int?)l3.Id },
elementSorter: e => (e.Id1, e.Id2, e.Id3),
elementAsserter: (e, a) =>
{
Assert.Equal(e.Id1, a.Id1);
Assert.Equal(e.Id2, a.Id2);
Assert.Equal(e.Id3, a.Id3);
});
}

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task GroupJoin_SelectMany_with_predicate_using_closure_nested(bool async)
{
var prm1 = 10;
var prm2 = 20;

return AssertQuery(
async,
ss => from l1 in ss.Set<Level1>()
join l2 in ss.Set<Level2>() on l1.Id equals l2.Level1_Optional_Id into grouping1
from l2 in grouping1.Where(x => x.Id != prm1)
join l3 in ss.Set<Level3>() on l2.Id equals l3.Level2_Optional_Id into grouping2
from l3 in grouping2.Where(x => x.Id != prm2)
select new { Id1 = l1.Id, Id2 = l2.Id, Id3 = l3.Id },
elementSorter: e => (e.Id1, e.Id2, e.Id3),
elementAsserter: (e, a) =>
{
Assert.Equal(e.Id1, a.Id1);
Assert.Equal(e.Id2, a.Id2);
Assert.Equal(e.Id3, a.Id3);
});
}

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task GroupJoin_SelectMany_DefaultIfEmpty_with_predicate_using_closure_nested_same_param(bool async)
{
var prm = 10;

return AssertQuery(
async,
ss => from l1 in ss.Set<Level1>()
join l2 in ss.Set<Level2>() on l1.Id equals l2.Level1_Optional_Id into grouping1
from l2 in grouping1.Where(x => x.Id != prm).DefaultIfEmpty()
join l3 in ss.Set<Level3>() on l2.Id equals l3.Level2_Optional_Id into grouping2
from l3 in grouping2.Where(x => x.Id != prm).DefaultIfEmpty()
select new { Id1 = l1.Id, Id2 = (int?)l2.Id, Id3 = (int?)l3.Id },
elementSorter: e => (e.Id1, e.Id2, e.Id3),
elementAsserter: (e, a) =>
{
Assert.Equal(e.Id1, a.Id1);
Assert.Equal(e.Id2, a.Id2);
Assert.Equal(e.Id3, a.Id3);
});
}

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task GroupJoin_SelectMany_with_predicate_using_closure_nested_same_param(bool async)
{
var prm = 10;

return AssertQuery(
async,
ss => from l1 in ss.Set<Level1>()
join l2 in ss.Set<Level2>() on l1.Id equals l2.Level1_Optional_Id into grouping1
from l2 in grouping1.Where(x => x.Id != prm)
join l3 in ss.Set<Level3>() on l2.Id equals l3.Level2_Optional_Id into grouping2
from l3 in grouping2.Where(x => x.Id != prm)
select new { Id1 = l1.Id, Id2 = l2.Id, Id3 = l3.Id },
elementSorter: e => (e.Id1, e.Id2, e.Id3),
elementAsserter: (e, a) =>
{
Assert.Equal(e.Id1, a.Id1);
Assert.Equal(e.Id2, a.Id2);
Assert.Equal(e.Id3, a.Id3);
});
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -4576,6 +4576,136 @@ public override async Task Multiple_required_navigation_with_EF_Property_Include
AssertSql();
}

public override async Task GroupJoin_SelectMany_DefaultIfEmpty_with_predicate_using_closure(bool async)
{
await base.GroupJoin_SelectMany_DefaultIfEmpty_with_predicate_using_closure(async);

AssertSql(
"""
@__prm_0='10'
SELECT [l].[Id] AS [Id1], [t].[Id] AS [Id2]
FROM [LevelOne] AS [l]
LEFT JOIN (
SELECT [l0].[Id], [l0].[Level1_Optional_Id]
FROM [LevelTwo] AS [l0]
WHERE [l0].[Id] <> @__prm_0
) AS [t] ON [l].[Id] = [t].[Level1_Optional_Id]
""");
}

public override async Task GroupJoin_SelectMany_with_predicate_using_closure(bool async)
{
await base.GroupJoin_SelectMany_with_predicate_using_closure(async);

AssertSql(
"""
@__prm_0='10'
SELECT [l].[Id] AS [Id1], [t].[Id] AS [Id2]
FROM [LevelOne] AS [l]
INNER JOIN (
SELECT [l0].[Id], [l0].[Level1_Optional_Id]
FROM [LevelTwo] AS [l0]
WHERE [l0].[Id] <> @__prm_0
) AS [t] ON [l].[Id] = [t].[Level1_Optional_Id]
""");
}

public override async Task GroupJoin_SelectMany_DefaultIfEmpty_with_predicate_using_closure_nested(bool async)
{
await base.GroupJoin_SelectMany_DefaultIfEmpty_with_predicate_using_closure_nested(async);

AssertSql(
"""
@__prm1_0='10'
@__prm2_1='20'
SELECT [l].[Id] AS [Id1], [t].[Id] AS [Id2], [t0].[Id] AS [Id3]
FROM [LevelOne] AS [l]
LEFT JOIN (
SELECT [l0].[Id], [l0].[Level1_Optional_Id]
FROM [LevelTwo] AS [l0]
WHERE [l0].[Id] <> @__prm1_0
) AS [t] ON [l].[Id] = [t].[Level1_Optional_Id]
LEFT JOIN (
SELECT [l1].[Id], [l1].[Level2_Optional_Id]
FROM [LevelThree] AS [l1]
WHERE [l1].[Id] <> @__prm2_1
) AS [t0] ON [t].[Id] = [t0].[Level2_Optional_Id]
""");
}

public override async Task GroupJoin_SelectMany_with_predicate_using_closure_nested(bool async)
{
await base.GroupJoin_SelectMany_with_predicate_using_closure_nested(async);

AssertSql(
"""
@__prm1_0='10'
@__prm2_1='20'
SELECT [l].[Id] AS [Id1], [t].[Id] AS [Id2], [t0].[Id] AS [Id3]
FROM [LevelOne] AS [l]
INNER JOIN (
SELECT [l0].[Id], [l0].[Level1_Optional_Id]
FROM [LevelTwo] AS [l0]
WHERE [l0].[Id] <> @__prm1_0
) AS [t] ON [l].[Id] = [t].[Level1_Optional_Id]
INNER JOIN (
SELECT [l1].[Id], [l1].[Level2_Optional_Id]
FROM [LevelThree] AS [l1]
WHERE [l1].[Id] <> @__prm2_1
) AS [t0] ON [t].[Id] = [t0].[Level2_Optional_Id]
""");
}

public override async Task GroupJoin_SelectMany_DefaultIfEmpty_with_predicate_using_closure_nested_same_param(bool async)
{
await base.GroupJoin_SelectMany_DefaultIfEmpty_with_predicate_using_closure_nested_same_param(async);

AssertSql(
"""
@__prm_0='10'
SELECT [l].[Id] AS [Id1], [t].[Id] AS [Id2], [t0].[Id] AS [Id3]
FROM [LevelOne] AS [l]
LEFT JOIN (
SELECT [l0].[Id], [l0].[Level1_Optional_Id]
FROM [LevelTwo] AS [l0]
WHERE [l0].[Id] <> @__prm_0
) AS [t] ON [l].[Id] = [t].[Level1_Optional_Id]
LEFT JOIN (
SELECT [l1].[Id], [l1].[Level2_Optional_Id]
FROM [LevelThree] AS [l1]
WHERE [l1].[Id] <> @__prm_0
) AS [t0] ON [t].[Id] = [t0].[Level2_Optional_Id]
""");
}

public override async Task GroupJoin_SelectMany_with_predicate_using_closure_nested_same_param(bool async)
{
await base.GroupJoin_SelectMany_with_predicate_using_closure_nested_same_param(async);

AssertSql(
"""
@__prm_0='10'
SELECT [l].[Id] AS [Id1], [t].[Id] AS [Id2], [t0].[Id] AS [Id3]
FROM [LevelOne] AS [l]
INNER JOIN (
SELECT [l0].[Id], [l0].[Level1_Optional_Id]
FROM [LevelTwo] AS [l0]
WHERE [l0].[Id] <> @__prm_0
) AS [t] ON [l].[Id] = [t].[Level1_Optional_Id]
INNER JOIN (
SELECT [l1].[Id], [l1].[Level2_Optional_Id]
FROM [LevelThree] AS [l1]
WHERE [l1].[Id] <> @__prm_0
) AS [t0] ON [t].[Id] = [t0].[Level2_Optional_Id]
""");
}

private void AssertSql(params string[] expected)
=> Fixture.TestSqlLoggerFactory.AssertBaseline(expected);
}
Loading

0 comments on commit 03541ad

Please sign in to comment.