Skip to content

Commit

Permalink
Handle nulls passed to typed equality method expression
Browse files Browse the repository at this point in the history
Fixes #25492
  • Loading branch information
ajcvickers committed Sep 10, 2021
1 parent 50daaa1 commit 236be8f
Show file tree
Hide file tree
Showing 2 changed files with 60 additions and 16 deletions.
21 changes: 18 additions & 3 deletions src/EFCore/ChangeTracking/ValueComparer`.cs
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,23 @@ public ValueComparer(
&& m.GetParameters().Length == 1
&& m.GetParameters()[0].ParameterType == typeof(T));

if (typedEquals != null)
{
return Expression.Lambda<Func<T?, T?, bool>>(
type.IsClass
? Expression.OrElse(
Expression.AndAlso(
Expression.Equal(param1, Expression.Constant(null, type)),
Expression.Equal(param2, Expression.Constant(null, type))),
Expression.AndAlso(
Expression.AndAlso(
Expression.NotEqual(param1, Expression.Constant(null, type)),
Expression.NotEqual(param2, Expression.Constant(null, type))),
Expression.Call(param1, typedEquals, param2)))
: Expression.Call(param1, typedEquals, param2),
param1, param2);
}

while (typedEquals == null
&& type != null)
{
Expand All @@ -151,9 +168,7 @@ public ValueComparer(
ObjectEqualsMethod,
Expression.Convert(param1, typeof(object)),
Expression.Convert(param2, typeof(object)))
: typedEquals.IsStatic
? Expression.Call(typedEquals, param1, param2)
: Expression.Call(param1, typedEquals, param2),
: Expression.Call(typedEquals, param1, param2),
param1, param2);
}

Expand Down
55 changes: 42 additions & 13 deletions test/EFCore.Tests/Storage/ValueComparerTest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -285,7 +285,37 @@ private void GenericCompareTest<T>(T value1, T value2, int? hashCode = null)
Assert.False(keyEquals(value2, value1));

Assert.Equal(hashCode ?? value1.GetHashCode(), getHashCode(value1));
Assert.Equal(hashCode ?? value1.GetHashCode(), getKeyHashCode(value1));
}

private void GenericCompareTestWithNulls<T>(T value1, T value2, int? hashCode = null)
where T : class
{
var comparer = new ValueComparer<T>(false);
var equals = comparer.EqualsExpression.Compile();
var getHashCode = comparer.HashCodeExpression.Compile();

Assert.True(equals(value1, value1));
Assert.True(equals(value2, value2));
Assert.False(equals(value1, value2));
Assert.False(equals(value2, value1));
Assert.False(equals(value1, null));
Assert.False(equals(null, value2));
Assert.True(equals(null, null));

var keyComparer = new ValueComparer<T>(true);
var keyEquals = keyComparer.EqualsExpression.Compile();
var getKeyHashCode = keyComparer.HashCodeExpression.Compile();

Assert.True(keyEquals(value1, value1));
Assert.True(keyEquals(value2, value2));
Assert.False(keyEquals(value1, value2));
Assert.False(keyEquals(value2, value1));
Assert.False(keyEquals(value1, null));
Assert.False(keyEquals(null, value2));
Assert.True(keyEquals(null, null));

Assert.Equal(hashCode ?? value1.GetHashCode(), getHashCode(value1));
Assert.Equal(hashCode ?? value1.GetHashCode(), getKeyHashCode(value1));
}

Expand All @@ -305,20 +335,19 @@ public void Default_raw_comparer_works_for_non_null_normal_types()
GenericCompareTest<decimal>(1, 2);
GenericCompareTest('A', 'B', 'A');
GenericCompareTest("A", "B");
GenericCompareTest<object>(1, "A");
GenericCompareTest(JustAnEnum.A, JustAnEnum.B);
GenericCompareTest(
new JustAClass { A = 1 }, new JustAClass { A = 2 });
GenericCompareTest(
new JustAClassWithEquality { A = 1 }, new JustAClassWithEquality { A = 2 });
GenericCompareTest(
new JustAClassWithEqualityOperators { A = 1 }, new JustAClassWithEqualityOperators { A = 2 });
GenericCompareTest(
new JustAStruct { A = 1 }, new JustAStruct { A = 2 });
GenericCompareTest(
new JustAStructWithEquality { A = 1 }, new JustAStructWithEquality { A = 2 });
GenericCompareTest(
new JustAStructWithEqualityOperators { A = 1 }, new JustAStructWithEqualityOperators { A = 2 });
GenericCompareTest(new JustAStruct { A = 1 }, new JustAStruct { A = 2 });
GenericCompareTest(new JustAStructWithEquality { A = 1 }, new JustAStructWithEquality { A = 2 });
GenericCompareTest(new JustAStructWithEqualityOperators { A = 1 }, new JustAStructWithEqualityOperators { A = 2 });
}

[ConditionalFact]
public void Default_raw_comparer_works_for_reference_types()
{
GenericCompareTestWithNulls<object>(1, "A");
GenericCompareTestWithNulls(new JustAClass { A = 1 }, new JustAClass { A = 2 });
GenericCompareTestWithNulls(new JustAClassWithEquality { A = 1 }, new JustAClassWithEquality { A = 2 });
GenericCompareTestWithNulls(new JustAClassWithEqualityOperators { A = 1 }, new JustAClassWithEqualityOperators { A = 2 });
}

[ConditionalFact]
Expand Down

0 comments on commit 236be8f

Please sign in to comment.