Skip to content

Commit ff7b6c5

Browse files
authored
Unwrap RCWs that are passed to Marshal.GetIUnknownForObject when using the global marshalling ComWrappers instance (#115436)
1 parent 9b493d2 commit ff7b6c5

File tree

5 files changed

+58
-18
lines changed

5 files changed

+58
-18
lines changed

src/coreclr/System.Private.CoreLib/src/System/Runtime/InteropServices/ComWrappers.CoreCLR.cs

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -67,11 +67,14 @@ internal static int CallICustomQueryInterface(ManagedObjectWrapperHolder holder,
6767

6868
internal static IntPtr GetOrCreateComInterfaceForObjectWithGlobalMarshallingInstance(object obj)
6969
{
70+
if (s_globalInstanceForMarshalling == null)
71+
{
72+
return IntPtr.Zero;
73+
}
74+
7075
try
7176
{
72-
return s_globalInstanceForMarshalling is null
73-
? IntPtr.Zero
74-
: s_globalInstanceForMarshalling.GetOrCreateComInterfaceForObject(obj, CreateComInterfaceFlags.TrackerSupport);
77+
return ComInterfaceForObject(obj);
7578
}
7679
catch (ArgumentException)
7780
{
@@ -83,9 +86,14 @@ internal static IntPtr GetOrCreateComInterfaceForObjectWithGlobalMarshallingInst
8386

8487
internal static object? GetOrCreateObjectForComInstanceWithGlobalMarshallingInstance(IntPtr comObject, CreateObjectFlags flags)
8588
{
89+
if (s_globalInstanceForMarshalling == null)
90+
{
91+
return null;
92+
}
93+
8694
try
8795
{
88-
return s_globalInstanceForMarshalling?.GetOrCreateObjectForComInstance(comObject, flags);
96+
return ComObjectForInterface(comObject, flags);
8997
}
9098
catch (ArgumentNullException)
9199
{

src/coreclr/nativeaot/System.Private.CoreLib/src/Internal/Runtime/CompilerHelpers/InteropHelpers.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -499,7 +499,7 @@ public static object ConvertNativeComInterfaceToManaged(IntPtr pUnk)
499499

500500
#if TARGET_WINDOWS
501501
#pragma warning disable CA1416
502-
return ComWrappers.ComObjectForInterface(pUnk);
502+
return ComWrappers.ComObjectForInterface(pUnk, CreateObjectFlags.TrackerObject | CreateObjectFlags.Unwrap);
503503
#pragma warning restore CA1416
504504
#else
505505
throw new PlatformNotSupportedException(SR.PlatformNotSupported_ComInterop);

src/coreclr/nativeaot/System.Private.CoreLib/src/System/Runtime/InteropServices/Marshal.Com.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -295,7 +295,7 @@ public static object GetTypedObjectForIUnknown(IntPtr pUnk, Type t)
295295
[SupportedOSPlatform("windows")]
296296
public static object GetObjectForIUnknown(IntPtr pUnk)
297297
{
298-
return ComWrappers.ComObjectForInterface(pUnk);
298+
return ComWrappers.ComObjectForInterface(pUnk, CreateObjectFlags.TrackerObject | CreateObjectFlags.Unwrap);
299299
}
300300

301301
[SupportedOSPlatform("windows")]

src/libraries/System.Private.CoreLib/src/System/Runtime/InteropServices/ComWrappers.cs

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -743,13 +743,17 @@ public void DisconnectTracker()
743743

744744
internal static object? GetOrCreateObjectFromWrapper(ComWrappers wrapper, IntPtr externalComObject)
745745
{
746-
if (s_globalInstanceForTrackerSupport != null && s_globalInstanceForTrackerSupport == wrapper)
746+
if (wrapper is null)
747+
{
748+
return null;
749+
}
750+
if (s_globalInstanceForTrackerSupport == wrapper)
747751
{
748752
return s_globalInstanceForTrackerSupport.GetOrCreateObjectForComInstance(externalComObject, CreateObjectFlags.TrackerObject);
749753
}
750-
else if (s_globalInstanceForMarshalling != null && s_globalInstanceForMarshalling == wrapper)
754+
else if (s_globalInstanceForMarshalling == wrapper)
751755
{
752-
return ComObjectForInterface(externalComObject);
756+
return ComObjectForInterface(externalComObject, CreateObjectFlags.TrackerObject | CreateObjectFlags.Unwrap);
753757
}
754758
else
755759
{
@@ -1470,7 +1474,12 @@ internal static IntPtr ComInterfaceForObject(object instance)
14701474
throw new NotSupportedException(SR.InvalidOperation_ComInteropRequireComWrapperInstance);
14711475
}
14721476

1473-
return s_globalInstanceForMarshalling.GetOrCreateComInterfaceForObject(instance, CreateComInterfaceFlags.None);
1477+
if (TryGetComInstance(instance, out IntPtr comObject))
1478+
{
1479+
return comObject;
1480+
}
1481+
1482+
return s_globalInstanceForMarshalling.GetOrCreateComInterfaceForObject(instance, CreateComInterfaceFlags.TrackerSupport);
14741483
}
14751484

14761485
internal static unsafe IntPtr ComInterfaceForObject(object instance, Guid targetIID)
@@ -1489,15 +1498,14 @@ internal static unsafe IntPtr ComInterfaceForObject(object instance, Guid target
14891498
return comObjectInterface;
14901499
}
14911500

1492-
internal static object ComObjectForInterface(IntPtr externalComObject)
1501+
internal static object ComObjectForInterface(IntPtr externalComObject, CreateObjectFlags flags)
14931502
{
14941503
if (s_globalInstanceForMarshalling == null)
14951504
{
14961505
throw new NotSupportedException(SR.InvalidOperation_ComInteropRequireComWrapperInstance);
14971506
}
14981507

1499-
// TrackerObject support and unwrapping matches the built-in semantics that the global marshalling scenario mimics.
1500-
return s_globalInstanceForMarshalling.GetOrCreateObjectForComInstance(externalComObject, CreateObjectFlags.TrackerObject | CreateObjectFlags.Unwrap);
1508+
return s_globalInstanceForMarshalling.GetOrCreateObjectForComInstance(externalComObject, flags);
15011509
}
15021510

15031511
internal static IntPtr GetOrCreateTrackerTarget(IntPtr externalComObject)

src/tests/Interop/COM/ComWrappers/GlobalInstance/GlobalInstance.cs

Lines changed: 29 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@ extern public static int UpdateTestObjectAsInterface(
4949

5050
private const string ManagedServerTypeName = "ConsumeNETServerTesting";
5151

52+
private const string IID_IUNKNOWN = "00000000-0000-0000-C000-000000000046";
5253
private const string IID_IDISPATCH = "00020400-0000-0000-C000-000000000046";
5354
private const string IID_IINSPECTABLE = "AF86E2E0-B12D-4c6a-9C5A-D7AA65101E90";
5455
class TestEx : Test
@@ -278,7 +279,7 @@ private static void ValidateMarshalAPIs(bool validateUseRegistered)
278279
var testObj = new Test();
279280
IntPtr comWrapper1 = Marshal.GetIUnknownForObject(testObj);
280281
Assert.NotEqual(IntPtr.Zero, comWrapper1);
281-
Assert.Equal(testObj, registeredWrapper.LastComputeVtablesObject);
282+
Assert.Same(testObj, registeredWrapper.LastComputeVtablesObject);
282283

283284
IntPtr comWrapper2 = Marshal.GetIUnknownForObject(testObj);
284285
Assert.Equal(comWrapper1, comWrapper2);
@@ -295,7 +296,7 @@ private static void ValidateMarshalAPIs(bool validateUseRegistered)
295296
var dispatchObj = new TestEx(IID_IDISPATCH);
296297
IntPtr dispatchWrapper = Marshal.GetIDispatchForObject(dispatchObj);
297298
Assert.NotEqual(IntPtr.Zero, dispatchWrapper);
298-
Assert.Equal(dispatchObj, registeredWrapper.LastComputeVtablesObject);
299+
Assert.Same(dispatchObj, registeredWrapper.LastComputeVtablesObject);
299300

300301
Console.WriteLine($" -- Validate Marshal.GetIDispatchForObject != Marshal.GetIUnknownForObject...");
301302
IntPtr unknownWrapper = Marshal.GetIUnknownForObject(dispatchObj);
@@ -309,7 +310,7 @@ private static void ValidateMarshalAPIs(bool validateUseRegistered)
309310
object objWrapper1 = Marshal.GetObjectForIUnknown(trackerObjRaw);
310311
Assert.Equal(validateUseRegistered, objWrapper1 is FakeWrapper);
311312
object objWrapper2 = Marshal.GetObjectForIUnknown(trackerObjRaw);
312-
Assert.Equal(objWrapper1, objWrapper2);
313+
Assert.Same(objWrapper1, objWrapper2);
313314

314315
Console.WriteLine($" -- Validate Marshal.GetUniqueObjectForIUnknown...");
315316

@@ -319,6 +320,29 @@ private static void ValidateMarshalAPIs(bool validateUseRegistered)
319320
Assert.NotEqual(objWrapper1, objWrapper3);
320321

321322
Marshal.Release(trackerObjRaw);
323+
324+
if (validateUseRegistered)
325+
{
326+
Console.WriteLine($" -- Validate Marshal.GetObjectForIUnknown and Marshal.GetIUnknownForObject unwrapping...");
327+
// Validate that the object returned by Marshal.GetObjectForIUnknown is the same as the original object passed to
328+
// Marshal.GetIUnknownForObject.
329+
IntPtr comWrapper3 = Marshal.GetIUnknownForObject(testObj);
330+
object unwrappedObj = Marshal.GetObjectForIUnknown(comWrapper3);
331+
Assert.Same(testObj, unwrappedObj);
332+
333+
// Validate that the pointer returned by Marshal.GetIUnknownForObject is the same one that was passed into
334+
// Marshal.GetObjectForIUnknown.
335+
IntPtr trackerObj2 = MockReferenceTrackerRuntime.CreateTrackerObject();
336+
Marshal.ThrowExceptionForHR(Marshal.QueryInterface(trackerObj2, Guid.Parse(IID_IUNKNOWN), out IntPtr trackerObj2Identity));
337+
Marshal.Release(trackerObj2);
338+
339+
object trackerObjectWrapper = Marshal.GetObjectForIUnknown(trackerObj2);
340+
IntPtr trackerObjUnwrapped = Marshal.GetIUnknownForObject(trackerObjectWrapper);
341+
Assert.Equal(trackerObj2Identity, trackerObjUnwrapped);
342+
343+
Marshal.Release(trackerObj2Identity);
344+
Marshal.Release(trackerObjUnwrapped);
345+
}
322346
}
323347

324348
private static void ValidatePInvokes(bool validateUseRegistered)
@@ -362,12 +386,12 @@ private static void ValidateInterfaceMarshaler<T>(UpdateTestObject<T> func, bool
362386

363387
T retObj;
364388
int hr = func(testObj as T, value, out retObj);
365-
Assert.Equal(testObj, GlobalComWrappers.Instance.LastComputeVtablesObject);
389+
Assert.Same(testObj, GlobalComWrappers.Instance.LastComputeVtablesObject);
366390
if (shouldSucceed)
367391
{
368392
Assert.True(retObj is Test);
369393
Assert.Equal(value, testObj.GetValue());
370-
Assert.Equal<object>(testObj, retObj);
394+
Assert.Same(testObj, retObj);
371395
}
372396
else
373397
{

0 commit comments

Comments
 (0)