Skip to content

[wasm] Lazy init of [JSExport] bindings #77293

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 10 commits into from
Dec 8, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,10 @@ internal static class Comparers
/// Comparer for an individual generated stub source as a syntax tree and the generated diagnostics for the stub.
/// </summary>
public static readonly IEqualityComparer<(MemberDeclarationSyntax Syntax, ImmutableArray<Diagnostic> Diagnostics)> GeneratedSyntax = new CustomValueTupleElementComparer<MemberDeclarationSyntax, ImmutableArray<Diagnostic>>(SyntaxEquivalentComparer.Instance, new ImmutableArraySequenceEqualComparer<Diagnostic>(EqualityComparer<Diagnostic>.Default));
public static readonly IEqualityComparer<(MemberDeclarationSyntax, StatementSyntax, AttributeListSyntax, ImmutableArray<Diagnostic>)> GeneratedSyntax4 =
new CustomValueTupleElementComparer<MemberDeclarationSyntax, StatementSyntax, AttributeListSyntax, ImmutableArray<Diagnostic>>(
SyntaxEquivalentComparer.Instance, SyntaxEquivalentComparer.Instance, SyntaxEquivalentComparer.Instance,
new ImmutableArraySequenceEqualComparer<Diagnostic>(EqualityComparer<Diagnostic>.Default));
}

/// <summary>
Expand Down Expand Up @@ -67,4 +71,34 @@ public int GetHashCode((T, U) obj)
throw new UnreachableException();
}
}

internal sealed class CustomValueTupleElementComparer<T, U, V, W> : IEqualityComparer<(T, U, V, W)>
{
private readonly IEqualityComparer<T> _item1Comparer;
private readonly IEqualityComparer<U> _item2Comparer;
private readonly IEqualityComparer<V> _item3Comparer;
private readonly IEqualityComparer<W> _item4Comparer;

public CustomValueTupleElementComparer(IEqualityComparer<T> item1Comparer, IEqualityComparer<U> item2Comparer, IEqualityComparer<V> item3Comparer, IEqualityComparer<W> item4Comparer)
{
_item1Comparer = item1Comparer;
_item2Comparer = item2Comparer;
_item3Comparer = item3Comparer;
_item4Comparer = item4Comparer;
}

public bool Equals((T, U, V, W) x, (T, U, V, W) y)
{
return _item1Comparer.Equals(x.Item1, y.Item1)
&& _item2Comparer.Equals(x.Item2, y.Item2)
&& _item3Comparer.Equals(x.Item3, y.Item3)
&& _item4Comparer.Equals(x.Item4, y.Item4)
;
}

public int GetHashCode((T, U, V, W) obj)
{
throw new UnreachableException();
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ internal static class Constants
public const string JSFunctionSignatureGlobal = "global::System.Runtime.InteropServices.JavaScript.JSFunctionBinding";
public const string JSMarshalerArgumentGlobal = "global::System.Runtime.InteropServices.JavaScript.JSMarshalerArgument";
public const string ModuleInitializerAttributeGlobal = "global::System.Runtime.CompilerServices.ModuleInitializerAttribute";
public const string CompilerGeneratedAttributeGlobal = "global::System.Runtime.CompilerServices.CompilerGeneratedAttribute";
public const string DynamicDependencyAttributeGlobal = "global::System.Diagnostics.CodeAnalysis.DynamicDependencyAttribute";
public const string ThreadStaticGlobal = "global::System.ThreadStaticAttribute";
public const string TaskGlobal = "global::System.Threading.Tasks.Task";
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -123,27 +123,35 @@ public BlockSyntax GenerateJSExportBody()
return Block(allStatements);
}

public BlockSyntax GenerateJSExportRegistration()
public static StatementSyntax[] GenerateJSExportArchitectureCheck()
{
var registrationStatements = new List<StatementSyntax>();
registrationStatements.Add(IfStatement(
BinaryExpression(SyntaxKind.NotEqualsExpression,
IdentifierName(Constants.OSArchitectureGlobal),
IdentifierName(Constants.ArchitectureWasmGlobal)),
ReturnStatement()));

var signatureArgs = new List<ArgumentSyntax>();

signatureArgs.Add(Argument(LiteralExpression(SyntaxKind.StringLiteralExpression, Literal(_signatureContext.QualifiedMethodName))));
signatureArgs.Add(Argument(LiteralExpression(SyntaxKind.NumericLiteralExpression, Literal(_signatureContext.TypesHash))));
return new StatementSyntax[]{
IfStatement(
BinaryExpression(SyntaxKind.LogicalOrExpression,
IdentifierName("initialized"),
BinaryExpression(SyntaxKind.NotEqualsExpression,
IdentifierName(Constants.OSArchitectureGlobal),
IdentifierName(Constants.ArchitectureWasmGlobal))),
ReturnStatement()),
ExpressionStatement(
AssignmentExpression(SyntaxKind.SimpleAssignmentExpression,
IdentifierName("initialized"),
LiteralExpression(SyntaxKind.TrueLiteralExpression))),
};
}

signatureArgs.Add(CreateSignaturesSyntax());
public StatementSyntax GenerateJSExportRegistration()
{
var signatureArgs = new List<ArgumentSyntax>
{
Argument(LiteralExpression(SyntaxKind.StringLiteralExpression, Literal(_signatureContext.QualifiedMethodName))),
Argument(LiteralExpression(SyntaxKind.NumericLiteralExpression, Literal(_signatureContext.TypesHash))),
CreateSignaturesSyntax()
};

registrationStatements.Add(ExpressionStatement(InvocationExpression(MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression,
return ExpressionStatement(InvocationExpression(MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression,
IdentifierName(Constants.JSFunctionSignatureGlobal), IdentifierName(Constants.BindCSFunctionMethod)))
.WithArgumentList(ArgumentList(SeparatedList(signatureArgs)))));

return Block(List(registrationStatements));
.WithArgumentList(ArgumentList(SeparatedList(signatureArgs))));
}

private ArgumentSyntax CreateSignaturesSyntax()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,13 @@
using System.Collections.Immutable;
using System.Diagnostics;
using System.Linq;
using System.Text;
using System.Threading;
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.CSharp;
using Microsoft.CodeAnalysis.CSharp.Syntax;
using static Microsoft.CodeAnalysis.CSharp.SyntaxFactory;
using System.Collections.Generic;

namespace Microsoft.Interop.JavaScript
{
Expand Down Expand Up @@ -73,7 +75,7 @@ public void Initialize(IncrementalGeneratorInitializationContext context)
return ImmutableArray.Create(Diagnostic.Create(GeneratorDiagnostics.JSExportRequiresAllowUnsafeBlocks, null));
}));

