Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -239,28 +239,29 @@ private MarshallingInfo GetMarshallingInfo(

// If we aren't overriding the marshalling at usage time,
// then fall back to the information on the element type itself.
foreach (AttributeData typeAttribute in type.GetAttributes())
{
if (GetMarshallingInfoForAttribute(typeAttribute, type, indirectionDepth, useSiteAttributes, GetMarshallingInfo) is MarshallingInfo marshallingInfo)
{
return marshallingInfo;
}
}
if (GetMarshallingInfoForAttributes(type.GetAttributes().AsSpan(), type, indirectionDepth, useSiteAttributes, GetMarshallingInfo) is MarshallingInfo info)
return info;

// If the type doesn't have custom attributes that dictate marshalling,
// then consider the type itself.
return GetMarshallingInfoForType(type, indirectionDepth, useSiteAttributes, GetMarshallingInfo) ?? NoMarshallingInfo.Instance;
}

private MarshallingInfo? GetMarshallingInfoForAttribute(AttributeData attribute, ITypeSymbol type, int indirectionDepth, UseSiteAttributeProvider useSiteAttributes, GetMarshallingInfoCallback marshallingInfoCallback)
{
return GetMarshallingInfoForAttributes(new AttributeData[] { attribute }.AsSpan(), type, indirectionDepth, useSiteAttributes, marshallingInfoCallback);
}

