Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix test and a bit of code cleanup #34731

Merged
merged 1 commit into from
Sep 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -504,15 +504,14 @@ protected override Expression VisitBinary(BinaryExpression binaryExpression)
: propertyMap.Values.Max() + 1;

var updatedExpression = newExpression.Update(
new[]
{
_parentVisitor.Dependencies.LiftableConstantFactory.CreateLiftableConstant(
[
_parentVisitor.Dependencies.LiftableConstantFactory.CreateLiftableConstant(
ValueBuffer.Empty,
static _ => ValueBuffer.Empty,
"emptyValueBuffer",
typeof(ValueBuffer)),
newExpression.Arguments[1]
});
]);

return Assign(binaryExpression.Left, updatedExpression);
}
Expand All @@ -524,15 +523,14 @@ protected override Expression VisitBinary(BinaryExpression binaryExpression)
_jsonMaterializationContextToJsonReaderDataAndKeyValuesParameterMapping[parameterExpression] = mappedParameter;

var updatedExpression = newExpression.Update(
new[]
{
_parentVisitor.Dependencies.LiftableConstantFactory.CreateLiftableConstant(
[
_parentVisitor.Dependencies.LiftableConstantFactory.CreateLiftableConstant(
ValueBuffer.Empty,
static _ => ValueBuffer.Empty,
"emptyValueBuffer",
typeof(ValueBuffer)),
newExpression.Arguments[1]
});
]);

return Assign(binaryExpression.Left, updatedExpression);
}
Expand Down Expand Up @@ -1392,7 +1390,7 @@ protected override Expression VisitMethodCall(MethodCallExpression methodCallExp
var valueExpression = MakeIndex(
keyPropertyValuesParameter,
ObjectArrayIndexerPropertyInfo,
new[] { Constant(index) });
[Constant(index)]);
return methodCallExpression.Type != valueExpression.Type
? Convert(valueExpression, methodCallExpression.Type)
: valueExpression;
Expand Down Expand Up @@ -1744,8 +1742,11 @@ protected override Expression VisitSwitch(SwitchExpression switchExpression)
//sometimes we have shadow snapshot and sometimes not, but type initializer always comes last
switch (body.Expressions[^1])
{
case UnaryExpression { Operand: BlockExpression innerBlock } jsonEntityTypeInitializerUnary
when jsonEntityTypeInitializerUnary.NodeType is ExpressionType.Convert or ExpressionType.ConvertChecked:
case UnaryExpression
{
Operand: BlockExpression innerBlock,
NodeType: ExpressionType.Convert or ExpressionType.ConvertChecked
} jsonEntityTypeInitializerUnary:
{
// in case of proxies, the entity initializer block is wrapped around Convert node
// that converts from the proxy type to the actual entity type.
Expand Down Expand Up @@ -1796,7 +1797,7 @@ protected override Expression VisitSwitch(SwitchExpression switchExpression)
case NewExpression jsonEntityTypeInitializerCtor:
var newInstanceVariable = Variable(jsonEntityTypeInitializerCtor.Type, "instance");
jsonEntityTypeInitializerBlock = Block(
new[] { newInstanceVariable },
[newInstanceVariable],
Assign(newInstanceVariable, jsonEntityTypeInitializerCtor),
newInstanceVariable);
break;
Expand Down Expand Up @@ -1886,14 +1887,7 @@ protected override Expression VisitSwitch(SwitchExpression switchExpression)

// Fixup is only needed for non-tracking queries, in case of tracking (or NoTrackingWithIdentityResolution) - ChangeTracker does the job
// or for empty/null collections of a tracking queries.
if (queryStateManager)
{
ProcessFixup(trackingInnerFixupMap);
}
else
{
ProcessFixup(innerFixupMap);
}
ProcessFixup(queryStateManager ? trackingInnerFixupMap : innerFixupMap);

finalBlockExpressions.Add(jsonEntityTypeVariable);

Expand Down Expand Up @@ -2113,7 +2107,7 @@ protected override Expression VisitConditional(ConditionalExpression conditional
Block(
ifTrueBlock.Variables,
ifTrueBlock.Expressions.Concat(
new Expression[] { Assign(entityAlreadyTrackedVariable, Constant(true)), Default(typeof(void)) })))
[Assign(entityAlreadyTrackedVariable, Constant(true)), Default(typeof(void))])))
};

