Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
@@ -0,0 +1,143 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

using System.Collections.Immutable;
using System.Linq;
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.CSharp;
using Microsoft.CodeAnalysis.CSharp.Syntax;
using Microsoft.CodeAnalysis.Diagnostics;
using Microsoft.CodeAnalysis.DotnetRuntime.Extensions;

namespace Microsoft.Interop.Analyzers
{
[DiagnosticAnalyzer(LanguageNames.CSharp)]
public class VtableIndexStubDiagnosticsAnalyzer : DiagnosticAnalyzer
{
public override ImmutableArray<DiagnosticDescriptor> SupportedDiagnostics { get; } =
ImmutableArray.Create(
GeneratorDiagnostics.InvalidAttributedMethodSignature,
GeneratorDiagnostics.InvalidAttributedMethodContainingTypeMissingModifiers,
GeneratorDiagnostics.ReturnConfigurationNotSupported,
GeneratorDiagnostics.InvalidAttributedMethodContainingTypeMissingUnmanagedObjectUnwrapperAttribute,
GeneratorDiagnostics.InvalidStringMarshallingConfigurationOnMethod,
GeneratorDiagnostics.InvalidExceptionMarshallingConfiguration,
GeneratorDiagnostics.ConfigurationNotSupported,
GeneratorDiagnostics.ParameterTypeNotSupported,
GeneratorDiagnostics.ReturnTypeNotSupported,
GeneratorDiagnostics.ParameterTypeNotSupportedWithDetails,
GeneratorDiagnostics.ReturnTypeNotSupportedWithDetails,
GeneratorDiagnostics.ParameterConfigurationNotSupported,
GeneratorDiagnostics.MarshalAsParameterConfigurationNotSupported,
GeneratorDiagnostics.MarshalAsReturnConfigurationNotSupported,
GeneratorDiagnostics.ConfigurationValueNotSupported,
GeneratorDiagnostics.MarshallingAttributeConfigurationNotSupported,
GeneratorDiagnostics.UnnecessaryParameterMarshallingInfo,
GeneratorDiagnostics.UnnecessaryReturnMarshallingInfo,
GeneratorDiagnostics.GeneratedComInterfaceUsageDoesNotFollowBestPractices);

public override void Initialize(AnalysisContext context)
{
context.ConfigureGeneratedCodeAnalysis(GeneratedCodeAnalysisFlags.None);
context.EnableConcurrentExecution();
context.RegisterCompilationStartAction(compilationContext =>
{
INamedTypeSymbol? virtualMethodIndexAttrType = compilationContext.Compilation.GetBestTypeByMetadataName(TypeNames.VirtualMethodIndexAttribute);
if (virtualMethodIndexAttrType is null)
return;

StubEnvironment env = new StubEnvironment(
compilationContext.Compilation,
compilationContext.Compilation.GetEnvironmentFlags());

compilationContext.RegisterSymbolAction(symbolContext =>
{
IMethodSymbol method = (IMethodSymbol)symbolContext.Symbol;
AttributeData? virtualMethodIndexAttr = null;
foreach (AttributeData attr in method.GetAttributes())
{
if (SymbolEqualityComparer.Default.Equals(attr.AttributeClass, virtualMethodIndexAttrType))
{
virtualMethodIndexAttr = attr;
break;
}
}

if (virtualMethodIndexAttr is null)
return;

foreach (SyntaxReference syntaxRef in method.DeclaringSyntaxReferences)
{
if (syntaxRef.GetSyntax(symbolContext.CancellationToken) is MethodDeclarationSyntax methodSyntax)
{
AnalyzeMethod(symbolContext, methodSyntax, method, env);
break;
}
}
}, SymbolKind.Method);
});
}

private static void AnalyzeMethod(SymbolAnalysisContext context, MethodDeclarationSyntax methodSyntax, IMethodSymbol method, StubEnvironment env)
{
DiagnosticInfo? invalidMethodDiagnostic = GetDiagnosticIfInvalidMethodForGeneration(methodSyntax, method);
if (invalidMethodDiagnostic is not null)
{
context.ReportDiagnostic(invalidMethodDiagnostic.ToDiagnostic());
return;
}

SourceAvailableIncrementalMethodStubGenerationContext stubContext = VtableIndexStubGenerator.CalculateStubInformation(methodSyntax, method, env, context.CancellationToken);

if (stubContext.VtableIndexData.Direction is MarshalDirection.ManagedToUnmanaged or MarshalDirection.Bidirectional)
{
var (_, diagnostics) = VirtualMethodPointerStubGenerator.GenerateManagedToNativeStub(stubContext, VtableIndexStubGeneratorHelpers.GetGeneratorResolver);
foreach (DiagnosticInfo diag in diagnostics)
context.ReportDiagnostic(diag.ToDiagnostic());
}

if (stubContext.VtableIndexData.Direction is MarshalDirection.UnmanagedToManaged or MarshalDirection.Bidirectional)
{
var (_, diagnostics) = VirtualMethodPointerStubGenerator.GenerateNativeToManagedStub(stubContext, VtableIndexStubGeneratorHelpers.GetGeneratorResolver);
foreach (DiagnosticInfo diag in diagnostics)
context.ReportDiagnostic(diag.ToDiagnostic());
}
}

internal static DiagnosticInfo? GetDiagnosticIfInvalidMethodForGeneration(MethodDeclarationSyntax methodSyntax, IMethodSymbol method)
{
// Verify the method has no generic types or defined implementation
// and is not marked static or sealed
if (methodSyntax.TypeParameterList is not null
|| methodSyntax.Body is not null
|| methodSyntax.Modifiers.Any(SyntaxKind.StaticKeyword)
|| methodSyntax.Modifiers.Any(SyntaxKind.SealedKeyword))
{
return DiagnosticInfo.Create(GeneratorDiagnostics.InvalidAttributedMethodSignature, methodSyntax.Identifier.GetLocation(), method.Name);
}

// Verify that the types the method is declared in are marked partial.
for (SyntaxNode? parentNode = methodSyntax.Parent; parentNode is TypeDeclarationSyntax typeDecl; parentNode = parentNode.Parent)
{
if (!typeDecl.Modifiers.Any(SyntaxKind.PartialKeyword))
{
return DiagnosticInfo.Create(GeneratorDiagnostics.InvalidAttributedMethodContainingTypeMissingModifiers, methodSyntax.Identifier.GetLocation(), method.Name, typeDecl.Identifier);
}
}

// Verify the method does not have a ref return
if (method.ReturnsByRef || method.ReturnsByRefReadonly)
{
return DiagnosticInfo.Create(GeneratorDiagnostics.ReturnConfigurationNotSupported, methodSyntax.Identifier.GetLocation(), "ref return", method.ToDisplayString());
}

// Verify there is an [UnmanagedObjectUnwrapperAttribute<TMapper>]
if (!method.ContainingType.GetAttributes().Any(att => att.AttributeClass.IsOfType(TypeNames.UnmanagedObjectUnwrapperAttribute)))
{
return DiagnosticInfo.Create(GeneratorDiagnostics.InvalidAttributedMethodContainingTypeMissingUnmanagedObjectUnwrapperAttribute, methodSyntax.Identifier.GetLocation(), method.Name);
}

return null;
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -157,8 +157,6 @@ public void Initialize(IncrementalGeneratorInitializationContext context)
);
});

