Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ internal class MethodFinder
{
private readonly ParsingConfig _parsingConfig;
private readonly IExpressionHelper _expressionHelper;
private readonly IDictionary<Type, MethodInfo[]> _cachedMethods;

/// <summary>
/// #794
Expand Down Expand Up @@ -43,19 +44,32 @@ public MethodFinder(ParsingConfig parsingConfig, IExpressionHelper expressionHel
{
_parsingConfig = Check.NotNull(parsingConfig);
_expressionHelper = Check.NotNull(expressionHelper);
_cachedMethods = new Dictionary<Type, MethodInfo[]>
{
{ typeof(Enumerable), typeof(Enumerable).GetMethods().Where(m => !m.IsGenericMethodDefinition).ToArray() },
{ typeof(Queryable), typeof(Queryable).GetMethods().Where(m => !m.IsGenericMethodDefinition).ToArray() }
};
}

public bool TryFindAggregateMethod(Type callType, string methodName, Type parameterType, [NotNullWhen(true)] out MethodInfo? aggregateMethod)
{
aggregateMethod = callType
.GetMethods()
.Where(m => m.Name == methodName && !m.IsGenericMethodDefinition)
.SelectMany(m => m.GetParameters(), (m, p) => new { Method = m, Parameter = p })
.Where(x => x.Parameter.ParameterType == parameterType)
.Select(x => x.Method)
.FirstOrDefault();

return aggregateMethod != null;
var nonGenericMethodsByName = _cachedMethods[callType]
.Where(m => m.Name == methodName)
.ToArray();

if (TypeHelper.TryGetAsEnumerable(parameterType, out var parameterTypeAsEnumerable))
{
aggregateMethod = nonGenericMethodsByName
.SelectMany(m => m.GetParameters(), (m, p) => new { Method = m, Parameter = p })
.Where(x => x.Parameter.ParameterType == parameterTypeAsEnumerable)
.Select(x => x.Method)
.FirstOrDefault();

return aggregateMethod != null;
}

aggregateMethod = null;
return false;
}

public bool CheckAggregateMethodAndTryUpdateArgsToMatchMethodArgs(string methodName, ref Expression[] args)
Expand Down
174 changes: 96 additions & 78 deletions src/System.Linq.Dynamic.Core/Parser/TypeHelper.cs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,24 @@ namespace System.Linq.Dynamic.Core.Parser;

internal static class TypeHelper
{
internal static bool TryGetAsEnumerable(Type type, [NotNullWhen(true)] out Type? enumerableType)
{
if (type.IsArray)
{
enumerableType = typeof(IEnumerable<>).MakeGenericType(type.GetElementType()!);
return true;
}

if (type.GetTypeInfo().IsGenericType && type.GetGenericTypeDefinition() == typeof(IEnumerable<>))
{
enumerableType = type;
return true;
}

enumerableType = null;
return false;
}

public static bool TryGetFirstGenericArgument(Type type, [NotNullWhen(true)] out Type? genericType)
{
var genericArguments = type.GetTypeInfo().GetGenericTypeArguments();
Expand Down Expand Up @@ -196,79 +214,79 @@ public static bool IsCompatibleWith(Type source, Type target)
}
return false;
#else
if (source == target)
{
return true;
}
if (source == target)
{
return true;
}

if (!target.GetTypeInfo().IsValueType)
{
return target.IsAssignableFrom(source);
}
if (!target.GetTypeInfo().IsValueType)
{
return target.IsAssignableFrom(source);
}

Type st = GetNonNullableType(source);
Type tt = GetNonNullableType(target);
Type st = GetNonNullableType(source);
Type tt = GetNonNullableType(target);

if (st != source && tt == target)
{
return false;
}
if (st != source && tt == target)
{
return false;
}

Type sc = st.GetTypeInfo().IsEnum ? typeof(object) : st;
Type tc = tt.GetTypeInfo().IsEnum ? typeof(object) : tt;
Type sc = st.GetTypeInfo().IsEnum ? typeof(object) : st;
Type tc = tt.GetTypeInfo().IsEnum ? typeof(object) : tt;

if (sc == typeof(sbyte))
{
if (tc == typeof(sbyte) || tc == typeof(short) || tc == typeof(int) || tc == typeof(long) || tc == typeof(float) || tc == typeof(double) || tc == typeof(decimal))
return true;
}
else if (sc == typeof(byte))
{
if (tc == typeof(byte) || tc == typeof(short) || tc == typeof(ushort) || tc == typeof(int) || tc == typeof(uint) || tc == typeof(long) || tc == typeof(ulong) || tc == typeof(float) || tc == typeof(double) || tc == typeof(decimal))
return true;
}
else if (sc == typeof(short))
{
if (tc == typeof(short) || tc == typeof(int) || tc == typeof(long) || tc == typeof(float) || tc == typeof(double) || tc == typeof(decimal))
return true;
}
else if (sc == typeof(ushort))
{
if (tc == typeof(ushort) || tc == typeof(int) || tc == typeof(uint) || tc == typeof(long) || tc == typeof(ulong) || tc == typeof(float) || tc == typeof(double) || tc == typeof(decimal))
return true;
}
else if (sc == typeof(int))
{
if (tc == typeof(int) || tc == typeof(long) || tc == typeof(float) || tc == typeof(double) || tc == typeof(decimal))
return true;
}
else if (sc == typeof(uint))
{
if (tc == typeof(uint) || tc == typeof(long) || tc == typeof(ulong) || tc == typeof(float) || tc == typeof(double) || tc == typeof(decimal))
return true;
}
else if (sc == typeof(long))
{
if (tc == typeof(long) || tc == typeof(float) || tc == typeof(double) || tc == typeof(decimal))
return true;
}
else if (sc == typeof(ulong))
{
if (tc == typeof(ulong) || tc == typeof(float) || tc == typeof(double) || tc == typeof(decimal))
return true;
}
else if (sc == typeof(float))
{
if (tc == typeof(float) || tc == typeof(double))
return true;
}

if (st == tt)
{
if (sc == typeof(sbyte))
{
if (tc == typeof(sbyte) || tc == typeof(short) || tc == typeof(int) || tc == typeof(long) || tc == typeof(float) || tc == typeof(double) || tc == typeof(decimal))
return true;
}
}
else if (sc == typeof(byte))
{
if (tc == typeof(byte) || tc == typeof(short) || tc == typeof(ushort) || tc == typeof(int) || tc == typeof(uint) || tc == typeof(long) || tc == typeof(ulong) || tc == typeof(float) || tc == typeof(double) || tc == typeof(decimal))
return true;
}
else if (sc == typeof(short))
{
if (tc == typeof(short) || tc == typeof(int) || tc == typeof(long) || tc == typeof(float) || tc == typeof(double) || tc == typeof(decimal))
return true;
}
else if (sc == typeof(ushort))
{
if (tc == typeof(ushort) || tc == typeof(int) || tc == typeof(uint) || tc == typeof(long) || tc == typeof(ulong) || tc == typeof(float) || tc == typeof(double) || tc == typeof(decimal))
return true;
}
else if (sc == typeof(int))
{
if (tc == typeof(int) || tc == typeof(long) || tc == typeof(float) || tc == typeof(double) || tc == typeof(decimal))
return true;
}
else if (sc == typeof(uint))
{
if (tc == typeof(uint) || tc == typeof(long) || tc == typeof(ulong) || tc == typeof(float) || tc == typeof(double) || tc == typeof(decimal))
return true;
}
else if (sc == typeof(long))
{
if (tc == typeof(long) || tc == typeof(float) || tc == typeof(double) || tc == typeof(decimal))
return true;
}
else if (sc == typeof(ulong))
{
if (tc == typeof(ulong) || tc == typeof(float) || tc == typeof(double) || tc == typeof(decimal))
return true;
}
else if (sc == typeof(float))
{
if (tc == typeof(float) || tc == typeof(double))
return true;
}

return false;
if (st == tt)
{
return true;
}

return false;
#endif
}

Expand Down Expand Up @@ -391,19 +409,19 @@ private static int GetNumericTypeKind(Type type)
return 0;
}
#else
if (type.GetTypeInfo().IsEnum)
{
return 0;
}
if (type.GetTypeInfo().IsEnum)
{
return 0;
}

if (type == typeof(char) || type == typeof(float) || type == typeof(double) || type == typeof(decimal))
return 1;
if (type == typeof(sbyte) || type == typeof(short) || type == typeof(int) || type == typeof(long))
return 2;
if (type == typeof(byte) || type == typeof(ushort) || type == typeof(uint) || type == typeof(ulong))
return 3;
if (type == typeof(char) || type == typeof(float) || type == typeof(double) || type == typeof(decimal))
return 1;
if (type == typeof(sbyte) || type == typeof(short) || type == typeof(int) || type == typeof(long))
return 2;
if (type == typeof(byte) || type == typeof(ushort) || type == typeof(uint) || type == typeof(ulong))
return 3;

return 0;
return 0;
#endif
}

