Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@ public class ProjectableDescriptor

public ParameterListSyntax? ParametersList { get; set; }

public IEnumerable<string>? ParameterTypeNames { get; set; }

public TypeParameterListSyntax? TypeParameterList { get; set; }

public SyntaxList<TypeParameterConstraintClauseSyntax>? ConstraintClauses { get; set; }
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,8 @@ x is IPropertySymbol xProperty &&
var expressionSyntaxRewriter = new ExpressionSyntaxRewriter(memberSymbol.ContainingType, nullConditionalRewriteSupport, semanticModel, context);
var declarationSyntaxRewriter = new DeclarationSyntaxRewriter(semanticModel);

var methodSymbol = memberSymbol as IMethodSymbol;

var descriptor = new ProjectableDescriptor {

UsingDirectives = member.SyntaxTree.GetRoot().DescendantNodes().OfType<UsingDirectiveSyntax>(),
Expand All @@ -128,6 +130,14 @@ x is IPropertySymbol xProperty &&
ParametersList = SyntaxFactory.ParameterList()
};

// Collect parameter type names for method overload disambiguation
if (methodSymbol is not null)
{
descriptor.ParameterTypeNames = methodSymbol.Parameters
.Select(p => p.Type.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat))
.ToList();
}

if (memberSymbol.ContainingType is INamedTypeSymbol { IsGenericType: true } containingNamedType)
{
descriptor.ClassTypeParameterList = SyntaxFactory.TypeParameterList();
Expand Down Expand Up @@ -196,8 +206,6 @@ x is IPropertySymbol xProperty &&
);
}

var methodSymbol = memberSymbol as IMethodSymbol;