IncrementalValuesProvider<(MemberDeclarationSyntax, ImmutableArray<Diagnostic>)> generateSingleStub = methodsToGenerate
IncrementalValuesProvider<(MemberDeclarationSyntax, StatementSyntax, AttributeListSyntax, ImmutableArray<Diagnostic>)> generateSingleStub = methodsToGenerate
.Combine(stubEnvironment)
.Combine(stubOptions)
.Select(static (data, ct) => new
Expand All @@ -90,42 +92,64 @@ public void Initialize(IncrementalGeneratorInitializationContext context)
.Select(
static (data, ct) => GenerateSource(data)
)
.WithComparer(Comparers.GeneratedSyntax)
.WithComparer(Comparers.GeneratedSyntax4)
.WithTrackingName(StepNames.GenerateSingleStub);

context.RegisterDiagnostics(generateSingleStub.SelectMany((stubInfo, ct) => stubInfo.Item2));
context.RegisterDiagnostics(generateSingleStub.SelectMany((stubInfo, ct) => stubInfo.Item4));

IncrementalValueProvider<ImmutableArray<(StatementSyntax, AttributeListSyntax)>> regSyntax = generateSingleStub
.Select(
static (data, ct) => (data.Item2, data.Item3))
.Collect();

IncrementalValueProvider<string> registration = regSyntax
.Select(static (data, ct) => GenerateRegSource(data))
.Select(static (data, ct) => data.NormalizeWhitespace().ToFullString());

IncrementalValueProvider<ImmutableArray<(string, string)>> generated = generateSingleStub
.Combine(registration)
.Select(
static (data, ct) => (data.Left.Item1.NormalizeWhitespace().ToFullString(), data.Right))
.Collect();


context.RegisterSourceOutput(generated,
(context, generatedSources) =>
{
// Don't generate a file if we don't have to, to avoid the extra IDE overhead once we have generated
// files in play.
if (generatedSources.IsEmpty)
return;

StringBuilder source = new();
// Mark in source that the file is auto-generated.
source.AppendLine("// <auto-generated/>");
// this is the assembly level registration
source.AppendLine(generatedSources[0].Item2);
// this is the method wrappers to be called from JS
foreach (var generated in generatedSources)
{
source.AppendLine(generated.Item1);
}

// Once https://github.com/dotnet/roslyn/issues/61326 is resolved, we can avoid the ToString() here.
context.AddSource("JSExports.g.cs", source.ToString());
});

context.RegisterConcatenatedSyntaxOutputs(generateSingleStub.Select((data, ct) => data.Item1), "JSExports.g.cs");
}

