Skip to content

CPP-737: add support for generic interfaces #22

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Sep 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
71 changes: 55 additions & 16 deletions Dojo.AutoGenerators/AutoInterfaceGenerator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,24 @@ private class ClassDefinition
public string Name { get; set; }
public string Namespace { get; set; }
public List<string> Methods { get; set; } = new();
public bool IsGeneric { get; set; }
public string GenericArguments { get; set; }
public string GenericConstraints { get; set; }
public string FullName => IsGeneric ? $"{Name}{GenericArguments}" : Name;
}

public static string GetGenericTypeArguments(INamedTypeSymbol classSymbol)
{
if (!classSymbol.IsGenericType) return null;

var bdr = new StringBuilder();

bdr.Append('<');
var args = classSymbol.TypeArguments.Select(arg => arg.ToString()).ToList();
bdr.Append(string.Join(", ", args));
bdr.Append('>');

return bdr.ToString();
}

public static string GetInterfaceName(ITypeSymbol typeSymbol)
Expand Down Expand Up @@ -61,7 +79,7 @@ public static string GetMethodDefinition(IMethodSymbol method)
}

bdr.Append(GetParametersDefinition(method));
bdr.Append(GetGenericTypeConstraints(method));
bdr.Append(GetGenericTypeConstraints(method, symbol => symbol.IsGenericMethod, symbol => symbol.TypeParameters));
bdr.Append(';');
return bdr.ToString();
}
Expand Down Expand Up @@ -125,32 +143,50 @@ public static string GetParametersDefinition(IMethodSymbol method)
return "(" + string.Join(", ", parameters) + ")";
}

