Skip to content

Commit b8a558e

Browse files
authored
Fix MethodFinder TryFindAggregateMethod to support array (#923)
* Fix MethodFinder TryFindAggregateMethod to support array * 3
1 parent ef89139 commit b8a558e

File tree

6 files changed

+317
-163
lines changed

6 files changed

+317
-163
lines changed

src/System.Linq.Dynamic.Core/Parser/SupportedMethods/MethodFinder.cs

Lines changed: 23 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ internal class MethodFinder
1010
{
1111
private readonly ParsingConfig _parsingConfig;
1212
private readonly IExpressionHelper _expressionHelper;
13+
private readonly IDictionary<Type, MethodInfo[]> _cachedMethods;
1314

1415
/// <summary>
1516
/// #794
@@ -43,19 +44,32 @@ public MethodFinder(ParsingConfig parsingConfig, IExpressionHelper expressionHel
4344
{
4445
_parsingConfig = Check.NotNull(parsingConfig);
4546
_expressionHelper = Check.NotNull(expressionHelper);
47+
_cachedMethods = new Dictionary<Type, MethodInfo[]>
48+
{
49+
{ typeof(Enumerable), typeof(Enumerable).GetMethods().Where(m => !m.IsGenericMethodDefinition).ToArray() },
50+
{ typeof(Queryable), typeof(Queryable).GetMethods().Where(m => !m.IsGenericMethodDefinition).ToArray() }
51+
};
4652
}
4753

4854
public bool TryFindAggregateMethod(Type callType, string methodName, Type parameterType, [NotNullWhen(true)] out MethodInfo? aggregateMethod)
4955
{
50-
aggregateMethod = callType
51-
.GetMethods()
52-
.Where(m => m.Name == methodName && !m.IsGenericMethodDefinition)
53-
.SelectMany(m => m.GetParameters(), (m, p) => new { Method = m, Parameter = p })
54-
.Where(x => x.Parameter.ParameterType == parameterType)
55-
.Select(x => x.Method)
56-
.FirstOrDefault();
57-
58-
return aggregateMethod != null;
56+
var nonGenericMethodsByName = _cachedMethods[callType]
57+
.Where(m => m.Name == methodName)
58+
.ToArray();
59+
60+
if (TypeHelper.TryGetAsEnumerable(parameterType, out var parameterTypeAsEnumerable))
61+
{
62+
aggregateMethod = nonGenericMethodsByName
63+
.SelectMany(m => m.GetParameters(), (m, p) => new { Method = m, Parameter = p })
64+
.Where(x => x.Parameter.ParameterType == parameterTypeAsEnumerable)
65+
.Select(x => x.Method)
66+
.FirstOrDefault();
67+
68+
return aggregateMethod != null;
69+
}
70+
71+
aggregateMethod = null;
72+
return false;
5973
}
6074

6175
public bool CheckAggregateMethodAndTryUpdateArgsToMatchMethodArgs(string methodName, ref Expression[] args)

src/System.Linq.Dynamic.Core/Parser/TypeHelper.cs

Lines changed: 96 additions & 78 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,24 @@ namespace System.Linq.Dynamic.Core.Parser;
66

77
internal static class TypeHelper
88
{
9+
internal static bool TryGetAsEnumerable(Type type, [NotNullWhen(true)] out Type? enumerableType)
10+
{
11+
if (type.IsArray)
12+
{
13+
enumerableType = typeof(IEnumerable<>).MakeGenericType(type.GetElementType()!);
14+
return true;
15+
}
16+
17+
if (type.GetTypeInfo().IsGenericType && type.GetGenericTypeDefinition() == typeof(IEnumerable<>))
18+
{
19+
enumerableType = type;
20+
return true;
21+
}
22+
23+
enumerableType = null;
24+
return false;
25+
}
26+
927
public static bool TryGetFirstGenericArgument(Type type, [NotNullWhen(true)] out Type? genericType)
1028
{
1129
var genericArguments = type.GetTypeInfo().GetGenericTypeArguments();
@@ -196,79 +214,79 @@ public static bool IsCompatibleWith(Type source, Type target)
196214
}
197215
return false;
198216
#else
199-
if (source == target)
200-
{
201-
return true;
202-
}
217+
if (source == target)
218+
{
219+
return true;
220+
}
203221

204-
if (!target.GetTypeInfo().IsValueType)
205-
{
206-
return target.IsAssignableFrom(source);
207-
}
222+
if (!target.GetTypeInfo().IsValueType)
223+
{
224+
return target.IsAssignableFrom(source);
225+
}
208226

209-
Type st = GetNonNullableType(source);
210-
Type tt = GetNonNullableType(target);
227+
Type st = GetNonNullableType(source);
228+
Type tt = GetNonNullableType(target);
211229

212-
if (st != source && tt == target)
213-
{
214-
return false;
215-
}
230+
if (st != source && tt == target)
231+
{
232+
return false;
233+
}
216234

217-
Type sc = st.GetTypeInfo().IsEnum ? typeof(object) : st;
218-
Type tc = tt.GetTypeInfo().IsEnum ? typeof(object) : tt;
235+
Type sc = st.GetTypeInfo().IsEnum ? typeof(object) : st;
236+
Type tc = tt.GetTypeInfo().IsEnum ? typeof(object) : tt;
219237

220-
if (sc == typeof(sbyte))
221-
{
222-
if (tc == typeof(sbyte) || tc == typeof(short) || tc == typeof(int) || tc == typeof(long) || tc == typeof(float) || tc == typeof(double) || tc == typeof(decimal))
223-
return true;
224-
}
225-
else if (sc == typeof(byte))
226-
{
227-
if (tc == typeof(byte) || tc == typeof(short) || tc == typeof(ushort) || tc == typeof(int) || tc == typeof(uint) || tc == typeof(long) || tc == typeof(ulong) || tc == typeof(float) || tc == typeof(double) || tc == typeof(decimal))
228-
return true;
229-
}
230-
else if (sc == typeof(short))
231-
{
232-
if (tc == typeof(short) || tc == typeof(int) || tc == typeof(long) || tc == typeof(float) || tc == typeof(double) || tc == typeof(decimal))
233-
return true;
234-
}
235-
else if (sc == typeof(ushort))
236-
{
237-
if (tc == typeof(ushort) || tc == typeof(int) || tc == typeof(uint) || tc == typeof(long) || tc == typeof(ulong) || tc == typeof(float) || tc == typeof(double) || tc == typeof(decimal))
238-
return true;
239-
}
240-
else if (sc == typeof(int))
241-
{
242-
if (tc == typeof(int) || tc == typeof(long) || tc == typeof(float) || tc == typeof(double) || tc == typeof(decimal))
243-
return true;
244-
}
245-
else if (sc == typeof(uint))
246-
{
247-
if (tc == typeof(uint) || tc == typeof(long) || tc == typeof(ulong) || tc == typeof(float) || tc == typeof(double) || tc == typeof(decimal))
248-
return true;
249-
}
250-
else if (sc == typeof(long))
251-
{
252-
if (tc == typeof(long) || tc == typeof(float) || tc == typeof(double) || tc == typeof(decimal))
253-
return true;
254-
}
255-
else if (sc == typeof(ulong))
256-
{
257-
if (tc == typeof(ulong) || tc == typeof(float) || tc == typeof(double) || tc == typeof(decimal))
258-
return true;
259-
}
260-
else if (sc == typeof(float))
261-
{
262-
if (tc == typeof(float) || tc == typeof(double))
263-
return true;
264-
}
265-
266-
if (st == tt)
267-
{
238+
if (sc == typeof(sbyte))
239+
{
240+
if (tc == typeof(sbyte) || tc == typeof(short) || tc == typeof(int) || tc == typeof(long) || tc == typeof(float) || tc == typeof(double) || tc == typeof(decimal))
268241
return true;
269-
}
242+
}
243+
else if (sc == typeof(byte))
244+
{
245+
if (tc == typeof(byte) || tc == typeof(short) || tc == typeof(ushort) || tc == typeof(int) || tc == typeof(uint) || tc == typeof(long) || tc == typeof(ulong) || tc == typeof(float) || tc == typeof(double) || tc == typeof(decimal))
246+
return true;
247+
}
248+
else if (sc == typeof(short))
249+
{
250+
if (tc == typeof(short) || tc == typeof(int) || tc == typeof(long) || tc == typeof(float) || tc == typeof(double) || tc == typeof(decimal))
251+
return true;
252+
}
253+
else if (sc == typeof(ushort))
254+
{
255+
if (tc == typeof(ushort) || tc == typeof(int) || tc == typeof(uint) || tc == typeof(long) || tc == typeof(ulong) || tc == typeof(float) || tc == typeof(double) || tc == typeof(decimal))
256+
return true;
257+
}
258+
else if (sc == typeof(int))
259+
{
260+
if (tc == typeof(int) || tc == typeof(long) || tc == typeof(float) || tc == typeof(double) || tc == typeof(decimal))
261+
return true;
262+
}
263+
else if (sc == typeof(uint))
264+
{
265+
if (tc == typeof(uint) || tc == typeof(long) || tc == typeof(ulong) || tc == typeof(float) || tc == typeof(double) || tc == typeof(decimal))
266+
return true;
267+
}
268+
else if (sc == typeof(long))
269+
{
270+
if (tc == typeof(long) || tc == typeof(float) || tc == typeof(double) || tc == typeof(decimal))
271+
return true;
272+
}
273+
else if (sc == typeof(ulong))
274+
{
275+
if (tc == typeof(ulong) || tc == typeof(float) || tc == typeof(double) || tc == typeof(decimal))
276+
return true;
277+
}
278+
else if (sc == typeof(float))
279+
{
280+
if (tc == typeof(float) || tc == typeof(double))
281+
return true;
282+
}
270283

271-
return false;
284+
if (st == tt)
285+
{
286+
return true;
287+
}
288+
289+
return false;
272290
#endif
273291
}
274292

@@ -391,19 +409,19 @@ private static int GetNumericTypeKind(Type type)
391409
return 0;
392410
}
393411
#else
394-
if (type.GetTypeInfo().IsEnum)
395-
{
396-
return 0;
397-
}
412+
if (type.GetTypeInfo().IsEnum)
413+
{
414+
return 0;
415+
}
398416

399-
if (type == typeof(char) || type == typeof(float) || type == typeof(double) || type == typeof(decimal))
400-
return 1;
401-
if (type == typeof(sbyte) || type == typeof(short) || type == typeof(int) || type == typeof(long))
402-
return 2;
403-
if (type == typeof(byte) || type == typeof(ushort) || type == typeof(uint) || type == typeof(ulong))
404-
return 3;
417+
if (type == typeof(char) || type == typeof(float) || type == typeof(double) || type == typeof(decimal))
418+
return 1;
419+
if (type == typeof(sbyte) || type == typeof(short) || type == typeof(int) || type == typeof(long))
420+
return 2;
421+
if (type == typeof(byte) || type == typeof(ushort) || type == typeof(uint) || type == typeof(ulong))
422+
return 3;
405423

406-
return 0;
424+
return 0;
407425
#endif
408426
}
409427

@@ -484,7 +502,7 @@ private static void AddInterface(ICollection<Type> types, Type type)
484502

485503
public static bool TryParseEnum(string value, Type? type, [NotNullWhen(true)] out object? enumValue)
486504
{
487-
if (type is { } && type.GetTypeInfo().IsEnum && Enum.IsDefined(type, value))
505+
if (type != null && type.GetTypeInfo().IsEnum && Enum.IsDefined(type, value))
488506
{
489507
enumValue = Enum.Parse(type, value, true);
490508
return true;
Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
using System.Linq.Dynamic.Core.Tests.Helpers.Models;
2+
using Xunit;
3+
4+
namespace System.Linq.Dynamic.Core.Tests;
5+
6+
public partial class ExpressionTests
7+
{
8+
[Fact]
9+
public void ExpressionTests_Sum()
10+
{
11+
// Arrange
12+
int[] initValues = [1, 2, 3, 4, 5];
13+
var qry = initValues.AsQueryable().Select(x => new { strValue = "str", intValue = x }).GroupBy(x => x.strValue);
14+
15+
// Act
16+
var result = qry.Select("Sum(intValue)").AsDynamicEnumerable().ToArray()[0];
17+
18+
// Assert
19+
Assert.Equal(15, result);
20+
}
21+
22+
[Fact]
23+
public void ExpressionTests_Sum_LowerCase()
24+
{
25+
// Arrange
26+
int[] initValues = [1, 2, 3, 4, 5];
27+
var qry = initValues.AsQueryable().Select(x => new { strValue = "str", intValue = x }).GroupBy(x => x.strValue);
28+
29+
// Act
30+
var result = qry.Select("sum(intValue)").AsDynamicEnumerable().ToArray()[0];
31+
32+
// Assert
33+
Assert.Equal(15, result);
34+
}
35+
36+
[Fact]
37+
public void ExpressionTests_Sum2()
38+
{
39+
// Arrange
40+
var initValues = new[]
41+
{
42+
new SimpleValuesModel { FloatValue = 1 },
43+
new SimpleValuesModel { FloatValue = 2 },
44+
new SimpleValuesModel { FloatValue = 3 },
45+
};
46+
47+
var qry = initValues.AsQueryable();
48+
49+
// Act
50+
var result = qry.Select("FloatValue").Sum();
51+
var result2 = ((IQueryable<float>)qry.Select("FloatValue")).Sum();
52+
53+
// Assert
54+
Assert.Equal(6.0f, result);
55+
Assert.Equal(6.0f, result2);
56+
}
57+
}

test/System.Linq.Dynamic.Core.Tests/ExpressionTests.cs

Lines changed: 0 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -2154,57 +2154,6 @@ public void ExpressionTests_Subtract_Number()
21542154
Check.That(result).ContainsExactly(expected);
21552155
}
21562156

2157-
[Fact]
2158-
public void ExpressionTests_Sum()
2159-
{
2160-
// Arrange
2161-
int[] initValues = { 1, 2, 3, 4, 5 };
2162-
var qry = initValues.AsQueryable().Select(x => new { strValue = "str", intValue = x }).GroupBy(x => x.strValue);
2163-
2164-
// Act
2165-
var result = qry.Select("Sum(intValue)").AsDynamicEnumerable().ToArray()[0];
2166-
2167-
// Assert
2168-
Assert.Equal(15, result);
2169-
}
2170-
2171-
[Fact]
2172-
public void ExpressionTests_Sum_LowerCase()
2173-
{
2174-
// Arrange
2175-
int[] initValues = { 1, 2, 3, 4, 5 };
2176-
var qry = initValues.AsQueryable().Select(x => new { strValue = "str", intValue = x }).GroupBy(x => x.strValue);
2177-
2178-
// Act
2179-
var result = qry.Select("sum(intValue)").AsDynamicEnumerable().ToArray()[0];
2180-
2181-
// Assert
2182-
Assert.Equal(15, result);
2183-
}
2184-
2185-
[Fact]
2186-
public void ExpressionTests_Sum2()
2187-
{
2188-
// Arrange
2189-
var initValues = new[]
2190-
{
2191-
new SimpleValuesModel { FloatValue = 1 },
2192-
new SimpleValuesModel { FloatValue = 2 },
2193-
new SimpleValuesModel { FloatValue = 3 },
2194-
};
2195-
2196-
var qry = initValues.AsQueryable();
2197-
2198-
// Act
2199-
var result = qry.Select("FloatValue").Sum();
2200-
var result2 = ((IQueryable<float>)qry.Select("FloatValue")).Sum();
2201-
2202-
// Assert
2203-
Assert.Equal(6.0f, result);
2204-
Assert.Equal(6.0f, result2);
2205-
}
2206-
2207-
22082157
[Fact]
22092158
public void ExpressionTests_Type_Integer()
22102159
{

0 commit comments

Comments
 (0)