Skip to content

Commit

Permalink
Fix to #23761 - UseRelationalNulls causes subquery to include NOT IN …
Browse files Browse the repository at this point in the history
…(NULL, x)

We do optimization which combines comparisons based on the same property into IN and NOT IN. Problem is that if the comparisons contain nulls, we add those into the list of IN values, generating queries like  a IN (1, 2, NULL). In normal circumstances, during null semantics processing we extract these nulls and produce correct IS NULL / IS NOT NULL calls, however when useRelationalNulls is enabled, we don't run null semantics visitor and therefore don't "fix" the IN expressions.

Fix is to only apply the initial optimization for c# null semantics.

Fixes #23761
  • Loading branch information
maumar committed Jan 15, 2021
1 parent 09e8bad commit 545de3a
Show file tree
Hide file tree
Showing 4 changed files with 241 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -24,16 +24,18 @@ namespace Microsoft.EntityFrameworkCore.Query.Internal
public class SqlExpressionSimplifyingExpressionVisitor : ExpressionVisitor
{
private readonly ISqlExpressionFactory _sqlExpressionFactory;
private readonly bool _useRelationalNulls;

/// <summary>
/// This is an internal API that supports the Entity Framework Core infrastructure and not subject to
/// the same compatibility standards as public APIs. It may be changed or removed without notice in
/// any release. You should only use it directly in your code with extreme caution and knowing that
/// doing so can result in application failures when updating to a new Entity Framework Core release.
/// </summary>
public SqlExpressionSimplifyingExpressionVisitor([NotNull] ISqlExpressionFactory sqlExpressionFactory)
public SqlExpressionSimplifyingExpressionVisitor([NotNull] ISqlExpressionFactory sqlExpressionFactory, bool useRelationalNulls)
{
_sqlExpressionFactory = sqlExpressionFactory;
_useRelationalNulls = useRelationalNulls;
}

/// <summary>
Expand Down Expand Up @@ -264,6 +266,15 @@ private Expression SimplifySqlBinary(SqlBinaryExpression sqlBinaryExpression)
leftValue = leftCandidateInfo.ConstantValue;
rightValue = rightCandidateInfo.ConstantValue;

// for relational nulls we can't combine comparisons that contain null
// a != 1 && a != null would be converted to a NOT IN (1, null), which never returns any results
// we need to keep it in the original form so that a != null gets converted to a IS NOT NULL instead
// for c# null semantics it's fine because null semantics visitor extracts null back into proper null checks
if (_useRelationalNulls && (leftValue == null || rightValue == null))
{
return sqlBinaryExpression.Update(left, right);
}

resultArray = ConstructCollection(leftValue, rightValue);
}
else if (leftConstantIsEnumerable && rightConstantIsEnumerable)
Expand All @@ -284,6 +295,11 @@ private Expression SimplifySqlBinary(SqlBinaryExpression sqlBinaryExpression)
? rightCandidateInfo.ConstantValue
: leftCandidateInfo.ConstantValue;

if (_useRelationalNulls && rightValue == null)
{
return sqlBinaryExpression.Update(left, right);
}

resultArray = AddToCollection((IEnumerable)leftValue, rightValue);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
using System;
using System.Linq.Expressions;
using JetBrains.Annotations;
using Microsoft.EntityFrameworkCore.Infrastructure;
using Microsoft.EntityFrameworkCore.Query.Internal;
using Microsoft.EntityFrameworkCore.Utilities;

