Skip to content

Commit aefb0fc

Browse files
Don't create a COM weak reference if the object is an aggregated COMWrappers RCW. (#61267)
* Don't create a COM weak reference if the object is an aggregated COMWrappers RCW. * Add test for weak reference + aggregation with native weak reference impl. * Apply suggestions from code review Co-authored-by: Aaron Robinson <arobins@microsoft.com> Co-authored-by: Aaron Robinson <arobins@microsoft.com>
1 parent dee4c0c commit aefb0fc

File tree

5 files changed

+218
-39
lines changed

5 files changed

+218
-39
lines changed

src/coreclr/vm/interoplibinterface.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ class ComWrappersNative
3434
static void MarkWrapperAsComActivated(_In_ IUnknown* wrapperMaybe);
3535

3636
public: // Unwrapping support
37-
static IUnknown* GetIdentityForObject(_In_ OBJECTREF* objectPROTECTED, _In_ REFIID riid, _Out_ INT64* wrapperId);
37+
static IUnknown* GetIdentityForObject(_In_ OBJECTREF* objectPROTECTED, _In_ REFIID riid, _Out_ INT64* wrapperId, _Out_ bool* isAggregated);
3838
static bool HasManagedObjectComWrapper(_In_ OBJECTREF object, _Out_ bool* isActive);
3939

4040
public: // GC interaction

src/coreclr/vm/interoplibinterface_comwrappers.cpp

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,9 @@ namespace
5353
// The EOC is "detached" and no longer used to map between identity and a managed object.
5454
// This will only be set if the EOC was inserted into the cache.
5555
Flags_Detached = 8,
56+
57+
// This EOC is an aggregated instance
58+
Flags_Aggregated = 16
5659
};
5760
DWORD Flags;
5861

@@ -900,7 +903,11 @@ namespace
900903
: ExternalObjectContext::Flags_None) |
901904
(uniqueInstance
902905
? ExternalObjectContext::Flags_None
903-
: ExternalObjectContext::Flags_InCache);
906+
: ExternalObjectContext::Flags_InCache) |
907+
((flags & CreateObjectFlags::CreateObjectFlags_Aggregated) != 0
908+
? ExternalObjectContext::Flags_Aggregated
909+
: ExternalObjectContext::Flags_None);
910+
904911
ExternalObjectContext::Construct(
905912
resultHolder.GetContext(),
906913
identity,
@@ -1774,7 +1781,7 @@ bool GlobalComWrappersForTrackerSupport::TryGetOrCreateObjectForComInstance(
17741781
objRef);
17751782
}
17761783

