Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,11 @@ protected override Expression VisitMethodCall(MethodCallExpression node)

static Expression VisitContainsMethod(MethodCallExpression node, MethodInfo method, ReadOnlyCollection<Expression> arguments)
{
if (method.IsOneOf(MemoryExtensionsMethod.ContainsWithReadOnlySpanAndValue, MemoryExtensionsMethod.ContainsWithSpanAndValue))
var hasNoComparer = method.IsOneOf(MemoryExtensionsMethod.ContainsWithReadOnlySpanAndValue, MemoryExtensionsMethod.ContainsWithSpanAndValue);
var hasNullComparer = method.Is(MemoryExtensionsMethod.ContainsWithReadOnlySpanAndValueAndComparer) && arguments[2] is ConstantExpression { Value: null };

// C# 14 targets MemoryExtensionsMethod.Contains, rewrite it back to Enumerable.Contains
if (hasNoComparer || hasNullComparer)
{
var itemType = method.GetGenericArguments().Single();
var span = arguments[0];
Expand All @@ -65,29 +69,17 @@ static Expression VisitContainsMethod(MethodCallExpression node, MethodInfo meth
[unwrappedSpan, value]);
}
}
else if (method.Is(MemoryExtensionsMethod.ContainsWithReadOnlySpanAndValueAndComparer))
{
var itemType = method.GetGenericArguments().Single();
var span = arguments[0];
var value = arguments[1];
var comparer = arguments[2];

if (TryUnwrapSpanImplicitCast(span, out var unwrappedSpan) &&
unwrappedSpan.Type.ImplementsIEnumerableOf(itemType))
{
return
Expression.Call(
EnumerableMethod.ContainsWithComparer.MakeGenericMethod(itemType),
[unwrappedSpan, value, comparer]);
}
}

return node;
}

static Expression VisitSequenceEqualMethod(MethodCallExpression node, MethodInfo method, ReadOnlyCollection<Expression> arguments)
{
if (method.IsOneOf(MemoryExtensionsMethod.SequenceEqualWithReadOnlySpanAndReadOnlySpan, MemoryExtensionsMethod.SequenceEqualWithSpanAndReadOnlySpan))
var hasNoComparer = method.IsOneOf(MemoryExtensionsMethod.SequenceEqualWithReadOnlySpanAndReadOnlySpan, MemoryExtensionsMethod.SequenceEqualWithSpanAndReadOnlySpan);
var hasNullComparer = method.IsOneOf(MemoryExtensionsMethod.SequenceEqualWithReadOnlySpanAndReadOnlySpanAndComparer, MemoryExtensionsMethod.SequenceEqualWithSpanAndReadOnlySpanAndComparer) && arguments[2] is ConstantExpression { Value: null };

// C# 14 targets MemoryExtensionsMethod.SequenceEquals, rewrite it back to Enumerable.SequenceEquals
if (hasNoComparer || hasNullComparer)
{
var itemType = method.GetGenericArguments().Single();
var span = arguments[0];
Expand All @@ -104,24 +96,6 @@ static Expression VisitSequenceEqualMethod(MethodCallExpression node, MethodInfo
[unwrappedSpan, unwrappedOther]);
}
}
else if (method.IsOneOf(MemoryExtensionsMethod.SequenceEqualWithReadOnlySpanAndReadOnlySpanAndComparer, MemoryExtensionsMethod.SequenceEqualWithSpanAndReadOnlySpanAndComparer))
{
var itemType = method.GetGenericArguments().Single();
var span = arguments[0];
var other = arguments[1];
var comparer = arguments[2];

if (TryUnwrapSpanImplicitCast(span, out var unwrappedSpan) &&
TryUnwrapSpanImplicitCast(other, out var unwrappedOther) &&
unwrappedSpan.Type.ImplementsIEnumerableOf(itemType) &&
unwrappedOther.Type.ImplementsIEnumerableOf(itemType))
{
return
Expression.Call(
EnumerableMethod.SequenceEqualWithComparer.MakeGenericMethod(itemType),
[unwrappedSpan, unwrappedOther, comparer]);
}
}

return node;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,20 @@ public void MemoryExtensions_Contains_in_Where_should_work()
results.Select(x => x.Id).Should().Equal(2, 3);
}

[Fact]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These new tests are fine.

