Skip to content

Commit

Permalink
Test: Add tests for string based include (#21058)
Browse files Browse the repository at this point in the history
  • Loading branch information
smitpatel authored May 29, 2020
1 parent 82a7309 commit 95e779b
Show file tree
Hide file tree
Showing 7 changed files with 296 additions and 19 deletions.
7 changes: 2 additions & 5 deletions src/EFCore/Extensions/EntityFrameworkQueryableExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2443,18 +2443,15 @@ source.Provider is EntityQueryProvider
}

internal static readonly MethodInfo ThenIncludeAfterEnumerableMethodInfo
= GetThenIncludeMethodInfo(typeof(IEnumerable<>));

private static MethodInfo GetThenIncludeMethodInfo(Type navType)
=> typeof(EntityFrameworkQueryableExtensions)
= typeof(EntityFrameworkQueryableExtensions)
.GetTypeInfo().GetDeclaredMethods(nameof(ThenInclude))
.Where(mi => mi.GetGenericArguments().Count() == 3)
.Single(
mi =>
{
var typeInfo = mi.GetParameters()[0].ParameterType.GenericTypeArguments[1];
return typeInfo.IsGenericType
&& typeInfo.GetGenericTypeDefinition() == navType;
&& typeInfo.GetGenericTypeDefinition() == typeof(IEnumerable<>);
});

internal static readonly MethodInfo ThenIncludeAfterReferenceMethodInfo
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.

using System.Threading.Tasks;
using Microsoft.EntityFrameworkCore.TestUtilities;
using Xunit;
using Xunit.Abstractions;

namespace Microsoft.EntityFrameworkCore.Query
{
public class NorthwindStringIncludeQueryInMemoryTest : NorthwindStringIncludeQueryTestBase<NorthwindQueryInMemoryFixture<NoopModelCustomizer>>
{
public NorthwindStringIncludeQueryInMemoryTest(NorthwindQueryInMemoryFixture<NoopModelCustomizer> fixture, ITestOutputHelper testOutputHelper)
: base(fixture)
{
//TestLoggerFactory.TestOutputHelper = testOutputHelper;
}

[ConditionalTheory(Skip = "Issue#17386")]
public override Task Include_collection_with_last_no_orderby(bool async)
{
return base.Include_collection_with_last_no_orderby(async);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -219,8 +219,8 @@ public virtual Task NoTracking_Include_with_cycles_does_not_throw_when_performin
return AssertQuery(
async,
ss => (from i in ss.Set<Order>().Include(o => o.Customer.Orders)
where i.OrderID < 10800
select i)
where i.OrderID < 10800
select i)
.PerformIdentityResolution());
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -76,18 +76,6 @@ public virtual async Task Include_property(bool async)
ss => ss.Set<Order>().Include(o => o.OrderDate)))).Message);
}

[ConditionalTheory(Skip = "issue #15312")]
[MemberData(nameof(IsAsyncData))]
public virtual async Task Include_non_existing_navigation(bool async)
{
Assert.Equal(
CoreStrings.IncludeBadNavigation("ArcticMonkeys", nameof(Order)),
(await Assert.ThrowsAsync<InvalidOperationException>(
() => AssertQuery(
async,
ss => ss.Set<Order>().Include("ArcticMonkeys")))).Message);
}

[ConditionalTheory(Skip = "issue #15312")]
[MemberData(nameof(IsAsyncData))]
public virtual async Task Include_property_expression_invalid(bool async)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,227 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.

using System;
using System.Collections.Generic;
using System.Linq;
using System.Linq.Expressions;
using System.Reflection;
using System.Threading.Tasks;
using Microsoft.EntityFrameworkCore.Diagnostics;
using Microsoft.EntityFrameworkCore.TestUtilities;
using Microsoft.EntityFrameworkCore.TestModels.Northwind;
using Xunit;
using Microsoft.EntityFrameworkCore.Internal;

// ReSharper disable InconsistentNaming
// ReSharper disable StringStartsWithIsCultureSpecific

#pragma warning disable RCS1202 // Avoid NullReferenceException.