context.RegisterDiagnostics(attributedInterfaces.SelectMany(static (data, ct) => data.Diagnostics));

// Create list of methods (inherited and declared) and their owning interface
var interfaceContextsToGenerate = attributedInterfaces.SelectMany(static (a, ct) => a.InterfaceContexts);
var comMethodContexts = attributedInterfaces.Select(static (a, ct) => a.MethodContexts);
Expand All @@ -185,11 +183,6 @@ public void Initialize(IncrementalGeneratorInitializationContext context)
GenerateIUnknownDerivedAttributeApplication(x.Interface.Info, ct).NormalizeWhitespace()
]));

// Report diagnostics for managed-to-unmanaged and unmanaged-to-managed stubs, deduplicating diagnostics that are reported for both.
context.RegisterDiagnostics(
interfaceAndMethodsContexts
.SelectMany(static (data, ct) => data.DeclaredMethods.SelectMany(m => m.ManagedToUnmanagedStub.Diagnostics).Union(data.DeclaredMethods.SelectMany(m => m.UnmanagedToManagedStub.Diagnostics))));

var filesToGenerate = syntaxes
.Select(static (methodSyntaxes, ct) =>
{
Expand Down Expand Up @@ -443,7 +436,7 @@ private static IncrementalMethodStubGenerationContext CalculateSharedStubInforma
ComInterfaceDispatchMarshallingInfo.Instance);
}

