Skip to content

Commit

Permalink
Reuse generated code when possible (#1821)
Browse files Browse the repository at this point in the history
* Initialize size improvements

* Also optimize ccw

* Fix build

* Add comments and renaming
  • Loading branch information
manodasanW authored Oct 13, 2024
1 parent eb67781 commit db00a61
Showing 1 changed file with 130 additions and 43 deletions.
173 changes: 130 additions & 43 deletions src/Authoring/WinRT.SourceGenerator/AotOptimizer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -773,21 +773,21 @@ private static void GenerateVtableAttributes(
GenerateVtableAttributes(sourceProductionContext.AddSource, value.vtableAttributes, value.context.properties.isCsWinRTComponent, value.context.escapedAssemblyName);
}

internal static string GenerateVtableEntry(VtableAttribute vtableAttribute, string escapedAssemblyName)
internal static string GenerateVtableEntry(VtableEntry vtableEntry, string escapedAssemblyName)
{
StringBuilder source = new();

foreach (var genericInterface in vtableAttribute.GenericInterfaces)
foreach (var genericInterface in vtableEntry.GenericInterfaces)
{
source.AppendLine(GenericVtableInitializerStrings.GetInstantiationInitFunction(
genericInterface.GenericDefinition,
genericInterface.GenericParameters,
escapedAssemblyName));
}

if (vtableAttribute.IsDelegate)
if (vtableEntry.IsDelegate)
{
var @interface = vtableAttribute.Interfaces.First();
var @interface = vtableEntry.Interfaces.First();
source.AppendLine();
source.AppendLine($$"""
var delegateInterface = new global::System.Runtime.InteropServices.ComWrappers.ComInterfaceEntry
Expand All @@ -799,15 +799,15 @@ internal static string GenerateVtableEntry(VtableAttribute vtableAttribute, stri
return global::WinRT.DelegateTypeDetails<{{@interface}}>.GetExposedInterfaces(delegateInterface);
""");
}
else if (vtableAttribute.Interfaces.Any())
else if (vtableEntry.Interfaces.Any())
{
source.AppendLine();
source.AppendLine($$"""
return new global::System.Runtime.InteropServices.ComWrappers.ComInterfaceEntry[]
{
""");

foreach (var @interface in vtableAttribute.Interfaces)
foreach (var @interface in vtableEntry.Interfaces)
{
var genericStartIdx = @interface.IndexOf('<');
var interfaceStaticsMethod = @interface[..(genericStartIdx == -1 ? @interface.Length : genericStartIdx)] + "Methods";
Expand Down Expand Up @@ -840,6 +840,10 @@ internal static string GenerateVtableEntry(VtableAttribute vtableAttribute, stri

internal static void GenerateVtableAttributes(Action<string, string> addSource, ImmutableArray<VtableAttribute> vtableAttributes, bool isCsWinRTComponentFromAotOptimizer, string escapedAssemblyName)
{
var vtableEntryToVtableClassName = new Dictionary<VtableEntry, string>();
StringBuilder vtableClassesSource = new();
bool firstVtableClass = true;

// Using ToImmutableHashSet to avoid duplicate entries from the use of partial classes by the developer
// to split out their implementation. When they do that, we will get multiple entries here for that
// and try to generate the same attribute and file with the same data as we use the semantic model
Expand All @@ -850,11 +854,10 @@ internal static void GenerateVtableAttributes(Action<string, string> addSource,
// from the AOT optimizer, then any public types are not handled
// right now as they are handled by the WinRT component source generator
// calling this.
if (((isCsWinRTComponentFromAotOptimizer && !vtableAttribute.IsPublic) || !isCsWinRTComponentFromAotOptimizer) &&
if (((isCsWinRTComponentFromAotOptimizer && !vtableAttribute.IsPublic) || !isCsWinRTComponentFromAotOptimizer) &&
vtableAttribute.Interfaces.Any())
{
StringBuilder source = new();
source.AppendLine("using static WinRT.TypeExtensions;\n");
if (!vtableAttribute.IsGlobalNamespace)
{
source.AppendLine($$"""
Expand All @@ -863,6 +866,16 @@ namespace {{vtableAttribute.Namespace}}
""");
}

// Check if this class shares the same vtable as another class. If so, reuse the same generated class for it.
VtableEntry entry = new(vtableAttribute.Interfaces, vtableAttribute.GenericInterfaces, vtableAttribute.IsDelegate);
bool vtableEntryExists = vtableEntryToVtableClassName.TryGetValue(entry, out var ccwClassName);
if (!vtableEntryExists)
{
var @namespace = vtableAttribute.IsGlobalNamespace ? "" : $"{vtableAttribute.Namespace}.";
ccwClassName = GeneratorHelper.EscapeTypeNameForIdentifier(@namespace + vtableAttribute.ClassName);
vtableEntryToVtableClassName.Add(entry, ccwClassName);
}

var escapedClassName = GeneratorHelper.EscapeTypeNameForIdentifier(vtableAttribute.ClassName);

// Simple case when the type is not nested
Expand All @@ -874,7 +887,7 @@ namespace {{vtableAttribute.Namespace}}
}

source.AppendLine($$"""
[global::WinRT.WinRTExposedType(typeof({{escapedClassName}}WinRTTypeDetails))]
[global::WinRT.WinRTExposedType(typeof(global::WinRT.{{escapedAssemblyName}}VtableClasses.{{ccwClassName}}WinRTTypeDetails))]
partial class {{vtableAttribute.ClassName}}
{
}
Expand All @@ -900,7 +913,7 @@ partial class {{vtableAttribute.ClassName}}
}

source.AppendLine($$"""
[global::WinRT.WinRTExposedType(typeof({{escapedClassName}}WinRTTypeDetails))]
[global::WinRT.WinRTExposedType(typeof(global::WinRT.{{escapedAssemblyName}}VtableClasses.{{ccwClassName}}WinRTTypeDetails))]
partial {{classHierarchy[0].GetTypeKeyword()}} {{classHierarchy[0].QualifiedName}}
{
}
Expand All @@ -913,62 +926,78 @@ partial class {{vtableAttribute.ClassName}}
}
}

source.AppendLine();
source.AppendLine($$"""
internal sealed class {{escapedClassName}}WinRTTypeDetails : global::WinRT.IWinRTExposedTypeDetails
// Only generate class, if this is the first time we run into this set of vtables.
if (!vtableEntryExists)
{
if (firstVtableClass)
{
vtableClassesSource.AppendLine($$"""
namespace WinRT.{{escapedAssemblyName}}VtableClasses
{
""");
firstVtableClass = false;
}
else
{
vtableClassesSource.AppendLine();
}

vtableClassesSource.AppendLine($$"""
internal sealed class {{ccwClassName}}WinRTTypeDetails : global::WinRT.IWinRTExposedTypeDetails
{
public global::System.Runtime.InteropServices.ComWrappers.ComInterfaceEntry[] GetExposedInterfaces()
{
""");

if (vtableAttribute.Interfaces.Any())
{
foreach (var genericInterface in vtableAttribute.GenericInterfaces)
if (vtableAttribute.Interfaces.Any())
{
source.AppendLine(GenericVtableInitializerStrings.GetInstantiationInitFunction(
genericInterface.GenericDefinition,
genericInterface.GenericParameters,
escapedAssemblyName));
}
foreach (var genericInterface in vtableAttribute.GenericInterfaces)
{
vtableClassesSource.AppendLine(GenericVtableInitializerStrings.GetInstantiationInitFunction(
genericInterface.GenericDefinition,
genericInterface.GenericParameters,
escapedAssemblyName));
}

source.AppendLine();
source.AppendLine($$"""
vtableClassesSource.AppendLine();
vtableClassesSource.AppendLine($$"""
return new global::System.Runtime.InteropServices.ComWrappers.ComInterfaceEntry[]
{
""");

foreach (var @interface in vtableAttribute.Interfaces)
{
var genericStartIdx = @interface.IndexOf('<');
var interfaceStaticsMethod = @interface[..(genericStartIdx == -1 ? @interface.Length : genericStartIdx)] + "Methods";
if (genericStartIdx != -1)
foreach (var @interface in vtableAttribute.Interfaces)
{
interfaceStaticsMethod += @interface[genericStartIdx..@interface.Length];
}
var genericStartIdx = @interface.IndexOf('<');
var interfaceStaticsMethod = @interface[..(genericStartIdx == -1 ? @interface.Length : genericStartIdx)] + "Methods";
if (genericStartIdx != -1)
{
interfaceStaticsMethod += @interface[genericStartIdx..@interface.Length];
}

source.AppendLine($$"""
vtableClassesSource.AppendLine($$"""
new global::System.Runtime.InteropServices.ComWrappers.ComInterfaceEntry
{
IID = global::ABI.{{interfaceStaticsMethod}}.IID,
Vtable = global::ABI.{{interfaceStaticsMethod}}.AbiToProjectionVftablePtr
},
""");
}
source.AppendLine($$"""
}
vtableClassesSource.AppendLine($$"""
};
""");
}
else
{
source.AppendLine($$"""
}
else
{
vtableClassesSource.AppendLine($$"""
return global::System.Array.Empty<global::System.Runtime.InteropServices.ComWrappers.ComInterfaceEntry>();
""");
}
}

source.AppendLine($$"""
vtableClassesSource.AppendLine($$"""
}
}
""");
}

if (!vtableAttribute.IsGlobalNamespace)
{
Expand All @@ -979,6 +1008,12 @@ internal sealed class {{escapedClassName}}WinRTTypeDetails : global::WinRT.IWinR
addSource($"{prefix}{escapedClassName}.WinRTVtable.g.cs", source.ToString());
}
}

if (vtableClassesSource.Length != 0)
{
vtableClassesSource.AppendLine("}");
addSource($"WinRTCCWVtable.g.cs", vtableClassesSource.ToString());
}
}

private static void GenerateCCWForGenericInstantiation(
Expand Down Expand Up @@ -1444,12 +1479,37 @@ private static ComWrappers.ComInterfaceEntry[] LookupVtableEntries(Type type)
""");
}

// We gather all the class names that have the same vtable and generate it
// as part of one if to reduce generated code.
var vtableEntryToClassNameList = new Dictionary<VtableEntry, List<string>>();
foreach (var vtableAttribute in value.vtableAttributes.ToImmutableHashSet())
{
VtableEntry entry = new(vtableAttribute.Interfaces, vtableAttribute.GenericInterfaces, vtableAttribute.IsDelegate);
if (!vtableEntryToClassNameList.TryGetValue(entry, out var classNameList))
{
classNameList = new List<string>();
vtableEntryToClassNameList.Add(entry, classNameList);
}
classNameList.Add(vtableAttribute.VtableLookupClassName);
}

foreach (var vtableEntry in vtableEntryToClassNameList)
{
source.AppendLine($$"""
if (typeName == "{{vtableEntry.Value[0]}}"
""");

for (var i = 1; i < vtableEntry.Value.Count; i++)
{
source.AppendLine($$"""
|| typeName == "{{vtableEntry.Value[i]}}"
""");
}

source.AppendLine($$"""
if (typeName == "{{vtableAttribute.VtableLookupClassName}}")
)
{
{{GenerateVtableEntry(vtableAttribute, value.context.escapedAssemblyName)}}
{{GenerateVtableEntry(vtableEntry.Key, value.context.escapedAssemblyName)}}
}
""");
}
Expand All @@ -1469,12 +1529,34 @@ private static string LookupRuntimeClassName(Type type)
string typeName = type.ToString();
""");

var runtimeClassNameToClassNameList = new Dictionary<string, List<string>>();
foreach (var vtableAttribute in value.vtableAttributes.ToImmutableHashSet().Where(static v => !string.IsNullOrEmpty(v.RuntimeClassName)))
{
if (!runtimeClassNameToClassNameList.TryGetValue(vtableAttribute.RuntimeClassName, out var classNameList))
{
classNameList = new List<string>();
runtimeClassNameToClassNameList.Add(vtableAttribute.RuntimeClassName, classNameList);
}
classNameList.Add(vtableAttribute.VtableLookupClassName);
}

foreach (var entry in runtimeClassNameToClassNameList)
{
source.AppendLine($$"""
if (typeName == "{{vtableAttribute.VtableLookupClassName}}")
if (typeName == "{{entry.Value[0]}}"
""");

for (var i = 1; i < entry.Value.Count; i++)
{
source.AppendLine($$"""
|| typeName == "{{entry.Value[i]}}"
""");
}

source.AppendLine($$"""
)
{
return "{{vtableAttribute.RuntimeClassName}}";
return "{{entry.Key}}";
}
""");
}
Expand Down Expand Up @@ -1630,6 +1712,11 @@ internal sealed record VtableAttribute(
bool IsPublic,
string RuntimeClassName = default);

sealed record VtableEntry(
EquatableArray<string> Interfaces,
EquatableArray<GenericInterface> GenericInterfaces,
bool IsDelegate);

internal readonly record struct BindableCustomProperty(
string Name,
string Type,
Expand Down

0 comments on commit db00a61

Please sign in to comment.