namespace Microsoft.EntityFrameworkCore.Query
{
public abstract class NorthwindStringIncludeQueryTestBase<TFixture> : NorthwindIncludeQueryTestBase<TFixture>
where TFixture : NorthwindQueryFixtureBase<NoopModelCustomizer>, new()
{
private static readonly IncludeRewritingExpressionVisitor _includeRewritingExpressionVisitor = new IncludeRewritingExpressionVisitor();
protected NorthwindStringIncludeQueryTestBase(TFixture fixture)
: base(fixture)
{
}

[ConditionalTheory(Skip = "issue #15312")]
[MemberData(nameof(IsAsyncData))]
public virtual async Task Include_non_existing_navigation(bool async)
{
Assert.Equal(
CoreStrings.IncludeBadNavigation("ArcticMonkeys", nameof(Order)),
(await Assert.ThrowsAsync<InvalidOperationException>(
() => AssertQuery(
async,
ss => ss.Set<Order>().Include("ArcticMonkeys")))).Message);
}

// Property expression cannot be converted to string include
public override Task Include_property_expression_invalid(bool async) => Task.CompletedTask;

// Property expression cannot be converted to string include
public override Task Then_include_property_expression_invalid(bool async) => Task.CompletedTask;

public override async Task Include_closes_reader(bool async)
{
using var context = CreateContext();
if (async)
{
Assert.NotNull(await context.Set<Customer>().Include("Orders").FirstOrDefaultAsync());
Assert.NotNull(await context.Set<Product>().ToListAsync());
}
else
{
Assert.NotNull(context.Set<Customer>().Include("Orders").FirstOrDefault());
Assert.NotNull(context.Set<Product>().ToList());
}
}

public override async Task Include_collection_dependent_already_tracked(bool async)
{
using var context = CreateContext();
var orders = context.Set<Order>().Where(o => o.CustomerID == "ALFKI").ToList();
Assert.Equal(6, context.ChangeTracker.Entries().Count());

var customer
= async
? await context.Set<Customer>()
.Include("Orders")
.SingleAsync(c => c.CustomerID == "ALFKI")
: context.Set<Customer>()
.Include("Orders")
.Single(c => c.CustomerID == "ALFKI");

Assert.Equal(orders, customer.Orders, LegacyReferenceEqualityComparer.Instance);
Assert.Equal(6, customer.Orders.Count);
Assert.True(orders.All(o => ReferenceEquals(o.Customer, customer)));
Assert.Equal(6 + 1, context.ChangeTracker.Entries().Count());
}

public override async Task Include_collection_principal_already_tracked(bool async)
{
using var context = CreateContext();
var customer1 = context.Set<Customer>().Single(c => c.CustomerID == "ALFKI");
Assert.Single(context.ChangeTracker.Entries());

var customer2
= async
? await context.Set<Customer>()
.Include("Orders")
.SingleAsync(c => c.CustomerID == "ALFKI")
: context.Set<Customer>()
.Include("Orders")
.Single(c => c.CustomerID == "ALFKI");

Assert.Same(customer1, customer2);
Assert.Equal(6, customer2.Orders.Count);
Assert.True(customer2.Orders.All(o => o.Customer != null));
Assert.Equal(7, context.ChangeTracker.Entries().Count());
}

public override async Task Include_reference_dependent_already_tracked(bool async)
{
using var context = CreateContext();
var customer = context.Set<Customer>().Single(o => o.CustomerID == "ALFKI");
Assert.Single(context.ChangeTracker.Entries());

var orders
= async
? await context.Set<Order>().Include("Customer").Where(o => o.CustomerID == "ALFKI").ToListAsync()
: context.Set<Order>().Include("Customer").Where(o => o.CustomerID == "ALFKI").ToList();

Assert.Equal(6, orders.Count);
Assert.True(orders.All(o => ReferenceEquals(o.Customer, customer)));
Assert.Equal(7, context.ChangeTracker.Entries().Count());
}

protected override Expression RewriteServerQueryExpression(Expression serverQueryExpression)
{
serverQueryExpression = base.RewriteServerQueryExpression(serverQueryExpression);

return _includeRewritingExpressionVisitor.Visit(serverQueryExpression);
}

private class IncludeRewritingExpressionVisitor : ExpressionVisitor
{
private static readonly MethodInfo _includeMethodInfo
= typeof(EntityFrameworkQueryableExtensions)
.GetTypeInfo().GetDeclaredMethods(nameof(EntityFrameworkQueryableExtensions.Include))
.Single(
mi =>
mi.GetGenericArguments().Count() == 2
&& mi.GetParameters().Any(
pi => pi.Name == "navigationPropertyPath" && pi.ParameterType != typeof(string)));

private static readonly MethodInfo _stringIncludeMethodInfo
= typeof(EntityFrameworkQueryableExtensions)
.GetTypeInfo().GetDeclaredMethods(nameof(EntityFrameworkQueryableExtensions.Include))
.Single(
mi => mi.GetParameters().Any(
pi => pi.Name == "navigationPropertyPath" && pi.ParameterType == typeof(string)));

private static readonly MethodInfo _thenIncludeAfterReferenceMethodInfo
= typeof(EntityFrameworkQueryableExtensions)
.GetTypeInfo().GetDeclaredMethods(nameof(EntityFrameworkQueryableExtensions.ThenInclude))
.Single(
mi => mi.GetGenericArguments().Count() == 3
&& mi.GetParameters()[0].ParameterType.GenericTypeArguments[1].IsGenericParameter);

private static readonly MethodInfo _thenIncludeAfterEnumerableMethodInfo
= typeof(EntityFrameworkQueryableExtensions)
.GetTypeInfo().GetDeclaredMethods(nameof(EntityFrameworkQueryableExtensions.ThenInclude))
.Where(mi => mi.GetGenericArguments().Count() == 3)
.Single(
mi =>
{
var typeInfo = mi.GetParameters()[0].ParameterType.GenericTypeArguments[1];
return typeInfo.IsGenericType
&& typeInfo.GetGenericTypeDefinition() == typeof(IEnumerable<>);
});

protected override Expression VisitMethodCall(MethodCallExpression methodCallExpression)
{
if (methodCallExpression.Method.DeclaringType == typeof(EntityFrameworkQueryableExtensions)
&& methodCallExpression.Method.IsGenericMethod)
{
var genericMethodDefinition = methodCallExpression.Method.GetGenericMethodDefinition();
if (genericMethodDefinition == _includeMethodInfo)
{
var source = Visit(methodCallExpression.Arguments[0]);

return Expression.Call(
_stringIncludeMethodInfo.MakeGenericMethod(methodCallExpression.Method.GetGenericArguments()[0]),
source,
Expression.Constant(GetPath(methodCallExpression.Arguments[1].UnwrapLambdaFromQuote().Body)));
}

if (genericMethodDefinition == _thenIncludeAfterEnumerableMethodInfo
|| genericMethodDefinition == _thenIncludeAfterReferenceMethodInfo)
{
var innerIncludeMethodCall = (MethodCallExpression)Visit(methodCallExpression.Arguments[0]);
var innerNavigationPath = (string)((ConstantExpression)innerIncludeMethodCall.Arguments[1]).Value;
var currentNavigationpath = GetPath(methodCallExpression.Arguments[1].UnwrapLambdaFromQuote().Body);

return innerIncludeMethodCall.Update(
innerIncludeMethodCall.Object,
new[]
{
innerIncludeMethodCall.Arguments[0],
Expression.Constant($"{innerNavigationPath}.{currentNavigationpath}")
});
}
}

return base.VisitMethodCall(methodCallExpression);
}

private static string GetPath(Expression expression)
{
switch (expression)
{
case MemberExpression memberExpression:
if (memberExpression.Expression is ParameterExpression)
{
return memberExpression.Member.Name;
}

return $"{GetPath(memberExpression.Expression)}.{memberExpression.Member.Name}";

case UnaryExpression unaryExpression
when unaryExpression.NodeType == ExpressionType.Convert
|| unaryExpression.NodeType == ExpressionType.Convert
|| unaryExpression.NodeType == ExpressionType.TypeAs:
return GetPath(unaryExpression.Operand);

default:
throw new NotImplementedException("Unhandled expression tree in Include lambda");
}
}
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.

using Microsoft.EntityFrameworkCore.TestUtilities;

namespace Microsoft.EntityFrameworkCore.Query
{
public class NorthwindStringIncludeQuerySqlServerTest : NorthwindStringIncludeQueryTestBase<NorthwindQuerySqlServerFixture<NoopModelCustomizer>>
{
// ReSharper disable once UnusedParameter.Local
public NorthwindStringIncludeQuerySqlServerTest(NorthwindQuerySqlServerFixture<NoopModelCustomizer> fixture)
: base(fixture)
{
Fixture.TestSqlLoggerFactory.Clear();
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.

using System.Threading.Tasks;
using Microsoft.EntityFrameworkCore.TestUtilities;
using Xunit.Abstractions;

namespace Microsoft.EntityFrameworkCore.Query
{
public class NorthwindStringIncludeQuerySqliteTest : NorthwindStringIncludeQueryTestBase<NorthwindQuerySqliteFixture<NoopModelCustomizer>>
{
public NorthwindStringIncludeQuerySqliteTest(NorthwindQuerySqliteFixture<NoopModelCustomizer> fixture, ITestOutputHelper testOutputHelper)
: base(fixture)
{
//TestSqlLoggerFactory.CaptureOutput(testOutputHelper);
}

// Sqlite does not support Apply operations
public override Task Include_collection_with_cross_apply_with_filter(bool async) => Task.CompletedTask;

public override Task Include_collection_with_outer_apply_with_filter(bool async) => Task.CompletedTask;
}
}

0 comments on commit 95e779b

Please sign in to comment.