From 95e779b5b623c71f74bad7bcd7b98723d4892cf6 Mon Sep 17 00:00:00 2001 From: Smit Patel Date: Thu, 28 May 2020 18:36:53 -0700 Subject: [PATCH] Test: Add tests for string based include (#21058) --- .../EntityFrameworkQueryableExtensions.cs | 7 +- ...NorthwindStringIncludeQueryInMemoryTest.cs | 25 ++ ...NorthwindIncludeNoTrackingQueryTestBase.cs | 4 +- .../Query/NorthwindIncludeQueryTestBase.cs | 12 - .../NorthwindStringIncludeQueryTestBase.cs | 227 ++++++++++++++++++ ...orthwindStringIncludeQuerySqlServerTest.cs | 17 ++ .../NorthwindStringIncludeQuerySqliteTest.cs | 23 ++ 7 files changed, 296 insertions(+), 19 deletions(-) create mode 100644 test/EFCore.InMemory.FunctionalTests/Query/NorthwindStringIncludeQueryInMemoryTest.cs create mode 100644 test/EFCore.Specification.Tests/Query/NorthwindStringIncludeQueryTestBase.cs create mode 100644 test/EFCore.SqlServer.FunctionalTests/Query/NorthwindStringIncludeQuerySqlServerTest.cs create mode 100644 test/EFCore.Sqlite.FunctionalTests/Query/NorthwindStringIncludeQuerySqliteTest.cs diff --git a/src/EFCore/Extensions/EntityFrameworkQueryableExtensions.cs b/src/EFCore/Extensions/EntityFrameworkQueryableExtensions.cs index 82d4d79c2c5..2154059af6a 100644 --- a/src/EFCore/Extensions/EntityFrameworkQueryableExtensions.cs +++ b/src/EFCore/Extensions/EntityFrameworkQueryableExtensions.cs @@ -2443,10 +2443,7 @@ source.Provider is EntityQueryProvider } internal static readonly MethodInfo ThenIncludeAfterEnumerableMethodInfo - = GetThenIncludeMethodInfo(typeof(IEnumerable<>)); - - private static MethodInfo GetThenIncludeMethodInfo(Type navType) - => typeof(EntityFrameworkQueryableExtensions) + = typeof(EntityFrameworkQueryableExtensions) .GetTypeInfo().GetDeclaredMethods(nameof(ThenInclude)) .Where(mi => mi.GetGenericArguments().Count() == 3) .Single( @@ -2454,7 +2451,7 @@ private static MethodInfo GetThenIncludeMethodInfo(Type navType) { var typeInfo = mi.GetParameters()[0].ParameterType.GenericTypeArguments[1]; return typeInfo.IsGenericType - && typeInfo.GetGenericTypeDefinition() == navType; + && typeInfo.GetGenericTypeDefinition() == typeof(IEnumerable<>); }); internal static readonly MethodInfo ThenIncludeAfterReferenceMethodInfo diff --git a/test/EFCore.InMemory.FunctionalTests/Query/NorthwindStringIncludeQueryInMemoryTest.cs b/test/EFCore.InMemory.FunctionalTests/Query/NorthwindStringIncludeQueryInMemoryTest.cs new file mode 100644 index 00000000000..2b53fe08a7d --- /dev/null +++ b/test/EFCore.InMemory.FunctionalTests/Query/NorthwindStringIncludeQueryInMemoryTest.cs @@ -0,0 +1,25 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System.Threading.Tasks; +using Microsoft.EntityFrameworkCore.TestUtilities; +using Xunit; +using Xunit.Abstractions; + +namespace Microsoft.EntityFrameworkCore.Query +{ + public class NorthwindStringIncludeQueryInMemoryTest : NorthwindStringIncludeQueryTestBase> + { + public NorthwindStringIncludeQueryInMemoryTest(NorthwindQueryInMemoryFixture fixture, ITestOutputHelper testOutputHelper) + : base(fixture) + { + //TestLoggerFactory.TestOutputHelper = testOutputHelper; + } + + [ConditionalTheory(Skip = "Issue#17386")] + public override Task Include_collection_with_last_no_orderby(bool async) + { + return base.Include_collection_with_last_no_orderby(async); + } + } +} diff --git a/test/EFCore.Specification.Tests/Query/NorthwindIncludeNoTrackingQueryTestBase.cs b/test/EFCore.Specification.Tests/Query/NorthwindIncludeNoTrackingQueryTestBase.cs index 5cf56614aab..a81d8df4a7d 100644 --- a/test/EFCore.Specification.Tests/Query/NorthwindIncludeNoTrackingQueryTestBase.cs +++ b/test/EFCore.Specification.Tests/Query/NorthwindIncludeNoTrackingQueryTestBase.cs @@ -219,8 +219,8 @@ public virtual Task NoTracking_Include_with_cycles_does_not_throw_when_performin return AssertQuery( async, ss => (from i in ss.Set().Include(o => o.Customer.Orders) - where i.OrderID < 10800 - select i) + where i.OrderID < 10800 + select i) .PerformIdentityResolution()); } diff --git a/test/EFCore.Specification.Tests/Query/NorthwindIncludeQueryTestBase.cs b/test/EFCore.Specification.Tests/Query/NorthwindIncludeQueryTestBase.cs index 6703667e51b..5c1b3c7ce43 100644 --- a/test/EFCore.Specification.Tests/Query/NorthwindIncludeQueryTestBase.cs +++ b/test/EFCore.Specification.Tests/Query/NorthwindIncludeQueryTestBase.cs @@ -76,18 +76,6 @@ public virtual async Task Include_property(bool async) ss => ss.Set().Include(o => o.OrderDate)))).Message); } - [ConditionalTheory(Skip = "issue #15312")] - [MemberData(nameof(IsAsyncData))] - public virtual async Task Include_non_existing_navigation(bool async) - { - Assert.Equal( - CoreStrings.IncludeBadNavigation("ArcticMonkeys", nameof(Order)), - (await Assert.ThrowsAsync( - () => AssertQuery( - async, - ss => ss.Set().Include("ArcticMonkeys")))).Message); - } - [ConditionalTheory(Skip = "issue #15312")] [MemberData(nameof(IsAsyncData))] public virtual async Task Include_property_expression_invalid(bool async) diff --git a/test/EFCore.Specification.Tests/Query/NorthwindStringIncludeQueryTestBase.cs b/test/EFCore.Specification.Tests/Query/NorthwindStringIncludeQueryTestBase.cs new file mode 100644 index 00000000000..5d62b3a812c --- /dev/null +++ b/test/EFCore.Specification.Tests/Query/NorthwindStringIncludeQueryTestBase.cs @@ -0,0 +1,227 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Linq.Expressions; +using System.Reflection; +using System.Threading.Tasks; +using Microsoft.EntityFrameworkCore.Diagnostics; +using Microsoft.EntityFrameworkCore.TestUtilities; +using Microsoft.EntityFrameworkCore.TestModels.Northwind; +using Xunit; +using Microsoft.EntityFrameworkCore.Internal; + +// ReSharper disable InconsistentNaming +// ReSharper disable StringStartsWithIsCultureSpecific + +#pragma warning disable RCS1202 // Avoid NullReferenceException. + +namespace Microsoft.EntityFrameworkCore.Query +{ + public abstract class NorthwindStringIncludeQueryTestBase : NorthwindIncludeQueryTestBase + where TFixture : NorthwindQueryFixtureBase, new() + { + private static readonly IncludeRewritingExpressionVisitor _includeRewritingExpressionVisitor = new IncludeRewritingExpressionVisitor(); + protected NorthwindStringIncludeQueryTestBase(TFixture fixture) + : base(fixture) + { + } + + [ConditionalTheory(Skip = "issue #15312")] + [MemberData(nameof(IsAsyncData))] + public virtual async Task Include_non_existing_navigation(bool async) + { + Assert.Equal( + CoreStrings.IncludeBadNavigation("ArcticMonkeys", nameof(Order)), + (await Assert.ThrowsAsync( + () => AssertQuery( + async, + ss => ss.Set().Include("ArcticMonkeys")))).Message); + } + + // Property expression cannot be converted to string include + public override Task Include_property_expression_invalid(bool async) => Task.CompletedTask; + + // Property expression cannot be converted to string include + public override Task Then_include_property_expression_invalid(bool async) => Task.CompletedTask; + + public override async Task Include_closes_reader(bool async) + { + using var context = CreateContext(); + if (async) + { + Assert.NotNull(await context.Set().Include("Orders").FirstOrDefaultAsync()); + Assert.NotNull(await context.Set().ToListAsync()); + } + else + { + Assert.NotNull(context.Set().Include("Orders").FirstOrDefault()); + Assert.NotNull(context.Set().ToList()); + } + } + + public override async Task Include_collection_dependent_already_tracked(bool async) + { + using var context = CreateContext(); + var orders = context.Set().Where(o => o.CustomerID == "ALFKI").ToList(); + Assert.Equal(6, context.ChangeTracker.Entries().Count()); + + var customer + = async + ? await context.Set() + .Include("Orders") + .SingleAsync(c => c.CustomerID == "ALFKI") + : context.Set() + .Include("Orders") + .Single(c => c.CustomerID == "ALFKI"); + + Assert.Equal(orders, customer.Orders, LegacyReferenceEqualityComparer.Instance); + Assert.Equal(6, customer.Orders.Count); + Assert.True(orders.All(o => ReferenceEquals(o.Customer, customer))); + Assert.Equal(6 + 1, context.ChangeTracker.Entries().Count()); + } + + public override async Task Include_collection_principal_already_tracked(bool async) + { + using var context = CreateContext(); + var customer1 = context.Set().Single(c => c.CustomerID == "ALFKI"); + Assert.Single(context.ChangeTracker.Entries()); + + var customer2 + = async + ? await context.Set() + .Include("Orders") + .SingleAsync(c => c.CustomerID == "ALFKI") + : context.Set() + .Include("Orders") + .Single(c => c.CustomerID == "ALFKI"); + + Assert.Same(customer1, customer2); + Assert.Equal(6, customer2.Orders.Count); + Assert.True(customer2.Orders.All(o => o.Customer != null)); + Assert.Equal(7, context.ChangeTracker.Entries().Count()); + } + + public override async Task Include_reference_dependent_already_tracked(bool async) + { + using var context = CreateContext(); + var customer = context.Set().Single(o => o.CustomerID == "ALFKI"); + Assert.Single(context.ChangeTracker.Entries()); + + var orders + = async + ? await context.Set().Include("Customer").Where(o => o.CustomerID == "ALFKI").ToListAsync() + : context.Set().Include("Customer").Where(o => o.CustomerID == "ALFKI").ToList(); + + Assert.Equal(6, orders.Count); + Assert.True(orders.All(o => ReferenceEquals(o.Customer, customer))); + Assert.Equal(7, context.ChangeTracker.Entries().Count()); + } + + protected override Expression RewriteServerQueryExpression(Expression serverQueryExpression) + { + serverQueryExpression = base.RewriteServerQueryExpression(serverQueryExpression); + + return _includeRewritingExpressionVisitor.Visit(serverQueryExpression); + } + + private class IncludeRewritingExpressionVisitor : ExpressionVisitor + { + private static readonly MethodInfo _includeMethodInfo + = typeof(EntityFrameworkQueryableExtensions) + .GetTypeInfo().GetDeclaredMethods(nameof(EntityFrameworkQueryableExtensions.Include)) + .Single( + mi => + mi.GetGenericArguments().Count() == 2 + && mi.GetParameters().Any( + pi => pi.Name == "navigationPropertyPath" && pi.ParameterType != typeof(string))); + + private static readonly MethodInfo _stringIncludeMethodInfo + = typeof(EntityFrameworkQueryableExtensions) + .GetTypeInfo().GetDeclaredMethods(nameof(EntityFrameworkQueryableExtensions.Include)) + .Single( + mi => mi.GetParameters().Any( + pi => pi.Name == "navigationPropertyPath" && pi.ParameterType == typeof(string))); + + private static readonly MethodInfo _thenIncludeAfterReferenceMethodInfo + = typeof(EntityFrameworkQueryableExtensions) + .GetTypeInfo().GetDeclaredMethods(nameof(EntityFrameworkQueryableExtensions.ThenInclude)) + .Single( + mi => mi.GetGenericArguments().Count() == 3 + && mi.GetParameters()[0].ParameterType.GenericTypeArguments[1].IsGenericParameter); + + private static readonly MethodInfo _thenIncludeAfterEnumerableMethodInfo + = typeof(EntityFrameworkQueryableExtensions) + .GetTypeInfo().GetDeclaredMethods(nameof(EntityFrameworkQueryableExtensions.ThenInclude)) + .Where(mi => mi.GetGenericArguments().Count() == 3) + .Single( + mi => + { + var typeInfo = mi.GetParameters()[0].ParameterType.GenericTypeArguments[1]; + return typeInfo.IsGenericType + && typeInfo.GetGenericTypeDefinition() == typeof(IEnumerable<>); + }); + + protected override Expression VisitMethodCall(MethodCallExpression methodCallExpression) + { + if (methodCallExpression.Method.DeclaringType == typeof(EntityFrameworkQueryableExtensions) + && methodCallExpression.Method.IsGenericMethod) + { + var genericMethodDefinition = methodCallExpression.Method.GetGenericMethodDefinition(); + if (genericMethodDefinition == _includeMethodInfo) + { + var source = Visit(methodCallExpression.Arguments[0]); + + return Expression.Call( + _stringIncludeMethodInfo.MakeGenericMethod(methodCallExpression.Method.GetGenericArguments()[0]), + source, + Expression.Constant(GetPath(methodCallExpression.Arguments[1].UnwrapLambdaFromQuote().Body))); + } + + if (genericMethodDefinition == _thenIncludeAfterEnumerableMethodInfo + || genericMethodDefinition == _thenIncludeAfterReferenceMethodInfo) + { + var innerIncludeMethodCall = (MethodCallExpression)Visit(methodCallExpression.Arguments[0]); + var innerNavigationPath = (string)((ConstantExpression)innerIncludeMethodCall.Arguments[1]).Value; + var currentNavigationpath = GetPath(methodCallExpression.Arguments[1].UnwrapLambdaFromQuote().Body); + + return innerIncludeMethodCall.Update( + innerIncludeMethodCall.Object, + new[] + { + innerIncludeMethodCall.Arguments[0], + Expression.Constant($"{innerNavigationPath}.{currentNavigationpath}") + }); + } + } + + return base.VisitMethodCall(methodCallExpression); + } + + private static string GetPath(Expression expression) + { + switch (expression) + { + case MemberExpression memberExpression: + if (memberExpression.Expression is ParameterExpression) + { + return memberExpression.Member.Name; + } + + return $"{GetPath(memberExpression.Expression)}.{memberExpression.Member.Name}"; + + case UnaryExpression unaryExpression + when unaryExpression.NodeType == ExpressionType.Convert + || unaryExpression.NodeType == ExpressionType.Convert + || unaryExpression.NodeType == ExpressionType.TypeAs: + return GetPath(unaryExpression.Operand); + + default: + throw new NotImplementedException("Unhandled expression tree in Include lambda"); + } + } + } + } +} diff --git a/test/EFCore.SqlServer.FunctionalTests/Query/NorthwindStringIncludeQuerySqlServerTest.cs b/test/EFCore.SqlServer.FunctionalTests/Query/NorthwindStringIncludeQuerySqlServerTest.cs new file mode 100644 index 00000000000..0b3dcd2b89f --- /dev/null +++ b/test/EFCore.SqlServer.FunctionalTests/Query/NorthwindStringIncludeQuerySqlServerTest.cs @@ -0,0 +1,17 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using Microsoft.EntityFrameworkCore.TestUtilities; + +namespace Microsoft.EntityFrameworkCore.Query +{ + public class NorthwindStringIncludeQuerySqlServerTest : NorthwindStringIncludeQueryTestBase> + { + // ReSharper disable once UnusedParameter.Local + public NorthwindStringIncludeQuerySqlServerTest(NorthwindQuerySqlServerFixture fixture) + : base(fixture) + { + Fixture.TestSqlLoggerFactory.Clear(); + } + } +} diff --git a/test/EFCore.Sqlite.FunctionalTests/Query/NorthwindStringIncludeQuerySqliteTest.cs b/test/EFCore.Sqlite.FunctionalTests/Query/NorthwindStringIncludeQuerySqliteTest.cs new file mode 100644 index 00000000000..0d9fe2a0a11 --- /dev/null +++ b/test/EFCore.Sqlite.FunctionalTests/Query/NorthwindStringIncludeQuerySqliteTest.cs @@ -0,0 +1,23 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System.Threading.Tasks; +using Microsoft.EntityFrameworkCore.TestUtilities; +using Xunit.Abstractions; + +namespace Microsoft.EntityFrameworkCore.Query +{ + public class NorthwindStringIncludeQuerySqliteTest : NorthwindStringIncludeQueryTestBase> + { + public NorthwindStringIncludeQuerySqliteTest(NorthwindQuerySqliteFixture fixture, ITestOutputHelper testOutputHelper) + : base(fixture) + { + //TestSqlLoggerFactory.CaptureOutput(testOutputHelper); + } + + // Sqlite does not support Apply operations + public override Task Include_collection_with_cross_apply_with_filter(bool async) => Task.CompletedTask; + + public override Task Include_collection_with_outer_apply_with_filter(bool async) => Task.CompletedTask; + } +}