Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix to #23761 - UseRelationalNulls causes subquery to include NOT IN (NULL, x) #23856

Merged
merged 1 commit into from
Jan 15, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -24,16 +24,18 @@ namespace Microsoft.EntityFrameworkCore.Query.Internal
public class SqlExpressionSimplifyingExpressionVisitor : ExpressionVisitor
{
private readonly ISqlExpressionFactory _sqlExpressionFactory;
private readonly bool _useRelationalNulls;

/// <summary>
/// This is an internal API that supports the Entity Framework Core infrastructure and not subject to
/// the same compatibility standards as public APIs. It may be changed or removed without notice in
/// any release. You should only use it directly in your code with extreme caution and knowing that
/// doing so can result in application failures when updating to a new Entity Framework Core release.
/// </summary>
public SqlExpressionSimplifyingExpressionVisitor([NotNull] ISqlExpressionFactory sqlExpressionFactory)
public SqlExpressionSimplifyingExpressionVisitor([NotNull] ISqlExpressionFactory sqlExpressionFactory, bool useRelationalNulls)
{
_sqlExpressionFactory = sqlExpressionFactory;
_useRelationalNulls = useRelationalNulls;
}

/// <summary>
Expand Down Expand Up @@ -264,6 +266,15 @@ private Expression SimplifySqlBinary(SqlBinaryExpression sqlBinaryExpression)
leftValue = leftCandidateInfo.ConstantValue;
rightValue = rightCandidateInfo.ConstantValue;

// for relational nulls we can't combine comparisons that contain null
// a != 1 && a != null would be converted to a NOT IN (1, null), which never returns any results
// we need to keep it in the original form so that a != null gets converted to a IS NOT NULL instead
// for c# null semantics it's fine because null semantics visitor extracts null back into proper null checks
if (_useRelationalNulls && (leftValue == null || rightValue == null))
{
return sqlBinaryExpression.Update(left, right);
}

resultArray = ConstructCollection(leftValue, rightValue);
}
else if (leftConstantIsEnumerable && rightConstantIsEnumerable)
Expand All @@ -284,6 +295,11 @@ private Expression SimplifySqlBinary(SqlBinaryExpression sqlBinaryExpression)
? rightCandidateInfo.ConstantValue
: leftCandidateInfo.ConstantValue;

if (_useRelationalNulls && rightValue == null)
{
return sqlBinaryExpression.Update(left, right);
}

resultArray = AddToCollection((IEnumerable)leftValue, rightValue);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
using System;
using System.Linq.Expressions;
using JetBrains.Annotations;
using Microsoft.EntityFrameworkCore.Infrastructure;
using Microsoft.EntityFrameworkCore.Query.Internal;
using Microsoft.EntityFrameworkCore.Utilities;