public static string GetGenericTypeConstraints(IMethodSymbol method)
public static string GetGenericTypeConstraints<T>(T symbol, Func<T, bool> isGeneric, Func<T, IEnumerable<ITypeParameterSymbol>> getTypeParameters)
{
if (!method.IsGenericMethod
|| method.TypeParameters.Length == 0)
var typeParameters = getTypeParameters(symbol)?.ToList() ?? new List<ITypeParameterSymbol>();
if (!isGeneric(symbol) || typeParameters.Count == 0)
{
return string.Empty;
}

StringBuilder bdr = new();
foreach (var typeParam in method.TypeParameters)
foreach (var typeParam in typeParameters)
{
if (typeParam.ConstraintTypes.Length > 0)
if (typeParam.ConstraintTypes.Length > 0 || typeParam.HasReferenceTypeConstraint ||
typeParam.HasValueTypeConstraint || typeParam.HasConstructorConstraint)
{
bdr.AppendLine();
bdr.Append($" where {typeParam.Name} : {string.Join(", ", typeParam.ConstraintTypes.Select(x => x.ToDisplayString()))}");
bdr.Append($" where {typeParam.Name} : ");
}
else if (typeParam.HasReferenceTypeConstraint)
else
{
bdr.AppendLine();
bdr.Append($" where {typeParam.Name} : class");
continue;
}

var constraints = new List<string>();
if (typeParam.ConstraintTypes.Length > 0)
{
constraints.Add($"{string.Join(", ", typeParam.ConstraintTypes.Select(x => x.ToDisplayString()))}");
}
else if (typeParam.HasValueTypeConstraint)

if (typeParam.HasReferenceTypeConstraint)
{
bdr.AppendLine();
bdr.Append($" where {typeParam.Name} : struct");
constraints.Add("class");
}

if (typeParam.HasValueTypeConstraint)
{
constraints.Add("struct");
}

if (typeParam.HasConstructorConstraint)
{
constraints.Add("new()");
}

bdr.Append(string.Join(", ", constraints));
}

return bdr.ToString();
Expand Down Expand Up @@ -188,10 +224,13 @@ public void Execute(GeneratorExecutionContext context)
{
var classDefinition = new ClassDefinition();

var symbolModel = semanticModel.GetDeclaredSymbol(classNode) as ITypeSymbol;
var symbolModel = semanticModel.GetDeclaredSymbol(classNode) as INamedTypeSymbol;

classDefinition.Name = GetInterfaceName(symbolModel);
classDefinition.Namespace = GetNamespaceFullName(symbolModel.ContainingNamespace);
classDefinition.IsGeneric = symbolModel.IsGenericType;
classDefinition.GenericArguments = GetGenericTypeArguments(symbolModel);
classDefinition.GenericConstraints = GetGenericTypeConstraints(symbolModel, symbol => symbol.IsGenericType, symbol => symbol.TypeParameters);

foreach (var member in symbolModel.GetMembers())
{
Expand Down Expand Up @@ -238,12 +277,12 @@ public void Execute(GeneratorExecutionContext context)
namespace {classDefinition.Namespace}
{{
[GeneratedCode(""Dojo.SourceGenerator"", ""{Assembly.GetExecutingAssembly().GetName().Version}"")]
public partial class {classDefinition.Name}: I{classDefinition.Name}
public partial class {classDefinition.FullName}: I{classDefinition.FullName}{classDefinition.GenericConstraints}
{{
}}

[GeneratedCode(""Dojo.SourceGenerator"", ""{Assembly.GetExecutingAssembly().GetName().Version}"")]
public interface I{classDefinition.Name}
public interface I{classDefinition.FullName}{classDefinition.GenericConstraints}
{{
");
// add the filepath of each tree to the class we're building
Expand Down
51 changes: 51 additions & 0 deletions Dojo.Generators.Tests/AutoInterfaceTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -161,5 +161,56 @@ public interface ITestFoo
// Assert
GeneratorTestHelper.CompareSources(expectedSource, actual);
}

[Theory]
[InlineData("TestFoo<T>", "TestFoo<T>", "ITestFoo<T>")]
[InlineData("TestFoo<T>\n where T : class", "TestFoo<T>", "ITestFoo<T>\n where T : class")]
[InlineData("TestFoo<T>\n where T : class, new()", "TestFoo<T>", "ITestFoo<T>\n where T : class, new()")]
[InlineData("TestFoo<T>\n where T : struct", "TestFoo<T>", "ITestFoo<T>\n where T : struct")]
[InlineData("TestFoo<T>\n where T : Dojo.Generators.Tests.AutoInterfaceTests", "TestFoo<T>", "ITestFoo<T>\n where T : Dojo.Generators.Tests.AutoInterfaceTests")]
[InlineData("TestFoo<T1, T2>", "TestFoo<T1, T2>", "ITestFoo<T1, T2>")]
[InlineData("TestFoo<T1, T2> where T1 : class", "TestFoo<T1, T2>", "ITestFoo<T1, T2>\n where T1 : class")]
[InlineData("TestFoo<T1, T2> where T1 : class where T2 : class", "TestFoo<T1, T2>", "ITestFoo<T1, T2>\n where T1 : class\n where T2 : class")]
public void GenericInterface_Generate(string source, string expectedType, string expectedInterface)
{
// Arrange
string userSource = $@"
using System;
using System.Text;
namespace Level1.Level2
{{
[AutoInterface]
public partial class {source}
{{
public void SomeMethod()
{{
return null;
}}
}}
}}";

string expectedSource = $@"
using System;
using System.CodeDom.Compiler;

namespace Level1.Level2
{{
[GeneratedCode(""Dojo.SourceGenerator"", ""{Assembly.GetExecutingAssembly().GetName().Version}"")]
public partial class {expectedType}: {expectedInterface}
{{
}}

[GeneratedCode(""Dojo.SourceGenerator"", ""{Assembly.GetExecutingAssembly().GetName().Version}"")]
public interface {expectedInterface}
{{
void SomeMethod();
}}
}}";
// Act
var actual = GeneratorTestHelper.GenerateFromSource<AutoInterfaceGenerator>(userSource);

// Assert
GeneratorTestHelper.CompareSources(expectedSource, actual);
}
}
}