resultBlockVariables.AddRange(ifFalseBlock.Variables.ToList());
Expand Down Expand Up @@ -2266,7 +2260,7 @@ protected override Expression VisitBinary(BinaryExpression node)
? (Expression)currentVariable
: Convert(currentVariable, genericMethod.GetParameters()[1].ParameterType);
return Block(
new[] { currentVariable },
[currentVariable],
MakeMemberAccess(instance, property.GetMemberInfo(forMaterialization: true, forSet: false))
.Assign(currentVariable),
IfThenElse(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,35 +61,35 @@ protected override Expression VisitBinary(BinaryExpression binaryExpression)

// Convert x.Count > 0 and x.Count != 0 to x.Any()
{
NodeType: ExpressionType.GreaterThan or ExpressionType.NotEqual,
Left: MemberExpression
{
Member: { Name: nameof(ICollection<object>.Count), DeclaringType.IsGenericType: true } member,
Expression: Expression source
},
Right: ConstantExpression { Value: 0 }
}
when (member.DeclaringType.GetGenericTypeDefinition().GetInterfaces().Any(
x => x.IsGenericType && x.GetGenericTypeDefinition() == typeof(ICollection<>)))
=> VisitMethodCall(
Expression.Call(
EnumerableMethods.AnyWithoutPredicate.MakeGenericMethod(source.Type.GetSequenceType()),
source)),
NodeType: ExpressionType.GreaterThan or ExpressionType.NotEqual,
Left: MemberExpression
{
Member: { Name: nameof(ICollection<object>.Count), DeclaringType.IsGenericType: true } member,
Expression: Expression source
},
Right: ConstantExpression { Value: 0 }
}
when (member.DeclaringType.GetGenericTypeDefinition().GetInterfaces().Any(
x => x.IsGenericType && x.GetGenericTypeDefinition() == typeof(ICollection<>)))
=> VisitMethodCall(
Expression.Call(
EnumerableMethods.AnyWithoutPredicate.MakeGenericMethod(source.Type.GetSequenceType()),
source)),

// Same for arrays: convert x.Length > 0 and x.Length != 0 to x.Any()
{
NodeType: ExpressionType.GreaterThan or ExpressionType.NotEqual,
Left: UnaryExpression
{
NodeType: ExpressionType.ArrayLength,
Operand: Expression source
},
Right: ConstantExpression { Value: 0 }
}
=> VisitMethodCall(
Expression.Call(
EnumerableMethods.AnyWithoutPredicate.MakeGenericMethod(source.Type.GetSequenceType()),
source)),
NodeType: ExpressionType.GreaterThan or ExpressionType.NotEqual,
Left: UnaryExpression
{
NodeType: ExpressionType.ArrayLength,
Operand: Expression source
},
Right: ConstantExpression { Value: 0 }
}
=> VisitMethodCall(
Expression.Call(
EnumerableMethods.AnyWithoutPredicate.MakeGenericMethod(source.Type.GetSequenceType()),
source)),

