Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve codegen for merged activation factories #1660

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 49 additions & 5 deletions src/Authoring/WinRT.SourceGenerator/Generator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
using System.Reflection.Metadata.Ecma335;
using System.Reflection.PortableExecutable;
using System.Text;
using System.Threading;

namespace Generator
{
Expand Down Expand Up @@ -196,13 +197,43 @@ public void Generate()
Logger.Close();
}

private static bool ShouldEmitCallToTryGetDependentActivationFactory(GeneratorExecutionContext context)
{
if (!context.AnalyzerConfigOptions.GetCsWinRTMergeReferencedActivationFactories())
{
return false;
}

foreach (MetadataReference metadataReference in context.Compilation.References)
{
if (context.Compilation.GetAssemblyOrModuleSymbol(metadataReference) is not IAssemblySymbol assemblySymbol)
{
continue;
}

// Check if the current assembly is a WinRT component (we just need one)
if (MergeReferencedActivationFactoriesGenerator.TryGetDependentAssemblyExportsTypeName(
assemblySymbol,
context.Compilation,
CancellationToken.None,
out _))
{
return true;
}
}

return false;
}

/// <summary>
/// Generates the native exports for a WinRT component.
/// </summary>
/// <param name="context">The <see cref="GeneratorExecutionContext"/> value to use to produce source files.</param>
public static void GenerateWinRTNativeExports(GeneratorExecutionContext context)
{
context.AddSource("NativeExports.g.cs", """
StringBuilder builder = new();

builder.AppendLine("""
// <auto-generated/>
#pragma warning disable

Expand Down Expand Up @@ -240,11 +271,22 @@ public static int DllGetActivationFactory(void* activatableClassId, void** facto

IntPtr obj = GetActivationFactory(fullyQualifiedTypeName);

if ((void*)obj is null)
{
obj = TryGetDependentActivationFactory(fullyQualifiedTypeName);
}
""");

// Only emit this call if we have actually generated that. We want to avoid generating
// that default implementation in every single assembly the generator runs on.
if (ShouldEmitCallToTryGetDependentActivationFactory(context))
{
builder.AppendLine("""
if ((void*)obj is null)
{
obj = TryGetDependentActivationFactory(fullyQualifiedTypeName);
}

""");
}

builder.Append("""
if ((void*)obj is null)
{
*factory = null;
Expand Down Expand Up @@ -279,6 +321,8 @@ public static int DllCanUnloadNow()
}
}
""");

context.AddSource("NativeExports.g.cs", builder.ToString());
}

/// <summary>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,11 @@ public void Initialize(IncrementalGeneratorInitializationContext context)
// Generate the chaining helper
context.RegisterImplementationSourceOutput(assemblyExportsTypeNames, static (context, assemblyExportsTypeNames) =>
{
if (assemblyExportsTypeNames.IsEmpty)
{
return;
}
StringBuilder builder = new();
builder.AppendLine("""
Expand All @@ -88,14 +93,10 @@ partial class Module
/// <param name="fullyQualifiedTypeName">The marshalled fully qualified type name of the activation factory to retrieve.</param>
/// <returns>The pointer to the activation factory that corresponds with the class specified by <paramref name="fullyQualifiedTypeName"/>.</returns>
internal static unsafe IntPtr TryGetDependentActivationFactory(ReadOnlySpan<char> fullyQualifiedTypeName)
{
""");
{
IntPtr obj;
if (!assemblyExportsTypeNames.IsEmpty)
{
builder.AppendLine(" IntPtr obj;");
builder.AppendLine();
}
""");
foreach (string assemblyExportsTypeName in assemblyExportsTypeNames)
{
Expand Down Expand Up @@ -129,7 +130,7 @@ internal static unsafe IntPtr TryGetDependentActivationFactory(ReadOnlySpan<char
/// <param name="token">The <see cref="CancellationToken"/> instance to use.</param>
/// <param name="name">The resulting type name, if found.</param>
/// <returns>Whether a type name was found.</returns>
private static bool TryGetDependentAssemblyExportsTypeName(
internal static bool TryGetDependentAssemblyExportsTypeName(
IAssemblySymbol assemblySymbol,
Compilation compilation,
CancellationToken token,
Expand Down