Skip to content

Commit ec7c478

Browse files
committed
Add OnlyOnInclude option for the root rewritter
1 parent dfa59c8 commit ec7c478

File tree

3 files changed

+47
-11
lines changed

3 files changed

+47
-11
lines changed

samples/BasicSample/Program.cs

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ public class User
2121
public string FullName { get; set; }
2222
private string _FullName => FirstName + " " + LastName;
2323

24-
[Projectable(UseMemberBody = nameof(_TotalSpent))]
24+
[Projectable(UseMemberBody = nameof(_TotalSpent), OnlyOnInclude = true)]
2525
public double TotalSpent { get; set; }
2626
private double _TotalSpent => Orders.Sum(x => x.PriceSum);
2727

@@ -154,10 +154,11 @@ public static void Main(string[] args)
154154
}
155155

156156
{
157-
var result = dbContext.Users.FirstOrDefault();
157+
Console.WriteLine($"Unloaded total: {dbContext.Users.First().TotalSpent}");
158+
var result = dbContext.Users.Include(x => x.TotalSpent).FirstOrDefault();
158159
Console.WriteLine($"Our first user {result.FullName} has spent {result.TotalSpent}");
159160

160-
result = dbContext.Users.FirstOrDefault(x => x.TotalSpent > 1);
161+
result = dbContext.Users.Include(x => x.TotalSpent).FirstOrDefault(x => x.TotalSpent > 1);
161162
Console.WriteLine($"Our first user {result.FullName} has spent {result.TotalSpent}");
162163

163164
var spent = dbContext.Users.Sum(x => x.TotalSpent);

src/EntityFrameworkCore.Projectables.Abstractions/ProjectableAttribute.cs

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,4 @@
11
using System;
2-
using System.Collections.Generic;
3-
using System.Linq;
4-
using System.Text;
5-
using System.Threading.Tasks;
62

73
namespace EntityFrameworkCore.Projectables
84
{
@@ -23,5 +19,12 @@ public sealed class ProjectableAttribute : Attribute
2319
/// or null to get it from the current member.
2420
/// </summary>
2521
public string? UseMemberBody { get; set; }
22+
23+
/// <summary>
24+
/// <c>true</c> will allow you to request for this property by
25+
/// explicitly calling .Include(x => x.Property) on the query,
26+
/// <c>false</c> will always consider this query to be included.
27+
/// </summary>
28+
public bool OnlyOnInclude { get; set; }
2629
}
2730
}

src/EntityFrameworkCore.Projectables/Services/ProjectableExpressionReplacer.cs

Lines changed: 36 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
using System.Linq.Expressions;
66
using System.Reflection;
77
using EntityFrameworkCore.Projectables.Extensions;
8+
using Microsoft.EntityFrameworkCore;
89
using Microsoft.EntityFrameworkCore.Metadata;
910
using Microsoft.EntityFrameworkCore.Query;
1011

@@ -16,6 +17,7 @@ public sealed class ProjectableExpressionReplacer : ExpressionVisitor
1617
readonly ExpressionArgumentReplacer _expressionArgumentReplacer = new();
1718
readonly Dictionary<MemberInfo, LambdaExpression?> _projectableMemberCache = new();
1819
private bool _disableRootRewrite;
20+
private List<string> _includedProjections = new();
1921
private IEntityType? _entityType;
2022

2123
private readonly MethodInfo _select;
@@ -60,6 +62,7 @@ bool TryGetReflectedExpression(MemberInfo memberInfo, [NotNullWhen(true)] out La
6062
public Expression? Replace(Expression? node)
6163
{
6264
_disableRootRewrite = false;
65+
_includedProjections.Clear();
6366
var ret = Visit(node);
6467

6568
if (_disableRootRewrite)
@@ -138,6 +141,28 @@ bool TryGetReflectedExpression(MemberInfo memberInfo, [NotNullWhen(true)] out La
138141

139142
protected override Expression VisitMethodCall(MethodCallExpression node)
140143
{
144+
if (node.Method.Name == nameof(EntityFrameworkQueryableExtensions.Include))
145+
{
146+
var include = node.Arguments[1] switch {
147+
ConstantExpression { Value: string str } => str,
148+
LambdaExpression { Body: MemberExpression member } => member.Member.Name,
149+
UnaryExpression { Operand: LambdaExpression { Body: MemberExpression member } } => member.Member.Name,
150+
_ => null
151+
};
152+
// Only rewrite the include if it includes a projectable property (or if we don't know what's happening).
153+
var ret = Visit(node.Arguments[0]);
154+
// The visit here is needed because we need the _entityType defined on the query root for the condition below.
155+
if (
156+
include != null
157+
&& _entityType?.ClrType
158+
?.GetProperty(include)
159+
?.GetCustomAttribute<ProjectableAttribute>() != null)
160+
{
161+
_includedProjections.Add(include);
162+
return ret;
163+
}
164+
}
165+
141166
// Replace MethodGroup arguments with their reflected expressions.
142167
// Note that MethodCallExpression.Update returns the original Expression if argument values have not changed.
143168
node = node.Update(node.Object, node.Arguments.Select(arg => arg switch {
@@ -212,13 +237,13 @@ PropertyInfo property when nodeExpression is not null
212237
var updatedBody = _expressionArgumentReplacer.Visit(reflectedExpression.Body);
213238
_expressionArgumentReplacer.ParameterArgumentMapping.Clear();
214239

215-
return base.Visit(
240+
return Visit(
216241
updatedBody
217242
);
218243
}
219244
else
220245
{
221-
return base.Visit(
246+
return Visit(
222247
reflectedExpression.Body
223248
);
224249
}
@@ -243,7 +268,14 @@ protected override Expression VisitExtension(Expression node)
243268
private Expression _AddProjectableSelect(Expression node, IEntityType entityType)
244269
{
245270
var projectableProperties = entityType.ClrType.GetProperties()
246-
.Where(x => x.IsDefined(typeof(ProjectableAttribute), false))
271+
.Where(x => {
272+
var attr = x.GetCustomAttribute<ProjectableAttribute>();
273+
if (attr == null)
274+
return false;
275+
if (attr.OnlyOnInclude)
276+
return _includedProjections.Contains(x.Name);
277+
return true;
278+
})
247279
.Where(x => x.CanWrite)
248280
.ToList();
249281

@@ -288,7 +320,7 @@ private Expression _GetAccessor(PropertyInfo property, ParameterExpression para)
288320
_expressionArgumentReplacer.ParameterArgumentMapping.Add(lambda.Parameters[0], para);
289321
var updatedBody = _expressionArgumentReplacer.Visit(lambda.Body);
290322
_expressionArgumentReplacer.ParameterArgumentMapping.Clear();
291-
return base.Visit(updatedBody);
323+
return Visit(updatedBody);
292324
}
293325
}
294326
}

0 commit comments

Comments
 (0)