_ => base.VisitBinary(binaryExpression)
};
Expand Down Expand Up @@ -542,23 +542,33 @@ private MethodCallExpression TryNormalizeOrderAndOrderDescending(MethodCallExpre

private MethodCallExpression TryFlattenGroupJoinSelectMany(MethodCallExpression methodCallExpression)
{
var genericMethod = methodCallExpression.Method.GetGenericMethodDefinition();
if (genericMethod == QueryableMethods.SelectManyWithCollectionSelector)
switch (methodCallExpression)
{
// SelectMany
var selectManySource = methodCallExpression.Arguments[0];
if (selectManySource is MethodCallExpression { Method.IsGenericMethod: true } groupJoinMethod
&& groupJoinMethod.Method.GetGenericMethodDefinition() == QueryableMethods.GroupJoin)
case
{
Method: { Name: nameof(Queryable.SelectMany), IsGenericMethod: true } selectManyMethod,
Arguments:
[
MethodCallExpression
{
Method: { Name: nameof(QueryableMethods.GroupJoin), IsGenericMethod: true } groupJoinMethod,
Arguments: { Count: 5 } groupJoinArguments
},
_,
_
] selectManyArguments
}
when selectManyMethod.GetGenericMethodDefinition() == QueryableMethods.SelectManyWithCollectionSelector
&& groupJoinMethod.GetGenericMethodDefinition() == QueryableMethods.GroupJoin:
{
// GroupJoin
var outer = groupJoinMethod.Arguments[0];
var inner = groupJoinMethod.Arguments[1];
var outerKeySelector = groupJoinMethod.Arguments[2].UnwrapLambdaFromQuote();
var innerKeySelector = groupJoinMethod.Arguments[3].UnwrapLambdaFromQuote();
var groupJoinResultSelector = groupJoinMethod.Arguments[4].UnwrapLambdaFromQuote();
var outer = groupJoinArguments[0];
var inner = groupJoinArguments[1];
var outerKeySelector = groupJoinArguments[2].UnwrapLambdaFromQuote();
var innerKeySelector = groupJoinArguments[3].UnwrapLambdaFromQuote();
var groupJoinResultSelector = groupJoinArguments[4].UnwrapLambdaFromQuote();

var selectManyCollectionSelector = methodCallExpression.Arguments[1].UnwrapLambdaFromQuote();
var selectManyResultSelector = methodCallExpression.Arguments[2].UnwrapLambdaFromQuote();
var selectManyCollectionSelector = selectManyArguments[1].UnwrapLambdaFromQuote();
var selectManyResultSelector = selectManyArguments[2].UnwrapLambdaFromQuote();

var collectionSelectorBody = selectManyCollectionSelector.Body;
var defaultIfEmpty = false;
Expand Down Expand Up @@ -603,14 +613,16 @@ private MethodCallExpression TryFlattenGroupJoinSelectMany(MethodCallExpression
resultSelectorBody,
groupJoinResultSelector.Parameters[0],
selectManyResultSelector.Parameters[1]);
var genericArguments = groupJoinMethod.Method.GetGenericArguments();
var genericArguments = groupJoinMethod.GetGenericArguments();
genericArguments[^1] = resultSelector.ReturnType;

return Expression.Call(
(defaultIfEmpty ? QueryableExtensions.LeftJoinMethodInfo : QueryableMethods.Join).MakeGenericMethod(
genericArguments),
outer, inner, outerKeySelector, innerKeySelector, resultSelector);
}

break;
// TODO: Convert correlated patterns to SelectMany
//else
//{
Expand Down Expand Up @@ -643,22 +655,30 @@ private MethodCallExpression TryFlattenGroupJoinSelectMany(MethodCallExpression
// selectManyResultSelector.Parameters[1]);
//}
}
}
else if (genericMethod == QueryableMethods.SelectManyWithoutCollectionSelector)
{
// SelectMany
var selectManySource = methodCallExpression.Arguments[0];
if (selectManySource is MethodCallExpression { Method.IsGenericMethod: true } groupJoinMethod
&& groupJoinMethod.Method.GetGenericMethodDefinition() == QueryableMethods.GroupJoin)

case
{
Method: { Name: nameof(Queryable.SelectMany), IsGenericMethod: true } selectManyMethod,
Arguments:
[
MethodCallExpression
{
Method: { Name: nameof(QueryableMethods.GroupJoin), IsGenericMethod: true } groupJoinMethod,
Arguments: { Count: 5 } groupJoinArguments
},
_
] selectManyArguments
}
when selectManyMethod.GetGenericMethodDefinition() == QueryableMethods.SelectManyWithoutCollectionSelector
&& groupJoinMethod.GetGenericMethodDefinition() == QueryableMethods.GroupJoin:
{
// GroupJoin
var outer = groupJoinMethod.Arguments[0];
var inner = groupJoinMethod.Arguments[1];
var outerKeySelector = groupJoinMethod.Arguments[2].UnwrapLambdaFromQuote();
var innerKeySelector = groupJoinMethod.Arguments[3].UnwrapLambdaFromQuote();
var groupJoinResultSelector = groupJoinMethod.Arguments[4].UnwrapLambdaFromQuote();
var outer = groupJoinArguments[0];
var inner = groupJoinArguments[1];
var outerKeySelector = groupJoinArguments[2].UnwrapLambdaFromQuote();
var innerKeySelector = groupJoinArguments[3].UnwrapLambdaFromQuote();
var groupJoinResultSelector = groupJoinArguments[4].UnwrapLambdaFromQuote();

var selectManyResultSelector = methodCallExpression.Arguments[1].UnwrapLambdaFromQuote();
var selectManyResultSelector = selectManyArguments[1].UnwrapLambdaFromQuote();

var groupJoinResultSelectorBody = groupJoinResultSelector.Body;
var defaultIfEmpty = false;
Expand Down Expand Up @@ -693,14 +713,16 @@ private MethodCallExpression TryFlattenGroupJoinSelectMany(MethodCallExpression
groupJoinResultSelector.Parameters[0],
innerKeySelector.Parameters[0]);

var genericArguments = groupJoinMethod.Method.GetGenericArguments();
var genericArguments = groupJoinMethod.GetGenericArguments();
genericArguments[^1] = resultSelector.ReturnType;

return Expression.Call(
(defaultIfEmpty ? QueryableExtensions.LeftJoinMethodInfo : QueryableMethods.Join).MakeGenericMethod(
genericArguments),
outer, inner, outerKeySelector, innerKeySelector, resultSelector);
}

break;
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1102,7 +1102,7 @@ public virtual Task Two_captured_variables_in_same_lambda()
public virtual Task Two_captured_variables_in_different_lambdas()
=> Test(
"""
var starts = "blog";
var starts = "Blog";
var ends = "2";
var blog = await context.Blogs.Where(b => b.Name.StartsWith(starts)).Where(b => b.Name.EndsWith(ends)).SingleAsync();
Assert.Equal(9, blog.Id);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1952,7 +1952,7 @@ public override async Task Two_captured_variables_in_different_lambdas()

AssertSql(
"""
@__starts_0_startswith='blog%' (Size = 4000)
@__starts_0_startswith='Blog%' (Size = 4000)
@__ends_1_endswith='%2' (Size = 4000)

SELECT TOP(2) [b].[Id], [b].[Name], [b].[Json]
Expand Down