Expand All @@ -14,6 +15,8 @@ namespace Microsoft.EntityFrameworkCore.Query
/// <inheritdoc />
public class RelationalQueryTranslationPostprocessor : QueryTranslationPostprocessor
{
private readonly bool _useRelationalNulls;

/// <summary>
/// Creates a new instance of the <see cref="RelationalQueryTranslationPostprocessor" /> class.
/// </summary>
Expand All @@ -30,6 +33,7 @@ public RelationalQueryTranslationPostprocessor(
Check.NotNull(queryCompilationContext, nameof(queryCompilationContext));

RelationalDependencies = relationalDependencies;
_useRelationalNulls = RelationalOptionsExtension.Extract(queryCompilationContext.ContextOptions).UseRelationalNulls;
}

/// <summary>
Expand All @@ -45,7 +49,7 @@ public override Expression Process(Expression query)
query = new CollectionJoinApplyingExpressionVisitor((RelationalQueryCompilationContext)QueryCompilationContext).Visit(query);
query = new TableAliasUniquifyingExpressionVisitor().Visit(query);
query = new SelectExpressionPruningExpressionVisitor().Visit(query);
query = new SqlExpressionSimplifyingExpressionVisitor(RelationalDependencies.SqlExpressionFactory).Visit(query);
query = new SqlExpressionSimplifyingExpressionVisitor(RelationalDependencies.SqlExpressionFactory, _useRelationalNulls).Visit(query);
query = new RelationalValueConverterCompensatingExpressionVisitor(RelationalDependencies.SqlExpressionFactory).Visit(query);

#pragma warning disable 618
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1530,13 +1530,137 @@ await AssertQuery(

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual async Task False_compared_to_negated_is_null(bool async)
public virtual Task False_compared_to_negated_is_null(bool async)
{
await AssertQuery(
return AssertQuery(
async,
ss => ss.Set<NullSemanticsEntity1>().Where(e => false == (!(e.NullableStringA == null))));
}

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task Multiple_non_equality_comparisons_with_null_in_the_middle(bool async)
{
return AssertQuery(
async,
ss => ss.Set<NullSemanticsEntity1>().Where(e => e.NullableIntA != 1 && e.NullableIntA != null && e.NullableIntA != 2));
}

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual async Task Multiple_non_equality_comparisons_including_null_comparison_work_for_relational_null_semantics(bool async)
{
var ctx = CreateContext(useRelationalNulls: true);

var expected = ctx.Entities1.AsEnumerable().Where(e => e.NullableIntA != 1 && e.NullableIntA != null).ToList();
ClearLog();
var query = ctx.Entities1.Where(e => e.NullableIntA != 1 && e.NullableIntA != null);

var result = async ? await query.ToListAsync() : query.ToList();
Assert.Equal(expected.Count, result.Count);
}

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual async Task Multiple_non_equality_comparisons_without_null_comparison_work_for_relational_null_semantics(bool async)
{
var ctx = CreateContext(useRelationalNulls: true);

var expected = ctx.Entities1.AsEnumerable().Where(e => e.NullableIntA != 1 && e.NullableIntA != 2 && e.NullableIntA != null).ToList();
ClearLog();
var query = ctx.Entities1.Where(e => e.NullableIntA != 1 && e.NullableIntA != 2);

var result = async ? await query.ToListAsync() : query.ToList();
Assert.Equal(expected.Count, result.Count);
}

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual async Task Multiple_equality_comparisons_including_null_comparison_work_for_relational_null_semantics(bool async)
{
var ctx = CreateContext(useRelationalNulls: true);

var expected = ctx.Entities1.AsEnumerable().Where(e => e.NullableIntA == 1 || e.NullableIntA == null).ToList();
ClearLog();
var query = ctx.Entities1.Where(e => e.NullableIntA == 1 || e.NullableIntA == null);
maumar marked this conversation as resolved.
Show resolved Hide resolved

var result = async ? await query.ToListAsync() : query.ToList();
Assert.Equal(expected.Count, result.Count);
}

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual async Task Multiple_contains_calls_get_combined_into_one_for_relational_null_semantics(bool async)
{
var ctx = CreateContext(useRelationalNulls: true);

var expected = ctx.Entities1.AsEnumerable().Where(e => new int?[] { 1, 2, 3 }.Contains(e.NullableIntA)).ToList();
maumar marked this conversation as resolved.
Show resolved Hide resolved

ClearLog();
var query = ctx.Entities1.Where(e => new int?[] { 1, null }.Contains(e.NullableIntA)
|| new int?[] { 2, null, 3 }.Contains(e.NullableIntA));

var result = async ? await query.ToListAsync() : query.ToList();
Assert.Equal(expected.Count, result.Count);
}

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual async Task Multiple_negated_contains_calls_get_combined_into_one_for_relational_null_semantics(bool async)
{
var ctx = CreateContext(useRelationalNulls: true);
var query = ctx.Entities1.Where(e => !(new int?[] { 1, null }.Contains(e.NullableIntA))
&& !(new int?[] { 2, null, 3 }.Contains(e.NullableIntA)));

var result = async ? await query.ToListAsync() : query.ToList();
Assert.Empty(result);
}

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual async Task Contains_with_comparison_dont_get_combined_for_relational_null_semantics(bool async)
{
var ctx = CreateContext(useRelationalNulls: true);

var expected = ctx.Entities1.AsEnumerable().Where(e => new int?[] { 1, 2 }.Contains(e.NullableIntA) || e.NullableIntA == null).ToList();

ClearLog();
var query = ctx.Entities1.Where(e => new int?[] { 1, 2 }.Contains(e.NullableIntA) || e.NullableIntA == null);

var result = async ? await query.ToListAsync() : query.ToList();
Assert.Equal(expected.Count, result.Count);
}

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual async Task Negated_contains_with_comparison_dont_get_combined_for_relational_null_semantics(bool async)
{
var ctx = CreateContext(useRelationalNulls: true);

var expected = ctx.Entities1.AsEnumerable().Where(e => !(new int?[] { 1, 2 }.Contains(e.NullableIntA)) && e.NullableIntA != null).ToList();

ClearLog();
var query = ctx.Entities1.Where(e => e.NullableIntA != null && !(new int?[] { 1, 2 }.Contains(e.NullableIntA)));

var result = async ? await query.ToListAsync() : query.ToList();
Assert.Equal(expected.Count, result.Count);
}

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual async Task Negated_contains_with_comparison_without_null_get_combined_for_relational_null_semantics(bool async)
{
var ctx = CreateContext(useRelationalNulls: true);

var expected = ctx.Entities1.AsEnumerable().Where(e => !(new int?[] { 1, 2, 3, null }.Contains(e.NullableIntA))).ToList();

ClearLog();
var query = ctx.Entities1.Where(e => e.NullableIntA != 3 && !(new int?[] { 1, 2 }.Contains(e.NullableIntA)));

var result = async ? await query.ToListAsync() : query.ToList();
Assert.Equal(expected.Count, result.Count);
}

private string NormalizeDelimitersInRawString(string sql)
=> Fixture.TestStore.NormalizeDelimitersInRawString(sql);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1894,6 +1894,96 @@ FROM [Entities1] AS [e]
WHERE [e].[NullableStringA] IS NULL");
}

public override async Task Multiple_non_equality_comparisons_with_null_in_the_middle(bool async)
{
await base.Multiple_non_equality_comparisons_with_null_in_the_middle(async);

AssertSql(
@"SELECT [e].[Id], [e].[BoolA], [e].[BoolB], [e].[BoolC], [e].[IntA], [e].[IntB], [e].[IntC], [e].[NullableBoolA], [e].[NullableBoolB], [e].[NullableBoolC], [e].[NullableIntA], [e].[NullableIntB], [e].[NullableIntC], [e].[NullableStringA], [e].[NullableStringB], [e].[NullableStringC], [e].[StringA], [e].[StringB], [e].[StringC]
FROM [Entities1] AS [e]
WHERE [e].[NullableIntA] NOT IN (1, 2) AND [e].[NullableIntA] IS NOT NULL");
}

public override async Task Multiple_non_equality_comparisons_including_null_comparison_work_for_relational_null_semantics(bool async)
{
await base.Multiple_non_equality_comparisons_including_null_comparison_work_for_relational_null_semantics(async);

AssertSql(
@"SELECT [e].[Id], [e].[BoolA], [e].[BoolB], [e].[BoolC], [e].[IntA], [e].[IntB], [e].[IntC], [e].[NullableBoolA], [e].[NullableBoolB], [e].[NullableBoolC], [e].[NullableIntA], [e].[NullableIntB], [e].[NullableIntC], [e].[NullableStringA], [e].[NullableStringB], [e].[NullableStringC], [e].[StringA], [e].[StringB], [e].[StringC]
FROM [Entities1] AS [e]
WHERE ([e].[NullableIntA] <> 1) AND [e].[NullableIntA] IS NOT NULL");
}

public override async Task Multiple_non_equality_comparisons_without_null_comparison_work_for_relational_null_semantics(bool async)
{
await base.Multiple_non_equality_comparisons_without_null_comparison_work_for_relational_null_semantics(async);

AssertSql(
@"SELECT [e].[Id], [e].[BoolA], [e].[BoolB], [e].[BoolC], [e].[IntA], [e].[IntB], [e].[IntC], [e].[NullableBoolA], [e].[NullableBoolB], [e].[NullableBoolC], [e].[NullableIntA], [e].[NullableIntB], [e].[NullableIntC], [e].[NullableStringA], [e].[NullableStringB], [e].[NullableStringC], [e].[StringA], [e].[StringB], [e].[StringC]
FROM [Entities1] AS [e]
WHERE [e].[NullableIntA] NOT IN (1, 2)");
}

public override async Task Multiple_equality_comparisons_including_null_comparison_work_for_relational_null_semantics(bool async)
{
await base.Multiple_equality_comparisons_including_null_comparison_work_for_relational_null_semantics(async);

AssertSql(
@"SELECT [e].[Id], [e].[BoolA], [e].[BoolB], [e].[BoolC], [e].[IntA], [e].[IntB], [e].[IntC], [e].[NullableBoolA], [e].[NullableBoolB], [e].[NullableBoolC], [e].[NullableIntA], [e].[NullableIntB], [e].[NullableIntC], [e].[NullableStringA], [e].[NullableStringB], [e].[NullableStringC], [e].[StringA], [e].[StringB], [e].[StringC]
FROM [Entities1] AS [e]
WHERE ([e].[NullableIntA] = 1) OR [e].[NullableIntA] IS NULL");
}

public override async Task Multiple_contains_calls_get_combined_into_one_for_relational_null_semantics(bool async)
{
await base.Multiple_contains_calls_get_combined_into_one_for_relational_null_semantics(async);

AssertSql(
@"SELECT [e].[Id], [e].[BoolA], [e].[BoolB], [e].[BoolC], [e].[IntA], [e].[IntB], [e].[IntC], [e].[NullableBoolA], [e].[NullableBoolB], [e].[NullableBoolC], [e].[NullableIntA], [e].[NullableIntB], [e].[NullableIntC], [e].[NullableStringA], [e].[NullableStringB], [e].[NullableStringC], [e].[StringA], [e].[StringB], [e].[StringC]
FROM [Entities1] AS [e]
WHERE [e].[NullableIntA] IN (1, NULL, 2, 3)");
}

public override async Task Multiple_negated_contains_calls_get_combined_into_one_for_relational_null_semantics(bool async)
{
await base.Multiple_negated_contains_calls_get_combined_into_one_for_relational_null_semantics(async);

AssertSql(
@"SELECT [e].[Id], [e].[BoolA], [e].[BoolB], [e].[BoolC], [e].[IntA], [e].[IntB], [e].[IntC], [e].[NullableBoolA], [e].[NullableBoolB], [e].[NullableBoolC], [e].[NullableIntA], [e].[NullableIntB], [e].[NullableIntC], [e].[NullableStringA], [e].[NullableStringB], [e].[NullableStringC], [e].[StringA], [e].[StringB], [e].[StringC]
FROM [Entities1] AS [e]
WHERE [e].[NullableIntA] NOT IN (1, NULL, 2, 3)");
}

public override async Task Contains_with_comparison_dont_get_combined_for_relational_null_semantics(bool async)
{
await base.Contains_with_comparison_dont_get_combined_for_relational_null_semantics(async);

AssertSql(
@"SELECT [e].[Id], [e].[BoolA], [e].[BoolB], [e].[BoolC], [e].[IntA], [e].[IntB], [e].[IntC], [e].[NullableBoolA], [e].[NullableBoolB], [e].[NullableBoolC], [e].[NullableIntA], [e].[NullableIntB], [e].[NullableIntC], [e].[NullableStringA], [e].[NullableStringB], [e].[NullableStringC], [e].[StringA], [e].[StringB], [e].[StringC]
FROM [Entities1] AS [e]
WHERE [e].[NullableIntA] IN (1, 2) OR [e].[NullableIntA] IS NULL");
}

public override async Task Negated_contains_with_comparison_dont_get_combined_for_relational_null_semantics(bool async)
{
await base.Negated_contains_with_comparison_dont_get_combined_for_relational_null_semantics(async);

AssertSql(
@"SELECT [e].[Id], [e].[BoolA], [e].[BoolB], [e].[BoolC], [e].[IntA], [e].[IntB], [e].[IntC], [e].[NullableBoolA], [e].[NullableBoolB], [e].[NullableBoolC], [e].[NullableIntA], [e].[NullableIntB], [e].[NullableIntC], [e].[NullableStringA], [e].[NullableStringB], [e].[NullableStringC], [e].[StringA], [e].[StringB], [e].[StringC]
FROM [Entities1] AS [e]
WHERE [e].[NullableIntA] IS NOT NULL AND [e].[NullableIntA] NOT IN (1, 2)");
}

public override async Task Negated_contains_with_comparison_without_null_get_combined_for_relational_null_semantics(bool async)
{
await base.Negated_contains_with_comparison_without_null_get_combined_for_relational_null_semantics(async);

AssertSql(
@"SELECT [e].[Id], [e].[BoolA], [e].[BoolB], [e].[BoolC], [e].[IntA], [e].[IntB], [e].[IntC], [e].[NullableBoolA], [e].[NullableBoolB], [e].[NullableBoolC], [e].[NullableIntA], [e].[NullableIntB], [e].[NullableIntC], [e].[NullableStringA], [e].[NullableStringB], [e].[NullableStringC], [e].[StringA], [e].[StringB], [e].[StringC]
FROM [Entities1] AS [e]
WHERE [e].[NullableIntA] NOT IN (1, 2, 3)");
}

private void AssertSql(params string[] expected)
=> Fixture.TestSqlLoggerFactory.AssertBaseline(expected);

Expand All @@ -1911,5 +2001,8 @@ protected override NullSemanticsContext CreateContext(bool useRelationalNulls =

return context;
}

protected override void ClearLog()
=> Fixture.TestSqlLoggerFactory.Clear();
}
}