private static MemberDeclarationSyntax PrintGeneratedSource(
ContainingSyntax userDeclaredMethod,
JSSignatureContext stub,
ContainingSyntaxContext containingSyntaxContext,
BlockSyntax wrapperStatements, BlockSyntax registerStatements)
BlockSyntax wrapperStatements, string wrapperName)
{
var WrapperName = "__Wrapper_" + userDeclaredMethod.Identifier + "_" + stub.TypesHash;
var RegistrationName = "__Register_" + userDeclaredMethod.Identifier + "_" + stub.TypesHash;

MemberDeclarationSyntax wrappperMethod = MethodDeclaration(PredefinedType(Token(SyntaxKind.VoidKeyword)), Identifier(WrapperName))
MemberDeclarationSyntax wrappperMethod = MethodDeclaration(PredefinedType(Token(SyntaxKind.VoidKeyword)), Identifier(wrapperName))
.WithModifiers(TokenList(new[] { Token(SyntaxKind.InternalKeyword), Token(SyntaxKind.StaticKeyword), Token(SyntaxKind.UnsafeKeyword) }))
.WithParameterList(ParameterList(SingletonSeparatedList(
Parameter(Identifier("__arguments_buffer")).WithType(PointerType(ParseTypeName(Constants.JSMarshalerArgumentGlobal))))))
.WithBody(wrapperStatements);

MemberDeclarationSyntax registerMethod = MethodDeclaration(PredefinedType(Token(SyntaxKind.VoidKeyword)), Identifier(RegistrationName))
.WithAttributeLists(List(new AttributeListSyntax[]{
AttributeList(SingletonSeparatedList(Attribute(IdentifierName(Constants.ModuleInitializerAttributeGlobal)))),
AttributeList(SingletonSeparatedList(Attribute(IdentifierName(Constants.DynamicDependencyAttributeGlobal))
.WithArgumentList(AttributeArgumentList(SeparatedList(new[]{
AttributeArgument(LiteralExpression(SyntaxKind.StringLiteralExpression, Literal(WrapperName))),
AttributeArgument(TypeOfExpression(ParseTypeName(stub.StubTypeFullName)))}
)))))}))
.WithModifiers(TokenList(new[] { Token(SyntaxKind.InternalKeyword), Token(SyntaxKind.StaticKeyword) }))
.WithBody(registerStatements);


MemberDeclarationSyntax toPrint = containingSyntaxContext.WrapMembersInContainingSyntaxWithUnsafeModifier(wrappperMethod, registerMethod);
MemberDeclarationSyntax toPrint = containingSyntaxContext.WrapMembersInContainingSyntaxWithUnsafeModifier(wrappperMethod);

return toPrint;
}
Expand Down Expand Up @@ -197,7 +221,68 @@ private static IncrementalStubGenerationContext CalculateStubInformation(
return MarshallingGeneratorFactoryKey.Create((env.TargetFramework, env.TargetFrameworkVersion, options), jsGeneratorFactory);
}