1777-
IUnknown* ComWrappersNative::GetIdentityForObject(_In_ OBJECTREF* objectPROTECTED, _In_ REFIID riid, _Out_ INT64* wrapperId)
1784+
IUnknown* ComWrappersNative::GetIdentityForObject(_In_ OBJECTREF* objectPROTECTED, _In_ REFIID riid, _Out_ INT64* wrapperId, _Out_ bool* isAggregated)
17781785
{
17791786
CONTRACTL
17801787
{
@@ -1807,6 +1814,7 @@ IUnknown* ComWrappersNative::GetIdentityForObject(_In_ OBJECTREF* objectPROTECTE
18071814
{
18081815
ExternalObjectContext* context = reinterpret_cast<ExternalObjectContext*>(contextMaybe);
18091816
*wrapperId = context->WrapperId;
1817+
*isAggregated = context->IsSet(ExternalObjectContext::Flags_Aggregated);
18101818

18111819
IUnknown* identity = reinterpret_cast<IUnknown*>(context->Identity);
18121820
GCX_PREEMP();

src/coreclr/vm/weakreferencenative.cpp

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,7 @@ struct WeakHandleSpinLockHolder
108108
//
109109
// In order to qualify to be used with a HNDTYPE_WEAK_NATIVE_COM, the incoming object must:
110110
// * be an RCW
111+
// * not be an aggregated RCW
111112
// * respond to a QI for IWeakReferenceSource
112113
// * succeed when asked for an IWeakReference*
113114
//
@@ -149,7 +150,14 @@ NativeComWeakHandleInfo* GetComWeakReferenceInfo(OBJECTREF* pObject)
149150
#endif
150151
{
151152
#ifdef FEATURE_COMWRAPPERS
152-
pWeakReferenceSource = reinterpret_cast<IWeakReferenceSource*>(ComWrappersNative::GetIdentityForObject(pObject, IID_IWeakReferenceSource, &wrapperId));
153+
bool isAggregated = false;
154+
pWeakReferenceSource = reinterpret_cast<IWeakReferenceSource*>(ComWrappersNative::GetIdentityForObject(pObject, IID_IWeakReferenceSource, &wrapperId, &isAggregated));
155+
if (isAggregated)
156+
{
157+
// If the RCW is an aggregated RCW, then the managed object cannot be recreated from the IUnknown as the outer IUnknown wraps the managed object.
158+
// In this case, don't create a weak reference backed by a COM weak reference.
159+
pWeakReferenceSource = nullptr;
160+
}
153161
#endif
154162
}
155163

@@ -448,7 +456,7 @@ FCIMPL3(void, WeakReferenceNative::Create, WeakReferenceObject * pThisUNSAFE, Ob
448456
_ASSERTE(gc.pThis->GetMethodTable()->CanCastToClass(pWeakReferenceMT));
449457

450458
// Create the handle.
451-
#if defined(FEATURE_COMINTEROP) || defined(FEATURE_COMWRAPPERS)
459+
#if defined(FEATURE_COMINTEROP) || defined(FEATURE_COMWRAPPERS)
452460
NativeComWeakHandleInfo *comWeakHandleInfo = nullptr;
453461
if (gc.pTarget != NULL)
454462
{
@@ -690,7 +698,7 @@ FCIMPL1(Object *, WeakReferenceNative::GetTarget, WeakReferenceObject * pThisUNS
690698

691699
OBJECTREF pTarget = GetWeakReferenceTarget(pThis);
692700

693-
#if defined(FEATURE_COMINTEROP) || defined(FEATURE_COMWRAPPERS)
701+
#if defined(FEATURE_COMINTEROP) || defined(FEATURE_COMWRAPPERS)
694702
// If we found an object, or we're not a native COM weak reference, then we're done. Othewrise
695703
// we can try to create a new RCW to the underlying native COM object if it's still alive.
696704
if (pTarget != NULL || !IsNativeComWeakReferenceHandle(pThis->m_Handle))
@@ -718,7 +726,7 @@ FCIMPL1(Object *, WeakReferenceOfTNative::GetTarget, WeakReferenceObject * pThis
718726
OBJECTREF pTarget = GetWeakReferenceTarget(pThis);
719727

720728

721-
#if defined(FEATURE_COMINTEROP) || defined(FEATURE_COMWRAPPERS)
729+
#if defined(FEATURE_COMINTEROP) || defined(FEATURE_COMWRAPPERS)
722730
// If we found an object, or we're not a native COM weak reference, then we're done. Othewrise
723731
// we can try to create a new RCW to the underlying native COM object if it's still alive.
724732
if (pTarget != NULL || !IsNativeComWeakReferenceHandle(pThis->m_Handle))

src/tests/Interop/COM/ComWrappers/WeakReference/WeakReferenceNative.cpp

Lines changed: 114 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -167,8 +167,122 @@ namespace
167167
return UnknownImpl::DoRelease();
168168
}
169169
};
170+
171+
struct WeakReferenceSource : public IWeakReferenceSource, public IInspectable
172+
{
173+
private:
174+
IUnknown* _outerUnknown;
175+
ComSmartPtr<WeakReference> _weakReference;
176+
public:
177+
WeakReferenceSource(IUnknown* outerUnknown)
178+
:_outerUnknown(outerUnknown),
179+
_weakReference(new WeakReference(this, 1))
180+
{
181+
}
182+
183+
STDMETHOD(GetWeakReference)(IWeakReference** ppWeakReference)
184+
{
185+
_weakReference->AddRef();
186+
*ppWeakReference = _weakReference;
187+
return S_OK;
188+
}
189+
190+
STDMETHOD(QueryInterface)(
191+
/* [in] */ REFIID riid,
192+
/* [iid_is][out] */ void ** ppvObject)
193+
{
194+
if (riid == __uuidof(IWeakReferenceSource))
195+
{
196+
*ppvObject = static_cast<IWeakReferenceSource*>(this);
197+
_weakReference->AddStrongRef();
198+
return S_OK;
199+
}
200+
return _outerUnknown->QueryInterface(riid, ppvObject);
201+
}
202+
STDMETHOD_(ULONG, AddRef)(void)
203+
{
204+
return _weakReference->AddStrongRef();
205+
}
206+
STDMETHOD_(ULONG, Release)(void)
207+
{
208+
return _weakReference->ReleaseStrongRef();
209+
}
210+
211+
STDMETHOD(GetRuntimeClassName)(HSTRING* pRuntimeClassName)
212+
{
213+
return E_NOTIMPL;
214+
}
215+
216+
STDMETHOD(GetIids)(
217+
ULONG *iidCount,
218+
IID **iids)
219+
{
220+
return E_NOTIMPL;
221+
}
222+
223+
STDMETHOD(GetTrustLevel)(TrustLevel *trustLevel)
224+
{
225+
*trustLevel = FullTrust;
226+
return S_OK;
227+
}
228+
};
229+
230+
struct AggregatedWeakReferenceSource : IInspectable
231+
{
232+
private:
233+
IUnknown* _outerUnknown;
234+
ComSmartPtr<WeakReferenceSource> _weakReference;
235+
public:
236+
AggregatedWeakReferenceSource(IUnknown* outerUnknown)
237+
:_outerUnknown(outerUnknown),
238+
_weakReference(new WeakReferenceSource(outerUnknown))
239+
{
240+
}
241+
242+
STDMETHOD(GetRuntimeClassName)(HSTRING* pRuntimeClassName)
243+
{
244+
return E_NOTIMPL;
245+
}
246+
247+
STDMETHOD(GetIids)(
248+
ULONG *iidCount,
249+
IID **iids)
250+
{
251+
return E_NOTIMPL;
252+
}
253+
254+
STDMETHOD(GetTrustLevel)(TrustLevel *trustLevel)
255+
{
256+
*trustLevel = FullTrust;
257+
return S_OK;
258+
}
259+
260+
STDMETHOD(QueryInterface)(
261+
/* [in] */ REFIID riid,
262+
/* [iid_is][out] */ void ** ppvObject)
263+
{
264+
if (riid == __uuidof(IWeakReferenceSource))
265+
{
266+
return _weakReference->QueryInterface(riid, ppvObject);
267+
}
268+
return _outerUnknown->QueryInterface(riid, ppvObject);
269+
}
270+
STDMETHOD_(ULONG, AddRef)(void)
271+
{
272+
return _outerUnknown->AddRef();
273+
}
274+
STDMETHOD_(ULONG, Release)(void)
275+
{
276+
return _outerUnknown->Release();
277+
}
278+
};
170279
}
171280
extern "C" DLL_EXPORT WeakReferencableObject* STDMETHODCALLTYPE CreateWeakReferencableObject()
172281
{
173282
return new WeakReferencableObject();
174283
}
284+
285+
extern "C" DLL_EXPORT AggregatedWeakReferenceSource* STDMETHODCALLTYPE CreateAggregatedWeakReferenceObject(IUnknown* pOuter)
286+
{
287+
return new AggregatedWeakReferenceSource(pOuter);
288+
}

0 commit comments

Comments
 (0)