private MarshallingInfo? GetMarshallingInfoForAttributes(ReadOnlySpan<AttributeData> attrs, ITypeSymbol type, int indirectionDepth, UseSiteAttributeProvider useSiteAttributes, GetMarshallingInfoCallback marshallingInfoCallback)
{
foreach (var parser in _marshallingAttributeParsers)
{
// Automatically ignore invalid attributes.
// The compiler will already error on them.
if (attribute.AttributeConstructor is not null && parser.CanParseAttributeType(attribute.AttributeClass))
foreach (var attr in attrs)
{
return parser.ParseAttribute(attribute, type, indirectionDepth, useSiteAttributes, marshallingInfoCallback);
if (attr.AttributeConstructor is not null && parser.CanParseAttributeType(attr.AttributeClass))
{
return parser.ParseAttribute(attr, type, indirectionDepth, useSiteAttributes, marshallingInfoCallback);
}
}
}
return null;
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

using System;
using System.Runtime.InteropServices;
using System.Runtime.InteropServices.Marshalling;
using SharedTypes.ComInterfaces;
using Xunit;

namespace ComInterfaceGenerator.Tests
{
public unsafe partial class NativeMarshallingAttributeTests
{
[LibraryImport(NativeExportsNE.NativeExportsNE_Binary, EntryPoint = "new_unique_marshalling")]
internal static partial IUniqueMarshalling NewUniqueMarshalling();

[Fact]
public void MethodReturningComInterfaceReturnsUniqueInstance()
{
// When a COM interface method returns the same interface type,
// it should return a new managed instance, not the cached one
var obj = NewUniqueMarshalling();
obj.SetValue(42);

var returnedObj = obj.GetThis();

// Should be a different managed object
Assert.NotSame(obj, returnedObj);

// But should refer to the same underlying COM object
Assert.Equal(42, returnedObj.GetValue());

// Modifying through one should affect the other
returnedObj.SetValue(100);
Assert.Equal(100, obj.GetValue());
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -751,5 +751,21 @@ public Bidirectional(IComInterfaceAttributeProvider attributeProvider)

public IComInterfaceAttributeProvider AttributeProvider { get; }
}

public string ComInterfaceWithNativeMarshalling => $$"""
using System;
using System.Runtime.CompilerServices;
using System.Runtime.InteropServices;
using System.Runtime.InteropServices.Marshalling;

[assembly:DisableRuntimeMarshalling]

{{GeneratedComInterface()}}
[NativeMarshalling(typeof(UniqueComInterfaceMarshaller<IFoo>))]
partial interface IFoo
{
void DoWorkTogether(IFoo foo);
}
""";
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -349,6 +349,7 @@ public static IEnumerable<object[]> ComInterfaceSnippetsToCompile()
yield return new object[] { ID(), codeSnippets.ForwarderWithPreserveSigAndRefKind("ref readonly") };
yield return new object[] { ID(), codeSnippets.ForwarderWithPreserveSigAndRefKind("in") };
yield return new object[] { ID(), codeSnippets.ForwarderWithPreserveSigAndRefKind("out") };
yield return new object[] { ID(), codeSnippets.ComInterfaceWithNativeMarshalling };
}

public static IEnumerable<object[]> ManagedToUnmanagedComInterfaceSnippetsToCompile()
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

using System;
using System.Runtime.InteropServices;
using System.Runtime.InteropServices.Marshalling;

namespace SharedTypes.ComInterfaces
{
[GeneratedComInterface]
[Guid(IID)]
[NativeMarshalling(typeof(UniqueComInterfaceMarshaller<IUniqueMarshalling>))]
internal partial interface IUniqueMarshalling
{
int GetValue();
void SetValue(int x);
IUniqueMarshalling GetThis();

public const string IID = "E11D5F3E-DD57-4E7E-A78C-F5F8B8E0A1F4";
}

[GeneratedComClass]
internal partial class UniqueMarshalling : IUniqueMarshalling
{
int _data = 0;
public int GetValue() => _data;
public void SetValue(int x) => _data = x;
public IUniqueMarshalling GetThis() => this;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
<ItemGroup>
<Compile Include="$(CommonPath)DisableRuntimeMarshalling.cs"
Link="Common\DisableRuntimeMarshalling.cs" />
<Compile Include="..\Common\ComInterfaces\IUniqueMarshalling.cs"
Link="ComInterfaces\IUniqueMarshalling.cs" />
</ItemGroup>

<ItemGroup>
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

using System.Runtime.InteropServices;
using System.Runtime.InteropServices.Marshalling;
using SharedTypes.ComInterfaces;
using Xunit;

namespace LibraryImportGenerator.IntegrationTests
{
partial class NativeExportsNE
{
[LibraryImport(NativeExportsNE_Binary, EntryPoint = "new_unique_marshalling")]
internal static partial IUniqueMarshalling GetUniqueMarshalling();
}

public class NativeMarshallingAttributeTests
{
[Fact]
public void GetSameComInterfaceTwiceReturnsUniqueInstances()
{
// When using NativeMarshalling with UniqueComInterfaceMarshaller,
// calling GetUniqueMarshalling() twice returns different managed instances for the same COM object
var obj1 = NativeExportsNE.GetUniqueMarshalling();
var obj2 = NativeExportsNE.GetUniqueMarshalling();

Assert.NotSame(obj1, obj2);

// Both refer to the same underlying COM object (same cached pointer)
obj1.SetValue(42);
Assert.Equal(42, obj2.GetValue());

// Modifying through one should affect the other
obj2.SetValue(100);
Assert.Equal(100, obj1.GetValue());
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -1526,5 +1526,25 @@ public void Free() { }
}
}
""";

public static string ComInterfaceWithNativeMarshallingInLibraryImport => """
using System.Runtime.InteropServices;
using System.Runtime.InteropServices.Marshalling;

[GeneratedComInterface]
[Guid("0E7204B5-4B61-4E06-B872-82BA652F2ECA")]
[NativeMarshalling(typeof(UniqueComInterfaceMarshaller<IFoo>))]
partial interface IFoo
{
void DoWork();
}

static partial class PInvokes
{
[LibraryImport("lib")]
[return: MarshalAs(UnmanagedType.I1)]
public static partial bool TryGetFoo(out IFoo foo);
}
""";
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -255,6 +255,7 @@ public static IEnumerable<object[]> CodeSnippetsToCompile()

// Type-level interop generator trigger attributes
yield return new[] { ID(), CodeSnippets.GeneratedComInterface };
yield return new[] { ID(), CodeSnippets.ComInterfaceWithNativeMarshallingInLibraryImport };

// Parameter modifiers
yield return new[] { ID(), CodeSnippets.SingleParameterWithModifier("int", "scoped ref") };
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

using System.Runtime.InteropServices;
using System.Runtime.InteropServices.Marshalling;
using SharedTypes.ComInterfaces;

namespace NativeExports.ComInterfaceGenerator
{
public static unsafe class UniqueMarshalling
{
private static void* s_cachedPtr = null;

// Call from another assembly to get a ptr to make an RCW
[UnmanagedCallersOnly(EntryPoint = "new_unique_marshalling")]
public static void* CreateComObject()
{
if (s_cachedPtr == null)
{
StrategyBasedComWrappers wrappers = new();
var myObject = new SharedTypes.ComInterfaces.UniqueMarshalling();
nint ptr = wrappers.GetOrCreateComInterfaceForObject(myObject, CreateComInterfaceFlags.None);
s_cachedPtr = (void*)ptr;
}

return s_cachedPtr;
}
}
}
Loading