diff --git a/src/EFCore.Cosmos/Query/Internal/CosmosProjectionBindingExpressionVisitor.cs b/src/EFCore.Cosmos/Query/Internal/CosmosProjectionBindingExpressionVisitor.cs index 6afeb04123d..e323fa44b56 100644 --- a/src/EFCore.Cosmos/Query/Internal/CosmosProjectionBindingExpressionVisitor.cs +++ b/src/EFCore.Cosmos/Query/Internal/CosmosProjectionBindingExpressionVisitor.cs @@ -42,8 +42,10 @@ private readonly IDictionary _projectionMapping private readonly Stack _projectionMembers = new(); +#pragma warning disable CS0618 // Type or member is obsolete private readonly IDictionary _collectionShaperMapping = new Dictionary(); +#pragma warning restore CS0618 // Type or member is obsolete private readonly Stack _includedNavigations = new(); @@ -366,7 +368,9 @@ protected override Expression VisitMember(MemberExpression memberExpression) Expression.Convert(objectArrayProjectionExpression.InnerProjection, typeof(object)), typeof(ValueBuffer)), nullable: true); +#pragma warning disable CS0618 // Type or member is obsolete return new CollectionShaperExpression( +#pragma warning restore CS0618 // Type or member is obsolete objectArrayProjectionExpression, innerShaperExpression, navigation, @@ -570,7 +574,9 @@ protected override Expression VisitMethodCall(MethodCallExpression methodCallExp Expression.Convert(objectArrayProjectionExpression.InnerProjection, typeof(object)), typeof(ValueBuffer)), nullable: true); +#pragma warning disable CS0618 // Type or member is obsolete return new CollectionShaperExpression( +#pragma warning restore CS0618 // Type or member is obsolete objectArrayProjectionExpression, innerShaperExpression, navigation, @@ -599,7 +605,9 @@ protected override Expression VisitMethodCall(MethodCallExpression methodCallExp case nameof(Queryable.Select) when genericMethod == QueryableMethods.Select: +#pragma warning disable CS0618 // Type or member is obsolete if (!(visitedSource is CollectionShaperExpression shaper)) +#pragma warning restore CS0618 // Type or member is obsolete { return null; } diff --git a/src/EFCore.Cosmos/Query/Internal/CosmosShapedQueryCompilingExpressionVisitor.CosmosProjectionBindingRemovingExpressionVisitorBase.cs b/src/EFCore.Cosmos/Query/Internal/CosmosShapedQueryCompilingExpressionVisitor.CosmosProjectionBindingRemovingExpressionVisitorBase.cs index 16ca6f86119..e2ed6b2828d 100644 --- a/src/EFCore.Cosmos/Query/Internal/CosmosShapedQueryCompilingExpressionVisitor.CosmosProjectionBindingRemovingExpressionVisitorBase.cs +++ b/src/EFCore.Cosmos/Query/Internal/CosmosShapedQueryCompilingExpressionVisitor.CosmosProjectionBindingRemovingExpressionVisitorBase.cs @@ -248,7 +248,9 @@ protected override Expression VisitExtension(Expression extensionExpression) projectionBindingExpression.Type, (projection.Expression as SqlExpression)?.TypeMapping); } +#pragma warning disable CS0618 // Type or member is obsolete case CollectionShaperExpression collectionShaperExpression: +#pragma warning restore CS0618 // Type or member is obsolete { ObjectArrayProjectionExpression objectArrayProjection; switch (collectionShaperExpression.Projection) diff --git a/src/EFCore.Cosmos/Query/Internal/CosmosShapedQueryCompilingExpressionVisitor.JObjectInjectingExpressionVisitor.cs b/src/EFCore.Cosmos/Query/Internal/CosmosShapedQueryCompilingExpressionVisitor.JObjectInjectingExpressionVisitor.cs index 0ce72f8ba61..db278fe49d1 100644 --- a/src/EFCore.Cosmos/Query/Internal/CosmosShapedQueryCompilingExpressionVisitor.JObjectInjectingExpressionVisitor.cs +++ b/src/EFCore.Cosmos/Query/Internal/CosmosShapedQueryCompilingExpressionVisitor.JObjectInjectingExpressionVisitor.cs @@ -53,7 +53,9 @@ protected override Expression VisitExtension(Expression extensionExpression) expressions); } +#pragma warning disable CS0618 // Type or member is obsolete case CollectionShaperExpression collectionShaperExpression: +#pragma warning restore CS0618 // Type or member is obsolete { _currentEntityIndex++; diff --git a/src/EFCore.InMemory/Properties/InMemoryStrings.Designer.cs b/src/EFCore.InMemory/Properties/InMemoryStrings.Designer.cs index f0ca3dfe42b..21cc04b226b 100644 --- a/src/EFCore.InMemory/Properties/InMemoryStrings.Designer.cs +++ b/src/EFCore.InMemory/Properties/InMemoryStrings.Designer.cs @@ -29,6 +29,12 @@ private static readonly ResourceManager _resourceManager public static string DefaultIfEmptyAppliedAfterProjection => GetString("DefaultIfEmptyAppliedAfterProjection"); + /// + /// Using 'Distinct' operation on a projection containing a subquery is not supported. + /// + public static string DistinctOnSubqueryNotSupported + => GetString("DistinctOnSubqueryNotSupported"); + /// /// The specified entity type '{derivedType}' is not derived from '{entityType}'. /// @@ -59,6 +65,12 @@ public static string NullabilityErrorExceptionSensitive(object? requiredProperti GetString("NullabilityErrorExceptionSensitive", nameof(requiredProperties), nameof(entityType), nameof(keyValue)), requiredProperties, entityType, keyValue); + /// + /// Unable to translate set operation after client projection has been applied. Consider moving the set operation before the last 'Select' call. + /// + public static string SetOperationsNotAllowedAfterClientEvaluation + => GetString("SetOperationsNotAllowedAfterClientEvaluation"); + /// /// Unable to bind '{memberType}' '{member}' to entity projection of '{entityType}'. /// diff --git a/src/EFCore.InMemory/Properties/InMemoryStrings.resx b/src/EFCore.InMemory/Properties/InMemoryStrings.resx index 4eaae2fee06..f914f2fc4a0 100644 --- a/src/EFCore.InMemory/Properties/InMemoryStrings.resx +++ b/src/EFCore.InMemory/Properties/InMemoryStrings.resx @@ -120,6 +120,9 @@ Cannot apply 'DefaultIfEmpty' after a client-evaluated projection. Consider applying 'DefaultIfEmpty' before last 'Select' or use 'AsEnumerable' before 'DefaultIfEmpty' to apply it on client-side. + + Using 'Distinct' operation on a projection containing a subquery is not supported. + The specified entity type '{derivedType}' is not derived from '{entityType}'. @@ -140,6 +143,9 @@ Required properties '{requiredProperties}' are missing for the instance of entity type '{entityType}' with the key value '{keyValue}'. + + Unable to translate set operation after client projection has been applied. Consider moving the set operation before the last 'Select' call. + Unable to bind '{memberType}' '{member}' to entity projection of '{entityType}'. diff --git a/src/EFCore.InMemory/Query/Internal/CollectionResultShaperExpression.cs b/src/EFCore.InMemory/Query/Internal/CollectionResultShaperExpression.cs new file mode 100644 index 00000000000..bb9849ddb5f --- /dev/null +++ b/src/EFCore.InMemory/Query/Internal/CollectionResultShaperExpression.cs @@ -0,0 +1,128 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Collections.Generic; +using System.Linq.Expressions; +using Microsoft.EntityFrameworkCore.Infrastructure; +using Microsoft.EntityFrameworkCore.Metadata; +using Microsoft.EntityFrameworkCore.Query; +using Microsoft.EntityFrameworkCore.Utilities; + +namespace Microsoft.EntityFrameworkCore.InMemory.Query.Internal +{ + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + public class CollectionResultShaperExpression : Expression, IPrintableExpression + { + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + public CollectionResultShaperExpression( + Expression projection, + Expression innerShaper, + INavigationBase? navigation, + Type elementType) + { + Check.NotNull(projection, nameof(projection)); + Check.NotNull(innerShaper, nameof(innerShaper)); + + Projection = projection; + InnerShaper = innerShaper; + Navigation = navigation; + ElementType = elementType; + } + + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + public virtual Expression Projection { get; } + + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + public virtual Expression InnerShaper { get; } + + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + public virtual INavigationBase? Navigation { get; } + + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + public virtual Type ElementType { get; } + + /// + public sealed override ExpressionType NodeType + => ExpressionType.Extension; + + /// + public override Type Type + => Navigation?.ClrType ?? typeof(List<>).MakeGenericType(ElementType); + + /// + protected override Expression VisitChildren(ExpressionVisitor visitor) + { + Check.NotNull(visitor, nameof(visitor)); + + var projection = visitor.Visit(Projection); + var innerShaper = visitor.Visit(InnerShaper); + + return Update(projection, innerShaper); + } + + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + public virtual CollectionResultShaperExpression Update( + Expression projection, + Expression innerShaper) + { + Check.NotNull(projection, nameof(projection)); + Check.NotNull(innerShaper, nameof(innerShaper)); + + return projection != Projection || innerShaper != InnerShaper + ? new CollectionResultShaperExpression(projection, innerShaper, Navigation, ElementType) + : this; + } + + /// + void IPrintableExpression.Print(ExpressionPrinter expressionPrinter) + { + Check.NotNull(expressionPrinter, nameof(expressionPrinter)); + + expressionPrinter.AppendLine("CollectionResultShaperExpression:"); + using (expressionPrinter.Indent()) + { + expressionPrinter.Append("("); + expressionPrinter.Visit(Projection); + expressionPrinter.Append(", "); + expressionPrinter.Visit(InnerShaper); + expressionPrinter.AppendLine($", {Navigation?.Name}, {ElementType.ShortDisplayName()})"); + } + } + } +} diff --git a/src/EFCore.InMemory/Query/Internal/InMemoryExpressionTranslatingExpressionVisitor.cs b/src/EFCore.InMemory/Query/Internal/InMemoryExpressionTranslatingExpressionVisitor.cs index 8a9f0de5a15..7df44c10cc3 100644 --- a/src/EFCore.InMemory/Query/Internal/InMemoryExpressionTranslatingExpressionVisitor.cs +++ b/src/EFCore.InMemory/Query/Internal/InMemoryExpressionTranslatingExpressionVisitor.cs @@ -278,12 +278,7 @@ protected override Expression VisitExtension(Expression extensionExpression) case ProjectionBindingExpression projectionBindingExpression when projectionBindingExpression.ProjectionMember != null: - return ((InMemoryQueryExpression)projectionBindingExpression.QueryExpression) - .GetMappedProjection(projectionBindingExpression.ProjectionMember); - - //case ProjectionBindingExpression projectionBindingExpression - // when projectionBindingExpression.Index is int index: - // return ((InMemoryQueryExpression)projectionBindingExpression.QueryExpression).Projection[index]; + return ((InMemoryQueryExpression)projectionBindingExpression.QueryExpression).GetProjection(projectionBindingExpression); case InMemoryGroupByShaperExpression inMemoryGroupByShaperExpression: return new GroupingElementExpression( @@ -747,9 +742,8 @@ static Expression RemapLambda(GroupingElementExpression groupingElement, LambdaE } return ProcessSingleResultScalar( - subquery.ServerQueryExpression, - subquery.GetMappedProjection(projectionBindingExpression.ProjectionMember), - subquery.CurrentParameter, + subquery, + subquery.GetProjection(projectionBindingExpression), methodCallExpression.Type); } @@ -1209,23 +1203,13 @@ protected override Expression VisitUnary(UnaryExpression unaryExpression) Expression readValueExpression; var projectionBindingExpression = (ProjectionBindingExpression)entityShaper.ValueBufferExpression; - if (projectionBindingExpression.ProjectionMember != null) - { - var entityProjectionExpression = (EntityProjectionExpression)inMemoryQueryExpression.GetMappedProjection( - projectionBindingExpression.ProjectionMember); - readValueExpression = entityProjectionExpression.BindProperty(property); - } - else - { - // This has to be index map since entities cannot map to just integer index - var index = projectionBindingExpression.IndexMap![property]; - readValueExpression = inMemoryQueryExpression.Projection[index]; - } + var entityProjectionExpression = (EntityProjectionExpression)inMemoryQueryExpression.GetProjection( + projectionBindingExpression); + readValueExpression = entityProjectionExpression.BindProperty(property); return ProcessSingleResultScalar( - inMemoryQueryExpression.ServerQueryExpression, + inMemoryQueryExpression, readValueExpression, - inMemoryQueryExpression.CurrentParameter, type); } @@ -1233,36 +1217,40 @@ protected override Expression VisitUnary(UnaryExpression unaryExpression) } private static Expression ProcessSingleResultScalar( - Expression serverQuery, + InMemoryQueryExpression inMemoryQueryExpression, Expression readValueExpression, - Expression valueBufferParameter, Type type) { - var singleResult = ((LambdaExpression)((NewExpression)serverQuery).Arguments[0]).Body; - if (readValueExpression is UnaryExpression unaryExpression + if (inMemoryQueryExpression.ServerQueryExpression is not NewExpression) + { + // The terminating operator is not applied + // It is of FirstOrDefault kind + // So we change to single column projection and then apply it. + inMemoryQueryExpression.ReplaceProjection(new Dictionary + { + { new ProjectionMember(), readValueExpression } + }); + inMemoryQueryExpression.ApplyProjection(); + } + + var serverQuery = inMemoryQueryExpression.ServerQueryExpression; + serverQuery = ((LambdaExpression)((NewExpression)serverQuery).Arguments[0]).Body; + if (serverQuery is UnaryExpression unaryExpression && unaryExpression.NodeType == ExpressionType.Convert && unaryExpression.Type == typeof(object)) { - readValueExpression = unaryExpression.Operand; + serverQuery = unaryExpression.Operand; } var valueBufferVariable = Expression.Variable(typeof(ValueBuffer)); - var replacedReadExpression = ReplacingExpressionVisitor.Replace( - valueBufferParameter, - valueBufferVariable, - readValueExpression); - - replacedReadExpression = replacedReadExpression.Type == type - ? replacedReadExpression - : Expression.Convert(replacedReadExpression, type); - + var readExpression = valueBufferVariable.CreateValueBufferReadValueExpression(type, index: 0, property: null); return Expression.Block( variables: new[] { valueBufferVariable }, - Expression.Assign(valueBufferVariable, singleResult), + Expression.Assign(valueBufferVariable, serverQuery), Expression.Condition( Expression.MakeMemberAccess(valueBufferVariable, _valueBufferIsEmpty), Expression.Default(type), - replacedReadExpression)); + readExpression)); } [UsedImplicitly] diff --git a/src/EFCore.InMemory/Query/Internal/InMemoryProjectionBindingExpressionVisitor.cs b/src/EFCore.InMemory/Query/Internal/InMemoryProjectionBindingExpressionVisitor.cs index b07de3e2c13..7bbea8f094c 100644 --- a/src/EFCore.InMemory/Query/Internal/InMemoryProjectionBindingExpressionVisitor.cs +++ b/src/EFCore.InMemory/Query/Internal/InMemoryProjectionBindingExpressionVisitor.cs @@ -27,11 +27,12 @@ public class InMemoryProjectionBindingExpressionVisitor : ExpressionVisitor private readonly InMemoryExpressionTranslatingExpressionVisitor _expressionTranslatingExpressionVisitor; private InMemoryQueryExpression _queryExpression; - private bool _clientEval; + private bool _indexBasedBinding; private Dictionary? _entityProjectionCache; private readonly Dictionary _projectionMapping = new(); + private List? _clientProjections; private readonly Stack _projectionMembers = new(); /// @@ -58,7 +59,7 @@ public InMemoryProjectionBindingExpressionVisitor( public virtual Expression Translate(InMemoryQueryExpression queryExpression, Expression expression) { _queryExpression = queryExpression; - _clientEval = false; + _indexBasedBinding = false; _projectionMembers.Push(new ProjectionMember()); @@ -67,20 +68,25 @@ public virtual Expression Translate(InMemoryQueryExpression queryExpression, Exp if (result == QueryCompilationContext.NotTranslatedExpression) { - _clientEval = true; + _indexBasedBinding = true; + _projectionMapping.Clear(); _entityProjectionCache = new(); + _clientProjections = new(); expandedExpression = _queryableMethodTranslatingExpressionVisitor.ExpandWeakEntities(_queryExpression, expression); result = Visit(expandedExpression); + _queryExpression.ReplaceProjection(_clientProjections); + _clientProjections = null; + } + else + { + _queryExpression.ReplaceProjection(_projectionMapping); _projectionMapping.Clear(); } - _queryExpression.ReplaceProjectionMapping(_projectionMapping); _queryExpression = null!; - _projectionMapping.Clear(); _projectionMembers.Clear(); - result = MatchTypes(result!, expression.Type); return result; @@ -113,33 +119,54 @@ public virtual Expression Translate(InMemoryQueryExpression queryExpression, Exp return parameter; } - if (_clientEval) + if (_indexBasedBinding) { switch (expression) { case ConstantExpression _: return expression; + case ProjectionBindingExpression projectionBindingExpression: + var mappedProjection = _queryExpression.GetProjection(projectionBindingExpression); + if (mappedProjection is EntityProjectionExpression entityProjection) + { + return AddClientProjection(entityProjection, typeof(ValueBuffer)); + } + + if (mappedProjection is not InMemoryQueryExpression) + { + return AddClientProjection(mappedProjection, expression.Type.MakeNullable()); + } + + throw new InvalidOperationException(CoreStrings.TranslationFailed(projectionBindingExpression.Print())); + case MaterializeCollectionNavigationExpression materializeCollectionNavigationExpression: - return AddCollectionProjection( - _queryableMethodTranslatingExpressionVisitor.TranslateSubquery( - materializeCollectionNavigationExpression.Subquery)!, + { + var subquery = _queryableMethodTranslatingExpressionVisitor.TranslateSubquery( + materializeCollectionNavigationExpression.Subquery)!; + _clientProjections!.Add(subquery.QueryExpression); + return new CollectionResultShaperExpression( + new ProjectionBindingExpression(_queryExpression, _clientProjections.Count - 1, typeof(IEnumerable)), + subquery.ShaperExpression, materializeCollectionNavigationExpression.Navigation, materializeCollectionNavigationExpression.Navigation.ClrType.GetSequenceType()); + } case MethodCallExpression methodCallExpression: - { if (methodCallExpression.Method.IsGenericMethod && methodCallExpression.Method.DeclaringType == typeof(Enumerable) - && methodCallExpression.Method.Name == nameof(Enumerable.ToList)) + && methodCallExpression.Method.Name == nameof(Enumerable.ToList) + && methodCallExpression.Arguments.Count == 1 + && methodCallExpression.Arguments[0].Type.TryGetElementType(typeof(IQueryable<>)) != null) { - var subqueryTranslation = _queryableMethodTranslatingExpressionVisitor.TranslateSubquery( + var subquery = _queryableMethodTranslatingExpressionVisitor.TranslateSubquery( methodCallExpression.Arguments[0]); - - if (subqueryTranslation != null) + if (subquery != null) { - return AddCollectionProjection( - subqueryTranslation, + _clientProjections!.Add(subquery.QueryExpression); + return new CollectionResultShaperExpression( + new ProjectionBindingExpression(_queryExpression, _clientProjections.Count - 1, typeof(IEnumerable)), + subquery.ShaperExpression, null, methodCallExpression.Method.GetGenericArguments()[0]); } @@ -149,30 +176,40 @@ public virtual Expression Translate(InMemoryQueryExpression queryExpression, Exp var subquery = _queryableMethodTranslatingExpressionVisitor.TranslateSubquery(methodCallExpression); if (subquery != null) { - if (subquery.ResultCardinality == ResultCardinality.Enumerable) + // This simplifies the check when subquery is translated and can be lifted as scalar. + var scalarTranslation = _expressionTranslatingExpressionVisitor.Translate(subquery); + if (scalarTranslation != null) { - return AddCollectionProjection(subquery, null, subquery.ShaperExpression.Type); + return AddClientProjection(scalarTranslation, expression.Type.MakeNullable()); } - return new SingleResultShaperExpression( - new ProjectionBindingExpression( - _queryExpression, - _queryExpression.AddSubqueryProjection(subquery, out var innerShaper), - typeof(ValueBuffer)), - innerShaper, - subquery.ShaperExpression.Type); + if (subquery.ResultCardinality == ResultCardinality.Enumerable) + { + _clientProjections!.Add(subquery.QueryExpression); + var projectionBindingExpression = new ProjectionBindingExpression( + _queryExpression, _clientProjections.Count - 1, typeof(IEnumerable)); + return new CollectionResultShaperExpression( + projectionBindingExpression, subquery.ShaperExpression, navigation: null, subquery.ShaperExpression.Type); + } + else + { + _clientProjections!.Add(subquery.QueryExpression); + var projectionBindingExpression = new ProjectionBindingExpression( + _queryExpression, _clientProjections.Count - 1, typeof(ValueBuffer)); + return new SingleResultShaperExpression(projectionBindingExpression, subquery.ShaperExpression); + } } } - break; - } } var translation = _expressionTranslatingExpressionVisitor.Translate(expression); - return translation == null - ? base.Visit(expression) - : new ProjectionBindingExpression( - _queryExpression, _queryExpression.AddToProjection(translation), expression.Type.MakeNullable()); + if (translation != null) + { + return AddClientProjection(translation, expression.Type.MakeNullable()); + } + + return base.Visit(expression); } else { @@ -247,19 +284,18 @@ protected override Expression VisitExtension(Expression extensionExpression) } entityProjectionExpression = (EntityProjectionExpression)((InMemoryQueryExpression)projectionBindingExpression.QueryExpression) - .GetMappedProjection(projectionBindingExpression.ProjectionMember); + .GetProjection(projectionBindingExpression); } else { entityProjectionExpression = (EntityProjectionExpression)entityShaperExpression.ValueBufferExpression; } - if (_clientEval) + if (_indexBasedBinding) { if (!_entityProjectionCache!.TryGetValue(entityProjectionExpression, out var entityProjectionBinding)) { - entityProjectionBinding = new ProjectionBindingExpression( - _queryExpression, _queryExpression.AddToProjection(entityProjectionExpression)); + entityProjectionBinding = AddClientProjection(entityProjectionExpression, typeof(ValueBuffer)); _entityProjectionCache[entityProjectionExpression] = entityProjectionBinding; } @@ -274,7 +310,7 @@ protected override Expression VisitExtension(Expression extensionExpression) if (extensionExpression is IncludeExpression includeExpression) { - return _clientEval + return _indexBasedBinding ? base.VisitExtension(includeExpression) : QueryCompilationContext.NotTranslatedExpression; } @@ -330,7 +366,7 @@ protected override MemberAssignment VisitMemberAssignment(MemberAssignment membe { var expression = memberAssignment.Expression; Expression? visitedExpression; - if (_clientEval) + if (_indexBasedBinding) { visitedExpression = Visit(memberAssignment.Expression); } @@ -442,7 +478,7 @@ protected override Expression VisitNew(NewExpression newExpression) return newExpression; } - if (!_clientEval + if (!_indexBasedBinding && newExpression.Members == null) { return QueryCompilationContext.NotTranslatedExpression; @@ -453,7 +489,7 @@ protected override Expression VisitNew(NewExpression newExpression) { var argument = newExpression.Arguments[i]; Expression? visitedArgument; - if (_clientEval) + if (_indexBasedBinding) { visitedArgument = Visit(argument); } @@ -502,21 +538,6 @@ protected override Expression VisitUnary(UnaryExpression unaryExpression) : unaryExpression.Update(MatchTypes(operand, unaryExpression.Operand.Type)); } - private CollectionShaperExpression AddCollectionProjection( - ShapedQueryExpression subquery, - INavigationBase? navigation, - Type elementType) - => new( - new ProjectionBindingExpression( - _queryExpression, - _queryExpression.AddSubqueryProjection( - subquery, - out var innerShaper), - typeof(IEnumerable)), - innerShaper, - navigation, - elementType); - private static Expression MatchTypes(Expression expression, Type targetType) { if (targetType != expression.Type @@ -529,5 +550,17 @@ private static Expression MatchTypes(Expression expression, Type targetType) return expression; } + + private ProjectionBindingExpression AddClientProjection(Expression expression, Type type) + { + var existingIndex = _clientProjections!.FindIndex(e => e.Equals(expression)); + if (existingIndex == -1) + { + _clientProjections.Add(expression); + existingIndex = _clientProjections.Count - 1; + } + + return new ProjectionBindingExpression(_queryExpression, existingIndex, type); + } } } diff --git a/src/EFCore.InMemory/Query/Internal/InMemoryQueryExpression.Helper.cs b/src/EFCore.InMemory/Query/Internal/InMemoryQueryExpression.Helper.cs new file mode 100644 index 00000000000..365871b76cc --- /dev/null +++ b/src/EFCore.InMemory/Query/Internal/InMemoryQueryExpression.Helper.cs @@ -0,0 +1,177 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Collections; +using System.Collections.Generic; +using System.Diagnostics.CodeAnalysis; +using System.Linq.Expressions; +using Microsoft.EntityFrameworkCore.Query; +using Microsoft.EntityFrameworkCore.Storage; +using Microsoft.EntityFrameworkCore.Utilities; + +namespace Microsoft.EntityFrameworkCore.InMemory.Query.Internal +{ + public partial class InMemoryQueryExpression + { + private sealed class ResultEnumerable : IEnumerable + { + private readonly Func _getElement; + + public ResultEnumerable(Func getElement) + { + _getElement = getElement; + } + + public IEnumerator GetEnumerator() + => new ResultEnumerator(_getElement()); + + IEnumerator IEnumerable.GetEnumerator() + => GetEnumerator(); + + private sealed class ResultEnumerator : IEnumerator + { + private readonly ValueBuffer _value; + private bool _moved; + + public ResultEnumerator(ValueBuffer value) + { + _value = value; + _moved = _value.IsEmpty; + } + + public bool MoveNext() + { + if (!_moved) + { + _moved = true; + + return _moved; + } + + return false; + } + + public void Reset() + { + _moved = false; + } + + object IEnumerator.Current + => Current; + + public ValueBuffer Current + => !_moved ? ValueBuffer.Empty : _value; + + void IDisposable.Dispose() + { + } + } + } + + private sealed class ProjectionMemberRemappingExpressionVisitor : ExpressionVisitor + { + private readonly Expression _queryExpression; + private readonly Dictionary _projectionMemberMappings; + + public ProjectionMemberRemappingExpressionVisitor( + Expression queryExpression, Dictionary projectionMemberMappings) + { + _queryExpression = queryExpression; + _projectionMemberMappings = projectionMemberMappings; + } + + [return: NotNullIfNotNull("expression")] + public override Expression? Visit(Expression? expression) + { + if (expression is ProjectionBindingExpression projectionBindingExpression) + { + Check.DebugAssert(projectionBindingExpression.ProjectionMember != null, + "ProjectionBindingExpression must have projection member."); + + return new ProjectionBindingExpression( + _queryExpression, + _projectionMemberMappings[projectionBindingExpression.ProjectionMember], + projectionBindingExpression.Type); + } + + return base.Visit(expression); + } + } + + private sealed class ProjectionMemberToIndexConvertingExpressionVisitor : ExpressionVisitor + { + private readonly Expression _queryExpression; + private readonly Dictionary _projectionMemberMappings; + + public ProjectionMemberToIndexConvertingExpressionVisitor( + Expression queryExpression, Dictionary projectionMemberMappings) + { + _queryExpression = queryExpression; + _projectionMemberMappings = projectionMemberMappings; + } + + [return: NotNullIfNotNull("expression")] + public override Expression? Visit(Expression? expression) + { + if (expression is ProjectionBindingExpression projectionBindingExpression) + { + Check.DebugAssert(projectionBindingExpression.ProjectionMember != null, + "ProjectionBindingExpression must have projection member."); + + return new ProjectionBindingExpression( + _queryExpression, + _projectionMemberMappings[projectionBindingExpression.ProjectionMember], + projectionBindingExpression.Type); + } + + return base.Visit(expression); + } + } + + private sealed class ProjectionIndexRemappingExpressionVisitor : ExpressionVisitor + { + private readonly Expression _oldExpression; + private readonly Expression _newExpression; + private readonly int[] _indexMap; + + public ProjectionIndexRemappingExpressionVisitor( + Expression oldExpression, Expression newExpression, int[] indexMap) + { + _oldExpression = oldExpression; + _newExpression = newExpression; + _indexMap = indexMap; + } + + [return: NotNullIfNotNull("expression")] + public override Expression? Visit(Expression? expression) + { + if (expression is ProjectionBindingExpression projectionBindingExpression + && ReferenceEquals(projectionBindingExpression.QueryExpression, _oldExpression)) + { + Check.DebugAssert(projectionBindingExpression.Index != null, + "ProjectionBindingExpression must have index."); + + return new ProjectionBindingExpression( + _newExpression, + _indexMap[projectionBindingExpression.Index.Value], + projectionBindingExpression.Type); + } + + return base.Visit(expression); + } + } + + private sealed class EntityShaperNullableMarkingExpressionVisitor : ExpressionVisitor + { + protected override Expression VisitExtension(Expression extensionExpression) + { + Check.NotNull(extensionExpression, nameof(extensionExpression)); + + return extensionExpression is EntityShaperExpression entityShaper + ? entityShaper.MakeNullable() + : base.VisitExtension(extensionExpression); + } + } + } +} diff --git a/src/EFCore.InMemory/Query/Internal/InMemoryQueryExpression.ResultEnumerable.cs b/src/EFCore.InMemory/Query/Internal/InMemoryQueryExpression.ResultEnumerable.cs deleted file mode 100644 index 4f4f92e5610..00000000000 --- a/src/EFCore.InMemory/Query/Internal/InMemoryQueryExpression.ResultEnumerable.cs +++ /dev/null @@ -1,68 +0,0 @@ -// Copyright (c) .NET Foundation. All rights reserved. -// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. - -using System; -using System.Collections; -using System.Collections.Generic; -using Microsoft.EntityFrameworkCore.Storage; - -namespace Microsoft.EntityFrameworkCore.InMemory.Query.Internal -{ - public partial class InMemoryQueryExpression - { - private sealed class ResultEnumerable : IEnumerable - { - private readonly Func _getElement; - - public ResultEnumerable(Func getElement) - { - _getElement = getElement; - } - - public IEnumerator GetEnumerator() - => new ResultEnumerator(_getElement()); - - IEnumerator IEnumerable.GetEnumerator() - => GetEnumerator(); - - private sealed class ResultEnumerator : IEnumerator - { - private readonly ValueBuffer _value; - private bool _moved; - - public ResultEnumerator(ValueBuffer value) - { - _value = value; - _moved = _value.IsEmpty; - } - - public bool MoveNext() - { - if (!_moved) - { - _moved = true; - - return _moved; - } - - return false; - } - - public void Reset() - { - _moved = false; - } - - object IEnumerator.Current - => Current; - - public ValueBuffer Current - => !_moved ? ValueBuffer.Empty : _value; - - void IDisposable.Dispose() - { - } - } - } - } -} diff --git a/src/EFCore.InMemory/Query/Internal/InMemoryQueryExpression.cs b/src/EFCore.InMemory/Query/Internal/InMemoryQueryExpression.cs index 955b35d6f58..848fa000318 100644 --- a/src/EFCore.InMemory/Query/Internal/InMemoryQueryExpression.cs +++ b/src/EFCore.InMemory/Query/Internal/InMemoryQueryExpression.cs @@ -3,13 +3,9 @@ using System; using System.Collections.Generic; -using System.Diagnostics; -using System.Diagnostics.CodeAnalysis; -using System.Dynamic; using System.Linq; using System.Linq.Expressions; using System.Reflection; -using Microsoft.EntityFrameworkCore.Diagnostics; using Microsoft.EntityFrameworkCore.Infrastructure; using Microsoft.EntityFrameworkCore.InMemory.Internal; using Microsoft.EntityFrameworkCore.Metadata; @@ -37,13 +33,17 @@ private static readonly PropertyInfo _valueBufferCountMemberInfo private static readonly MethodInfo _leftJoinMethodInfo = typeof(InMemoryQueryExpression).GetTypeInfo() .GetDeclaredMethods(nameof(LeftJoin)).Single(mi => mi.GetParameters().Length == 6); - private readonly List _clientProjectionExpressions = new(); - private readonly List _projectionMappingExpressions = new(); + private static readonly ConstructorInfo _resultEnumerableConstructor + = typeof(ResultEnumerable).GetConstructors().Single(); private readonly ParameterExpression _valueBufferParameter; - - private IDictionary _projectionMapping = new Dictionary(); private ParameterExpression? _groupingParameter; + private MethodInfo? _singleResultMethodInfo; + private bool _scalarServerQuery; + + private Dictionary _projectionMapping = new(); + private readonly List _clientProjections = new(); + private readonly List _projectionMappingExpressions = new(); /// /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to @@ -55,7 +55,7 @@ public InMemoryQueryExpression(IEntityType entityType) { _valueBufferParameter = Parameter(typeof(ValueBuffer), "valueBuffer"); ServerQueryExpression = new InMemoryTableExpression(entityType); - var readExpressionMap = new Dictionary(); + var propertyExpressionsMap = new Dictionary(); var selectorExpressions = new List(); foreach (var property in entityType.GetAllBaseTypesInclusive().SelectMany(et => et.GetDeclaredProperties())) { @@ -64,7 +64,7 @@ public InMemoryQueryExpression(IEntityType entityType) Check.DebugAssert(property.GetIndex() == selectorExpressions.Count - 1, "Properties should be ordered in same order as their indexes."); - readExpressionMap[property] = propertyExpression; + propertyExpressionsMap[property] = propertyExpression; _projectionMappingExpressions.Add(propertyExpression); } @@ -77,7 +77,7 @@ public InMemoryQueryExpression(IEntityType entityType) var entityCheck = derivedEntityType.GetConcreteDerivedTypesInclusive() .Select( e => keyValueComparer.ExtractEqualsBody( - readExpressionMap[discriminatorProperty], + propertyExpressionsMap[discriminatorProperty], Constant(e.GetDiscriminatorValue(), discriminatorProperty.ClrType))) .Aggregate((l, r) => OrElse(l, r)); @@ -90,7 +90,7 @@ public InMemoryQueryExpression(IEntityType entityType) selectorExpressions.Add(propertyExpression); var readExpression = CreateReadValueExpression(property.ClrType, selectorExpressions.Count - 1, property); - readExpressionMap[property] = readExpression; + propertyExpressionsMap[property] = readExpression; _projectionMappingExpressions.Add(readExpression); } } @@ -110,19 +110,10 @@ public InMemoryQueryExpression(IEntityType entityType) selectorLambda); } - var entityProjection = new EntityProjectionExpression(entityType, readExpressionMap); + var entityProjection = new EntityProjectionExpression(entityType, propertyExpressionsMap); _projectionMapping[new ProjectionMember()] = entityProjection; } - /// - /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to - /// the same compatibility standards as public APIs. It may be changed or removed without notice in - /// any release. You should only use it directly in your code with extreme caution and knowing that - /// doing so can result in application failures when updating to a new Entity Framework Core release. - /// - public virtual IReadOnlyList Projection - => _clientProjectionExpressions; - /// /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to /// the same compatibility standards as public APIs. It may be changed or removed without notice in @@ -146,35 +137,12 @@ public virtual ParameterExpression CurrentParameter /// any release. You should only use it directly in your code with extreme caution and knowing that /// doing so can result in application failures when updating to a new Entity Framework Core release. /// - public override Type Type - => typeof(IEnumerable); - - /// - /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to - /// the same compatibility standards as public APIs. It may be changed or removed without notice in - /// any release. You should only use it directly in your code with extreme caution and knowing that - /// doing so can result in application failures when updating to a new Entity Framework Core release. - /// - public sealed override ExpressionType NodeType - => ExpressionType.Extension; - - /// - /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to - /// the same compatibility standards as public APIs. It may be changed or removed without notice in - /// any release. You should only use it directly in your code with extreme caution and knowing that - /// doing so can result in application failures when updating to a new Entity Framework Core release. - /// - public virtual Expression GetSingleScalarProjection() + public virtual void ReplaceProjection(IReadOnlyList clientProjections) { - var expression = CreateReadValueExpression(ServerQueryExpression.Type, 0, null); _projectionMapping.Clear(); - _projectionMapping[new ProjectionMember()] = expression; - _projectionMappingExpressions.Add(expression); - _groupingParameter = null; - - ConvertToEnumerable(); - - return new ProjectionBindingExpression(this, new ProjectionMember(), expression.Type.MakeNullable()); + _projectionMappingExpressions.Clear(); + _clientProjections.Clear(); + _clientProjections.AddRange(clientProjections); } /// @@ -183,131 +151,80 @@ public virtual Expression GetSingleScalarProjection() /// any release. You should only use it directly in your code with extreme caution and knowing that /// doing so can result in application failures when updating to a new Entity Framework Core release. /// - public virtual void ConvertToEnumerable() + public virtual void ReplaceProjection(IReadOnlyDictionary projectionMapping) { - if (ServerQueryExpression.Type.TryGetSequenceType() == null) + _projectionMapping.Clear(); + _projectionMappingExpressions.Clear(); + _clientProjections.Clear(); + var selectorExpressions = new List(); + foreach (var keyValuePair in projectionMapping) { - if (ServerQueryExpression.Type != typeof(ValueBuffer)) + if (keyValuePair.Value is EntityProjectionExpression entityProjectionExpression) { - if (ServerQueryExpression.Type.IsValueType) - { - ServerQueryExpression = Convert(ServerQueryExpression, typeof(object)); - } - - ServerQueryExpression = New( - typeof(ResultEnumerable).GetConstructors().Single(), - Lambda>( - New( - _valueBufferConstructor, - NewArrayInit(typeof(object), ServerQueryExpression)))); + _projectionMapping[keyValuePair.Key] = AddEntityProjection(entityProjectionExpression); } else { - ServerQueryExpression = New( - typeof(ResultEnumerable).GetConstructors().Single(), - Lambda>(ServerQueryExpression)); + selectorExpressions.Add(keyValuePair.Value); + var readExpression = CreateReadValueExpression( + keyValuePair.Value.Type, selectorExpressions.Count - 1, InferPropertyFromInner(keyValuePair.Value)); + _projectionMapping[keyValuePair.Key] = readExpression; + _projectionMappingExpressions.Add(readExpression); } } - } - /// - /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to - /// the same compatibility standards as public APIs. It may be changed or removed without notice in - /// any release. You should only use it directly in your code with extreme caution and knowing that - /// doing so can result in application failures when updating to a new Entity Framework Core release. - /// - public virtual void ReplaceProjectionMapping(IDictionary projectionMappings) - { - _projectionMapping.Clear(); - _projectionMappingExpressions.Clear(); - LambdaExpression? selectorLambda = null; - if (_clientProjectionExpressions.Count > 0) + if (selectorExpressions.Count == 0) { - var remappedProjections = _clientProjectionExpressions - .Select((e, i) => CreateReadValueExpression(e.Type, i, InferPropertyFromInner(e))).ToList(); + // No server correlated term in projection so add dummy 1. + selectorExpressions.Add(Constant(1)); + } - selectorLambda = Lambda( - New( - _valueBufferConstructor, - NewArrayInit( - typeof(object), - _clientProjectionExpressions.Select(e => e.Type.IsValueType ? Convert(e, typeof(object)) : e))), - CurrentParameter); + var selectorLambda = Lambda( + New( + _valueBufferConstructor, + NewArrayInit( + typeof(object), + selectorExpressions.Select(e => e.Type.IsValueType ? Convert(e, typeof(object)) : e).ToArray())), + CurrentParameter); - _clientProjectionExpressions.Clear(); - _clientProjectionExpressions.AddRange(remappedProjections); - } - else - { - var selectorExpressions = new List(); - foreach (var kvp in projectionMappings) - { - if (kvp.Value is EntityProjectionExpression entityProjectionExpression) - { - _projectionMapping[kvp.Key] = UpdateEntityProjection(entityProjectionExpression); - } - else - { - selectorExpressions.Add(kvp.Value); - var expression = CreateReadValueExpression( - kvp.Value.Type, selectorExpressions.Count - 1, InferPropertyFromInner(kvp.Value)); - _projectionMapping[kvp.Key] = expression; - _projectionMappingExpressions.Add(expression); - } - } + ServerQueryExpression = Call( + EnumerableMethods.Select.MakeGenericMethod(CurrentParameter.Type, typeof(ValueBuffer)), + ServerQueryExpression, + selectorLambda); - if (selectorExpressions.Count == 0) + _groupingParameter = null; + + EntityProjectionExpression AddEntityProjection(EntityProjectionExpression entityProjectionExpression) + { + var readExpressionMap = new Dictionary(); + foreach (var property in GetAllPropertiesInHierarchy(entityProjectionExpression.EntityType)) { - // No server correlated term in projection so add dummy 1. - selectorExpressions.Add(Constant(1)); + var expression = entityProjectionExpression.BindProperty(property); + selectorExpressions.Add(expression); + var newExpression = CreateReadValueExpression(expression.Type, selectorExpressions.Count - 1, property); + readExpressionMap[property] = newExpression; + _projectionMappingExpressions.Add(newExpression); } - selectorLambda = Lambda( - New( - _valueBufferConstructor, - NewArrayInit( - typeof(object), - selectorExpressions.Select(e => e.Type.IsValueType ? Convert(e, typeof(object)) : e))), - CurrentParameter); + var result = new EntityProjectionExpression(entityProjectionExpression.EntityType, readExpressionMap); - EntityProjectionExpression UpdateEntityProjection(EntityProjectionExpression entityProjection) + // Also compute nested entity projections + foreach (var navigation in entityProjectionExpression.EntityType.GetAllBaseTypes() + .Concat(entityProjectionExpression.EntityType.GetDerivedTypesInclusive()) + .SelectMany(t => t.GetDeclaredNavigations())) { - var readExpressionMap = new Dictionary(); - foreach (var property in GetAllPropertiesInHierarchy(entityProjection.EntityType)) + var boundEntityShaperExpression = entityProjectionExpression.BindNavigation(navigation); + if (boundEntityShaperExpression != null) { - var expression = entityProjection.BindProperty(property); - selectorExpressions.Add(expression); - var newExpression = CreateReadValueExpression(expression.Type, selectorExpressions.Count - 1, property); - readExpressionMap[property] = newExpression; - _projectionMappingExpressions.Add(newExpression); + var innerEntityProjection = (EntityProjectionExpression)boundEntityShaperExpression.ValueBufferExpression; + var newInnerEntityProjection = AddEntityProjection(innerEntityProjection); + boundEntityShaperExpression = boundEntityShaperExpression.Update(newInnerEntityProjection); + result.AddNavigationBinding(navigation, boundEntityShaperExpression); } - - var result = new EntityProjectionExpression(entityProjection.EntityType, readExpressionMap); - - // Also compute nested entity projections - foreach (var navigation in entityProjection.EntityType.GetAllBaseTypes() - .Concat(entityProjection.EntityType.GetDerivedTypesInclusive()) - .SelectMany(t => t.GetDeclaredNavigations())) - { - var boundEntityShaperExpression = entityProjection.BindNavigation(navigation); - if (boundEntityShaperExpression != null) - { - var innerEntityProjection = (EntityProjectionExpression)boundEntityShaperExpression.ValueBufferExpression; - var newInnerEntityProjection = UpdateEntityProjection(innerEntityProjection); - boundEntityShaperExpression = boundEntityShaperExpression.Update(newInnerEntityProjection); - result.AddNavigationBinding(navigation, boundEntityShaperExpression); - } - } - - return result; } - } - ServerQueryExpression = Call( - EnumerableMethods.Select.MakeGenericMethod(CurrentParameter.Type, typeof(ValueBuffer)), - ServerQueryExpression, - selectorLambda); - _groupingParameter = null; + return result; + } } /// @@ -316,16 +233,10 @@ EntityProjectionExpression UpdateEntityProjection(EntityProjectionExpression ent /// any release. You should only use it directly in your code with extreme caution and knowing that /// doing so can result in application failures when updating to a new Entity Framework Core release. /// - public virtual IReadOnlyDictionary AddToProjection(EntityProjectionExpression entityProjectionExpression) - { - var indexMap = new Dictionary(); - foreach (var property in GetAllPropertiesInHierarchy(entityProjectionExpression.EntityType)) - { - indexMap[property] = AddToProjection(entityProjectionExpression.BindProperty(property)); - } - - return indexMap; - } + public virtual Expression GetProjection(ProjectionBindingExpression projectionBindingExpression) + => projectionBindingExpression.ProjectionMember != null + ? _projectionMapping[projectionBindingExpression.ProjectionMember] + : _clientProjections[projectionBindingExpression.Index!.Value]; /// /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to @@ -333,53 +244,106 @@ public virtual IReadOnlyDictionary AddToProjection(EntityProject /// any release. You should only use it directly in your code with extreme caution and knowing that /// doing so can result in application failures when updating to a new Entity Framework Core release. /// - public virtual int AddToProjection(Expression expression) + public virtual void ApplyProjection() { - _clientProjectionExpressions.Add(expression); + if (_scalarServerQuery) + { + _projectionMapping[new ProjectionMember()] = Constant(0); + return; + } - return _clientProjectionExpressions.Count - 1; - } + var selectorExpressions = new List(); + if (_clientProjections.Count > 0) + { + for (var i = 0; i < _clientProjections.Count; i++) + { + var projection = _clientProjections[i]; + switch (projection) + { + case EntityProjectionExpression entityProjectionExpression: + { + var indexMap = new Dictionary(); + foreach (var property in GetAllPropertiesInHierarchy(entityProjectionExpression.EntityType)) + { + selectorExpressions.Add(entityProjectionExpression.BindProperty(property)); + indexMap[property] = selectorExpressions.Count - 1; + } - /// - /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to - /// the same compatibility standards as public APIs. It may be changed or removed without notice in - /// any release. You should only use it directly in your code with extreme caution and knowing that - /// doing so can result in application failures when updating to a new Entity Framework Core release. - /// - public virtual int AddSubqueryProjection( - ShapedQueryExpression shapedQueryExpression, - out Expression innerShaper) - { - var subquery = (InMemoryQueryExpression)shapedQueryExpression.QueryExpression; - subquery.ApplyProjection(); - var serverQueryExpression = subquery.ServerQueryExpression; + _clientProjections[i] = Constant(indexMap); + break; + } + + case InMemoryQueryExpression inMemoryQueryExpression: + { + var singleResult = inMemoryQueryExpression._scalarServerQuery || inMemoryQueryExpression._singleResultMethodInfo != null; + inMemoryQueryExpression.ApplyProjection(); + var serverQuery = inMemoryQueryExpression.ServerQueryExpression; + if (singleResult) + { + serverQuery = ((LambdaExpression)((NewExpression)serverQuery).Arguments[0]).Body; + } + selectorExpressions.Add(serverQuery); + _clientProjections[i] = Constant(selectorExpressions.Count - 1); + break; + } - if (serverQueryExpression is MethodCallExpression selectMethodCall - && selectMethodCall.Arguments[0].Type == typeof(ResultEnumerable)) + default: + selectorExpressions.Add(projection); + _clientProjections[i] = Constant(selectorExpressions.Count - 1); + break; + } + } + } + else { - var terminatingMethodCall = - (MethodCallExpression)((LambdaExpression)((NewExpression)selectMethodCall.Arguments[0]).Arguments[0]).Body; - selectMethodCall = selectMethodCall.Update( - null!, new[] { terminatingMethodCall.Arguments[0], selectMethodCall.Arguments[1] }); - serverQueryExpression = terminatingMethodCall.Update(null!, new[] { selectMethodCall }); + var newProjectionMapping = new Dictionary(); + foreach (var keyValuePair in _projectionMapping) + { + if (keyValuePair.Value is EntityProjectionExpression entityProjectionExpression) + { + var indexMap = new Dictionary(); + foreach (var property in GetAllPropertiesInHierarchy(entityProjectionExpression.EntityType)) + { + selectorExpressions.Add(entityProjectionExpression.BindProperty(property)); + indexMap[property] = selectorExpressions.Count - 1; + } + + newProjectionMapping[keyValuePair.Key] = Constant(indexMap); + } + else + { + selectorExpressions.Add(keyValuePair.Value); + newProjectionMapping[keyValuePair.Key] = Constant(selectorExpressions.Count - 1); + } + } + _projectionMapping = newProjectionMapping; + _projectionMappingExpressions.Clear(); } - innerShaper = new ShaperRemappingExpressionVisitor(subquery._projectionMapping) - .Visit(shapedQueryExpression.ShaperExpression); + var selectorLambda = Lambda( + New( + _valueBufferConstructor, + NewArrayInit( + typeof(object), + selectorExpressions.Select(e => e.Type.IsValueType ? Convert(e, typeof(object)) : e).ToArray())), + CurrentParameter); - innerShaper = Lambda(innerShaper, subquery.CurrentParameter); + ServerQueryExpression = Call( + EnumerableMethods.Select.MakeGenericMethod(CurrentParameter.Type, typeof(ValueBuffer)), + ServerQueryExpression, + selectorLambda); - return AddToProjection(serverQueryExpression); - } + if (_singleResultMethodInfo != null) + { + ServerQueryExpression = Call( + _singleResultMethodInfo.MakeGenericMethod(CurrentParameter.Type), + ServerQueryExpression); - /// - /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to - /// the same compatibility standards as public APIs. It may be changed or removed without notice in - /// any release. You should only use it directly in your code with extreme caution and knowing that - /// doing so can result in application failures when updating to a new Entity Framework Core release. - /// - public virtual Expression GetMappedProjection(ProjectionMember member) - => _projectionMapping[member]; + _singleResultMethodInfo = null; + + ConvertToEnumerable(); + } + } /// /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to @@ -399,7 +363,7 @@ public virtual void UpdateServerQueryExpression(Expression serverQueryExpression public virtual void ApplySetOperation(MethodInfo setOperationMethodInfo, InMemoryQueryExpression source2) { Check.DebugAssert(_groupingParameter == null, "Cannot apply set operation after GroupBy without flattening."); - if (_clientProjectionExpressions.Count == 0) + if (_clientProjections.Count == 0) { var projectionMapping = new Dictionary(); var source1SelectorExpressions = new List(); @@ -470,21 +434,7 @@ public virtual void ApplySetOperation(MethodInfo setOperationMethodInfo, InMemor } else { - Check.DebugAssert(_clientProjectionExpressions.Count == source2._clientProjectionExpressions.Count, - "Index count in both source should match."); - - // In case of client projections, indexes must match so we don't worry about it. - // We still have to formualte outer client projections again for nullability. - for (var i = 0; i < source2._clientProjectionExpressions.Count; i++) - { - var type1 = _clientProjectionExpressions[i].Type; - var type2 = source2._clientProjectionExpressions[i].Type; - if (!type1.IsNullableValueType() - && type2.IsNullableValueType()) - { - _clientProjectionExpressions[i] = MakeReadValueNullable(_clientProjectionExpressions[i]); - } - } + throw new InvalidOperationException(InMemoryStrings.SetOperationsNotAllowedAfterClientEvaluation); } ServerQueryExpression = Call( @@ -499,7 +449,7 @@ public virtual void ApplySetOperation(MethodInfo setOperationMethodInfo, InMemor /// public virtual void ApplyDefaultIfEmpty() { - if (_clientProjectionExpressions.Count != 0) + if (_clientProjections.Count != 0) { throw new InvalidOperationException(InMemoryStrings.DefaultIfEmptyAppliedAfterProjection); } @@ -507,27 +457,15 @@ public virtual void ApplyDefaultIfEmpty() var projectionMapping = new Dictionary(); foreach (var keyValuePair in _projectionMapping) { - if (keyValuePair.Value is EntityProjectionExpression entityProjection) - { - var map = new Dictionary(); - foreach (var property in GetAllPropertiesInHierarchy(entityProjection.EntityType)) - { - map[property] = MakeReadValueNullable(entityProjection.BindProperty(property)); - } - - projectionMapping[keyValuePair.Key] = new EntityProjectionExpression(entityProjection.EntityType, map); - } - else - { - projectionMapping[keyValuePair.Key] = MakeReadValueNullable(keyValuePair.Value); - } + projectionMapping[keyValuePair.Key] = keyValuePair.Value is EntityProjectionExpression entityProjectionExpression + ? MakeEntityProjectionNullable(entityProjectionExpression) + : (Expression)MakeReadValueNullable(keyValuePair.Value); } _projectionMapping = projectionMapping; var projectionMappingExpressions = _projectionMappingExpressions.Select(e => MakeReadValueNullable(e)).ToList(); _projectionMappingExpressions.Clear(); _projectionMappingExpressions.AddRange(projectionMappingExpressions); - _groupingParameter = null; ServerQueryExpression = Call( @@ -542,33 +480,58 @@ public virtual void ApplyDefaultIfEmpty() /// any release. You should only use it directly in your code with extreme caution and knowing that /// doing so can result in application failures when updating to a new Entity Framework Core release. /// - public virtual void ApplyProjection() + public virtual void ApplyDistinct() { - if (_clientProjectionExpressions.Count == 0) + Check.DebugAssert(!_scalarServerQuery && _singleResultMethodInfo == null, "Cannot apply distinct on single result query"); + Check.DebugAssert(_groupingParameter == null, "Cannot apply distinct after GroupBy before flattening."); + + var selectorExpressions = new List(); + if (_clientProjections.Count == 0) { - var result = new Dictionary(); - foreach (var keyValuePair in _projectionMapping) + selectorExpressions.AddRange(_projectionMappingExpressions); + if (selectorExpressions.Count == 0) { - result[keyValuePair.Key] = keyValuePair.Value is EntityProjectionExpression entityProjection - ? Constant(AddToProjection(entityProjection)) - : Constant(AddToProjection(keyValuePair.Value)); + // No server correlated term in projection so add dummy 1. + selectorExpressions.Add(Constant(1)); } + } + else + { + for (var i = 0; i < _clientProjections.Count; i++) + { + var projection = _clientProjections[i]; + if (projection is InMemoryQueryExpression) + { + throw new InvalidOperationException(InMemoryStrings.DistinctOnSubqueryNotSupported); + } - _projectionMapping = result; + if (projection is EntityProjectionExpression entityProjectionExpression) + { + _clientProjections[i] = TraverseEntityProjection(selectorExpressions, entityProjectionExpression, makeNullable: false); + } + else + { + selectorExpressions.Add(projection); + _clientProjections[i] = CreateReadValueExpression( + projection.Type, selectorExpressions.Count - 1, InferPropertyFromInner(projection)); + + } + } } var selectorLambda = Lambda( - New( - _valueBufferConstructor, - NewArrayInit( - typeof(object), - _clientProjectionExpressions.Select(e => e.Type.IsValueType ? Convert(e, typeof(object)) : e).ToArray())), - CurrentParameter); + New( + _valueBufferConstructor, + NewArrayInit( + typeof(object), + selectorExpressions.Select(e => e.Type.IsValueType ? Convert(e, typeof(object)) : e).ToArray())), + CurrentParameter); ServerQueryExpression = Call( - EnumerableMethods.Select.MakeGenericMethod(typeof(ValueBuffer), typeof(ValueBuffer)), - ServerQueryExpression, - selectorLambda); + EnumerableMethods.Distinct.MakeGenericMethod(typeof(ValueBuffer)), + Call(EnumerableMethods.Select.MakeGenericMethod(CurrentParameter.Type, typeof(ValueBuffer)), + ServerQueryExpression, + selectorLambda)); } /// @@ -633,12 +596,13 @@ public virtual InMemoryGroupByShaperExpression ApplyGrouping( /// any release. You should only use it directly in your code with extreme caution and knowing that /// doing so can result in application failures when updating to a new Entity Framework Core release. /// - public virtual void AddInnerJoin( + public virtual Expression AddInnerJoin( InMemoryQueryExpression innerQueryExpression, LambdaExpression outerKeySelector, LambdaExpression innerKeySelector, - Type transparentIdentifierType) - => AddJoin(innerQueryExpression, outerKeySelector, innerKeySelector, transparentIdentifierType, innerNullable: false); + Expression outerShaperExpression, + Expression innerShaperExpression) + => AddJoin(innerQueryExpression, outerKeySelector, innerKeySelector, outerShaperExpression, innerShaperExpression, innerNullable: false); /// /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to @@ -646,12 +610,13 @@ public virtual void AddInnerJoin( /// any release. You should only use it directly in your code with extreme caution and knowing that /// doing so can result in application failures when updating to a new Entity Framework Core release. /// - public virtual void AddLeftJoin( + public virtual Expression AddLeftJoin( InMemoryQueryExpression innerQueryExpression, LambdaExpression outerKeySelector, LambdaExpression innerKeySelector, - Type transparentIdentifierType) - => AddJoin(innerQueryExpression, outerKeySelector, innerKeySelector, transparentIdentifierType, innerNullable: true); + Expression outerShaperExpression, + Expression innerShaperExpression) + => AddJoin(innerQueryExpression, outerKeySelector, innerKeySelector, outerShaperExpression, innerShaperExpression, innerNullable: true); /// /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to @@ -659,11 +624,12 @@ public virtual void AddLeftJoin( /// any release. You should only use it directly in your code with extreme caution and knowing that /// doing so can result in application failures when updating to a new Entity Framework Core release. /// - public virtual void AddSelectMany( + public virtual Expression AddSelectMany( InMemoryQueryExpression innerQueryExpression, - Type transparentIdentifierType, + Expression outerShaperExpression, + Expression innerShaperExpression, bool innerNullable) - => AddJoin(innerQueryExpression, null, null, transparentIdentifierType, innerNullable); + => AddJoin(innerQueryExpression, null, null, outerShaperExpression, innerShaperExpression, innerNullable); /// /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to @@ -678,29 +644,27 @@ public virtual EntityShaperExpression AddNavigationToWeakEntityType( LambdaExpression outerKeySelector, LambdaExpression innerKeySelector) { + Check.DebugAssert(_clientProjections.Count == 0, "Cannot expand weak entity navigation after client projection yet."); var innerNullable = !navigation.ForeignKey.IsRequiredDependent; var outerParameter = Parameter(typeof(ValueBuffer), "outer"); var innerParameter = Parameter(typeof(ValueBuffer), "inner"); var replacingVisitor = new ReplacingExpressionVisitor( new Expression[] { CurrentParameter, innerQueryExpression.CurrentParameter }, new Expression[] { outerParameter, innerParameter }); - var resultSelectorExpressions = _projectionMappingExpressions - .Select(e => replacingVisitor.Visit(e)) - .ToList(); - - var outerIndex = resultSelectorExpressions.Count; - var innerEntityProjection = (EntityProjectionExpression)innerQueryExpression.GetMappedProjection(new ProjectionMember()); + var selectorExpressions = _projectionMappingExpressions.Select(e => replacingVisitor.Visit(e)).ToList(); + var outerIndex = selectorExpressions.Count; + var innerEntityProjection = (EntityProjectionExpression)innerQueryExpression._projectionMapping[new ProjectionMember()]; var innerReadExpressionMap = new Dictionary(); foreach (var property in GetAllPropertiesInHierarchy(innerEntityProjection.EntityType)) { - var replacedExpression = replacingVisitor.Visit(innerEntityProjection.BindProperty(property)); + var propertyExpression = innerEntityProjection.BindProperty(property); if (innerNullable) { - replacedExpression = MakeReadValueNullable(replacedExpression); + propertyExpression = MakeReadValueNullable(propertyExpression); } - resultSelectorExpressions.Add(replacedExpression); - var readValueExperssion = CreateReadValueExpression(replacedExpression.Type, resultSelectorExpressions.Count - 1, property); + selectorExpressions.Add(propertyExpression); + var readValueExperssion = CreateReadValueExpression(propertyExpression.Type, selectorExpressions.Count - 1, property); innerReadExpressionMap[property] = readValueExperssion; _projectionMappingExpressions.Add(readValueExperssion); } @@ -710,7 +674,9 @@ public virtual EntityShaperExpression AddNavigationToWeakEntityType( var resultSelector = Lambda( New(_valueBufferConstructor, NewArrayInit(typeof(object), - resultSelectorExpressions.Select(e => e.Type.IsValueType ? Convert(e, typeof(object)) : e))), + selectorExpressions + .Select(e => replacingVisitor.Visit(e)) + .Select(e => e.Type.IsValueType ? Convert(e, typeof(object)) : e))), outerParameter, innerParameter); @@ -725,7 +691,7 @@ public virtual EntityShaperExpression AddNavigationToWeakEntityType( innerKeySelector, resultSelector, Constant(new ValueBuffer( - Enumerable.Repeat((object?)null, innerQueryExpression._projectionMappingExpressions.Count).ToArray()))); + Enumerable.Repeat((object?)null, selectorExpressions.Count - outerIndex).ToArray()))); } else { @@ -745,6 +711,55 @@ public virtual EntityShaperExpression AddNavigationToWeakEntityType( return entityShaper; } + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + public virtual Expression GetSingleScalarProjection() + { + var expression = CreateReadValueExpression(ServerQueryExpression.Type, 0, null); + _projectionMapping.Clear(); + _projectionMappingExpressions.Clear(); + _clientProjections.Clear(); + _projectionMapping[new ProjectionMember()] = expression; + _projectionMappingExpressions.Add(expression); + _groupingParameter = null; + + _scalarServerQuery = true; + ConvertToEnumerable(); + + return new ProjectionBindingExpression(this, new ProjectionMember(), expression.Type.MakeNullable()); + } + + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + public virtual void ConvertToSingleResult(MethodInfo methodInfo) + { + _singleResultMethodInfo = methodInfo; + } + + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + public override Type Type => typeof(IEnumerable); + + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + public sealed override ExpressionType NodeType => ExpressionType.Extension; + /// /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to /// the same compatibility standards as public APIs. It may be changed or removed without notice in @@ -765,14 +780,30 @@ void IPrintableExpression.Print(ExpressionPrinter expressionPrinter) } expressionPrinter.AppendLine(); - expressionPrinter.AppendLine("ProjectionMapping:"); - using (expressionPrinter.Indent()) + if (_clientProjections.Count > 0) { - foreach (var projectionMapping in _projectionMapping) + expressionPrinter.AppendLine("ClientProjections:"); + using (expressionPrinter.Indent()) { - expressionPrinter.Append("Member: " + projectionMapping.Key + " Projection: "); - expressionPrinter.Visit(projectionMapping.Value); - expressionPrinter.AppendLine(","); + for (var i = 0; i < _clientProjections.Count; i++) + { + expressionPrinter.AppendLine(); + expressionPrinter.Append(i.ToString()).Append(" -> "); + expressionPrinter.Visit(_clientProjections[i]); + } + } + } + else + { + expressionPrinter.AppendLine("ProjectionMapping:"); + using (expressionPrinter.Indent()) + { + foreach (var projectionMapping in _projectionMapping) + { + expressionPrinter.Append("Member: " + projectionMapping.Key + " Projection: "); + expressionPrinter.Visit(projectionMapping.Value); + expressionPrinter.AppendLine(","); + } } } @@ -824,104 +855,150 @@ private Expression GetGroupingKey(Expression key, List groupingExpre } } - private void AddJoin( + private Expression AddJoin( InMemoryQueryExpression innerQueryExpression, LambdaExpression? outerKeySelector, LambdaExpression? innerKeySelector, - Type transparentIdentifierType, + Expression outerShaperExpression, + Expression innerShaperExpression, bool innerNullable) { + var transparentIdentifierType = TransparentIdentifierFactory.Create(outerShaperExpression.Type, innerShaperExpression.Type); + var outerMemberInfo = transparentIdentifierType.GetTypeInfo().GetRequiredDeclaredField("Outer"); + var innerMemberInfo = transparentIdentifierType.GetTypeInfo().GetRequiredDeclaredField("Inner"); + var outerClientEval = _clientProjections.Count > 0; + var innerClientEval = innerQueryExpression._clientProjections.Count > 0; + var resultSelectorExpressions = new List(); var outerParameter = Parameter(typeof(ValueBuffer), "outer"); var innerParameter = Parameter(typeof(ValueBuffer), "inner"); - var projectionMapping = new Dictionary(); var replacingVisitor = new ReplacingExpressionVisitor( new Expression[] { CurrentParameter, innerQueryExpression.CurrentParameter }, new Expression[] { outerParameter, innerParameter }); + int outerIndex; - var outerMemberInfo = transparentIdentifierType.GetTypeInfo().GetRequiredDeclaredField("Outer"); - var innerMemberInfo = transparentIdentifierType.GetTypeInfo().GetRequiredDeclaredField("Inner"); - foreach (var projection in _projectionMapping) + if (outerClientEval) { - if (projection.Value is EntityProjectionExpression entityProjection) + // Outer projection are already populated + if (innerClientEval) { - var readExpressionMap = new Dictionary(); - foreach (var property in GetAllPropertiesInHierarchy(entityProjection.EntityType)) + // Add inner to projection and update indexes + var indexMap = new int[innerQueryExpression._clientProjections.Count]; + for (var i = 0; i < innerQueryExpression._clientProjections.Count; i++) { - var replacedExpression = replacingVisitor.Visit(entityProjection.BindProperty(property)); - readExpressionMap[property] = CreateReadValueExpression( - replacedExpression.Type, GetIndex(replacedExpression), property); + var projectionToAdd = innerQueryExpression._clientProjections[i]; + projectionToAdd = MakeNullable(projectionToAdd, innerNullable); + _clientProjections.Add(projectionToAdd); + indexMap[i] = _clientProjections.Count - 1; } + innerQueryExpression._clientProjections.Clear(); - projectionMapping[projection.Key.Prepend(outerMemberInfo)] - = new EntityProjectionExpression(entityProjection.EntityType, readExpressionMap); + innerShaperExpression = new ProjectionIndexRemappingExpressionVisitor(innerQueryExpression, this, indexMap).Visit(innerShaperExpression); } else { - var replacedExpression = replacingVisitor.Visit(projection.Value); - projectionMapping[projection.Key.Prepend(outerMemberInfo)] = CreateReadValueExpression( - projection.Value.Type, GetIndex(replacedExpression), InferPropertyFromInner(projection.Value)); + // Apply inner projection mapping and convert projection member binding to indexes + var mapping = ConvertProjectionMappingToClientProjections(innerQueryExpression._projectionMapping, innerNullable); + innerShaperExpression = new ProjectionMemberToIndexConvertingExpressionVisitor(this, mapping).Visit(innerShaperExpression); } - } - var outerIndex = _projectionMappingExpressions.Count; - foreach (var projection in innerQueryExpression._projectionMapping) + // TODO: We still need to populate and generate result selector + // Further for a subquery in projection we may need to update correlation terms used inside it. + throw new NotImplementedException(); + } + else { - if (projection.Value is EntityProjectionExpression entityProjection) + if (innerClientEval) { - var readExpressionMap = new Dictionary(); - foreach (var property in GetAllPropertiesInHierarchy(entityProjection.EntityType)) + // Since inner proojections are populated, we need to populate outer also + var mapping = ConvertProjectionMappingToClientProjections(_projectionMapping); + outerShaperExpression = new ProjectionMemberToIndexConvertingExpressionVisitor(this, mapping).Visit(outerShaperExpression); + + var indexMap = new int[innerQueryExpression._clientProjections.Count]; + for (var i = 0; i < innerQueryExpression._clientProjections.Count; i++) { - var replacedExpression = replacingVisitor.Visit(entityProjection.BindProperty(property)); - if (innerNullable) - { - replacedExpression = MakeReadValueNullable(replacedExpression); - } - readExpressionMap[property] = CreateReadValueExpression( - replacedExpression.Type, GetIndex(replacedExpression) + outerIndex, property); + var projectionToAdd = innerQueryExpression._clientProjections[i]; + projectionToAdd = MakeNullable(projectionToAdd, innerNullable); + _clientProjections.Add(projectionToAdd); + indexMap[i] = _clientProjections.Count - 1; } + innerQueryExpression._clientProjections.Clear(); - projectionMapping[projection.Key.Prepend(innerMemberInfo)] - = new EntityProjectionExpression(entityProjection.EntityType, readExpressionMap); + innerShaperExpression = new ProjectionIndexRemappingExpressionVisitor(innerQueryExpression, this, indexMap).Visit(innerShaperExpression); + // TODO: We still need to populate and generate result selector + // Further for a subquery in projection we may need to update correlation terms used inside it. + throw new NotImplementedException(); } else { - var replacedExpression = replacingVisitor.Visit(projection.Value); - if (innerNullable) + var projectionMapping = new Dictionary(); + var mapping = new Dictionary(); + foreach (var projection in _projectionMapping) { - replacedExpression = MakeReadValueNullable(replacedExpression); + var newProjectionMember = projection.Key.Prepend(outerMemberInfo); + mapping[projection.Key] = newProjectionMember; + if (projection.Value is EntityProjectionExpression entityProjectionExpression) + { + projectionMapping[newProjectionMember] = TraverseEntityProjection( + resultSelectorExpressions, entityProjectionExpression, makeNullable: false); + } + else + { + resultSelectorExpressions.Add(projection.Value); + projectionMapping[newProjectionMember] = CreateReadValueExpression( + projection.Value.Type, resultSelectorExpressions.Count - 1, InferPropertyFromInner(projection.Value)); + } } - projectionMapping[projection.Key.Prepend(innerMemberInfo)] = CreateReadValueExpression( - replacedExpression.Type, GetIndex(replacedExpression) + outerIndex, InferPropertyFromInner(replacedExpression)); - } - } + outerShaperExpression = new ProjectionMemberRemappingExpressionVisitor(this, mapping).Visit(outerShaperExpression); + mapping.Clear(); - var resultSelectorExpressions = new List(); - foreach (var expression in _projectionMappingExpressions) - { - var updatedExpression = replacingVisitor.Visit(expression); - resultSelectorExpressions.Add( - updatedExpression.Type.IsValueType ? Convert(updatedExpression, typeof(object)) : updatedExpression); - } + outerIndex = resultSelectorExpressions.Count; + foreach (var projection in innerQueryExpression._projectionMapping) + { + var newProjectionMember = projection.Key.Prepend(innerMemberInfo); + mapping[projection.Key] = newProjectionMember; + if (projection.Value is EntityProjectionExpression entityProjectionExpression) + { + projectionMapping[newProjectionMember] = TraverseEntityProjection( + resultSelectorExpressions, entityProjectionExpression, innerNullable); + } + else + { + var expression = projection.Value; + if (innerNullable) + { + expression = MakeReadValueNullable(expression); + } + resultSelectorExpressions.Add(expression); + projectionMapping[newProjectionMember] = CreateReadValueExpression( + expression.Type, resultSelectorExpressions.Count - 1, InferPropertyFromInner(projection.Value)); + } + } + innerShaperExpression = new ProjectionMemberRemappingExpressionVisitor(this, mapping).Visit(innerShaperExpression); + mapping.Clear(); - foreach (var expression in innerQueryExpression._projectionMappingExpressions) - { - var replacedExpression = replacingVisitor.Visit(expression); - if (innerNullable) - { - replacedExpression = MakeReadValueNullable(replacedExpression); + _projectionMapping = projectionMapping; } - resultSelectorExpressions.Add( - replacedExpression.Type.IsValueType ? Convert(replacedExpression, typeof(object)) : replacedExpression); - - _projectionMappingExpressions.Add( - CreateReadValueExpression( - innerNullable ? expression.Type.MakeNullable() : expression.Type, - GetIndex(expression) + outerIndex, - InferPropertyFromInner(expression))); } var resultSelector = Lambda( - New(_valueBufferConstructor, NewArrayInit(typeof(object), resultSelectorExpressions)), + New( + _valueBufferConstructor, NewArrayInit(typeof(object), + resultSelectorExpressions.Select((e, i) => + { + var expression = replacingVisitor.Visit(e); + if (innerNullable + && i > outerIndex) + { + expression = MakeReadValueNullable(expression); + } + + if (expression.Type.IsValueType) + { + expression = Convert(expression, typeof(object)); + } + + return expression; + }))), outerParameter, innerParameter); @@ -939,7 +1016,7 @@ private void AddJoin( innerKeySelector, resultSelector, Constant(new ValueBuffer( - Enumerable.Repeat((object?)null, innerQueryExpression._projectionMappingExpressions.Count).ToArray()))); + Enumerable.Repeat((object?)null, resultSelectorExpressions.Count - outerIndex).ToArray()))); } else { @@ -965,16 +1042,54 @@ private void AddJoin( resultSelector); } - _projectionMapping = projectionMapping; + if (innerNullable) + { + innerShaperExpression = new EntityShaperNullableMarkingExpressionVisitor().Visit(innerShaperExpression); + } + + return New( + transparentIdentifierType.GetTypeInfo().DeclaredConstructors.Single(), + new[] { outerShaperExpression, innerShaperExpression }, outerMemberInfo, innerMemberInfo); + + static Expression MakeNullable(Expression expression, bool nullable) + => nullable + ? expression is EntityProjectionExpression entityProjection + ? MakeEntityProjectionNullable(entityProjection) + : MakeReadValueNullable(expression) + : expression; } - private static int GetIndex(Expression expression) - => (int)((ConstantExpression)((MethodCallExpression)expression).Arguments[1]).Value!; + private void ConvertToEnumerable() + { + if (ServerQueryExpression.Type.TryGetSequenceType() == null) + { + if (ServerQueryExpression.Type != typeof(ValueBuffer)) + { + if (ServerQueryExpression.Type.IsValueType) + { + ServerQueryExpression = Convert(ServerQueryExpression, typeof(object)); + } + + ServerQueryExpression = New( + _resultEnumerableConstructor, + Lambda>( + New( + _valueBufferConstructor, + NewArrayInit(typeof(object), ServerQueryExpression)))); + } + else + { + ServerQueryExpression = New( + _resultEnumerableConstructor, + Lambda>(ServerQueryExpression)); + } + } + } private MethodCallExpression CreateReadValueExpression(Type type, int index, IPropertyBase? property) => (MethodCallExpression)_valueBufferParameter.CreateValueBufferReadValueExpression(type, index, property); - private IEnumerable GetAllPropertiesInHierarchy(IEntityType entityType) + private static IEnumerable GetAllPropertiesInHierarchy(IEntityType entityType) => entityType.GetAllBaseTypes().Concat(entityType.GetDerivedTypesInclusive()) .SelectMany(t => t.GetDeclaredProperties()); @@ -985,6 +1100,82 @@ private IEnumerable GetAllPropertiesInHierarchy(IEntityType entityTyp ? methodCallExpression.Arguments[2].GetConstantValue() : null; + private static EntityProjectionExpression MakeEntityProjectionNullable(EntityProjectionExpression entityProjectionExpression) + { + var readExpressionMap = new Dictionary(); + foreach (var property in GetAllPropertiesInHierarchy(entityProjectionExpression.EntityType)) + { + readExpressionMap[property] = MakeReadValueNullable(entityProjectionExpression.BindProperty(property)); + } + + var result = new EntityProjectionExpression(entityProjectionExpression.EntityType, readExpressionMap); + + // Also compute nested entity projections + foreach (var navigation in entityProjectionExpression.EntityType.GetAllBaseTypes() + .Concat(entityProjectionExpression.EntityType.GetDerivedTypesInclusive()) + .SelectMany(t => t.GetDeclaredNavigations())) + { + var boundEntityShaperExpression = entityProjectionExpression.BindNavigation(navigation); + if (boundEntityShaperExpression != null) + { + var innerEntityProjection = (EntityProjectionExpression)boundEntityShaperExpression.ValueBufferExpression; + var newInnerEntityProjection = MakeEntityProjectionNullable(innerEntityProjection); + boundEntityShaperExpression = boundEntityShaperExpression.Update(newInnerEntityProjection); + result.AddNavigationBinding(navigation, boundEntityShaperExpression); + } + } + + return result; + } + + private Dictionary ConvertProjectionMappingToClientProjections( + Dictionary projectionMapping, + bool makeNullable = false) + { + var mapping = new Dictionary(); + var entityProjectionCache = new Dictionary(ReferenceEqualityComparer.Instance); + foreach (var projection in projectionMapping) + { + var projectionMember = projection.Key; + var projectionToAdd = projection.Value; + + if (projectionToAdd is EntityProjectionExpression entityProjection) + { + if (!entityProjectionCache.TryGetValue(entityProjection, out var value)) + { + var entityProjectionToCache = entityProjection; + if (makeNullable) + { + entityProjection = MakeEntityProjectionNullable(entityProjection); + } + _clientProjections.Add(entityProjection); + value = _clientProjections.Count - 1; + entityProjectionCache[entityProjectionToCache] = value; + } + + mapping[projectionMember] = value; + } + else + { + if (makeNullable) + { + projectionToAdd = MakeReadValueNullable(projectionToAdd); + } + var existingIndex = _clientProjections.FindIndex(e => e.Equals(projectionToAdd)); + if (existingIndex == -1) + { + _clientProjections.Add(projectionToAdd); + existingIndex = _clientProjections.Count - 1; + } + mapping[projectionMember] = existingIndex; + } + } + + projectionMapping.Clear(); + + return mapping; + } + private static IEnumerable LeftJoin( IEnumerable outer, IEnumerable inner, @@ -995,7 +1186,7 @@ private static IEnumerable LeftJoin( => outer.GroupJoin(inner, outerKeySelector, innerKeySelector, (oe, ies) => new { oe, ies }) .SelectMany(t => t.ies.DefaultIfEmpty(defaultValue), (t, i) => resultSelector(t.oe, i)); - private MethodCallExpression MakeReadValueNullable(Expression expression) + private static MethodCallExpression MakeReadValueNullable(Expression expression) { Check.DebugAssert(expression is MethodCallExpression, "Expression must be method call expression."); @@ -1008,32 +1199,40 @@ private MethodCallExpression MakeReadValueNullable(Expression expression) methodCallExpression.Arguments); } - private sealed class ShaperRemappingExpressionVisitor : ExpressionVisitor + private EntityProjectionExpression TraverseEntityProjection( + List selectorExpressions, EntityProjectionExpression entityProjectionExpression, bool makeNullable) { - private readonly IDictionary _projectionMapping; - - public ShaperRemappingExpressionVisitor(IDictionary projectionMapping) + var readExpressionMap = new Dictionary(); + foreach (var property in GetAllPropertiesInHierarchy(entityProjectionExpression.EntityType)) { - _projectionMapping = projectionMapping; + var expression = entityProjectionExpression.BindProperty(property); + if (makeNullable) + { + expression = MakeReadValueNullable(expression); + } + selectorExpressions.Add(expression); + var newExpression = CreateReadValueExpression(expression.Type, selectorExpressions.Count - 1, property); + readExpressionMap[property] = newExpression; } - [return: NotNullIfNotNull("expression")] - public override Expression? Visit(Expression? expression) + var result = new EntityProjectionExpression(entityProjectionExpression.EntityType, readExpressionMap); + + // Also compute nested entity projections + foreach (var navigation in entityProjectionExpression.EntityType.GetAllBaseTypes() + .Concat(entityProjectionExpression.EntityType.GetDerivedTypesInclusive()) + .SelectMany(t => t.GetDeclaredNavigations())) { - if (expression is ProjectionBindingExpression projectionBindingExpression - && projectionBindingExpression.ProjectionMember != null) + var boundEntityShaperExpression = entityProjectionExpression.BindNavigation(navigation); + if (boundEntityShaperExpression != null) { - var mappingValue = ((ConstantExpression)_projectionMapping[projectionBindingExpression.ProjectionMember]).Value; - return mappingValue is IReadOnlyDictionary indexMap - ? new ProjectionBindingExpression(projectionBindingExpression.QueryExpression, indexMap) - : mappingValue is int index - ? new ProjectionBindingExpression( - projectionBindingExpression.QueryExpression, index, projectionBindingExpression.Type) - : throw new InvalidOperationException(CoreStrings.UnknownEntity("ProjectionMapping")); + var innerEntityProjection = (EntityProjectionExpression)boundEntityShaperExpression.ValueBufferExpression; + var newInnerEntityProjection = TraverseEntityProjection(selectorExpressions, innerEntityProjection, makeNullable); + boundEntityShaperExpression = boundEntityShaperExpression.Update(newInnerEntityProjection); + result.AddNavigationBinding(navigation, boundEntityShaperExpression); } - - return base.Visit(expression); } + + return result; } } } diff --git a/src/EFCore.InMemory/Query/Internal/InMemoryQueryableMethodTranslatingExpressionVisitor.cs b/src/EFCore.InMemory/Query/Internal/InMemoryQueryableMethodTranslatingExpressionVisitor.cs index 045e43f7be9..b06a998fff0 100644 --- a/src/EFCore.InMemory/Query/Internal/InMemoryQueryableMethodTranslatingExpressionVisitor.cs +++ b/src/EFCore.InMemory/Query/Internal/InMemoryQueryableMethodTranslatingExpressionVisitor.cs @@ -152,7 +152,7 @@ private static ShapedQueryExpression CreateShapedQueryExpressionStatic(IEntityTy if (source.ShaperExpression is GroupByShaperExpression) { - inMemoryQueryExpression.ReplaceProjectionMapping(new Dictionary()); + inMemoryQueryExpression.ReplaceProjection(new Dictionary()); } inMemoryQueryExpression.UpdateServerQueryExpression( @@ -186,7 +186,7 @@ private static ShapedQueryExpression CreateShapedQueryExpressionStatic(IEntityTy if (source.ShaperExpression is GroupByShaperExpression) { - inMemoryQueryExpression.ReplaceProjectionMapping(new Dictionary()); + inMemoryQueryExpression.ReplaceProjection(new Dictionary()); } inMemoryQueryExpression.UpdateServerQueryExpression( @@ -268,7 +268,9 @@ private static ShapedQueryExpression CreateShapedQueryExpressionStatic(IEntityTy EnumerableMethods.Select.MakeGenericMethod(inMemoryQueryExpression.CurrentParameter.Type, item.Type), inMemoryQueryExpression.ServerQueryExpression, Expression.Lambda( - inMemoryQueryExpression.GetMappedProjection(new ProjectionMember()), inMemoryQueryExpression.CurrentParameter)), + inMemoryQueryExpression.GetProjection( + new ProjectionBindingExpression(inMemoryQueryExpression, new ProjectionMember(), item.Type)), + inMemoryQueryExpression.CurrentParameter)), item)); return source.UpdateShaperExpression(Expression.Convert(inMemoryQueryExpression.GetSingleScalarProjection(), typeof(bool))); @@ -298,7 +300,7 @@ private static ShapedQueryExpression CreateShapedQueryExpressionStatic(IEntityTy if (source.ShaperExpression is GroupByShaperExpression) { - inMemoryQueryExpression.ReplaceProjectionMapping(new Dictionary()); + inMemoryQueryExpression.ReplaceProjection(new Dictionary()); } inMemoryQueryExpression.UpdateServerQueryExpression( @@ -338,12 +340,7 @@ private static ShapedQueryExpression CreateShapedQueryExpressionStatic(IEntityTy { Check.NotNull(source, nameof(source)); - var inMemoryQueryExpression = (InMemoryQueryExpression)source.QueryExpression; - - inMemoryQueryExpression.UpdateServerQueryExpression( - Expression.Call( - EnumerableMethods.Distinct.MakeGenericMethod(inMemoryQueryExpression.CurrentParameter.Type), - inMemoryQueryExpression.ServerQueryExpression)); + ((InMemoryQueryExpression)source.QueryExpression).ApplyDistinct(); return source; } @@ -575,23 +572,16 @@ private static ShapedQueryExpression CreateShapedQueryExpressionStatic(IEntityTy } (outerKeySelector, innerKeySelector) = (newOuterKeySelector, newInnerKeySelector); - var transparentIdentifierType = TransparentIdentifierFactory.Create( - resultSelector.Parameters[0].Type, - resultSelector.Parameters[1].Type); - - ((InMemoryQueryExpression)outer.QueryExpression).AddInnerJoin( + var outerShaperExpression = ((InMemoryQueryExpression)outer.QueryExpression).AddInnerJoin( (InMemoryQueryExpression)inner.QueryExpression, outerKeySelector, innerKeySelector, - transparentIdentifierType); - -#pragma warning disable CS0618 // Type or member is obsolete See issue#21200 - return TranslateResultSelectorForJoin( - outer, - resultSelector, - inner.ShaperExpression, - transparentIdentifierType); -#pragma warning restore CS0618 // Type or member is obsolete + outer.ShaperExpression, + inner.ShaperExpression); + + outer = outer.UpdateShaperExpression(outerShaperExpression); + + return TranslateTwoParameterSelector(outer, resultSelector); } private (LambdaExpression? OuterKeySelector, LambdaExpression? InnerKeySelector) ProcessJoinKeySelector( @@ -754,23 +744,16 @@ static bool IsConvertedToNullable(Expression outer, Expression inner) (outerKeySelector, innerKeySelector) = (newOuterKeySelector, newInnerKeySelector); - var transparentIdentifierType = TransparentIdentifierFactory.Create( - resultSelector.Parameters[0].Type, - resultSelector.Parameters[1].Type); - - ((InMemoryQueryExpression)outer.QueryExpression).AddLeftJoin( + var outerShaperExpression = ((InMemoryQueryExpression)outer.QueryExpression).AddLeftJoin( (InMemoryQueryExpression)inner.QueryExpression, outerKeySelector, innerKeySelector, - transparentIdentifierType); - -#pragma warning disable CS0618 // Type or member is obsolete See issue#21200 - return TranslateResultSelectorForJoin( - outer, - resultSelector, - MarkShaperNullable(inner.ShaperExpression), - transparentIdentifierType); -#pragma warning restore CS0618 // Type or member is obsolete + outer.ShaperExpression, + inner.ShaperExpression); + + outer = outer.UpdateShaperExpression(outerShaperExpression); + + return TranslateTwoParameterSelector(outer, resultSelector); } /// @@ -798,7 +781,7 @@ static bool IsConvertedToNullable(Expression outer, Expression inner) if (source.ShaperExpression is GroupByShaperExpression) { - inMemoryQueryExpression.ReplaceProjectionMapping(new Dictionary()); + inMemoryQueryExpression.ReplaceProjection(new Dictionary()); } inMemoryQueryExpression.UpdateServerQueryExpression( @@ -881,8 +864,8 @@ static bool IsConvertedToNullable(Expression outer, Expression inner) var projectionMember = projectionBindingExpression.ProjectionMember; Check.DebugAssert(new ProjectionMember().Equals(projectionMember), "Invalid ProjectionMember when processing OfType"); - var entityProjectionExpression = (EntityProjectionExpression)inMemoryQueryExpression.GetMappedProjection(projectionMember); - inMemoryQueryExpression.ReplaceProjectionMapping( + var entityProjectionExpression = (EntityProjectionExpression)inMemoryQueryExpression.GetProjection(projectionBindingExpression); + inMemoryQueryExpression.ReplaceProjection( new Dictionary { { projectionMember, entityProjectionExpression.UpdateEntityType(derivedType) } @@ -993,24 +976,12 @@ protected override ShapedQueryExpression TranslateSelect(ShapedQueryExpression s if (Visit(collectionSelectorBody) is ShapedQueryExpression inner) { - var transparentIdentifierType = TransparentIdentifierFactory.Create( - resultSelector.Parameters[0].Type, - resultSelector.Parameters[1].Type); - - var innerShaperExpression = defaultIfEmpty - ? MarkShaperNullable(inner.ShaperExpression) - : inner.ShaperExpression; - - ((InMemoryQueryExpression)source.QueryExpression).AddSelectMany( - (InMemoryQueryExpression)inner.QueryExpression, transparentIdentifierType, defaultIfEmpty); - -#pragma warning disable CS0618 // Type or member is obsolete See issue#21200 - return TranslateResultSelectorForJoin( - source, - resultSelector, - innerShaperExpression, - transparentIdentifierType); -#pragma warning restore CS0618 // Type or member is obsolete + var outerShaperExpression = ((InMemoryQueryExpression)source.QueryExpression).AddSelectMany( + (InMemoryQueryExpression)inner.QueryExpression, source.ShaperExpression, inner.ShaperExpression, defaultIfEmpty); + + source = source.UpdateShaperExpression(outerShaperExpression); + + return TranslateTwoParameterSelector(source, resultSelector); } return null; @@ -1454,21 +1425,9 @@ outerKey is NewArrayExpression newArrayExpression return innerShapedQuery; } - EntityProjectionExpression entityProjectionExpression; - if (entityShaperExpression.ValueBufferExpression is ProjectionBindingExpression projectionBindingExpression) - { - if (projectionBindingExpression.ProjectionMember == null) - { - throw new InvalidOperationException(); - } - entityProjectionExpression = (EntityProjectionExpression)_queryExpression.GetMappedProjection( - projectionBindingExpression.ProjectionMember); - } - else - { - entityProjectionExpression = (EntityProjectionExpression)entityShaperExpression.ValueBufferExpression; - } - + var entityProjectionExpression = entityShaperExpression.ValueBufferExpression is ProjectionBindingExpression projectionBindingExpression + ? (EntityProjectionExpression)_queryExpression.GetProjection(projectionBindingExpression) + : (EntityProjectionExpression)entityShaperExpression.ValueBufferExpression; var innerShaper = entityProjectionExpression.BindNavigation(navigation); if (innerShaper == null) { @@ -1514,6 +1473,30 @@ private static Expression AddConvertToObject(Expression expression) : expression; } + private ShapedQueryExpression TranslateTwoParameterSelector(ShapedQueryExpression source, LambdaExpression resultSelector) + { + var transparentIdentifierType = source.ShaperExpression.Type; + var transparentIdentifierParameter = Expression.Parameter(transparentIdentifierType); + + Expression original1 = resultSelector.Parameters[0]; + var replacement1 = AccessField(transparentIdentifierType, transparentIdentifierParameter, "Outer"); + Expression original2 = resultSelector.Parameters[1]; + var replacement2 = AccessField(transparentIdentifierType, transparentIdentifierParameter, "Inner"); + var newResultSelector = Expression.Lambda( + new ReplacingExpressionVisitor( + new[] { original1, original2 }, new[] { replacement1, replacement2 }) + .Visit(resultSelector.Body), + transparentIdentifierParameter); + + return TranslateSelect(source, newResultSelector); + } + + private static Expression AccessField( + Type transparentIdentifierType, + Expression targetExpression, + string fieldName) + => Expression.Field(targetExpression, transparentIdentifierType.GetRequiredDeclaredField(fieldName)); + private ShapedQueryExpression? TranslateScalarAggregate( ShapedQueryExpression source, LambdaExpression? selector, @@ -1525,7 +1508,8 @@ private static Expression AddConvertToObject(Expression expression) selector = selector == null || selector.Body == selector.Parameters[0] ? Expression.Lambda( - inMemoryQueryExpression.GetMappedProjection(new ProjectionMember()), + inMemoryQueryExpression.GetProjection(new ProjectionBindingExpression( + inMemoryQueryExpression, new ProjectionMember(), returnType)), inMemoryQueryExpression.CurrentParameter) : TranslateLambdaExpression(source, selector, preserveType: true); @@ -1573,12 +1557,7 @@ MethodInfo GetMethod() source = newSource; } - inMemoryQueryExpression.UpdateServerQueryExpression( - Expression.Call( - method.MakeGenericMethod(inMemoryQueryExpression.CurrentParameter.Type), - inMemoryQueryExpression.ServerQueryExpression)); - - inMemoryQueryExpression.ConvertToEnumerable(); + inMemoryQueryExpression.ConvertToSingleResult(method); return source.ShaperExpression.Type != returnType ? source.UpdateShaperExpression(Expression.Convert(source.ShaperExpression, returnType)) diff --git a/src/EFCore.InMemory/Query/Internal/InMemoryShapedQueryCompilingExpressionVisitor.CustomShaperCompilingExpressionVisitor.cs b/src/EFCore.InMemory/Query/Internal/InMemoryShapedQueryCompilingExpressionVisitor.CustomShaperCompilingExpressionVisitor.cs deleted file mode 100644 index 89ef18b7494..00000000000 --- a/src/EFCore.InMemory/Query/Internal/InMemoryShapedQueryCompilingExpressionVisitor.CustomShaperCompilingExpressionVisitor.cs +++ /dev/null @@ -1,278 +0,0 @@ -// Copyright (c) .NET Foundation. All rights reserved. -// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. - -using System; -using System.Collections.Generic; -using System.Linq.Expressions; -using System.Reflection; -using Microsoft.EntityFrameworkCore.Infrastructure; -using Microsoft.EntityFrameworkCore.Metadata; -using Microsoft.EntityFrameworkCore.Query; -using Microsoft.EntityFrameworkCore.Storage; -using Microsoft.EntityFrameworkCore.Utilities; - -namespace Microsoft.EntityFrameworkCore.InMemory.Query.Internal -{ - public partial class InMemoryShapedQueryCompilingExpressionVisitor - { - private sealed class CustomShaperCompilingExpressionVisitor : ExpressionVisitor - { - private readonly bool _tracking; - - public CustomShaperCompilingExpressionVisitor(bool tracking) - { - _tracking = tracking; - } - - private static readonly MethodInfo _includeReferenceMethodInfo - = typeof(CustomShaperCompilingExpressionVisitor).GetRequiredDeclaredMethod(nameof(IncludeReference)); - - private static readonly MethodInfo _includeCollectionMethodInfo - = typeof(CustomShaperCompilingExpressionVisitor).GetRequiredDeclaredMethod(nameof(IncludeCollection)); - - private static readonly MethodInfo _materializeCollectionMethodInfo - = typeof(CustomShaperCompilingExpressionVisitor).GetRequiredDeclaredMethod(nameof(MaterializeCollection)); - - private static readonly MethodInfo _materializeSingleResultMethodInfo - = typeof(CustomShaperCompilingExpressionVisitor).GetRequiredDeclaredMethod(nameof(MaterializeSingleResult)); - - private static void IncludeReference( - QueryContext queryContext, - TEntity entity, - TIncludedEntity relatedEntity, - INavigationBase navigation, - INavigationBase inverseNavigation, - Action fixup, - bool trackingQuery) - where TIncludingEntity : class, TEntity - where TEntity : class - where TIncludedEntity : class - { - if (entity is TIncludingEntity includingEntity) - { - if (trackingQuery - && navigation.DeclaringEntityType.FindPrimaryKey() != null) - { - // For non-null relatedEntity StateManager will set the flag - if (relatedEntity == null) - { - queryContext.SetNavigationIsLoaded(includingEntity, navigation); - } - } - else - { - navigation.SetIsLoadedWhenNoTracking(includingEntity); - if (relatedEntity != null) - { - fixup(includingEntity, relatedEntity); - if (inverseNavigation != null - && !inverseNavigation.IsCollection) - { - inverseNavigation.SetIsLoadedWhenNoTracking(relatedEntity); - } - } - } - } - } - - private static void IncludeCollection( - QueryContext queryContext, - IEnumerable innerValueBuffers, - Func innerShaper, - TEntity entity, - INavigationBase navigation, - INavigationBase inverseNavigation, - Action fixup, - bool trackingQuery, - bool setLoaded) - where TIncludingEntity : class, TEntity - where TEntity : class - where TIncludedEntity : class - { - if (entity is TIncludingEntity includingEntity) - { - var collectionAccessor = navigation.GetCollectionAccessor()!; - collectionAccessor.GetOrCreate(includingEntity, forMaterialization: true); - - if (setLoaded) - { - if (trackingQuery) - { - queryContext.SetNavigationIsLoaded(entity, navigation); - } - else - { - navigation.SetIsLoadedWhenNoTracking(entity); - } - } - - foreach (var valueBuffer in innerValueBuffers) - { - var relatedEntity = innerShaper(queryContext, valueBuffer); - - if (!trackingQuery) - { - fixup(includingEntity, relatedEntity); - if (inverseNavigation != null) - { - inverseNavigation.SetIsLoadedWhenNoTracking(relatedEntity); - } - } - } - } - } - - private static TCollection MaterializeCollection( - QueryContext queryContext, - IEnumerable innerValueBuffers, - Func innerShaper, - IClrCollectionAccessor clrCollectionAccessor) - where TCollection : class, ICollection - { - var collection = (TCollection)(clrCollectionAccessor?.Create() ?? new List()); - - foreach (var valueBuffer in innerValueBuffers) - { - var element = innerShaper(queryContext, valueBuffer); - collection.Add(element); - } - - return collection; - } - - private static TResult MaterializeSingleResult( - QueryContext queryContext, - ValueBuffer valueBuffer, - Func innerShaper) - => valueBuffer.IsEmpty - ? default! - : innerShaper(queryContext, valueBuffer); - - protected override Expression VisitExtension(Expression extensionExpression) - { - Check.NotNull(extensionExpression, nameof(extensionExpression)); - - if (extensionExpression is IncludeExpression includeExpression) - { - var entityClrType = includeExpression.EntityExpression.Type; - var includingClrType = includeExpression.Navigation.DeclaringEntityType.ClrType; - var inverseNavigation = includeExpression.Navigation.Inverse; - var relatedEntityClrType = includeExpression.Navigation.TargetEntityType.ClrType; - if (includingClrType != entityClrType - && includingClrType.IsAssignableFrom(entityClrType)) - { - includingClrType = entityClrType; - } - - if (includeExpression.Navigation.IsCollection) - { - var collectionShaper = (CollectionShaperExpression)includeExpression.NavigationExpression; - return Expression.Call( - _includeCollectionMethodInfo.MakeGenericMethod(entityClrType, includingClrType, relatedEntityClrType), - QueryCompilationContext.QueryContextParameter, - collectionShaper.Projection, - Expression.Constant(((LambdaExpression)Visit(collectionShaper.InnerShaper)).Compile()), - includeExpression.EntityExpression, - Expression.Constant(includeExpression.Navigation), - Expression.Constant(inverseNavigation, typeof(INavigationBase)), - Expression.Constant( - GenerateFixup( - includingClrType, relatedEntityClrType, includeExpression.Navigation, inverseNavigation).Compile()), - Expression.Constant(_tracking), -#pragma warning disable EF1001 // Internal EF Core API usage. - Expression.Constant(includeExpression.SetLoaded)); -#pragma warning restore EF1001 // Internal EF Core API usage. - } - - return Expression.Call( - _includeReferenceMethodInfo.MakeGenericMethod(entityClrType, includingClrType, relatedEntityClrType), - QueryCompilationContext.QueryContextParameter, - includeExpression.EntityExpression, - includeExpression.NavigationExpression, - Expression.Constant(includeExpression.Navigation), - Expression.Constant(inverseNavigation, typeof(INavigationBase)), - Expression.Constant( - GenerateFixup( - includingClrType, relatedEntityClrType, includeExpression.Navigation, inverseNavigation).Compile()), - Expression.Constant(_tracking)); - } - - if (extensionExpression is CollectionShaperExpression collectionShaperExpression) - { - var navigation = collectionShaperExpression.Navigation; - var collectionAccessor = navigation?.GetCollectionAccessor(); - var collectionType = collectionAccessor?.CollectionType ?? collectionShaperExpression.Type; - var elementType = collectionShaperExpression.ElementType; - - return Expression.Call( - _materializeCollectionMethodInfo.MakeGenericMethod(elementType, collectionType), - QueryCompilationContext.QueryContextParameter, - collectionShaperExpression.Projection, - Expression.Constant(((LambdaExpression)Visit(collectionShaperExpression.InnerShaper)).Compile()), - Expression.Constant(collectionAccessor, typeof(IClrCollectionAccessor))); - } - - if (extensionExpression is SingleResultShaperExpression singleResultShaperExpression) - { - var innerShaper = (LambdaExpression)Visit(singleResultShaperExpression.InnerShaper); - - return Expression.Call( - _materializeSingleResultMethodInfo.MakeGenericMethod(singleResultShaperExpression.Type), - QueryCompilationContext.QueryContextParameter, - singleResultShaperExpression.Projection, - Expression.Constant(innerShaper.Compile())); - } - - return base.VisitExtension(extensionExpression); - } - - private static LambdaExpression GenerateFixup( - Type entityType, - Type relatedEntityType, - INavigationBase navigation, - INavigationBase? inverseNavigation) - { - var entityParameter = Expression.Parameter(entityType); - var relatedEntityParameter = Expression.Parameter(relatedEntityType); - var expressions = new List - { - navigation.IsCollection - ? AddToCollectionNavigation(entityParameter, relatedEntityParameter, navigation) - : AssignReferenceNavigation(entityParameter, relatedEntityParameter, navigation) - }; - - if (inverseNavigation != null) - { - expressions.Add( - inverseNavigation.IsCollection - ? AddToCollectionNavigation(relatedEntityParameter, entityParameter, inverseNavigation) - : AssignReferenceNavigation(relatedEntityParameter, entityParameter, inverseNavigation)); - } - - return Expression.Lambda(Expression.Block(typeof(void), expressions), entityParameter, relatedEntityParameter); - } - - private static Expression AssignReferenceNavigation( - ParameterExpression entity, - ParameterExpression relatedEntity, - INavigationBase navigation) - { - return entity.MakeMemberAccess(navigation.GetMemberInfo(forMaterialization: true, forSet: true)).Assign(relatedEntity); - } - - private static Expression AddToCollectionNavigation( - ParameterExpression entity, - ParameterExpression relatedEntity, - INavigationBase navigation) - => Expression.Call( - Expression.Constant(navigation.GetCollectionAccessor()), - _collectionAccessorAddMethodInfo, - entity, - relatedEntity, - Expression.Constant(true)); - - private static readonly MethodInfo _collectionAccessorAddMethodInfo - = typeof(IClrCollectionAccessor).GetRequiredDeclaredMethod(nameof(IClrCollectionAccessor.Add)); - } - } -} diff --git a/src/EFCore.InMemory/Query/Internal/InMemoryShapedQueryCompilingExpressionVisitor.InMemoryProjectionBindingRemovingExpressionVisitor.cs b/src/EFCore.InMemory/Query/Internal/InMemoryShapedQueryCompilingExpressionVisitor.InMemoryProjectionBindingRemovingExpressionVisitor.cs deleted file mode 100644 index 46acbffd280..00000000000 --- a/src/EFCore.InMemory/Query/Internal/InMemoryShapedQueryCompilingExpressionVisitor.InMemoryProjectionBindingRemovingExpressionVisitor.cs +++ /dev/null @@ -1,130 +0,0 @@ -// Copyright (c) .NET Foundation. All rights reserved. -// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. - -using System; -using System.Collections.Generic; -using System.Linq.Expressions; -using System.Reflection; -using Microsoft.EntityFrameworkCore.Infrastructure; -using Microsoft.EntityFrameworkCore.Metadata; -using Microsoft.EntityFrameworkCore.Query; -using Microsoft.EntityFrameworkCore.Storage; -using Microsoft.EntityFrameworkCore.Utilities; -using ExpressionExtensions = Microsoft.EntityFrameworkCore.Infrastructure.ExpressionExtensions; - -namespace Microsoft.EntityFrameworkCore.InMemory.Query.Internal -{ - public partial class InMemoryShapedQueryCompilingExpressionVisitor - { - private sealed class InMemoryProjectionBindingRemovingExpressionVisitor : ExpressionVisitor - { - private readonly IDictionary IndexMap, ParameterExpression valueBuffer)> - _materializationContextBindings - = new Dictionary IndexMap, ParameterExpression valueBuffer)>(); - - protected override Expression VisitBinary(BinaryExpression binaryExpression) - { - Check.NotNull(binaryExpression, nameof(binaryExpression)); - - if (binaryExpression.NodeType == ExpressionType.Assign - && binaryExpression.Left is ParameterExpression parameterExpression - && parameterExpression.Type == typeof(MaterializationContext)) - { - var newExpression = (NewExpression)binaryExpression.Right; - - var projectionBindingExpression = (ProjectionBindingExpression)newExpression.Arguments[0]; - var queryExpression = (InMemoryQueryExpression)projectionBindingExpression.QueryExpression; - - _materializationContextBindings[parameterExpression] - = ((IDictionary)GetProjectionIndex(queryExpression, projectionBindingExpression), - ((InMemoryQueryExpression)projectionBindingExpression.QueryExpression).CurrentParameter); - - var updatedExpression = newExpression.Update( - new[] { Expression.Constant(ValueBuffer.Empty), newExpression.Arguments[1] }); - - return Expression.MakeBinary(ExpressionType.Assign, binaryExpression.Left, updatedExpression); - } - - if (binaryExpression.NodeType == ExpressionType.Assign - && binaryExpression.Left is MemberExpression memberExpression - && memberExpression.Member is FieldInfo fieldInfo - && fieldInfo.IsInitOnly) - { - return memberExpression.Assign(Visit(binaryExpression.Right)); - } - - return base.VisitBinary(binaryExpression); - } - - protected override Expression VisitMethodCall(MethodCallExpression methodCallExpression) - { - Check.NotNull(methodCallExpression, nameof(methodCallExpression)); - - if (methodCallExpression.Method.IsGenericMethod - && methodCallExpression.Method.GetGenericMethodDefinition() == ExpressionExtensions.ValueBufferTryReadValueMethod) - { - var property = methodCallExpression.Arguments[2].GetConstantValue(); - var (indexMap, valueBuffer) = - _materializationContextBindings[ - (ParameterExpression)((MethodCallExpression)methodCallExpression.Arguments[0]).Object!]; - - Check.DebugAssert( - property != null || methodCallExpression.Type.IsNullableType(), "Must read nullable value without property"); - - return Expression.Call( - methodCallExpression.Method, - valueBuffer, - Expression.Constant(indexMap[property!]), - methodCallExpression.Arguments[2]); - } - - return base.VisitMethodCall(methodCallExpression); - } - - protected override Expression VisitExtension(Expression extensionExpression) - { - Check.NotNull(extensionExpression, nameof(extensionExpression)); - - if (extensionExpression is ProjectionBindingExpression projectionBindingExpression) - { - var queryExpression = (InMemoryQueryExpression)projectionBindingExpression.QueryExpression; - var projectionIndex = (int)GetProjectionIndex(queryExpression, projectionBindingExpression); - var valueBuffer = queryExpression.CurrentParameter; - var property = InferPropertyFromInner(queryExpression.Projection[projectionIndex]); - - Check.DebugAssert( - property != null - || projectionBindingExpression.Type.IsNullableType() - || projectionBindingExpression.Type == typeof(ValueBuffer), "Must read nullable value without property"); - - return valueBuffer.CreateValueBufferReadValueExpression(projectionBindingExpression.Type, projectionIndex, property); - } - - return base.VisitExtension(extensionExpression); - } - - private IPropertyBase? InferPropertyFromInner(Expression expression) - { - if (expression is MethodCallExpression methodCallExpression - && methodCallExpression.Method.IsGenericMethod - && methodCallExpression.Method.GetGenericMethodDefinition() == ExpressionExtensions.ValueBufferTryReadValueMethod) - { - return methodCallExpression.Arguments[2].GetConstantValue(); - } - - return null; - } - - private object GetProjectionIndex( - InMemoryQueryExpression queryExpression, - ProjectionBindingExpression projectionBindingExpression) - { - return projectionBindingExpression.ProjectionMember != null - ? queryExpression.GetMappedProjection(projectionBindingExpression.ProjectionMember).GetConstantValue() - : (projectionBindingExpression.Index != null - ? (object)projectionBindingExpression.Index - : projectionBindingExpression.IndexMap!); - } - } - } -} diff --git a/src/EFCore.InMemory/Query/Internal/InMemoryShapedQueryCompilingExpressionVisitor.ShaperExpressionProcessingExpressionVisitor.cs b/src/EFCore.InMemory/Query/Internal/InMemoryShapedQueryCompilingExpressionVisitor.ShaperExpressionProcessingExpressionVisitor.cs new file mode 100644 index 00000000000..408e9f5aced --- /dev/null +++ b/src/EFCore.InMemory/Query/Internal/InMemoryShapedQueryCompilingExpressionVisitor.ShaperExpressionProcessingExpressionVisitor.cs @@ -0,0 +1,420 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Collections.Generic; +using System.Linq.Expressions; +using System.Reflection; +using Microsoft.EntityFrameworkCore.Infrastructure; +using Microsoft.EntityFrameworkCore.Metadata; +using Microsoft.EntityFrameworkCore.Query; +using Microsoft.EntityFrameworkCore.Storage; +using Microsoft.EntityFrameworkCore.Utilities; + +namespace Microsoft.EntityFrameworkCore.InMemory.Query.Internal +{ + public partial class InMemoryShapedQueryCompilingExpressionVisitor + { + private sealed class ShaperExpressionProcessingExpressionVisitor : ExpressionVisitor + { + private static readonly MethodInfo _includeReferenceMethodInfo + = typeof(ShaperExpressionProcessingExpressionVisitor).GetRequiredDeclaredMethod(nameof(IncludeReference)); + private static readonly MethodInfo _includeCollectionMethodInfo + = typeof(ShaperExpressionProcessingExpressionVisitor).GetRequiredDeclaredMethod(nameof(IncludeCollection)); + private static readonly MethodInfo _materializeCollectionMethodInfo + = typeof(ShaperExpressionProcessingExpressionVisitor).GetRequiredDeclaredMethod(nameof(MaterializeCollection)); + private static readonly MethodInfo _materializeSingleResultMethodInfo + = typeof(ShaperExpressionProcessingExpressionVisitor).GetRequiredDeclaredMethod(nameof(MaterializeSingleResult)); + private static readonly MethodInfo _collectionAccessorAddMethodInfo + = typeof(IClrCollectionAccessor).GetRequiredDeclaredMethod(nameof(IClrCollectionAccessor.Add)); + + private readonly InMemoryShapedQueryCompilingExpressionVisitor _inMemoryShapedQueryCompilingExpressionVisitor; + private readonly bool _tracking; + private ParameterExpression? _valueBufferParameter; + + private readonly Dictionary _mapping = new(); + private readonly List _variables = new(); + private readonly List _expressions = new(); + private readonly Dictionary> _materializationContextBindings = new(); + + public ShaperExpressionProcessingExpressionVisitor( + InMemoryShapedQueryCompilingExpressionVisitor inMemoryShapedQueryCompilingExpressionVisitor, + InMemoryQueryExpression inMemoryQueryExpression, + bool tracking) + { + _inMemoryShapedQueryCompilingExpressionVisitor = inMemoryShapedQueryCompilingExpressionVisitor; + _valueBufferParameter = inMemoryQueryExpression.CurrentParameter; + _tracking = tracking; + } + + private ShaperExpressionProcessingExpressionVisitor( + InMemoryShapedQueryCompilingExpressionVisitor inMemoryShapedQueryCompilingExpressionVisitor, + bool tracking) + { + _inMemoryShapedQueryCompilingExpressionVisitor = inMemoryShapedQueryCompilingExpressionVisitor; + _tracking = tracking; + } + + public LambdaExpression ProcessShaper(Expression shaperExpression) + { + var result = Visit(shaperExpression); + _expressions.Add(result); + result = Expression.Block(_variables, _expressions); + + if (_valueBufferParameter == null) + { + // If parameter is null then the projection is not really server correlated so we can just put anything. + _valueBufferParameter = Expression.Parameter(typeof(ValueBuffer)); + } + + return Expression.Lambda(result, QueryCompilationContext.QueryContextParameter, _valueBufferParameter); + } + + protected override Expression VisitExtension(Expression extensionExpression) + { + switch (extensionExpression) + { + case EntityShaperExpression entityShaperExpression: + { + var key = entityShaperExpression.ValueBufferExpression; + if (!_mapping.TryGetValue(key, out var variable)) + { + variable = Expression.Parameter(entityShaperExpression.EntityType.ClrType); + _variables.Add(variable); + var innerShaper = _inMemoryShapedQueryCompilingExpressionVisitor.InjectEntityMaterializers(entityShaperExpression); + innerShaper = Visit(innerShaper); + _expressions.Add(Expression.Assign(variable, innerShaper)); + _mapping[key] = variable; + } + + return variable; + } + + case ProjectionBindingExpression projectionBindingExpression: + { + var key = projectionBindingExpression; + if (!_mapping.TryGetValue(key, out var variable)) + { + variable = Expression.Parameter(projectionBindingExpression.Type); + _variables.Add(variable); + var queryExpression = (InMemoryQueryExpression)projectionBindingExpression.QueryExpression; + if (_valueBufferParameter == null) + { + _valueBufferParameter = queryExpression.CurrentParameter; + } + var projectionIndex = queryExpression.GetProjection(projectionBindingExpression).GetConstantValue(); + + // We don't need to pass property when reading at top-level + _expressions.Add(Expression.Assign( + variable, queryExpression.CurrentParameter.CreateValueBufferReadValueExpression( + projectionBindingExpression.Type, projectionIndex, property: null))); + _mapping[key] = variable; + } + + return variable; + } + + case IncludeExpression includeExpression: + { + var entity = Visit(includeExpression.EntityExpression); + var entityClrType = includeExpression.EntityExpression.Type; + var includingClrType = includeExpression.Navigation.DeclaringEntityType.ClrType; + var inverseNavigation = includeExpression.Navigation.Inverse; + var relatedEntityClrType = includeExpression.Navigation.TargetEntityType.ClrType; + if (includingClrType != entityClrType + && includingClrType.IsAssignableFrom(entityClrType)) + { + includingClrType = entityClrType; + } + + if (includeExpression.Navigation.IsCollection) + { + var collectionResultShaperExpression = (CollectionResultShaperExpression)includeExpression.NavigationExpression; + var shaperLambda = new ShaperExpressionProcessingExpressionVisitor(_inMemoryShapedQueryCompilingExpressionVisitor, _tracking) + .ProcessShaper(collectionResultShaperExpression.InnerShaper); + _expressions.Add( + Expression.Call( + _includeCollectionMethodInfo.MakeGenericMethod(entityClrType, includingClrType, relatedEntityClrType), + QueryCompilationContext.QueryContextParameter, + Visit(collectionResultShaperExpression.Projection), + Expression.Constant(shaperLambda.Compile()), + entity, + Expression.Constant(includeExpression.Navigation), + Expression.Constant(inverseNavigation, typeof(INavigationBase)), + Expression.Constant( + GenerateFixup( + includingClrType, relatedEntityClrType, includeExpression.Navigation, inverseNavigation).Compile()), + Expression.Constant(_tracking), +#pragma warning disable EF1001 // Internal EF Core API usage. + Expression.Constant(includeExpression.SetLoaded))); +#pragma warning restore EF1001 // Internal EF Core API usage. + } + else + { + _expressions.Add(Expression.Call( + _includeReferenceMethodInfo.MakeGenericMethod(entityClrType, includingClrType, relatedEntityClrType), + QueryCompilationContext.QueryContextParameter, + entity, + Visit(includeExpression.NavigationExpression), + Expression.Constant(includeExpression.Navigation), + Expression.Constant(inverseNavigation, typeof(INavigationBase)), + Expression.Constant( + GenerateFixup( + includingClrType, relatedEntityClrType, includeExpression.Navigation, inverseNavigation).Compile()), + Expression.Constant(_tracking))); + } + + return entity; + } + + case CollectionResultShaperExpression collectionResultShaperExpression: + { + var navigation = collectionResultShaperExpression.Navigation; + var collectionAccessor = navigation?.GetCollectionAccessor(); + var collectionType = collectionAccessor?.CollectionType ?? collectionResultShaperExpression.Type; + var elementType = collectionResultShaperExpression.ElementType; + var shaperLambda = new ShaperExpressionProcessingExpressionVisitor(_inMemoryShapedQueryCompilingExpressionVisitor, _tracking) + .ProcessShaper(collectionResultShaperExpression.InnerShaper); + + return Expression.Call( + _materializeCollectionMethodInfo.MakeGenericMethod(elementType, collectionType), + QueryCompilationContext.QueryContextParameter, + Visit(collectionResultShaperExpression.Projection), + Expression.Constant(shaperLambda.Compile()), + Expression.Constant(collectionAccessor, typeof(IClrCollectionAccessor))); + } + + case SingleResultShaperExpression singleResultShaperExpression: + { + var shaperLambda = new ShaperExpressionProcessingExpressionVisitor(_inMemoryShapedQueryCompilingExpressionVisitor, _tracking) + .ProcessShaper(singleResultShaperExpression.InnerShaper); + + return Expression.Call( + _materializeSingleResultMethodInfo.MakeGenericMethod(singleResultShaperExpression.Type), + QueryCompilationContext.QueryContextParameter, + Visit(singleResultShaperExpression.Projection), + Expression.Constant(shaperLambda.Compile())); + } + } + + return base.VisitExtension(extensionExpression); + } + + protected override Expression VisitBinary(BinaryExpression binaryExpression) + { + Check.NotNull(binaryExpression, nameof(binaryExpression)); + + if (binaryExpression.NodeType == ExpressionType.Assign + && binaryExpression.Left is ParameterExpression parameterExpression + && parameterExpression.Type == typeof(MaterializationContext)) + { + var newExpression = (NewExpression)binaryExpression.Right; + + var projectionBindingExpression = (ProjectionBindingExpression)newExpression.Arguments[0]; + var queryExpression = (InMemoryQueryExpression)projectionBindingExpression.QueryExpression; + if (_valueBufferParameter == null) + { + _valueBufferParameter = queryExpression.CurrentParameter; + } + + _materializationContextBindings[parameterExpression] + = queryExpression.GetProjection(projectionBindingExpression).GetConstantValue>(); + + var updatedExpression = newExpression.Update( + new[] { Expression.Constant(ValueBuffer.Empty), newExpression.Arguments[1] }); + + return Expression.MakeBinary(ExpressionType.Assign, binaryExpression.Left, updatedExpression); + } + + if (binaryExpression.NodeType == ExpressionType.Assign + && binaryExpression.Left is MemberExpression memberExpression + && memberExpression.Member is FieldInfo fieldInfo + && fieldInfo.IsInitOnly) + { + return memberExpression.Assign(Visit(binaryExpression.Right)); + } + + return base.VisitBinary(binaryExpression); + } + + protected override Expression VisitMethodCall(MethodCallExpression methodCallExpression) + { + Check.NotNull(methodCallExpression, nameof(methodCallExpression)); + + if (methodCallExpression.Method.IsGenericMethod + && methodCallExpression.Method.GetGenericMethodDefinition() == EntityFrameworkCore.Infrastructure.ExpressionExtensions.ValueBufferTryReadValueMethod) + { + var property = methodCallExpression.Arguments[2].GetConstantValue(); + var indexMap = _materializationContextBindings[ + (ParameterExpression)((MethodCallExpression)methodCallExpression.Arguments[0]).Object!]; + + Check.DebugAssert( + property != null || methodCallExpression.Type.IsNullableType(), "Must read nullable value without property"); + + return Expression.Call( + methodCallExpression.Method, + _valueBufferParameter!, + Expression.Constant(indexMap[property!]), + methodCallExpression.Arguments[2]); + } + + return base.VisitMethodCall(methodCallExpression); + } + + private static void IncludeReference( + QueryContext queryContext, + TEntity entity, + TIncludedEntity relatedEntity, + INavigationBase navigation, + INavigationBase inverseNavigation, + Action fixup, + bool trackingQuery) + where TIncludingEntity : class, TEntity + where TEntity : class + where TIncludedEntity : class + { + if (entity is TIncludingEntity includingEntity) + { + if (trackingQuery + && navigation.DeclaringEntityType.FindPrimaryKey() != null) + { + // For non-null relatedEntity StateManager will set the flag + if (relatedEntity == null) + { + queryContext.SetNavigationIsLoaded(includingEntity, navigation); + } + } + else + { + navigation.SetIsLoadedWhenNoTracking(includingEntity); + if (relatedEntity != null) + { + fixup(includingEntity, relatedEntity); + if (inverseNavigation != null + && !inverseNavigation.IsCollection) + { + inverseNavigation.SetIsLoadedWhenNoTracking(relatedEntity); + } + } + } + } + } + + private static void IncludeCollection( + QueryContext queryContext, + IEnumerable innerValueBuffers, + Func innerShaper, + TEntity entity, + INavigationBase navigation, + INavigationBase inverseNavigation, + Action fixup, + bool trackingQuery, + bool setLoaded) + where TIncludingEntity : class, TEntity + where TEntity : class + where TIncludedEntity : class + { + if (entity is TIncludingEntity includingEntity) + { + var collectionAccessor = navigation.GetCollectionAccessor()!; + collectionAccessor.GetOrCreate(includingEntity, forMaterialization: true); + + if (setLoaded) + { + if (trackingQuery) + { + queryContext.SetNavigationIsLoaded(entity, navigation); + } + else + { + navigation.SetIsLoadedWhenNoTracking(entity); + } + } + + foreach (var valueBuffer in innerValueBuffers) + { + var relatedEntity = innerShaper(queryContext, valueBuffer); + + if (!trackingQuery) + { + fixup(includingEntity, relatedEntity); + if (inverseNavigation != null) + { + inverseNavigation.SetIsLoadedWhenNoTracking(relatedEntity); + } + } + } + } + } + + private static TCollection MaterializeCollection( + QueryContext queryContext, + IEnumerable innerValueBuffers, + Func innerShaper, + IClrCollectionAccessor clrCollectionAccessor) + where TCollection : class, ICollection + { + var collection = (TCollection)(clrCollectionAccessor?.Create() ?? new List()); + + foreach (var valueBuffer in innerValueBuffers) + { + var element = innerShaper(queryContext, valueBuffer); + collection.Add(element); + } + + return collection; + } + + private static TResult? MaterializeSingleResult( + QueryContext queryContext, + ValueBuffer valueBuffer, + Func innerShaper) + => valueBuffer.IsEmpty + ? default + : innerShaper(queryContext, valueBuffer); + + private static LambdaExpression GenerateFixup( + Type entityType, + Type relatedEntityType, + INavigationBase navigation, + INavigationBase? inverseNavigation) + { + var entityParameter = Expression.Parameter(entityType); + var relatedEntityParameter = Expression.Parameter(relatedEntityType); + var expressions = new List + { + navigation.IsCollection + ? AddToCollectionNavigation(entityParameter, relatedEntityParameter, navigation) + : AssignReferenceNavigation(entityParameter, relatedEntityParameter, navigation) + }; + + if (inverseNavigation != null) + { + expressions.Add( + inverseNavigation.IsCollection + ? AddToCollectionNavigation(relatedEntityParameter, entityParameter, inverseNavigation) + : AssignReferenceNavigation(relatedEntityParameter, entityParameter, inverseNavigation)); + } + + return Expression.Lambda(Expression.Block(typeof(void), expressions), entityParameter, relatedEntityParameter); + } + + private static Expression AssignReferenceNavigation( + ParameterExpression entity, + ParameterExpression relatedEntity, + INavigationBase navigation) + => entity.MakeMemberAccess(navigation.GetMemberInfo(forMaterialization: true, forSet: true)).Assign(relatedEntity); + + private static Expression AddToCollectionNavigation( + ParameterExpression entity, + ParameterExpression relatedEntity, + INavigationBase navigation) + => Expression.Call( + Expression.Constant(navigation.GetCollectionAccessor()), + _collectionAccessorAddMethodInfo, + entity, + relatedEntity, + Expression.Constant(true)); + } + } +} diff --git a/src/EFCore.InMemory/Query/Internal/InMemoryShapedQueryCompilingExpressionVisitor.cs b/src/EFCore.InMemory/Query/Internal/InMemoryShapedQueryCompilingExpressionVisitor.cs index 2dadf002550..81930ee4899 100644 --- a/src/EFCore.InMemory/Query/Internal/InMemoryShapedQueryCompilingExpressionVisitor.cs +++ b/src/EFCore.InMemory/Query/Internal/InMemoryShapedQueryCompilingExpressionVisitor.cs @@ -3,6 +3,7 @@ using System; using System.Collections.Generic; +using System.Diagnostics.CodeAnalysis; using System.Linq.Expressions; using System.Reflection; using Microsoft.EntityFrameworkCore.Metadata; @@ -44,10 +45,6 @@ protected override Expression VisitExtension(Expression extensionExpression) switch (extensionExpression) { - case InMemoryQueryExpression inMemoryQueryExpression: - inMemoryQueryExpression.ApplyProjection(); - return Visit(inMemoryQueryExpression.ServerQueryExpression); - case InMemoryTableExpression inMemoryTableExpression: return Expression.Call( _tableMethodInfo, @@ -69,27 +66,18 @@ protected override Expression VisitShapedQuery(ShapedQueryExpression shapedQuery Check.NotNull(shapedQueryExpression, nameof(shapedQueryExpression)); var inMemoryQueryExpression = (InMemoryQueryExpression)shapedQueryExpression.QueryExpression; + inMemoryQueryExpression.ApplyProjection(); - var shaper = new ShaperExpressionProcessingExpressionVisitor( - inMemoryQueryExpression, inMemoryQueryExpression.CurrentParameter) - .Inject(shapedQueryExpression.ShaperExpression); - - shaper = InjectEntityMaterializers(shaper); - - var innerEnumerable = Visit(inMemoryQueryExpression); - - shaper = new InMemoryProjectionBindingRemovingExpressionVisitor().Visit(shaper); - - shaper = new CustomShaperCompilingExpressionVisitor( - QueryCompilationContext.QueryTrackingBehavior == QueryTrackingBehavior.TrackAll).Visit(shaper); - - var shaperLambda = (LambdaExpression)shaper; + var shaperExpression = new ShaperExpressionProcessingExpressionVisitor( + this, inMemoryQueryExpression, QueryCompilationContext.QueryTrackingBehavior == QueryTrackingBehavior.TrackAll) + .ProcessShaper(shapedQueryExpression.ShaperExpression); + var innerEnumerable = Visit(inMemoryQueryExpression.ServerQueryExpression); return Expression.New( - typeof(QueryingEnumerable<>).MakeGenericType(shaperLambda.ReturnType).GetConstructors()[0], + typeof(QueryingEnumerable<>).MakeGenericType(shaperExpression.ReturnType).GetConstructors()[0], QueryCompilationContext.QueryContextParameter, innerEnumerable, - Expression.Constant(shaperLambda.Compile()), + Expression.Constant(shaperExpression.Compile()), Expression.Constant(_contextType), Expression.Constant( QueryCompilationContext.QueryTrackingBehavior == QueryTrackingBehavior.NoTrackingWithIdentityResolution), diff --git a/src/EFCore.InMemory/Query/Internal/ShaperExpressionProcessingExpressionVisitor.cs b/src/EFCore.InMemory/Query/Internal/ShaperExpressionProcessingExpressionVisitor.cs deleted file mode 100644 index 924132a8c15..00000000000 --- a/src/EFCore.InMemory/Query/Internal/ShaperExpressionProcessingExpressionVisitor.cs +++ /dev/null @@ -1,180 +0,0 @@ -// Copyright (c) .NET Foundation. All rights reserved. -// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. - -using System.Collections.Generic; -using System.Linq.Expressions; -using Microsoft.EntityFrameworkCore.Query; -using Microsoft.EntityFrameworkCore.Utilities; - -namespace Microsoft.EntityFrameworkCore.InMemory.Query.Internal -{ - /// - /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to - /// the same compatibility standards as public APIs. It may be changed or removed without notice in - /// any release. You should only use it directly in your code with extreme caution and knowing that - /// doing so can result in application failures when updating to a new Entity Framework Core release. - /// - public class ShaperExpressionProcessingExpressionVisitor : ExpressionVisitor - { - private readonly InMemoryQueryExpression? _queryExpression; - private readonly ParameterExpression _valueBufferParameter; - - private readonly IDictionary _mapping = new Dictionary(); - private readonly List _variables = new(); - private readonly List _expressions = new(); - - /// - /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to - /// the same compatibility standards as public APIs. It may be changed or removed without notice in - /// any release. You should only use it directly in your code with extreme caution and knowing that - /// doing so can result in application failures when updating to a new Entity Framework Core release. - /// - public ShaperExpressionProcessingExpressionVisitor( - InMemoryQueryExpression? queryExpression, - ParameterExpression valueBufferParameter) - { - _queryExpression = queryExpression; - _valueBufferParameter = valueBufferParameter; - } - - /// - /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to - /// the same compatibility standards as public APIs. It may be changed or removed without notice in - /// any release. You should only use it directly in your code with extreme caution and knowing that - /// doing so can result in application failures when updating to a new Entity Framework Core release. - /// - public virtual Expression Inject(Expression expression) - { - var result = Visit(expression); - _expressions.Add(result); - result = Expression.Block(_variables, _expressions); - - return ConvertToLambda(result); - } - - private LambdaExpression ConvertToLambda(Expression result) - => Expression.Lambda( - result, - QueryCompilationContext.QueryContextParameter, - _valueBufferParameter); - - /// - /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to - /// the same compatibility standards as public APIs. It may be changed or removed without notice in - /// any release. You should only use it directly in your code with extreme caution and knowing that - /// doing so can result in application failures when updating to a new Entity Framework Core release. - /// - protected override Expression VisitExtension(Expression extensionExpression) - { - Check.NotNull(extensionExpression, nameof(extensionExpression)); - - switch (extensionExpression) - { - case EntityShaperExpression entityShaperExpression: - { - var key = GenerateKey((ProjectionBindingExpression)entityShaperExpression.ValueBufferExpression); - if (!_mapping.TryGetValue(key, out var variable)) - { - variable = Expression.Parameter(entityShaperExpression.EntityType.ClrType); - _variables.Add(variable); - _expressions.Add(Expression.Assign(variable, entityShaperExpression)); - _mapping[key] = variable; - } - - return variable; - } - - case ProjectionBindingExpression projectionBindingExpression: - { - var key = GenerateKey(projectionBindingExpression); - if (!_mapping.TryGetValue(key, out var variable)) - { - variable = Expression.Parameter(projectionBindingExpression.Type); - _variables.Add(variable); - _expressions.Add(Expression.Assign(variable, projectionBindingExpression)); - _mapping[key] = variable; - } - - return variable; - } - - case IncludeExpression includeExpression: - { - var entity = Visit(includeExpression.EntityExpression); - if (includeExpression.NavigationExpression is CollectionShaperExpression collectionShaper) - { - var innerLambda = (LambdaExpression)collectionShaper.InnerShaper; - var innerShaper = new ShaperExpressionProcessingExpressionVisitor(null, innerLambda.Parameters[0]) - .Inject(innerLambda.Body); - - _expressions.Add( - includeExpression.Update( - entity, - collectionShaper.Update( - Visit(collectionShaper.Projection), - innerShaper))); - } - else - { - _expressions.Add( - includeExpression.Update( - entity, - Visit(includeExpression.NavigationExpression))); - } - - return entity; - } - - case CollectionShaperExpression collectionShaperExpression: - { - var key = GenerateKey((ProjectionBindingExpression)collectionShaperExpression.Projection); - if (!_mapping.TryGetValue(key, out var variable)) - { - var projection = Visit(collectionShaperExpression.Projection); - - variable = Expression.Parameter(collectionShaperExpression.Type); - _variables.Add(variable); - - var innerLambda = (LambdaExpression)collectionShaperExpression.InnerShaper; - var innerShaper = new ShaperExpressionProcessingExpressionVisitor(null, innerLambda.Parameters[0]) - .Inject(innerLambda.Body); - - _expressions.Add(Expression.Assign(variable, collectionShaperExpression.Update(projection, innerShaper))); - _mapping[key] = variable; - } - - return variable; - } - - case SingleResultShaperExpression singleResultShaperExpression: - { - var key = GenerateKey((ProjectionBindingExpression)singleResultShaperExpression.Projection); - if (!_mapping.TryGetValue(key, out var variable)) - { - var projection = Visit(singleResultShaperExpression.Projection); - - variable = Expression.Parameter(singleResultShaperExpression.Type); - _variables.Add(variable); - - var innerLambda = (LambdaExpression)singleResultShaperExpression.InnerShaper; - var innerShaper = new ShaperExpressionProcessingExpressionVisitor(null, innerLambda.Parameters[0]) - .Inject(innerLambda.Body); - - _expressions.Add(Expression.Assign(variable, singleResultShaperExpression.Update(projection, innerShaper))); - _mapping[key] = variable; - } - - return variable; - } - } - - return base.VisitExtension(extensionExpression); - } - - private Expression GenerateKey(ProjectionBindingExpression projectionBindingExpression) - => _queryExpression != null - && projectionBindingExpression.ProjectionMember != null - ? _queryExpression.GetMappedProjection(projectionBindingExpression.ProjectionMember) - : projectionBindingExpression; - } -} diff --git a/src/EFCore.InMemory/Query/Internal/SingleResultShaperExpression.cs b/src/EFCore.InMemory/Query/Internal/SingleResultShaperExpression.cs index f539d762f14..1b586e52cd1 100644 --- a/src/EFCore.InMemory/Query/Internal/SingleResultShaperExpression.cs +++ b/src/EFCore.InMemory/Query/Internal/SingleResultShaperExpression.cs @@ -4,6 +4,7 @@ using System; using System.Linq.Expressions; using Microsoft.EntityFrameworkCore.Query; +using Microsoft.EntityFrameworkCore.Storage; using Microsoft.EntityFrameworkCore.Utilities; namespace Microsoft.EntityFrameworkCore.InMemory.Query.Internal @@ -24,12 +25,11 @@ public class SingleResultShaperExpression : Expression, IPrintableExpression /// public SingleResultShaperExpression( Expression projection, - Expression innerShaper, - Type type) + Expression innerShaper) { Projection = projection; InnerShaper = innerShaper; - Type = type; + Type = innerShaper.Type; } /// @@ -56,7 +56,7 @@ protected override Expression VisitChildren(ExpressionVisitor visitor) /// public virtual SingleResultShaperExpression Update(Expression projection, Expression innerShaper) => projection != Projection || innerShaper != InnerShaper - ? new SingleResultShaperExpression(projection, innerShaper, Type) + ? new SingleResultShaperExpression(projection, innerShaper) : this; /// diff --git a/src/EFCore/Query/CollectionShaperExpression.cs b/src/EFCore/Query/CollectionShaperExpression.cs index cddc417cb0a..bdae2296500 100644 --- a/src/EFCore/Query/CollectionShaperExpression.cs +++ b/src/EFCore/Query/CollectionShaperExpression.cs @@ -18,6 +18,7 @@ namespace Microsoft.EntityFrameworkCore.Query /// not used in application code. /// /// + [Obsolete("Use provider specific expressions for collection results.")] public class CollectionShaperExpression : Expression, IPrintableExpression { /// diff --git a/src/EFCore/Query/ProjectionBindingExpression.cs b/src/EFCore/Query/ProjectionBindingExpression.cs index e63b76bf4a6..7211cfc16dc 100644 --- a/src/EFCore/Query/ProjectionBindingExpression.cs +++ b/src/EFCore/Query/ProjectionBindingExpression.cs @@ -66,6 +66,7 @@ public ProjectionBindingExpression( /// /// The query expression to get the value from. /// The index map to bind with query expression projection for ValueBuffer. + [Obsolete("The dictionary should be stored in client projection in query expression and access via index based binding.")] public ProjectionBindingExpression( Expression queryExpression, IReadOnlyDictionary indexMap) @@ -96,6 +97,7 @@ public ProjectionBindingExpression( /// /// The projection member to bind if binding is via index map for a value buffer. /// + [Obsolete("The dictionary should be stored in client projection in query expression and access via index based binding.")] public virtual IReadOnlyDictionary? IndexMap { get; } /// @@ -127,6 +129,7 @@ void IPrintableExpression.Print(ExpressionPrinter expressionPrinter) { expressionPrinter.Append(Index.ToString()!); } +#pragma warning disable CS0618 // Type or member is obsolete else if (IndexMap != null) { using (expressionPrinter.Indent()) @@ -137,6 +140,7 @@ void IPrintableExpression.Print(ExpressionPrinter expressionPrinter) } } } +#pragma warning restore CS0618 // Type or member is obsolete } /// @@ -153,10 +157,14 @@ private bool Equals(ProjectionBindingExpression projectionBindingExpression) ?? projectionBindingExpression.ProjectionMember == null) && Index == projectionBindingExpression.Index // Using reference equality here since if we are this far, we don't need to compare this. +#pragma warning disable CS0618 // Type or member is obsolete && IndexMap == projectionBindingExpression.IndexMap; +#pragma warning restore CS0618 // Type or member is obsolete /// public override int GetHashCode() +#pragma warning disable CS0618 // Type or member is obsolete => HashCode.Combine(QueryExpression, ProjectionMember, Index, IndexMap); +#pragma warning restore CS0618 // Type or member is obsolete } } diff --git a/test/EFCore.Cosmos.FunctionalTests/Query/NorthwindSelectQueryCosmosTest.cs b/test/EFCore.Cosmos.FunctionalTests/Query/NorthwindSelectQueryCosmosTest.cs index 8982c6a4fae..ceade5691a6 100644 --- a/test/EFCore.Cosmos.FunctionalTests/Query/NorthwindSelectQueryCosmosTest.cs +++ b/test/EFCore.Cosmos.FunctionalTests/Query/NorthwindSelectQueryCosmosTest.cs @@ -1302,6 +1302,12 @@ public override Task Take_on_correlated_collection_in_first(bool async) return base.Take_on_correlated_collection_in_first(async); } + [ConditionalTheory(Skip = "Cross collection join Issue#17246")] + public override Task Client_projection_via_ctor_arguments(bool async) + { + return base.Client_projection_via_ctor_arguments(async); + } + private void AssertSql(params string[] expected) => Fixture.TestSqlLoggerFactory.AssertBaseline(expected); diff --git a/test/EFCore.InMemory.FunctionalTests/Query/GearsOfWarQueryInMemoryTest.cs b/test/EFCore.InMemory.FunctionalTests/Query/GearsOfWarQueryInMemoryTest.cs index 02340f99ac6..d9177295444 100644 --- a/test/EFCore.InMemory.FunctionalTests/Query/GearsOfWarQueryInMemoryTest.cs +++ b/test/EFCore.InMemory.FunctionalTests/Query/GearsOfWarQueryInMemoryTest.cs @@ -1,8 +1,10 @@ // Copyright (c) .NET Foundation. All rights reserved. // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. +using System; using System.Threading.Tasks; using Microsoft.EntityFrameworkCore.Diagnostics; +using Microsoft.EntityFrameworkCore.InMemory.Internal; using Microsoft.EntityFrameworkCore.TestModels.GearsOfWarModel; using Xunit; using Xunit.Abstractions; @@ -52,5 +54,21 @@ public override Task Projecting_entity_as_well_as_correlated_collection_of_scala [ConditionalTheory(Skip = "issue #24325")] public override Task Correlated_collection_with_distinct_3_levels(bool async) => base.Correlated_collection_with_distinct_3_levels(async); + + public override async Task Projecting_correlated_collection_followed_by_Distinct(bool async) + { + var message = (await Assert.ThrowsAsync( + () => base.Projecting_correlated_collection_followed_by_Distinct(async))).Message; + + Assert.Equal(InMemoryStrings.DistinctOnSubqueryNotSupported, message); + } + + public override async Task Projecting_some_properties_as_well_as_correlated_collection_followed_by_Distinct(bool async) + { + var message = (await Assert.ThrowsAsync( + () => base.Projecting_some_properties_as_well_as_correlated_collection_followed_by_Distinct(async))).Message; + + Assert.Equal(InMemoryStrings.DistinctOnSubqueryNotSupported, message); + } } } diff --git a/test/EFCore.InMemory.FunctionalTests/Query/NorthwindSetOperationsQueryInMemoryTest.cs b/test/EFCore.InMemory.FunctionalTests/Query/NorthwindSetOperationsQueryInMemoryTest.cs index d6b30375bf3..55dbb654bdf 100644 --- a/test/EFCore.InMemory.FunctionalTests/Query/NorthwindSetOperationsQueryInMemoryTest.cs +++ b/test/EFCore.InMemory.FunctionalTests/Query/NorthwindSetOperationsQueryInMemoryTest.cs @@ -1,7 +1,11 @@ // Copyright (c) .NET Foundation. All rights reserved. // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. +using System; +using System.Threading.Tasks; +using Microsoft.EntityFrameworkCore.InMemory.Internal; using Microsoft.EntityFrameworkCore.TestUtilities; +using Xunit; using Xunit.Abstractions; namespace Microsoft.EntityFrameworkCore.Query @@ -18,5 +22,13 @@ public NorthwindSetOperationsQueryInMemoryTest( { //TestLoggerFactory.TestOutputHelper = testOutputHelper; } + + public override async Task Collection_projection_before_set_operation_fails(bool async) + { + var message = (await Assert.ThrowsAsync( + () => base.Collection_projection_before_set_operation_fails(async))).Message; + + Assert.Equal(InMemoryStrings.SetOperationsNotAllowedAfterClientEvaluation, message); + } } } diff --git a/test/EFCore.Relational.Specification.Tests/Query/GearsOfWarQueryRelationalTestBase.cs b/test/EFCore.Relational.Specification.Tests/Query/GearsOfWarQueryRelationalTestBase.cs index 58509668da3..24f621de037 100644 --- a/test/EFCore.Relational.Specification.Tests/Query/GearsOfWarQueryRelationalTestBase.cs +++ b/test/EFCore.Relational.Specification.Tests/Query/GearsOfWarQueryRelationalTestBase.cs @@ -19,8 +19,6 @@ protected GearsOfWarQueryRelationalTestBase(TFixture fixture) { } - [ConditionalTheory] - [MemberData(nameof(IsAsyncData))] public override async Task Correlated_collection_with_groupby_with_complex_grouping_key_not_projecting_identifier_column_with_group_aggregate_in_final_projection(bool async) { var message = (await Assert.ThrowsAsync( @@ -29,8 +27,6 @@ public override async Task Correlated_collection_with_groupby_with_complex_group Assert.Equal(RelationalStrings.InsufficientInformationToIdentifyOuterElementOfCollectionJoin, message); } - [ConditionalTheory] - [MemberData(nameof(IsAsyncData))] public override async Task Correlated_collection_with_distinct_not_projecting_identifier_column_also_projecting_complex_expressions(bool async) { var message = (await Assert.ThrowsAsync( @@ -125,8 +121,6 @@ await AssertQuery( elementSorter: e => e.Name); } - [ConditionalTheory] - [MemberData(nameof(IsAsyncData))] public override async Task Projecting_correlated_collection_followed_by_Distinct(bool async) { var message = (await Assert.ThrowsAsync( @@ -135,8 +129,6 @@ public override async Task Projecting_correlated_collection_followed_by_Distinct Assert.Equal(RelationalStrings.DistinctOnCollectionNotSupported, message); } - [ConditionalTheory] - [MemberData(nameof(IsAsyncData))] public override async Task Projecting_some_properties_as_well_as_correlated_collection_followed_by_Distinct(bool async) { var message = (await Assert.ThrowsAsync( @@ -145,8 +137,6 @@ public override async Task Projecting_some_properties_as_well_as_correlated_coll Assert.Equal(RelationalStrings.DistinctOnCollectionNotSupported, message); } - [ConditionalTheory] - [MemberData(nameof(IsAsyncData))] public override async Task Projecting_entity_as_well_as_correlated_collection_followed_by_Distinct(bool async) { var message = (await Assert.ThrowsAsync( @@ -155,8 +145,6 @@ public override async Task Projecting_entity_as_well_as_correlated_collection_fo Assert.Equal(RelationalStrings.DistinctOnCollectionNotSupported, message); } - [ConditionalTheory] - [MemberData(nameof(IsAsyncData))] public override async Task Projecting_entity_as_well_as_complex_correlated_collection_followed_by_Distinct(bool async) { var message = (await Assert.ThrowsAsync( @@ -165,8 +153,6 @@ public override async Task Projecting_entity_as_well_as_complex_correlated_colle Assert.Equal(RelationalStrings.DistinctOnCollectionNotSupported, message); } - [ConditionalTheory] - [MemberData(nameof(IsAsyncData))] public override async Task Projecting_entity_as_well_as_correlated_collection_of_scalars_followed_by_Distinct(bool async) { var message = (await Assert.ThrowsAsync( @@ -175,8 +161,6 @@ public override async Task Projecting_entity_as_well_as_correlated_collection_of Assert.Equal(RelationalStrings.DistinctOnCollectionNotSupported, message); } - [ConditionalTheory] - [MemberData(nameof(IsAsyncData))] public override async Task Correlated_collection_with_distinct_3_levels(bool async) { var message = (await Assert.ThrowsAsync( diff --git a/test/EFCore.Specification.Tests/Query/NorthwindSelectQueryTestBase.cs b/test/EFCore.Specification.Tests/Query/NorthwindSelectQueryTestBase.cs index ff898ead293..6e997b744e2 100644 --- a/test/EFCore.Specification.Tests/Query/NorthwindSelectQueryTestBase.cs +++ b/test/EFCore.Specification.Tests/Query/NorthwindSelectQueryTestBase.cs @@ -2317,5 +2317,61 @@ public virtual Task Take_on_correlated_collection_in_first(bool async) asserter: (e, a) => AssertCollection(e.Orders, a.Orders, ordered: true, elementAsserter: (ee, aa) => AssertEqual(ee.Title, aa.Title))); } + + [ConditionalTheory] + [MemberData(nameof(IsAsyncData))] + public virtual Task Client_projection_via_ctor_arguments(bool async) + { + return AssertSingle( + async, + ss => + ss.Set() + .Where(c => c.CustomerID == "ALFKI") + .Include(c => c.Orders) + .Select(c => new CustomerDetailsWithCount(c.CustomerID, c.City, + c.Orders.Select(o => new OrderInfo(o.OrderID, o.OrderDate)).ToList(), c.Orders.Count)), + asserter: (e, a) => + { + Assert.Equal(e.CustomerID, a.CustomerID); + Assert.Equal(e.City, a.City); + AssertCollection(e.OrderInfos, a.OrderInfos, + elementSorter: i => i.OrderID, + elementAsserter: (ie, ia) => + { + Assert.Equal(ie.OrderID, ia.OrderID); + Assert.Equal(ie.OrderDate, ia.OrderDate); + }); + Assert.Equal(e.OrderCount, a.OrderCount); + }); + } + + private class CustomerDetailsWithCount + { + public CustomerDetailsWithCount(string customerID, string city, List orderInfos, int orderCount) + { + CustomerID = customerID; + City = city; + OrderInfos = orderInfos; + OrderCount = orderCount; + } + + public string CustomerID { get; } + public string City { get; } + public List OrderInfos { get; } + public int OrderCount { get; } + } + + private class OrderInfo + { + public OrderInfo(int orderID, DateTime? orderDate) + { + OrderID = orderID; + OrderDate = orderDate; + } + + public int OrderID { get; } + public DateTime? OrderDate { get; } + } + } } diff --git a/test/EFCore.SqlServer.FunctionalTests/Query/NorthwindSelectQuerySqlServerTest.cs b/test/EFCore.SqlServer.FunctionalTests/Query/NorthwindSelectQuerySqlServerTest.cs index 5fe39cdfe1c..f9dc1d7a5d1 100644 --- a/test/EFCore.SqlServer.FunctionalTests/Query/NorthwindSelectQuerySqlServerTest.cs +++ b/test/EFCore.SqlServer.FunctionalTests/Query/NorthwindSelectQuerySqlServerTest.cs @@ -1824,6 +1824,24 @@ ORDER BY [o].[OrderDate] ORDER BY [t].[CustomerID], [t0].[OrderDate], [t0].[OrderID], [t0].[CustomerID]"); } + public override async Task Client_projection_via_ctor_arguments(bool async) + { + await base.Client_projection_via_ctor_arguments(async); + + AssertSql( + @"SELECT [t].[CustomerID], [t].[City], [o0].[OrderID], [o0].[OrderDate], [t].[c] +FROM ( + SELECT TOP(2) [c].[CustomerID], [c].[City], ( + SELECT COUNT(*) + FROM [Orders] AS [o] + WHERE [c].[CustomerID] = [o].[CustomerID]) AS [c] + FROM [Customers] AS [c] + WHERE [c].[CustomerID] = N'ALFKI' +) AS [t] +LEFT JOIN [Orders] AS [o0] ON [t].[CustomerID] = [o0].[CustomerID] +ORDER BY [t].[CustomerID], [o0].[OrderID]"); + } + private void AssertSql(params string[] expected) => Fixture.TestSqlLoggerFactory.AssertBaseline(expected); diff --git a/test/EFCore.Tests/ChangeTracking/Internal/QueryFixupTest.cs b/test/EFCore.Tests/ChangeTracking/Internal/QueryFixupTest.cs index d3d4949c544..366b47f274f 100644 --- a/test/EFCore.Tests/ChangeTracking/Internal/QueryFixupTest.cs +++ b/test/EFCore.Tests/ChangeTracking/Internal/QueryFixupTest.cs @@ -885,7 +885,7 @@ public void Query_owned() }); } - [ConditionalFact] + [ConditionalFact(Skip = "Issue#24807")] public void Query_subowned() { Seed();