Expand All @@ -14,6 +15,8 @@ namespace Microsoft.EntityFrameworkCore.Query
/// <inheritdoc />
public class RelationalQueryTranslationPostprocessor : QueryTranslationPostprocessor
{
private readonly bool _useRelationalNulls;

/// <summary>
/// Creates a new instance of the <see cref="RelationalQueryTranslationPostprocessor" /> class.
/// </summary>
Expand All @@ -30,6 +33,7 @@ public RelationalQueryTranslationPostprocessor(
Check.NotNull(queryCompilationContext, nameof(queryCompilationContext));

RelationalDependencies = relationalDependencies;
_useRelationalNulls = RelationalOptionsExtension.Extract(queryCompilationContext.ContextOptions).UseRelationalNulls;
}

/// <summary>
Expand All @@ -45,7 +49,7 @@ public override Expression Process(Expression query)
query = new CollectionJoinApplyingExpressionVisitor((RelationalQueryCompilationContext)QueryCompilationContext).Visit(query);
query = new TableAliasUniquifyingExpressionVisitor().Visit(query);
query = new SelectExpressionPruningExpressionVisitor().Visit(query);
query = new SqlExpressionSimplifyingExpressionVisitor(RelationalDependencies.SqlExpressionFactory).Visit(query);
query = new SqlExpressionSimplifyingExpressionVisitor(RelationalDependencies.SqlExpressionFactory, _useRelationalNulls).Visit(query);
query = new RelationalValueConverterCompensatingExpressionVisitor(RelationalDependencies.SqlExpressionFactory).Visit(query);

#pragma warning disable 618
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1530,13 +1530,137 @@ await AssertQuery(

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual async Task False_compared_to_negated_is_null(bool async)
public virtual Task False_compared_to_negated_is_null(bool async)
{
await AssertQuery(
return AssertQuery(
async,
ss => ss.Set<NullSemanticsEntity1>().Where(e => false == (!(e.NullableStringA == null))));
}

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task Multiple_non_equality_comparisons_with_null_in_the_middle(bool async)
{
return AssertQuery(
async,
ss => ss.Set<NullSemanticsEntity1>().Where(e => e.NullableIntA != 1 && e.NullableIntA != null && e.NullableIntA != 2));
}

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual async Task Multiple_non_equality_comparisons_including_null_comparison_work_for_relational_null_semantics(bool async)
{
var ctx = CreateContext(useRelationalNulls: true);

var expected = ctx.Entities1.AsEnumerable().Where(e => e.NullableIntA != 1 && e.NullableIntA != null).ToList();
ClearLog();
var query = ctx.Entities1.Where(e => e.NullableIntA != 1 && e.NullableIntA != null);

var result = async ? await query.ToListAsync() : query.ToList();
Assert.Equal(expected.Count, result.Count);
}

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual async Task Multiple_non_equality_comparisons_without_null_comparison_work_for_relational_null_semantics(bool async)
{
var ctx = CreateContext(useRelationalNulls: true);

var expected = ctx.Entities1.AsEnumerable().Where(e => e.NullableIntA != 1 && e.NullableIntA != 2 && e.NullableIntA != null).ToList();
ClearLog();
var query = ctx.Entities1.Where(e => e.NullableIntA != 1 && e.NullableIntA != 2);

var result = async ? await query.ToListAsync() : query.ToList();
Assert.Equal(expected.Count, result.Count);
}

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual async Task Multiple_equality_comparisons_including_null_comparison_work_for_relational_null_semantics(bool async)
{
var ctx = CreateContext(useRelationalNulls: true);

var expected = ctx.Entities1.AsEnumerable().Where(e => e.NullableIntA == 1 || e.NullableIntA == null).ToList();
ClearLog();
var query = ctx.Entities1.Where(e => e.NullableIntA == 1 || e.NullableIntA == null);

var result = async ? await query.ToListAsync() : query.ToList();
Assert.Equal(expected.Count, result.Count);
}

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual async Task Multiple_contains_calls_get_combined_into_one_for_relational_null_semantics(bool async)
{
var ctx = CreateContext(useRelationalNulls: true);

var expected = ctx.Entities1.AsEnumerable().Where(e => new int?[] { 1, 2, 3 }.Contains(e.NullableIntA)).ToList();

ClearLog();
var query = ctx.Entities1.Where(e => new int?[] { 1, null }.Contains(e.NullableIntA)
|| new int?[] { 2, null, 3 }.Contains(e.NullableIntA));

var result = async ? await query.ToListAsync() : query.ToList();
Assert.Equal(expected.Count, result.Count);
}

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual async Task Multiple_negated_contains_calls_get_combined_into_one_for_relational_null_semantics(bool async)
{
var ctx = CreateContext(useRelationalNulls: true);
var query = ctx.Entities1.Where(e => !(new int?[] { 1, null }.Contains(e.NullableIntA))
&& !(new int?[] { 2, null, 3 }.Contains(e.NullableIntA)));

var result = async ? await query.ToListAsync() : query.ToList();
Assert.Empty(result);
}

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual async Task Contains_with_comparison_dont_get_combined_for_relational_null_semantics(bool async)
{
var ctx = CreateContext(useRelationalNulls: true);

var expected = ctx.Entities1.AsEnumerable().Where(e => new int?[] { 1, 2 }.Contains(e.NullableIntA) || e.NullableIntA == null).ToList();

ClearLog();
var query = ctx.Entities1.Where(e => new int?[] { 1, 2 }.Contains(e.NullableIntA) || e.NullableIntA == null);

var result = async ? await query.ToListAsync() : query.ToList();
Assert.Equal(expected.Count, result.Count);
}

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual async Task Negated_contains_with_comparison_dont_get_combined_for_relational_null_semantics(bool async)
{
var ctx = CreateContext(useRelationalNulls: true);

var expected = ctx.Entities1.AsEnumerable().Where(e => !(new int?[] { 1, 2 }.Contains(e.NullableIntA)) && e.NullableIntA != null).ToList();

ClearLog();
var query = ctx.Entities1.Where(e => e.NullableIntA != null && !(new int?[] { 1, 2 }.Contains(e.NullableIntA)));

var result = async ? await query.ToListAsync() : query.ToList();
Assert.Equal(expected.Count, result.Count);
}

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual async Task Negated_contains_with_comparison_without_null_get_combined_for_relational_null_semantics(bool async)
{
var ctx = CreateContext(useRelationalNulls: true);

var expected = ctx.Entities1.AsEnumerable().Where(e => !(new int?[] { 1, 2, 3, null }.Contains(e.NullableIntA))).ToList();

ClearLog();
var query = ctx.Entities1.Where(e => e.NullableIntA != 3 && !(new int?[] { 1, 2 }.Contains(e.NullableIntA)));

var result = async ? await query.ToListAsync() : query.ToList();
Assert.Equal(expected.Count, result.Count);
}

private string NormalizeDelimitersInRawString(string sql)
=> Fixture.TestStore.NormalizeDelimitersInRawString(sql);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1894,6 +1894,96 @@ FROM [Entities1] AS [e]
WHERE [e].[NullableStringA] IS NULL");
}

public override async Task Multiple_non_equality_comparisons_with_null_in_the_middle(bool async)
{
await base.Multiple_non_equality_comparisons_with_null_in_the_middle(async);

AssertSql(
@"SELECT [e].[Id], [e].[BoolA], [e].[BoolB], [e].[BoolC], [e].[IntA], [e].[IntB], [e].[IntC], [e].[NullableBoolA], [e].[NullableBoolB], [e].[NullableBoolC], [e].[NullableIntA], [e].[NullableIntB], [e].[NullableIntC], [e].[NullableStringA], [e].[NullableStringB], [e].[NullableStringC], [e].[StringA], [e].[StringB], [e].[StringC]
FROM [Entities1] AS [e]
WHERE [e].[NullableIntA] NOT IN (1, 2) AND [e].[NullableIntA] IS NOT NULL");
}

public override async Task Multiple_non_equality_comparisons_including_null_comparison_work_for_relational_null_semantics(bool async)
{
await base.Multiple_non_equality_comparisons_including_null_comparison_work_for_relational_null_semantics(async);

AssertSql(
@"SELECT [e].[Id], [e].[BoolA], [e].[BoolB], [e].[BoolC], [e].[IntA], [e].[IntB], [e].[IntC], [e].[NullableBoolA], [e].[NullableBoolB], [e].[NullableBoolC], [e].[NullableIntA], [e].[NullableIntB], [e].[NullableIntC], [e].[NullableStringA], [e].[NullableStringB], [e].[NullableStringC], [e].[StringA], [e].[StringB], [e].[StringC]
FROM [Entities1] AS [e]
WHERE ([e].[NullableIntA] <> 1) AND [e].[NullableIntA] IS NOT NULL");
}

public override async Task Multiple_non_equality_comparisons_without_null_comparison_work_for_relational_null_semantics(bool async)
{
await base.Multiple_non_equality_comparisons_without_null_comparison_work_for_relational_null_semantics(async);

AssertSql(
@"SELECT [e].[Id], [e].[BoolA], [e].[BoolB], [e].[BoolC], [e].[IntA], [e].[IntB], [e].[IntC], [e].[NullableBoolA], [e].[NullableBoolB], [e].[NullableBoolC], [e].[NullableIntA], [e].[NullableIntB], [e].[NullableIntC], [e].[NullableStringA], [e].[NullableStringB], [e].[NullableStringC], [e].[StringA], [e].[StringB], [e].[StringC]
FROM [Entities1] AS [e]
WHERE [e].[NullableIntA] NOT IN (1, 2)");
}

public override async Task Multiple_equality_comparisons_including_null_comparison_work_for_relational_null_semantics(bool async)
{
await base.Multiple_equality_comparisons_including_null_comparison_work_for_relational_null_semantics(async);

AssertSql(
@"SELECT [e].[Id], [e].[BoolA], [e].[BoolB], [e].[BoolC], [e].[IntA], [e].[IntB], [e].[IntC], [e].[NullableBoolA], [e].[NullableBoolB], [e].[NullableBoolC], [e].[NullableIntA], [e].[NullableIntB], [e].[NullableIntC], [e].[NullableStringA], [e].[NullableStringB], [e].[NullableStringC], [e].[StringA], [e].[StringB], [e].[StringC]
FROM [Entities1] AS [e]
WHERE ([e].[NullableIntA] = 1) OR [e].[NullableIntA] IS NULL");
}

public override async Task Multiple_contains_calls_get_combined_into_one_for_relational_null_semantics(bool async)
{
await base.Multiple_contains_calls_get_combined_into_one_for_relational_null_semantics(async);

AssertSql(
@"SELECT [e].[Id], [e].[BoolA], [e].[BoolB], [e].[BoolC], [e].[IntA], [e].[IntB], [e].[IntC], [e].[NullableBoolA], [e].[NullableBoolB], [e].[NullableBoolC], [e].[NullableIntA], [e].[NullableIntB], [e].[NullableIntC], [e].[NullableStringA], [e].[NullableStringB], [e].[NullableStringC], [e].[StringA], [e].[StringB], [e].[StringC]
FROM [Entities1] AS [e]
WHERE [e].[NullableIntA] IN (1, NULL, 2, 3)");
}

public override async Task Multiple_negated_contains_calls_get_combined_into_one_for_relational_null_semantics(bool async)
{
await base.Multiple_negated_contains_calls_get_combined_into_one_for_relational_null_semantics(async);

AssertSql(
@"SELECT [e].[Id], [e].[BoolA], [e].[BoolB], [e].[BoolC], [e].[IntA], [e].[IntB], [e].[IntC], [e].[NullableBoolA], [e].[NullableBoolB], [e].[NullableBoolC], [e].[NullableIntA], [e].[NullableIntB], [e].[NullableIntC], [e].[NullableStringA], [e].[NullableStringB], [e].[NullableStringC], [e].[StringA], [e].[StringB], [e].[StringC]
FROM [Entities1] AS [e]
WHERE [e].[NullableIntA] NOT IN (1, NULL, 2, 3)");
}

public override async Task Contains_with_comparison_dont_get_combined_for_relational_null_semantics(bool async)
{
await base.Contains_with_comparison_dont_get_combined_for_relational_null_semantics(async);

AssertSql(
@"SELECT [e].[Id], [e].[BoolA], [e].[BoolB], [e].[BoolC], [e].[IntA], [e].[IntB], [e].[IntC], [e].[NullableBoolA], [e].[NullableBoolB], [e].[NullableBoolC], [e].[NullableIntA], [e].[NullableIntB], [e].[NullableIntC], [e].[NullableStringA], [e].[NullableStringB], [e].[NullableStringC], [e].[StringA], [e].[StringB], [e].[StringC]
FROM [Entities1] AS [e]
WHERE [e].[NullableIntA] IN (1, 2) OR [e].[NullableIntA] IS NULL");
}

public override async Task Negated_contains_with_comparison_dont_get_combined_for_relational_null_semantics(bool async)
{
await base.Negated_contains_with_comparison_dont_get_combined_for_relational_null_semantics(async);

AssertSql(
@"SELECT [e].[Id], [e].[BoolA], [e].[BoolB], [e].[BoolC], [e].[IntA], [e].[IntB], [e].[IntC], [e].[NullableBoolA], [e].[NullableBoolB], [e].[NullableBoolC], [e].[NullableIntA], [e].[NullableIntB], [e].[NullableIntC], [e].[NullableStringA], [e].[NullableStringB], [e].[NullableStringC], [e].[StringA], [e].[StringB], [e].[StringC]
FROM [Entities1] AS [e]
WHERE [e].[NullableIntA] IS NOT NULL AND [e].[NullableIntA] NOT IN (1, 2)");
}

public override async Task Negated_contains_with_comparison_without_null_get_combined_for_relational_null_semantics(bool async)
{
await base.Negated_contains_with_comparison_without_null_get_combined_for_relational_null_semantics(async);

AssertSql(
@"SELECT [e].[Id], [e].[BoolA], [e].[BoolB], [e].[BoolC], [e].[IntA], [e].[IntB], [e].[IntC], [e].[NullableBoolA], [e].[NullableBoolB], [e].[NullableBoolC], [e].[NullableIntA], [e].[NullableIntB], [e].[NullableIntC], [e].[NullableStringA], [e].[NullableStringB], [e].[NullableStringC], [e].[StringA], [e].[StringB], [e].[StringC]
FROM [Entities1] AS [e]
WHERE [e].[NullableIntA] NOT IN (1, 2, 3)");
}

private void AssertSql(params string[] expected)
=> Fixture.TestSqlLoggerFactory.AssertBaseline(expected);

Expand All @@ -1911,5 +2001,8 @@ protected override NullSemanticsContext CreateContext(bool useRelationalNulls =

return context;
}

protected override void ClearLog()
=> Fixture.TestSqlLoggerFactory.Clear();
}
}

0 comments on commit 545de3a

Please sign in to comment.