private static (MemberDeclarationSyntax, ImmutableArray<Diagnostic>) GenerateSource(
private static NamespaceDeclarationSyntax GenerateRegSource(
ImmutableArray<(StatementSyntax Registration, AttributeListSyntax Attribute)> methods)
{
const string generatedNamespace = "System.Runtime.InteropServices.JavaScript";
const string initializerClass = "__GeneratedInitializer";
const string initializerName = "__Register_";
const string selfInitName = "__Net7SelfInit_";

if (methods.IsEmpty) return NamespaceDeclaration(IdentifierName(generatedNamespace));

var registerStatements = new List<StatementSyntax>();
registerStatements.AddRange(JSExportCodeGenerator.GenerateJSExportArchitectureCheck());

var attributes = new List<AttributeListSyntax>();
foreach (var m in methods)
{
registerStatements.Add(m.Registration);
attributes.Add(m.Attribute);
}

FieldDeclarationSyntax field = FieldDeclaration(VariableDeclaration(PredefinedType(Token(SyntaxKind.BoolKeyword)))
.WithVariables(SingletonSeparatedList(
VariableDeclarator(Identifier("initialized")))))
.WithModifiers(TokenList(Token(SyntaxKind.StaticKeyword)));

MemberDeclarationSyntax method = MethodDeclaration(PredefinedType(Token(SyntaxKind.VoidKeyword)), Identifier(initializerName))
.WithAttributeLists(List(attributes))
.WithModifiers(TokenList(new[] { Token(SyntaxKind.StaticKeyword) }))
.WithBody(Block(registerStatements));

// when we are running code generated by .NET8 on .NET7 runtime we need to auto initialize the assembly, because .NET7 doesn't call the registration from JS
// this also keeps the code protected from trimming
MemberDeclarationSyntax initializerMethod = MethodDeclaration(PredefinedType(Token(SyntaxKind.VoidKeyword)), Identifier(selfInitName))
.WithAttributeLists(List(new[]{
AttributeList(SingletonSeparatedList(Attribute(IdentifierName(Constants.ModuleInitializerAttributeGlobal)))),
}))
.WithModifiers(TokenList(new[] {
Token(SyntaxKind.StaticKeyword),
Token(SyntaxKind.InternalKeyword)
}))
.WithBody(Block(
IfStatement(BinaryExpression(SyntaxKind.EqualsExpression,
IdentifierName("Environment.Version.Major"),
LiteralExpression(SyntaxKind.NumericLiteralExpression, Literal(7))),
Block(SingletonList<StatementSyntax>(
ExpressionStatement(InvocationExpression(IdentifierName(initializerName))))))));

var ns = NamespaceDeclaration(IdentifierName(generatedNamespace))
.WithMembers(
SingletonList<MemberDeclarationSyntax>(
ClassDeclaration(initializerClass)
.WithModifiers(TokenList(new SyntaxToken[]{
Token(SyntaxKind.UnsafeKeyword)}))
.WithMembers(List(new[] { field, initializerMethod, method }))
.WithAttributeLists(SingletonList(AttributeList(SingletonSeparatedList(
Attribute(IdentifierName(Constants.CompilerGeneratedAttributeGlobal)))
)))));

return ns;
}

private static (MemberDeclarationSyntax, StatementSyntax, AttributeListSyntax, ImmutableArray<Diagnostic>) GenerateSource(
IncrementalStubGenerationContext incrementalContext)
{
var diagnostics = new GeneratorDiagnostics();
Expand All @@ -215,10 +300,21 @@ private static (MemberDeclarationSyntax, ImmutableArray<Diagnostic>) GenerateSou
},
incrementalContext.GeneratorFactoryKey.GeneratorFactory);

BlockSyntax wrapper = stubGenerator.GenerateJSExportBody();
BlockSyntax registration = stubGenerator.GenerateJSExportRegistration();
var wrapperName = "__Wrapper_" + incrementalContext.StubMethodSyntaxTemplate.Identifier + "_" + incrementalContext.SignatureContext.TypesHash;

return (PrintGeneratedSource(incrementalContext.StubMethodSyntaxTemplate, incrementalContext.SignatureContext, incrementalContext.ContainingSyntaxContext, wrapper, registration), incrementalContext.Diagnostics.Array.AddRange(diagnostics.Diagnostics));
BlockSyntax wrapper = stubGenerator.GenerateJSExportBody();
StatementSyntax registration = stubGenerator.GenerateJSExportRegistration();
AttributeListSyntax registrationAttribute = AttributeList(SingletonSeparatedList(Attribute(IdentifierName(Constants.DynamicDependencyAttributeGlobal))
.WithArgumentList(AttributeArgumentList(SeparatedList(new[]{
AttributeArgument(LiteralExpression(SyntaxKind.StringLiteralExpression, Literal(wrapperName))),
AttributeArgument(LiteralExpression(SyntaxKind.StringLiteralExpression, Literal(incrementalContext.SignatureContext.StubTypeFullName))),
AttributeArgument(LiteralExpression(SyntaxKind.StringLiteralExpression, Literal(incrementalContext.SignatureContext.AssemblyName))),
}
)))));

return (PrintGeneratedSource(incrementalContext.ContainingSyntaxContext, wrapper, wrapperName),
registration, registrationAttribute,
incrementalContext.Diagnostics.Array.AddRange(diagnostics.Diagnostics));
}

private static Diagnostic? GetDiagnosticIfInvalidMethodForGeneration(MethodDeclarationSyntax methodSyntax, IMethodSymbol method)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,8 @@ public static JSSignatureContext Create(
StubTypeFullName = stubTypeFullName,
MethodName = fullName,
QualifiedMethodName = qualifiedName,
BindingName = "__signature_" + method.Name + "_" + typesHash
BindingName = "__signature_" + method.Name + "_" + typesHash,
AssemblyName = env.Compilation.AssemblyName,
};
}

Expand All @@ -87,6 +88,7 @@ private static string GetFullyQualifiedMethodName(StubEnvironment env, IMethodSy
public string MethodName { get; init; }
public string QualifiedMethodName { get; init; }
public string BindingName { get; init; }
public string AssemblyName { get; init; }

public override int GetHashCode()
{
Expand Down
Loading