private static IncrementalMethodStubGenerationContext CalculateStubInformation(MethodDeclarationSyntax? syntax, IMethodSymbol symbol, int index, StubEnvironment environment, ComInterfaceInfo owningInterface, CancellationToken ct)
internal static IncrementalMethodStubGenerationContext CalculateStubInformation(MethodDeclarationSyntax? syntax, IMethodSymbol symbol, int index, StubEnvironment environment, ComInterfaceInfo owningInterface, CancellationToken ct)
{
ISignatureDiagnosticLocations locations = syntax is null
? NoneSignatureDiagnosticLocations.Instance
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,20 +45,9 @@ public void Initialize(IncrementalGeneratorInitializationContext context)
.Where(
static modelData => modelData is not null);

var methodsWithDiagnostics = attributedMethods.Select(static (data, ct) =>
{
Diagnostic? diagnostic = GetDiagnosticIfInvalidMethodForGeneration(data.Syntax, data.Symbol);
return new { data.Syntax, data.Symbol, Diagnostic = diagnostic };
});

// Split the methods we want to generate and the ones we don't into two separate groups.
var methodsToGenerate = methodsWithDiagnostics.Where(static data => data.Diagnostic is null);
var invalidMethodDiagnostics = methodsWithDiagnostics.Where(static data => data.Diagnostic is not null);

context.RegisterSourceOutput(invalidMethodDiagnostics, static (context, invalidMethod) =>
{
context.ReportDiagnostic(invalidMethod.Diagnostic);
});
// Filter out methods that are invalid for generation (diagnostics for invalid methods are reported by the analyzer).
var methodsToGenerate = attributedMethods.Where(
static data => data is not null && VtableIndexStubDiagnosticsAnalyzer.GetDiagnosticIfInvalidMethodForGeneration(data.Syntax, data.Symbol) is null);

// Calculate all of information to generate both managed-to-unmanaged and unmanaged-to-managed stubs
// for each method.
Expand All @@ -84,8 +73,6 @@ public void Initialize(IncrementalGeneratorInitializationContext context)
.WithComparer(Comparers.GeneratedSyntax)
.WithTrackingName(StepNames.GenerateManagedToNativeStub);

context.RegisterDiagnostics(generateManagedToNativeStub.SelectMany((stubInfo, ct) => stubInfo.Item2));

context.RegisterConcatenatedSyntaxOutputs(generateManagedToNativeStub.Select((data, ct) => data.Item1), "ManagedToNativeStubs.g.cs");

// Filter the list of all stubs to only the stubs that requested unmanaged-to-managed stub generation.
Expand All @@ -101,8 +88,6 @@ public void Initialize(IncrementalGeneratorInitializationContext context)
.WithComparer(Comparers.GeneratedSyntax)
.WithTrackingName(StepNames.GenerateNativeToManagedStub);

context.RegisterDiagnostics(generateNativeToManagedStub.SelectMany((stubInfo, ct) => stubInfo.Item2));

context.RegisterConcatenatedSyntaxOutputs(generateNativeToManagedStub.Select((data, ct) => data.Item1), "NativeToManagedStubs.g.cs");

// Generate the native interface metadata for each interface that contains a method with the [VirtualMethodIndex] attribute.
Expand Down Expand Up @@ -195,7 +180,7 @@ public void Initialize(IncrementalGeneratorInitializationContext context)
};
}

private static SourceAvailableIncrementalMethodStubGenerationContext CalculateStubInformation(MethodDeclarationSyntax syntax, IMethodSymbol symbol, StubEnvironment environment, CancellationToken ct)
internal static SourceAvailableIncrementalMethodStubGenerationContext CalculateStubInformation(MethodDeclarationSyntax syntax, IMethodSymbol symbol, StubEnvironment environment, CancellationToken ct)
{
ct.ThrowIfCancellationRequested();
INamedTypeSymbol? lcidConversionAttrType = environment.Compilation.GetTypeByMetadataName(TypeNames.LCIDConversionAttribute);
Expand Down Expand Up @@ -388,42 +373,6 @@ private static (MemberDeclarationSyntax, ImmutableArray<DiagnosticInfo>) Generat
methodStub.Diagnostics.Array.AddRange(diagnostics));
}

