55using System . Linq . Expressions ;
66using System . Reflection ;
77using EntityFrameworkCore . Projectables . Extensions ;
8+ using Microsoft . EntityFrameworkCore ;
89using Microsoft . EntityFrameworkCore . Metadata ;
910using Microsoft . EntityFrameworkCore . Query ;
1011
@@ -16,6 +17,7 @@ public sealed class ProjectableExpressionReplacer : ExpressionVisitor
1617 readonly ExpressionArgumentReplacer _expressionArgumentReplacer = new ( ) ;
1718 readonly Dictionary < MemberInfo , LambdaExpression ? > _projectableMemberCache = new ( ) ;
1819 private bool _disableRootRewrite ;
20+ private List < string > _includedProjections = new ( ) ;
1921 private IEntityType ? _entityType ;
2022
2123 private readonly MethodInfo _select ;
@@ -60,6 +62,7 @@ bool TryGetReflectedExpression(MemberInfo memberInfo, [NotNullWhen(true)] out La
6062 public Expression ? Replace ( Expression ? node )
6163 {
6264 _disableRootRewrite = false ;
65+ _includedProjections . Clear ( ) ;
6366 var ret = Visit ( node ) ;
6467
6568 if ( _disableRootRewrite )
@@ -138,6 +141,28 @@ bool TryGetReflectedExpression(MemberInfo memberInfo, [NotNullWhen(true)] out La
138141
139142 protected override Expression VisitMethodCall ( MethodCallExpression node )
140143 {
144+ if ( node . Method . Name == nameof ( EntityFrameworkQueryableExtensions . Include ) )
145+ {
146+ var include = node . Arguments [ 1 ] switch {
147+ ConstantExpression { Value : string str } => str ,
148+ LambdaExpression { Body : MemberExpression member } => member . Member . Name ,
149+ UnaryExpression { Operand : LambdaExpression { Body : MemberExpression member } } => member . Member . Name ,
150+ _ => null
151+ } ;
152+ // Only rewrite the include if it includes a projectable property (or if we don't know what's happening).
153+ var ret = Visit ( node . Arguments [ 0 ] ) ;
154+ // The visit here is needed because we need the _entityType defined on the query root for the condition below.
155+ if (
156+ include != null
157+ && _entityType ? . ClrType
158+ ? . GetProperty ( include )
159+ ? . GetCustomAttribute < ProjectableAttribute > ( ) != null )
160+ {
161+ _includedProjections . Add ( include ) ;
162+ return ret ;
163+ }
164+ }
165+
141166 // Replace MethodGroup arguments with their reflected expressions.
142167 // Note that MethodCallExpression.Update returns the original Expression if argument values have not changed.
143168 node = node . Update ( node . Object , node . Arguments . Select ( arg => arg switch {
@@ -212,13 +237,13 @@ PropertyInfo property when nodeExpression is not null
212237 var updatedBody = _expressionArgumentReplacer . Visit ( reflectedExpression . Body ) ;
213238 _expressionArgumentReplacer . ParameterArgumentMapping . Clear ( ) ;
214239
215- return base . Visit (
240+ return Visit (
216241 updatedBody
217242 ) ;
218243 }
219244 else
220245 {
221- return base . Visit (
246+ return Visit (
222247 reflectedExpression . Body
223248 ) ;
224249 }
@@ -243,7 +268,14 @@ protected override Expression VisitExtension(Expression node)
243268 private Expression _AddProjectableSelect ( Expression node , IEntityType entityType )
244269 {
245270 var projectableProperties = entityType . ClrType . GetProperties ( )
246- . Where ( x => x . IsDefined ( typeof ( ProjectableAttribute ) , false ) )
271+ . Where ( x => {
272+ var attr = x . GetCustomAttribute < ProjectableAttribute > ( ) ;
273+ if ( attr == null )
274+ return false ;
275+ if ( attr . OnlyOnInclude )
276+ return _includedProjections . Contains ( x . Name ) ;
277+ return true ;
278+ } )
247279 . Where ( x => x . CanWrite )
248280 . ToList ( ) ;
249281
@@ -288,7 +320,7 @@ private Expression _GetAccessor(PropertyInfo property, ParameterExpression para)
288320 _expressionArgumentReplacer . ParameterArgumentMapping . Add ( lambda . Parameters [ 0 ] , para ) ;
289321 var updatedBody = _expressionArgumentReplacer . Visit ( lambda . Body ) ;
290322 _expressionArgumentReplacer . ParameterArgumentMapping . Clear ( ) ;
291- return base . Visit ( updatedBody ) ;
323+ return Visit ( updatedBody ) ;
292324 }
293325 }
294326}
0 commit comments