Skip to content

Make constructors private for parts of the ComInterfaceGenerator that should only be created from their static methods #101740

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 11 commits into from
May 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@ namespace Microsoft.Interop
[Generator]
public class ComClassGenerator : IIncrementalGenerator
{
private sealed record ComClassInfo(string ClassName, ContainingSyntaxContext ContainingSyntaxContext, ContainingSyntax ClassSyntax, SequenceEqualImmutableArray<string> ImplementedInterfacesNames);
public void Initialize(IncrementalGeneratorInitializationContext context)
{
var unsafeCodeIsEnabled = context.CompilationProvider.Select((comp, ct) => comp.Options is CSharpCompilationOptions { AllowUnsafe: true }); // Unsafe code enabled
Expand All @@ -27,54 +26,13 @@ public void Initialize(IncrementalGeneratorInitializationContext context)
static (node, ct) => node is ClassDeclarationSyntax,
static (context, ct) => context)
.Combine(unsafeCodeIsEnabled)
.Select((data, ct) =>
.Select(static (data, ct) =>
{
var context = data.Left;
var unsafeCodeIsEnabled = data.Right;
var type = (INamedTypeSymbol)context.TargetSymbol;
var syntax = (ClassDeclarationSyntax)context.TargetNode;
if (!unsafeCodeIsEnabled)
{
return DiagnosticOr<ComClassInfo>.From(DiagnosticInfo.Create(GeneratorDiagnostics.RequiresAllowUnsafeBlocks, syntax.Identifier.GetLocation()));
}

if (!syntax.IsInPartialContext(out _))
{
return DiagnosticOr<ComClassInfo>.From(
DiagnosticInfo.Create(
GeneratorDiagnostics.InvalidAttributedClassMissingPartialModifier,
syntax.Identifier.GetLocation(),
type.ToDisplayString()));
}

ImmutableArray<string>.Builder names = ImmutableArray.CreateBuilder<string>();
foreach (INamedTypeSymbol iface in type.AllInterfaces)
{
AttributeData? generatedComInterfaceAttribute = iface.GetAttributes().FirstOrDefault(attr => attr.AttributeClass?.ToDisplayString() == TypeNames.GeneratedComInterfaceAttribute);
if (generatedComInterfaceAttribute is not null)
{
var attributeData = GeneratedComInterfaceCompilationData.GetDataFromAttribute(generatedComInterfaceAttribute);
if (attributeData.Options.HasFlag(ComInterfaceOptions.ManagedObjectWrapper))
{
names.Add(iface.ToDisplayString());
}
}
}

if (names.Count == 0)
{
return DiagnosticOr<ComClassInfo>.From(DiagnosticInfo.Create(GeneratorDiagnostics.ClassDoesNotImplementAnyGeneratedComInterface,
syntax.Identifier.GetLocation(),
type.ToDisplayString()));
}


return DiagnosticOr<ComClassInfo>.From(
new ComClassInfo(
type.ToDisplayString(),
new ContainingSyntaxContext(syntax),
new ContainingSyntax(syntax.Modifiers, syntax.Kind(), syntax.Identifier, syntax.TypeParameterList),
new(names.ToImmutable())));
return ComClassInfo.From(type, syntax, unsafeCodeIsEnabled);
});

var attributedClasses = context.FilterAndReportDiagnostics(attributedClassesOrDiagnostics);
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

using System.Collections.Immutable;
using System.Linq;
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.CSharp.Syntax;

namespace Microsoft.Interop
{
internal sealed record ComClassInfo
{
public string ClassName { get; init; }
public ContainingSyntaxContext ContainingSyntaxContext { get; init; }
public ContainingSyntax ClassSyntax { get; init; }
public SequenceEqualImmutableArray<string> ImplementedInterfacesNames { get; init; }

private ComClassInfo(string className, ContainingSyntaxContext containingSyntaxContext, ContainingSyntax classSyntax, SequenceEqualImmutableArray<string> implementedInterfacesNames)
{
ClassName = className;
ContainingSyntaxContext = containingSyntaxContext;
ClassSyntax = classSyntax;
ImplementedInterfacesNames = implementedInterfacesNames;
}

public static DiagnosticOr<ComClassInfo> From(INamedTypeSymbol type, ClassDeclarationSyntax syntax, bool unsafeCodeIsEnabled)
{
if (!unsafeCodeIsEnabled)
{
return DiagnosticOr<ComClassInfo>.From(DiagnosticInfo.Create(GeneratorDiagnostics.RequiresAllowUnsafeBlocks, syntax.Identifier.GetLocation()));
}

if (!syntax.IsInPartialContext(out _))
{
return DiagnosticOr<ComClassInfo>.From(
DiagnosticInfo.Create(
GeneratorDiagnostics.InvalidAttributedClassMissingPartialModifier,
syntax.Identifier.GetLocation(),
type.ToDisplayString()));
}

ImmutableArray<string>.Builder names = ImmutableArray.CreateBuilder<string>();
foreach (INamedTypeSymbol iface in type.AllInterfaces)
{
AttributeData? generatedComInterfaceAttribute = iface.GetAttributes().FirstOrDefault(attr => attr.AttributeClass?.ToDisplayString() == TypeNames.GeneratedComInterfaceAttribute);
if (generatedComInterfaceAttribute is not null)
{
var attributeData = GeneratedComInterfaceCompilationData.GetDataFromAttribute(generatedComInterfaceAttribute);
if (attributeData.Options.HasFlag(ComInterfaceOptions.ManagedObjectWrapper))
{
names.Add(iface.ToDisplayString());
}
}
}

if (names.Count == 0)
{
return DiagnosticOr<ComClassInfo>.From(DiagnosticInfo.Create(GeneratorDiagnostics.ClassDoesNotImplementAnyGeneratedComInterface,
syntax.Identifier.GetLocation(),
type.ToDisplayString()));
}

return DiagnosticOr<ComClassInfo>.From(
new ComClassInfo(
type.ToDisplayString(),
new ContainingSyntaxContext(syntax),
new ContainingSyntax(syntax.Modifiers, syntax.Kind(), syntax.Identifier, syntax.TypeParameterList),
new(names.ToImmutable())));
}

public bool Equals(ComClassInfo? other)
{
return other is not null
&& ClassName == other.ClassName
&& ContainingSyntaxContext.Equals(other.ContainingSyntaxContext)
&& ImplementedInterfacesNames.SequenceEqual(other.ImplementedInterfacesNames);
}

public override int GetHashCode()
{
return HashCode.Combine(ClassName, ContainingSyntaxContext, ImplementedInterfacesNames);
}
}
}
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

using System;
using System.Collections.Generic;
using System.Collections.Immutable;
using System.Diagnostics;
Expand All @@ -9,8 +10,19 @@

namespace Microsoft.Interop
{
internal sealed record ComInterfaceContext(ComInterfaceInfo Info, ComInterfaceContext? Base, ComInterfaceOptions Options)
internal sealed record ComInterfaceContext
{
internal ComInterfaceInfo Info { get; init; }
internal ComInterfaceContext? Base { get; init; }
internal ComInterfaceOptions Options { get; init; }

private ComInterfaceContext(ComInterfaceInfo info, ComInterfaceContext? @base, ComInterfaceOptions options)
{
Info = info;
Base = @base;
Options = options;
}

/// <summary>
/// Takes a list of ComInterfaceInfo, and creates a list of ComInterfaceContext.
/// </summary>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,17 +14,40 @@ namespace Microsoft.Interop
/// <summary>
/// Information about a Com interface, but not its methods.
/// </summary>
internal sealed record ComInterfaceInfo(
ManagedTypeInfo Type,
string ThisInterfaceKey, // For associating interfaces to its base
string? BaseInterfaceKey, // For associating interfaces to its base
InterfaceDeclarationSyntax Declaration,
ContainingSyntaxContext TypeDefinitionContext,
ContainingSyntax ContainingSyntax,
Guid InterfaceId,
ComInterfaceOptions Options,
Location DiagnosticLocation)
internal sealed record ComInterfaceInfo
{
public ManagedTypeInfo Type { get; init; }
public string ThisInterfaceKey { get; init; }
public string? BaseInterfaceKey { get; init; }
public InterfaceDeclarationSyntax Declaration { get; init; }
public ContainingSyntaxContext TypeDefinitionContext { get; init; }
public ContainingSyntax ContainingSyntax { get; init; }
public Guid InterfaceId { get; init; }
public ComInterfaceOptions Options { get; init; }
public Location DiagnosticLocation { get; init; }

private ComInterfaceInfo(
ManagedTypeInfo type,
string thisInterfaceKey,
string? baseInterfaceKey,
InterfaceDeclarationSyntax declaration,
ContainingSyntaxContext typeDefinitionContext,
ContainingSyntax containingSyntax,
Guid interfaceId,
ComInterfaceOptions options,
Location diagnosticLocation)
{
Type = type;
ThisInterfaceKey = thisInterfaceKey;
BaseInterfaceKey = baseInterfaceKey;
Declaration = declaration;
TypeDefinitionContext = typeDefinitionContext;
ContainingSyntax = containingSyntax;
InterfaceId = interfaceId;
Options = options;
DiagnosticLocation = diagnosticLocation;
}

public static DiagnosticOrInterfaceInfo From(INamedTypeSymbol symbol, InterfaceDeclarationSyntax syntax, StubEnvironment env, CancellationToken _)
{
if (env.Compilation.Options is not CSharpCompilationOptions { AllowUnsafe: true }) // Unsafe code enabled
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

using System;
using System.Collections.Immutable;
using System.Diagnostics;
using System.Linq;
Expand All @@ -14,12 +15,25 @@ namespace Microsoft.Interop
/// <summary>
/// Represents a method that has been determined to be a COM interface method. Only contains info immediately available from an IMethodSymbol and MethodDeclarationSyntax.
/// </summary>
internal sealed record ComMethodInfo(
MethodDeclarationSyntax Syntax,
string MethodName,
SequenceEqualImmutableArray<AttributeInfo> Attributes,
bool IsUserDefinedShadowingMethod)
internal sealed record ComMethodInfo
{
public MethodDeclarationSyntax Syntax { get; init; }
public string MethodName { get; init; }
public SequenceEqualImmutableArray<AttributeInfo> Attributes { get; init; }
public bool IsUserDefinedShadowingMethod { get; init; }

private ComMethodInfo(
MethodDeclarationSyntax syntax,
string methodName,
SequenceEqualImmutableArray<AttributeInfo> attributes,
bool isUserDefinedShadowingMethod)
{
Syntax = syntax;
MethodName = methodName;
Attributes = attributes;
IsUserDefinedShadowingMethod = isUserDefinedShadowingMethod;
}

/// <summary>
/// Returns a list of tuples of ComMethodInfo, IMethodSymbol, and Diagnostic. If ComMethodInfo is null, Diagnostic will not be null, and vice versa.
/// </summary>
Expand Down Expand Up @@ -95,7 +109,6 @@ internal sealed record ComMethodInfo(
return DiagnosticOr<(ComMethodInfo, IMethodSymbol)>.From(DiagnosticInfo.Create(GeneratorDiagnostics.MethodNotDeclaredInAttributedInterface, method.Locations.FirstOrDefault(), method.ToDisplayString()));
}


// Find the matching declaration syntax
MethodDeclarationSyntax? comMethodDeclaringSyntax = null;
foreach (var declaringSyntaxReference in method.DeclaringSyntaxReferences)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,11 @@ public static ManagedTypeInfo CreateTypeInfoForTypeSymbol(ITypeSymbol type)
public sealed record SpecialTypeInfo(string FullTypeName, string DiagnosticFormattedName, SpecialType SpecialType) : ManagedTypeInfo(FullTypeName, DiagnosticFormattedName)
{
public static readonly SpecialTypeInfo Byte = new("byte", "byte", SpecialType.System_Byte);
public static readonly SpecialTypeInfo SByte = new("sbyte", "sbyte", SpecialType.System_SByte);
public static readonly SpecialTypeInfo Int16 = new("short", "short", SpecialType.System_Int16);
public static readonly SpecialTypeInfo UInt16 = new("ushort", "ushort", SpecialType.System_UInt16);
public static readonly SpecialTypeInfo Int32 = new("int", "int", SpecialType.System_Int32);
public static readonly SpecialTypeInfo UInt32 = new("uint", "uint", SpecialType.System_UInt32);
public static readonly SpecialTypeInfo Void = new("void", "void", SpecialType.System_Void);
public static readonly SpecialTypeInfo String = new("string", "string", SpecialType.System_String);
public static readonly SpecialTypeInfo Boolean = new("bool", "bool", SpecialType.System_Boolean);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -340,7 +340,7 @@ private ResolvedGenerator CreateNativeCollectionMarshaller(
marshallerType = marshallerType with
{
FullTypeName = marshallerTypeSyntax.ToString(),
DiagnosticFormattedName = marshallerTypeSyntax.ToString(),
DiagnosticFormattedName = marshallerTypeSyntax.ToString()
};
string newNativeTypeName = ReplacePlaceholderSyntaxWithUnmanagedTypeSyntax(marshallerData.NativeType.Syntax, marshalInfo, unmanagedElementType).ToFullString();
ManagedTypeInfo nativeType = marshallerData.NativeType with
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ protected BoolMarshallerBase(ManagedTypeInfo nativeType, int trueValue, int fals

public ManagedTypeInfo AsNativeType(TypePositionInfo info)
{
Debug.Assert(info.ManagedType is SpecialTypeInfo(_, _, SpecialType.System_Boolean));
Debug.Assert(info.ManagedType is SpecialTypeInfo { SpecialType: SpecialType.System_Boolean });
return _nativeType;
}

Expand Down Expand Up @@ -118,7 +118,7 @@ public sealed class ByteBoolMarshaller : BoolMarshallerBase
/// </summary>
/// <param name="signed">True if the byte should be signed, otherwise false</param>
public ByteBoolMarshaller(bool signed)
: base(new SpecialTypeInfo(signed ? "sbyte" : "byte", signed ? "sbyte" : "byte", signed ? SpecialType.System_SByte : SpecialType.System_Byte), trueValue: 1, falseValue: 0, compareToTrue: false)
: base(signed ? SpecialTypeInfo.SByte : SpecialTypeInfo.Byte, trueValue: 1, falseValue: 0, compareToTrue: false)
{
}
}
Expand All @@ -136,7 +136,7 @@ public sealed class WinBoolMarshaller : BoolMarshallerBase
/// </summary>
/// <param name="signed">True if the int should be signed, otherwise false</param>
public WinBoolMarshaller(bool signed)
: base(new SpecialTypeInfo(signed ? "int" : "uint", signed ? "int" : "uint", signed ? SpecialType.System_Int32 : SpecialType.System_UInt32), trueValue: 1, falseValue: 0, compareToTrue: false)
: base(signed ? SpecialTypeInfo.Int32 : SpecialTypeInfo.UInt32, trueValue: 1, falseValue: 0, compareToTrue: false)
{
}
}
Expand All @@ -149,7 +149,7 @@ public sealed class VariantBoolMarshaller : BoolMarshallerBase
private const short VARIANT_TRUE = -1;
private const short VARIANT_FALSE = 0;
public VariantBoolMarshaller()
: base(new SpecialTypeInfo("short", "short", SpecialType.System_Int16), trueValue: VARIANT_TRUE, falseValue: VARIANT_FALSE, compareToTrue: true)
: base(SpecialTypeInfo.Int16, trueValue: VARIANT_TRUE, falseValue: VARIANT_FALSE, compareToTrue: true)
{
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ public ResolvedGenerator Create(TypePositionInfo info, StubCodeContext context)
}

// Breaking change: [MarshalAs(UnmanagedType.Struct)] in object in unmanaged-to-managed scenarios will not respect VT_BYREF.
if (info is { RefKind: RefKind.In or RefKind.RefReadOnlyParameter, MarshallingAttributeInfo: NativeMarshallingAttributeInfo(ManagedTypeInfo(_, TypeNames.ComVariantMarshaller), _) }
if (info is { RefKind: RefKind.In or RefKind.RefReadOnlyParameter, MarshallingAttributeInfo: NativeMarshallingAttributeInfo(ManagedTypeInfo { DiagnosticFormattedName: TypeNames.ComVariantMarshaller }, _) }
&& context.Direction == MarshalDirection.UnmanagedToManaged)
{
gen = ResolvedGenerator.ResolvedWithDiagnostics(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ namespace Microsoft.Interop
{
public sealed class Utf16CharMarshaller : IMarshallingGenerator
{
private static readonly ManagedTypeInfo s_nativeType = new SpecialTypeInfo("ushort", "ushort", SpecialType.System_UInt16);
private static readonly ManagedTypeInfo s_nativeType = SpecialTypeInfo.UInt16;

public ValueBoundaryBehavior GetValueBoundaryBehavior(TypePositionInfo info, StubCodeContext context)
{
Expand All @@ -35,7 +35,7 @@ public ValueBoundaryBehavior GetValueBoundaryBehavior(TypePositionInfo info, Stu

public ManagedTypeInfo AsNativeType(TypePositionInfo info)
{
Debug.Assert(info.ManagedType is SpecialTypeInfo(_, _, SpecialType.System_Char));
Debug.Assert(info.ManagedType is SpecialTypeInfo {SpecialType: SpecialType.System_Char });
return s_nativeType;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,11 +64,11 @@ public ResolvedGenerator Create(
return ResolvedGenerator.Resolved(s_blittable);

// Pointer with no marshalling info
case { ManagedType: PointerTypeInfo(_, _, IsFunctionPointer: false), MarshallingAttributeInfo: NoMarshallingInfo }:
case { ManagedType: PointerTypeInfo{ IsFunctionPointer: false }, MarshallingAttributeInfo: NoMarshallingInfo }:
return ResolvedGenerator.Resolved(s_blittable);

// Function pointer with no marshalling info
case { ManagedType: PointerTypeInfo(_, _, IsFunctionPointer: true), MarshallingAttributeInfo: NoMarshallingInfo or MarshalAsInfo(UnmanagedType.FunctionPtr, _) }:
case { ManagedType: PointerTypeInfo { IsFunctionPointer: true }, MarshallingAttributeInfo: NoMarshallingInfo or MarshalAsInfo(UnmanagedType.FunctionPtr, _) }:
return ResolvedGenerator.Resolved(s_blittable);

// Bool with marshalling info
Expand Down
Loading
Loading