Expand Down Expand Up @@ -484,7 +502,7 @@ private static void AddInterface(ICollection<Type> types, Type type)

public static bool TryParseEnum(string value, Type? type, [NotNullWhen(true)] out object? enumValue)
{
if (type is { } && type.GetTypeInfo().IsEnum && Enum.IsDefined(type, value))
if (type != null && type.GetTypeInfo().IsEnum && Enum.IsDefined(type, value))
{
enumValue = Enum.Parse(type, value, true);
return true;
Expand Down
57 changes: 57 additions & 0 deletions test/System.Linq.Dynamic.Core.Tests/ExpressionTests.Sum.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
using System.Linq.Dynamic.Core.Tests.Helpers.Models;
using Xunit;

namespace System.Linq.Dynamic.Core.Tests;

public partial class ExpressionTests
{
[Fact]
public void ExpressionTests_Sum()
{
// Arrange
int[] initValues = [1, 2, 3, 4, 5];
var qry = initValues.AsQueryable().Select(x => new { strValue = "str", intValue = x }).GroupBy(x => x.strValue);

// Act
var result = qry.Select("Sum(intValue)").AsDynamicEnumerable().ToArray()[0];

// Assert
Assert.Equal(15, result);
}

[Fact]
public void ExpressionTests_Sum_LowerCase()
{
// Arrange
int[] initValues = [1, 2, 3, 4, 5];
var qry = initValues.AsQueryable().Select(x => new { strValue = "str", intValue = x }).GroupBy(x => x.strValue);

// Act
var result = qry.Select("sum(intValue)").AsDynamicEnumerable().ToArray()[0];

// Assert
Assert.Equal(15, result);
}

[Fact]
public void ExpressionTests_Sum2()
{
// Arrange
var initValues = new[]
{
new SimpleValuesModel { FloatValue = 1 },
new SimpleValuesModel { FloatValue = 2 },
new SimpleValuesModel { FloatValue = 3 },
};

var qry = initValues.AsQueryable();

// Act
var result = qry.Select("FloatValue").Sum();
var result2 = ((IQueryable<float>)qry.Select("FloatValue")).Sum();

// Assert
Assert.Equal(6.0f, result);
Assert.Equal(6.0f, result2);
}
}
51 changes: 0 additions & 51 deletions test/System.Linq.Dynamic.Core.Tests/ExpressionTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2154,57 +2154,6 @@ public void ExpressionTests_Subtract_Number()
Check.That(result).ContainsExactly(expected);
}

[Fact]
public void ExpressionTests_Sum()
{
// Arrange
int[] initValues = { 1, 2, 3, 4, 5 };
var qry = initValues.AsQueryable().Select(x => new { strValue = "str", intValue = x }).GroupBy(x => x.strValue);

// Act
var result = qry.Select("Sum(intValue)").AsDynamicEnumerable().ToArray()[0];

// Assert
Assert.Equal(15, result);
}

[Fact]
public void ExpressionTests_Sum_LowerCase()
{
// Arrange
int[] initValues = { 1, 2, 3, 4, 5 };
var qry = initValues.AsQueryable().Select(x => new { strValue = "str", intValue = x }).GroupBy(x => x.strValue);

// Act
var result = qry.Select("sum(intValue)").AsDynamicEnumerable().ToArray()[0];

// Assert
Assert.Equal(15, result);
}

[Fact]
public void ExpressionTests_Sum2()
{
// Arrange
var initValues = new[]
{
new SimpleValuesModel { FloatValue = 1 },
new SimpleValuesModel { FloatValue = 2 },
new SimpleValuesModel { FloatValue = 3 },
};

var qry = initValues.AsQueryable();

// Act
var result = qry.Select("FloatValue").Sum();
var result2 = ((IQueryable<float>)qry.Select("FloatValue")).Sum();

// Assert
Assert.Equal(6.0f, result);
Assert.Equal(6.0f, result2);
}


[Fact]
public void ExpressionTests_Type_Integer()
{
Expand Down
Loading
Loading