I would also add these new tests:

    [Fact]
    public void Enumerable_Contains_with_null_comparer_should_work()
    {
        var collection = Fixture.Collection;
        var names = new[] { "Two", "Three" };

        var queryable = collection.AsQueryable().Where((C x) => names.Contains(x.Name, null));

        var results = queryable.ToArray();
        results.Select(x => x.Id).Should().Equal(2, 3);
    }

    [Fact]
    public void Enumerable_SequenceEqual_with_null_comparer_work()
    {
        var collection = Fixture.Collection;
        var ratings = new[] { 1, 9, 6 };

        var queryable = collection.AsQueryable().Where((C x) => ratings.SequenceEqual(x.Ratings, null));

        var results = queryable.ToArray();
        results.Select(x => x.Id).Should().Equal(3);
    }

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

public void MemoryExtensions_Contains_in_Where_should_work_with_enum()
{
var collection = Fixture.Collection;
var daysOfWeek = new[] { DayOfWeek.Monday, DayOfWeek.Tuesday };

// Can't actually rewrite/fake these with MemoryExtensions.Contains overload with 3 args from .NET 10
// This test will activate correctly on .NET 10+
var queryable = collection.AsQueryable().Where(x => daysOfWeek.Contains(x.Day));

var results = queryable.ToArray();
results.Select(x => x.Id).Should().Equal(2, 3);
}

[Fact]
public void MemoryExtensions_Contains_in_Single_should_work()
{
Expand Down Expand Up @@ -93,6 +107,20 @@ public void MemoryExtensions_SequenceEqual_in_Where_should_work()
results.Select(x => x.Id).Should().Equal(3);
}

[Fact]
public void MemoryExtensions_SequenceEqual_in_Where_should_work_with_enum()
{
var collection = Fixture.Collection;
var daysOfWeek = new[] { DayOfWeek.Monday, DayOfWeek.Tuesday };

// Can't actually rewrite/fake these with MemoryExtensions.Contains overload with 3 args from .NET 10
// This test will activate correctly on .NET 10+
var queryable = collection.AsQueryable().Where(x => daysOfWeek.SequenceEqual(x.Days));

var results = queryable.ToArray();
results.Select(x => x.Id).Should().Equal(1);
}

[Fact]
public void MemoryExtensions_SequenceEqual_in_Single_should_work()
{
Expand Down Expand Up @@ -129,17 +157,19 @@ public void MemoryExtensions_SequenceEqual_in_Count_should_work()
public class C
{
public int Id { get; set; }
public DayOfWeek Day { get; set; }
public string Name { get; set; }
public int[] Ratings { get; set; }
public DayOfWeek[] Days { get; set; }
}

public sealed class ClassFixture : MongoCollectionFixture<C, BsonDocument>
{
protected override IEnumerable<BsonDocument> InitialData =>
[
BsonDocument.Parse("{ _id : 1, Name : \"One\", Ratings : [1, 2, 3, 4, 5] }"),
BsonDocument.Parse("{ _id : 2, Name : \"Two\", Ratings : [3, 4, 5, 6, 7] }"),
BsonDocument.Parse("{ _id : 3, Name : \"Three\", Ratings : [1, 9, 6] }")
BsonDocument.Parse("{ _id : 1, Name : \"One\", Day : 0, Ratings : [1, 2, 3, 4, 5], Days : [1, 2] }"),
BsonDocument.Parse("{ _id : 2, Name : \"Two\", Day : 1, Ratings : [3, 4, 5, 6, 7], Days: [1, 2, 3] }"),
BsonDocument.Parse("{ _id : 3, Name : \"Three\", Day : 2, Ratings : [1, 9, 6], Days: [2, 3, 4] }")
];
}

Expand Down Expand Up @@ -175,10 +205,13 @@ static Expression VisitContainsMethod(MethodCallExpression node, MethodInfo meth
if (source.Type.IsArray)
{
var readOnlySpan = ImplicitCastArrayToSpan(source, typeof(ReadOnlySpan<>), itemType);
return
Expression.Call(
MemoryExtensionsMethod.ContainsWithReadOnlySpanAndValue.MakeGenericMethod(itemType),
[readOnlySpan, value]);

// Not worth checking for IEquatable<T> and generating 3 args overload as that requires .NET 10
// which if we had we could run the tests on natively without this visitor.

return Expression.Call(
MemoryExtensionsMethod.ContainsWithReadOnlySpanAndValue.MakeGenericMethod(itemType),
[readOnlySpan, value]);
}
}
else if (method.Is(EnumerableMethod.ContainsWithComparer))
Expand Down
Loading