private static Diagnostic? GetDiagnosticIfInvalidMethodForGeneration(MethodDeclarationSyntax methodSyntax, IMethodSymbol method)
{
// Verify the method has no generic types or defined implementation
// and is not marked static or sealed
if (methodSyntax.TypeParameterList is not null
|| methodSyntax.Body is not null
|| methodSyntax.Modifiers.Any(SyntaxKind.StaticKeyword)
|| methodSyntax.Modifiers.Any(SyntaxKind.SealedKeyword))
{
return Diagnostic.Create(GeneratorDiagnostics.InvalidAttributedMethodSignature, methodSyntax.Identifier.GetLocation(), method.Name);
}

// Verify that the types the method is declared in are marked partial.
for (SyntaxNode? parentNode = methodSyntax.Parent; parentNode is TypeDeclarationSyntax typeDecl; parentNode = parentNode.Parent)
{
if (!typeDecl.Modifiers.Any(SyntaxKind.PartialKeyword))
{
return Diagnostic.Create(GeneratorDiagnostics.InvalidAttributedMethodContainingTypeMissingModifiers, methodSyntax.Identifier.GetLocation(), method.Name, typeDecl.Identifier);
}
}

// Verify the method does not have a ref return
if (method.ReturnsByRef || method.ReturnsByRefReadonly)
{
return Diagnostic.Create(GeneratorDiagnostics.ReturnConfigurationNotSupported, methodSyntax.Identifier.GetLocation(), "ref return", method.ToDisplayString());
}

// Verify there is an [UnmanagedObjectUnwrapperAttribute<TMapper>]
if (!method.ContainingType.GetAttributes().Any(att => att.AttributeClass.IsOfType(TypeNames.UnmanagedObjectUnwrapperAttribute)))
{
return Diagnostic.Create(GeneratorDiagnostics.InvalidAttributedMethodContainingTypeMissingUnmanagedObjectUnwrapperAttribute, methodSyntax.Identifier.GetLocation(), method.Name);
}

return null;
}

private static MemberDeclarationSyntax GenerateNativeInterfaceMetadata(ContainingSyntaxContext context)
{
return context.WrapMemberInContainingSyntaxWithUnsafeModifier(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
using Xunit;

using VerifyCS = Microsoft.Interop.UnitTests.Verifiers.CSharpCodeFixVerifier<
Microsoft.CodeAnalysis.Testing.EmptyDiagnosticAnalyzer,
Microsoft.Interop.Analyzers.ComInterfaceGeneratorDiagnosticsAnalyzer,
Microsoft.Interop.Analyzers.AddMarshalAsToElementFixer>;

namespace ComInterfaceGenerator.Unit.Tests
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
using Xunit;
using static Microsoft.Interop.UnitTests.TestUtils;
using StringMarshalling = System.Runtime.InteropServices.StringMarshalling;
using VerifyComInterfaceGenerator = Microsoft.Interop.UnitTests.Verifiers.CSharpSourceGeneratorVerifier<Microsoft.Interop.ComInterfaceGenerator, Microsoft.CodeAnalysis.Testing.EmptyDiagnosticAnalyzer>;
using VerifyComInterfaceGenerator = Microsoft.Interop.UnitTests.Verifiers.CSharpSourceGeneratorVerifier<Microsoft.Interop.ComInterfaceGenerator, Microsoft.Interop.Analyzers.ComInterfaceGeneratorDiagnosticsAnalyzer>;

namespace ComInterfaceGenerator.Unit.Tests
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
using Xunit;
using static Microsoft.Interop.UnitTests.TestUtils;
using StringMarshalling = System.Runtime.InteropServices.StringMarshalling;
using VerifyComInterfaceGenerator = Microsoft.Interop.UnitTests.Verifiers.CSharpSourceGeneratorVerifier<Microsoft.Interop.ComInterfaceGenerator, Microsoft.CodeAnalysis.Testing.EmptyDiagnosticAnalyzer>;
using VerifyComInterfaceGenerator = Microsoft.Interop.UnitTests.Verifiers.CSharpSourceGeneratorVerifier<Microsoft.Interop.ComInterfaceGenerator, Microsoft.Interop.Analyzers.ComInterfaceGeneratorDiagnosticsAnalyzer>;

namespace ComInterfaceGenerator.Unit.Tests
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
using Microsoft.Interop;
using Xunit;

using VerifyCS = Microsoft.Interop.UnitTests.Verifiers.CSharpSourceGeneratorVerifier<Microsoft.CodeAnalysis.Testing.EmptySourceGeneratorProvider, Microsoft.CodeAnalysis.Testing.EmptyDiagnosticAnalyzer>;
using VerifyCS = Microsoft.Interop.UnitTests.Verifiers.CSharpSourceGeneratorVerifier<Microsoft.CodeAnalysis.Testing.EmptySourceGeneratorProvider, Microsoft.Interop.Analyzers.ComInterfaceGeneratorDiagnosticsAnalyzer>;

namespace ComInterfaceGenerator.Unit.Tests
{
Expand Down
Loading