if (methodSymbol is { IsExtensionMethod: true })
{
var targetTypeSymbol = methodSymbol.Parameters.First().Type;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ static void Execute(MemberDeclarationSyntax member, Compilation compilation, Sou
throw new InvalidOperationException("Expected a memberName here");
}

var generatedClassName = ProjectionExpressionClassNameGenerator.GenerateName(projectable.ClassNamespace, projectable.NestedInClassNames, projectable.MemberName);
var generatedClassName = ProjectionExpressionClassNameGenerator.GenerateName(projectable.ClassNamespace, projectable.NestedInClassNames, projectable.MemberName, projectable.ParameterTypeNames);
var generatedFileName = projectable.ClassTypeParameterList is not null ? $"{generatedClassName}-{projectable.ClassTypeParameterList.ChildNodes().Count()}.g.cs" : $"{generatedClassName}.g.cs";

var classSyntax = ClassDeclaration(generatedClassName)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,21 +11,31 @@ public static class ProjectionExpressionClassNameGenerator
public const string Namespace = "EntityFrameworkCore.Projectables.Generated";

public static string GenerateName(string? namespaceName, IEnumerable<string>? nestedInClassNames, string memberName)
{
return GenerateName(namespaceName, nestedInClassNames, memberName, null);
}

public static string GenerateName(string? namespaceName, IEnumerable<string>? nestedInClassNames, string memberName, IEnumerable<string>? parameterTypeNames)
{
var stringBuilder = new StringBuilder();

return GenerateNameImpl(stringBuilder, namespaceName, nestedInClassNames, memberName);
return GenerateNameImpl(stringBuilder, namespaceName, nestedInClassNames, memberName, parameterTypeNames);
}

public static string GenerateFullName(string? namespaceName, IEnumerable<string>? nestedInClassNames, string memberName)
{
return GenerateFullName(namespaceName, nestedInClassNames, memberName, null);
}

public static string GenerateFullName(string? namespaceName, IEnumerable<string>? nestedInClassNames, string memberName, IEnumerable<string>? parameterTypeNames)
{
var stringBuilder = new StringBuilder(Namespace);
stringBuilder.Append('.');

return GenerateNameImpl(stringBuilder, namespaceName, nestedInClassNames, memberName);
return GenerateNameImpl(stringBuilder, namespaceName, nestedInClassNames, memberName, parameterTypeNames);
}

static string GenerateNameImpl(StringBuilder stringBuilder, string? namespaceName, IEnumerable<string>? nestedInClassNames, string memberName)
static string GenerateNameImpl(StringBuilder stringBuilder, string? namespaceName, IEnumerable<string>? nestedInClassNames, string memberName, IEnumerable<string>? parameterTypeNames)
{
stringBuilder.Append(namespaceName?.Replace('.', '_'));
stringBuilder.Append('_');
Expand Down Expand Up @@ -57,6 +67,35 @@ static string GenerateNameImpl(StringBuilder stringBuilder, string? namespaceNam
}
stringBuilder.Append(memberName);

// Add parameter types to make method overloads unique
if (parameterTypeNames is not null)
{
var parameterIndex = 0;
foreach (var parameterTypeName in parameterTypeNames)
{
stringBuilder.Append("_P");
stringBuilder.Append(parameterIndex);
stringBuilder.Append('_');
// Replace characters that are not valid in type names with underscores
var sanitizedTypeName = parameterTypeName
.Replace("global::", "") // Remove global:: prefix
.Replace('.', '_')
.Replace('<', '_')
.Replace('>', '_')
.Replace(',', '_')
.Replace(' ', '_')
.Replace('[', '_')
.Replace(']', '_')
.Replace('`', '_')
.Replace(':', '_') // Additional safety for any remaining colons
.Replace('?', '_'); // Handle nullable reference types
stringBuilder.Append(sanitizedTypeName);
parameterIndex++;
}
}

// Add generic arity at the very end (after parameter types)
// This matches how the CLR names generic types
if (arity > 0)
{
stringBuilder.Append('`');
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,15 +62,41 @@ public LambdaExpression FindGeneratedExpression(MemberInfo projectableMemberInfo
static LambdaExpression? GetExpressionFromGeneratedType(MemberInfo projectableMemberInfo)
{
var declaringType = projectableMemberInfo.DeclaringType ?? throw new InvalidOperationException("Expected a valid type here");
var generatedContainingTypeName = ProjectionExpressionClassNameGenerator.GenerateFullName(declaringType.Namespace, declaringType.GetNestedTypePath().Select(x => x.Name), projectableMemberInfo.Name);

// Keep track of the original declaring type's generic arguments for later use
var originalDeclaringType = declaringType;

// For generic types, use the generic type definition to match the generated name
// which is based on the open generic type
if (declaringType.IsGenericType && !declaringType.IsGenericTypeDefinition)
{
declaringType = declaringType.GetGenericTypeDefinition();
}

// Get parameter types for method overload disambiguation
// Use the same format as Roslyn's SymbolDisplayFormat.FullyQualifiedFormat
// which uses C# keywords for primitive types (int, string, etc.)
string[]? parameterTypeNames = null;
if (projectableMemberInfo is MethodInfo method)
{
// For generic methods, use the generic definition to get parameter types
// This ensures type parameters like TEntity are used instead of concrete types
var methodToInspect = method.IsGenericMethod ? method.GetGenericMethodDefinition() : method;

parameterTypeNames = methodToInspect.GetParameters()
.Select(p => GetFullTypeName(p.ParameterType))
.ToArray();
}

var generatedContainingTypeName = ProjectionExpressionClassNameGenerator.GenerateFullName(declaringType.Namespace, declaringType.GetNestedTypePath().Select(x => x.Name), projectableMemberInfo.Name, parameterTypeNames);

var expressionFactoryType = declaringType.Assembly.GetType(generatedContainingTypeName);

if (expressionFactoryType is not null)
{
if (expressionFactoryType.IsGenericTypeDefinition)
{
expressionFactoryType = expressionFactoryType.MakeGenericType(declaringType.GenericTypeArguments);
expressionFactoryType = expressionFactoryType.MakeGenericType(originalDeclaringType.GenericTypeArguments);
}

var expressionFactoryMethod = expressionFactoryType.GetMethod("Expression", BindingFlags.Static | BindingFlags.NonPublic);
Expand All @@ -93,6 +119,92 @@ public LambdaExpression FindGeneratedExpression(MemberInfo projectableMemberInfo

return null;
}

static string GetFullTypeName(Type type)
{
// Handle generic type parameters (e.g., T, TEntity)
if (type.IsGenericParameter)
{
return type.Name;
}

// Handle array types
if (type.IsArray)
{
var elementType = type.GetElementType();
if (elementType == null)
{
// Fallback for edge cases where GetElementType() might return null
return type.Name;
}

var rank = type.GetArrayRank();
var elementTypeName = GetFullTypeName(elementType);

if (rank == 1)
{
return $"{elementTypeName}[]";
}
else
{
var commas = new string(',', rank - 1);
return $"{elementTypeName}[{commas}]";
}
}

// Map primitive types to their C# keyword equivalents to match Roslyn's output
var typeKeyword = GetCSharpKeyword(type);
if (typeKeyword != null)
{
return typeKeyword;
}

// For generic types, construct the full name matching Roslyn's format
if (type.IsGenericType)
{
var genericTypeDef = type.GetGenericTypeDefinition();
var genericArgs = type.GetGenericArguments();
var baseName = genericTypeDef.FullName ?? genericTypeDef.Name;

// Remove the `n suffix (e.g., `1, `2)
var backtickIndex = baseName.IndexOf('`');
if (backtickIndex > 0)
{
baseName = baseName.Substring(0, backtickIndex);
}

var args = string.Join(", ", genericArgs.Select(GetFullTypeName));
return $"{baseName}<{args}>";
}

if (type.FullName != null)
{
// Replace + with . for nested types to match Roslyn's format
return type.FullName.Replace('+', '.');
}

return type.Name;
}

static string? GetCSharpKeyword(Type type)
{
if (type == typeof(bool)) return "bool";
if (type == typeof(byte)) return "byte";
if (type == typeof(sbyte)) return "sbyte";
if (type == typeof(char)) return "char";
if (type == typeof(decimal)) return "decimal";
if (type == typeof(double)) return "double";
if (type == typeof(float)) return "float";
if (type == typeof(int)) return "int";
if (type == typeof(uint)) return "uint";
if (type == typeof(long)) return "long";
if (type == typeof(ulong)) return "ulong";
if (type == typeof(short)) return "short";
if (type == typeof(ushort)) return "ushort";
if (type == typeof(object)) return "object";
if (type == typeof(string)) return "string";
return null;
}
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
SELECT [e].[Id], [e].[Id] + 10 AS [Result]
FROM [Entity] AS [e]
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
SELECT [e].[Id], CAST(LEN(N'Hello_' + [e].[Name]) AS int) AS [Result]
FROM [Entity] AS [e]
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
using System.Linq;
using System.Threading.Tasks;
using EntityFrameworkCore.Projectables;
using EntityFrameworkCore.Projectables.FunctionalTests.Helpers;
using Microsoft.EntityFrameworkCore;
using VerifyXunit;
using Xunit;

namespace EntityFrameworkCore.Projectables.FunctionalTests
{
[UsesVerify]
public class MethodOverloadsTests
{
public record Entity
{
public int Id { get; set; }
public string Name { get; set; } = "";

[Projectable]
public int Calculate(int x) => Id + x;

[Projectable]
public int Calculate(string prefix) => (prefix + Name).Length;
}

[Fact]
public Task MethodOverload_WithIntParameter()
{
using var dbContext = new SampleDbContext<Entity>();

var query = dbContext.Set<Entity>()
.Select(e => new { e.Id, Result = e.Calculate(10) });

return Verifier.Verify(query.ToQueryString());
}

[Fact]
public Task MethodOverload_WithStringParameter()
{
using var dbContext = new SampleDbContext<Entity>();

var query = dbContext.Set<Entity>()
.Select(e => new { e.Id, Result = e.Calculate("Hello_") });

return Verifier.Verify(query.ToQueryString());
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ using Foo;
namespace EntityFrameworkCore.Projectables.Generated
{
[global::System.ComponentModel.EditorBrowsable(global::System.ComponentModel.EditorBrowsableState.Never)]
static class Foo_C_Test
static class Foo_C_Test_P0_object
{
static global::System.Linq.Expressions.Expression<global::System.Func<object, bool>> Expression()
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ using Projectables.Repro;
namespace EntityFrameworkCore.Projectables.Generated
{
[global::System.ComponentModel.EditorBrowsable(global::System.ComponentModel.EditorBrowsableState.Never)]
static class Projectables_Repro_SomeExtensions_AsSomeResult
static class Projectables_Repro_SomeExtensions_AsSomeResult_P0_Projectables_Repro_SomeEntity
{
static global::System.Linq.Expressions.Expression<global::System.Func<global::Projectables.Repro.SomeEntity, string>> Expression()
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ using Foo;
namespace EntityFrameworkCore.Projectables.Generated
{
[global::System.ComponentModel.EditorBrowsable(global::System.ComponentModel.EditorBrowsableState.Never)]
static class Foo_EntityExtensions_Entity_Something
static class Foo_EntityExtensions_Entity_Something_P0_Foo_EntityExtensions_Entity
{
static global::System.Linq.Expressions.Expression<global::System.Func<global::Foo.EntityExtensions.Entity, global::Foo.EntityExtensions.Entity>> Expression()
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ using EntityFrameworkCore.Projectables;
namespace EntityFrameworkCore.Projectables.Generated
{
[global::System.ComponentModel.EditorBrowsable(global::System.ComponentModel.EditorBrowsableState.Never)]
static class _Foo_Calculate
static class _Foo_Calculate_P0_int
{
static global::System.Linq.Expressions.Expression<global::System.Func<global::Foo, int, int>> Expression()
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ using EntityFrameworkCore.Projectables;
namespace EntityFrameworkCore.Projectables.Generated
{
[global::System.ComponentModel.EditorBrowsable(global::System.ComponentModel.EditorBrowsableState.Never)]
static class _SomeExtensions_Test
static class _SomeExtensions_Test_P0_SomeFlag
{
static global::System.Linq.Expressions.Expression<global::System.Func<global::SomeFlag, bool>> Expression()
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ using Foo;
namespace EntityFrameworkCore.Projectables.Generated
{
[global::System.ComponentModel.EditorBrowsable(global::System.ComponentModel.EditorBrowsableState.Never)]
static class Foo_EntityExtensions_EnforceString
static class Foo_EntityExtensions_EnforceString_P0_T
{
static global::System.Linq.Expressions.Expression<global::System.Func<T, string>> Expression<T>()
where T : unmanaged
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ using Foo;
namespace EntityFrameworkCore.Projectables.Generated
{
[global::System.ComponentModel.EditorBrowsable(global::System.ComponentModel.EditorBrowsableState.Never)]
static class Foo_C_NextFoo
static class Foo_C_NextFoo_P0_System_Collections_Generic_List_object__P1_System_Collections_Generic_List_int__
{
static global::System.Linq.Expressions.Expression<global::System.Func<global::System.Collections.Generic.List<object>, global::System.Collections.Generic.List<int?>, global::System.Collections.Generic.List<object>>> Expression()
{
Expand Down
Loading