From b2079884700f1229d5b503171c4c8f92e9b49ec3 Mon Sep 17 00:00:00 2001 From: Johan Laanstra Date: Wed, 24 Jan 2024 12:12:05 -0800 Subject: [PATCH] Expose api's for Context on ObjectReference. (#1466) * Expose api's for Context on ObjectReference. * Add GC.KeepAlive. * Add old signature back. * Keep Resurrect. * Use Unsafe.AsRef * Update projections. * Fix. * Update baseline. * Use virtual method to safe cost of a field in every instance. * Fix passing iid. --- src/WinRT.Runtime/ComWrappersSupport.cs | 59 ++--- src/WinRT.Runtime/ComWrappersSupport.net5.cs | 10 +- src/WinRT.Runtime/ExceptionHelpers.cs | 1 + src/WinRT.Runtime/Marshalers.cs | 3 +- .../MatchingRefApiCompatBaseline.txt | 7 +- src/WinRT.Runtime/ObjectReference.cs | 158 +++++++++++-- src/cswinrt/code_writers.h | 12 +- src/cswinrt/strings/WinRT.cs | 223 ++++++------------ 8 files changed, 253 insertions(+), 220 deletions(-) diff --git a/src/WinRT.Runtime/ComWrappersSupport.cs b/src/WinRT.Runtime/ComWrappersSupport.cs index ce92e5d18..6b8d7b44c 100644 --- a/src/WinRT.Runtime/ComWrappersSupport.cs +++ b/src/WinRT.Runtime/ComWrappersSupport.cs @@ -75,14 +75,15 @@ public static void MarshalDelegateInvoke(IntPtr thisPtr, Action invoke) // If we are free threaded, we do not need to keep track of context. // This can either be if the object implements IAgileObject or the free threaded marshaler. - internal unsafe static bool IsFreeThreaded(IObjectReference objRef) + internal unsafe static bool IsFreeThreaded(IntPtr iUnknown) { - if (objRef.TryAs(InterfaceIIDs.IAgileObject_IID, out var agilePtr) >= 0) + if (Marshal.QueryInterface(iUnknown, ref Unsafe.AsRef(InterfaceIIDs.IAgileObject_IID), out var agilePtr) >= 0) { Marshal.Release(agilePtr); return true; } - else if (objRef.TryAs(InterfaceIIDs.IMarshal_IID, out var marshalPtr) >= 0) + + if (Marshal.QueryInterface(iUnknown, ref Unsafe.AsRef(InterfaceIIDs.IMarshal_IID), out var marshalPtr) >= 0) { try { @@ -103,6 +104,14 @@ internal unsafe static bool IsFreeThreaded(IObjectReference objRef) return false; } + internal unsafe static bool IsFreeThreaded(IObjectReference objRef) + { + var isFreeThreaded = IsFreeThreaded(objRef.ThisPtr); + // ThisPtr is owned by objRef, so need to make sure objRef stays alive. + GC.KeepAlive(objRef); + return isFreeThreaded; + } + public static IObjectReference GetObjectReferenceForInterface(IntPtr externalComObject) { return GetObjectReferenceForInterface(externalComObject); @@ -115,21 +124,7 @@ public static ObjectReference GetObjectReferenceForInterface(IntPtr extern return null; } - ObjectReference objRef = ObjectReference.FromAbi(externalComObject); - if (IsFreeThreaded(objRef)) - { - return objRef; - } - else - { - using (objRef) - { - return new ObjectReferenceWithContext( - objRef.GetRef(), - Context.GetContextCallback(), - Context.GetContextToken()); - } - } + return ObjectReference.FromAbi(externalComObject); } public static ObjectReference GetObjectReferenceForInterface(IntPtr externalComObject, Guid iid) @@ -144,31 +139,14 @@ internal static ObjectReference GetObjectReferenceForInterface(IntPtr exte return null; } - ObjectReference objRef; if (requireQI) { Marshal.ThrowExceptionForHR(Marshal.QueryInterface(externalComObject, ref iid, out IntPtr ptr)); - objRef = ObjectReference.Attach(ref ptr); + return ObjectReference.Attach(ref ptr, iid); } else { - objRef = ObjectReference.FromAbi(externalComObject); - } - - if (IsFreeThreaded(objRef)) - { - return objRef; - } - else - { - using (objRef) - { - return new ObjectReferenceWithContext( - objRef.GetRef(), - Context.GetContextCallback(), - Context.GetContextToken(), - iid); - } + return ObjectReference.FromAbi(externalComObject, iid); } } @@ -487,7 +465,12 @@ private static Func CreateCustomTypeMappingFactory(Type cu } var fromAbiMethodFunc = (Func) fromAbiMethod.CreateDelegate(typeof(Func)); - return (IInspectable obj) => fromAbiMethodFunc(obj.ThisPtr); + return (IInspectable obj) => + { + var fromAbiMethod = fromAbiMethodFunc(obj.ThisPtr); + GC.KeepAlive(obj); + return fromAbiMethod; + }; } internal static Func CreateTypedRcwFactory( diff --git a/src/WinRT.Runtime/ComWrappersSupport.net5.cs b/src/WinRT.Runtime/ComWrappersSupport.net5.cs index 05bd844f5..492de65db 100644 --- a/src/WinRT.Runtime/ComWrappersSupport.net5.cs +++ b/src/WinRT.Runtime/ComWrappersSupport.net5.cs @@ -175,7 +175,7 @@ public static object TryRegisterObjectForInterface(object obj, IntPtr thisPtr) public static IObjectReference CreateCCWForObject(object obj) { IntPtr ccw = ComWrappers.GetOrCreateComInterfaceForObject(obj, CreateComInterfaceFlags.TrackerSupport); - return ObjectReference.Attach(ref ccw); + return ObjectReference.Attach(ref ccw, InterfaceIIDs.IUnknown_IID); } internal static IntPtr CreateCCWForObjectForABI(object obj, Guid iid) @@ -350,8 +350,7 @@ public unsafe static void Init( // otherwise the new instance will be used. Since the inner was composed // it should answer immediately without going through the outer. Either way // the reference count will go to the new instance. - Guid iid = IReferenceTrackerVftbl.IID; - int hr = Marshal.QueryInterface(objRef.ThisPtr, ref iid, out referenceTracker); + int hr = Marshal.QueryInterface(objRef.ThisPtr, ref Unsafe.AsRef(IReferenceTrackerVftbl.IID), out referenceTracker); if (hr != 0) { referenceTracker = default; @@ -450,9 +449,8 @@ public unsafe static void Init( public unsafe static void Init(IObjectReference objRef, bool addRefFromTrackerSource = true) { if (objRef.ReferenceTrackerPtr == IntPtr.Zero) - { - Guid iid = IReferenceTrackerVftbl.IID; - int hr = Marshal.QueryInterface(objRef.ThisPtr, ref iid, out var referenceTracker); + { + int hr = Marshal.QueryInterface(objRef.ThisPtr, ref Unsafe.AsRef(IReferenceTrackerVftbl.IID), out var referenceTracker); if (hr == 0) { // WinUI scenario diff --git a/src/WinRT.Runtime/ExceptionHelpers.cs b/src/WinRT.Runtime/ExceptionHelpers.cs index 6f0f23752..ad40f032d 100644 --- a/src/WinRT.Runtime/ExceptionHelpers.cs +++ b/src/WinRT.Runtime/ExceptionHelpers.cs @@ -364,6 +364,7 @@ public static void ReportUnhandledError(Exception ex) if (restrictedErrorInfoRef != null) { roReportUnhandledError(restrictedErrorInfoRef.ThisPtr); + GC.KeepAlive(restrictedErrorInfoRef); } } } diff --git a/src/WinRT.Runtime/Marshalers.cs b/src/WinRT.Runtime/Marshalers.cs index f6dafb437..2d3acb376 100644 --- a/src/WinRT.Runtime/Marshalers.cs +++ b/src/WinRT.Runtime/Marshalers.cs @@ -1581,8 +1581,7 @@ public static T FromAbi(IntPtr ptr) IntPtr iunknownPtr = IntPtr.Zero; try { - Guid iid_iunknown = IUnknownVftbl.IID; - Marshal.QueryInterface(ptr, ref iid_iunknown, out iunknownPtr); + Marshal.QueryInterface(ptr, ref Unsafe.AsRef(IUnknownVftbl.IID), out iunknownPtr); if (IUnknownVftbl.IsReferenceToManagedObject(iunknownPtr)) { return (T)ComWrappersSupport.FindObject(iunknownPtr); diff --git a/src/WinRT.Runtime/MatchingRefApiCompatBaseline.txt b/src/WinRT.Runtime/MatchingRefApiCompatBaseline.txt index 1124a27ab..887755b32 100644 --- a/src/WinRT.Runtime/MatchingRefApiCompatBaseline.txt +++ b/src/WinRT.Runtime/MatchingRefApiCompatBaseline.txt @@ -131,4 +131,9 @@ MembersMustExist : Member 'public System.String WinRT.WindowsRuntimeTypeAttribut TypesMustExist : Type 'WinRT.WinRTExposedTypeAttribute' does not exist in the reference but it does exist in the implementation. MembersMustExist : Member 'public T ABI.System.Nullable.GetValue(WinRT.IInspectable)' does not exist in the reference but it does exist in the implementation. TypesMustExist : Type 'WinRT.EventRegistrationTokenTable' does not exist in the reference but it does exist in the implementation. -Total Issues: 132 +MembersMustExist : Member 'public System.Boolean WinRT.IObjectReference.IsFreeThreaded.get()' does not exist in the reference but it does exist in the implementation. +MembersMustExist : Member 'public System.Boolean WinRT.IObjectReference.IsInCurrentContext.get()' does not exist in the reference but it does exist in the implementation. +MembersMustExist : Member 'public WinRT.ObjectReference WinRT.ObjectReference.Attach(System.IntPtr, System.Guid)' does not exist in the reference but it does exist in the implementation. +MembersMustExist : Member 'public WinRT.ObjectReference WinRT.ObjectReference.FromAbi(System.IntPtr, System.Guid)' does not exist in the reference but it does exist in the implementation. +MembersMustExist : Member 'public WinRT.ObjectReference WinRT.ObjectReference.FromAbi(System.IntPtr, T, System.Guid)' does not exist in the reference but it does exist in the implementation. +Total Issues: 137 diff --git a/src/WinRT.Runtime/ObjectReference.cs b/src/WinRT.Runtime/ObjectReference.cs index 40efc1eb2..74fbe7b90 100644 --- a/src/WinRT.Runtime/ObjectReference.cs +++ b/src/WinRT.Runtime/ObjectReference.cs @@ -36,6 +36,17 @@ public IntPtr ThisPtr } } + public bool IsFreeThreaded => GetContextToken() == IntPtr.Zero; + + public bool IsInCurrentContext + { + get + { + var contextToken = GetContextToken(); + return contextToken == IntPtr.Zero || contextToken == Context.GetContextToken(); + } + } + private protected IntPtr ThisPtrFromOriginalContext { get @@ -306,7 +317,7 @@ internal bool Resurrect() protected virtual unsafe void AddRef(bool refFromTrackerSource) { Marshal.AddRef(ThisPtr); - if(refFromTrackerSource) + if (refFromTrackerSource) { AddRefFromTrackerSource(); } @@ -382,6 +393,11 @@ private protected virtual IntPtr GetThisPtrForCurrentContext() return ThisPtrFromOriginalContext; } + private protected virtual IntPtr GetContextToken() + { + return IntPtr.Zero; + } + public ObjectReferenceValue AsValue() { // Sharing ptr with objref. @@ -423,26 +439,64 @@ public T Vftbl } } + private protected ObjectReference(IntPtr thisPtr, T vftblT) : + base(thisPtr) + { + _vftbl = vftblT; + } + + private protected ObjectReference(IntPtr thisPtr) : + this(thisPtr, GetVtable(thisPtr)) + { + } + public static ObjectReference Attach(ref IntPtr thisPtr) { if (thisPtr == IntPtr.Zero) { return null; } - var obj = new ObjectReference(thisPtr); - thisPtr = IntPtr.Zero; - return obj; - } - ObjectReference(IntPtr thisPtr, T vftblT) : - base(thisPtr) - { - _vftbl = vftblT; + if (ComWrappersSupport.IsFreeThreaded(thisPtr)) + { + var obj = new ObjectReference(thisPtr); + thisPtr = IntPtr.Zero; + return obj; + } + else + { + var obj = new ObjectReferenceWithContext( + thisPtr, + Context.GetContextCallback(), + Context.GetContextToken()); + thisPtr = IntPtr.Zero; + return obj; + } } - private protected ObjectReference(IntPtr thisPtr) : - this(thisPtr, GetVtable(thisPtr)) + public static ObjectReference Attach(ref IntPtr thisPtr, Guid iid) { + if (thisPtr == IntPtr.Zero) + { + return null; + } + + if (ComWrappersSupport.IsFreeThreaded(thisPtr)) + { + var obj = new ObjectReference(thisPtr); + thisPtr = IntPtr.Zero; + return obj; + } + else + { + var obj = new ObjectReferenceWithContext( + thisPtr, + Context.GetContextCallback(), + Context.GetContextToken(), + iid); + thisPtr = IntPtr.Zero; + return obj; + } } public static unsafe ObjectReference FromAbi(IntPtr thisPtr, T vftblT) @@ -450,10 +504,48 @@ public static unsafe ObjectReference FromAbi(IntPtr thisPtr, T vftblT) if (thisPtr == IntPtr.Zero) { return null; + } + + Marshal.AddRef(thisPtr); + if (ComWrappersSupport.IsFreeThreaded(thisPtr)) + { + var obj = new ObjectReference(thisPtr, vftblT); + return obj; } + else + { + var obj = new ObjectReferenceWithContext( + thisPtr, + vftblT, + Context.GetContextCallback(), + Context.GetContextToken()); + return obj; + } + } + + public static unsafe ObjectReference FromAbi(IntPtr thisPtr, T vftblT, Guid iid) + { + if (thisPtr == IntPtr.Zero) + { + return null; + } + Marshal.AddRef(thisPtr); - var obj = new ObjectReference(thisPtr, vftblT); - return obj; + if (ComWrappersSupport.IsFreeThreaded(thisPtr)) + { + var obj = new ObjectReference(thisPtr, vftblT); + return obj; + } + else + { + var obj = new ObjectReferenceWithContext( + thisPtr, + vftblT, + Context.GetContextCallback(), + Context.GetContextToken(), + iid); + return obj; + } } public static ObjectReference FromAbi(IntPtr thisPtr) @@ -466,6 +558,16 @@ public static ObjectReference FromAbi(IntPtr thisPtr) return FromAbi(thisPtr, vftblT); } + public static ObjectReference FromAbi(IntPtr thisPtr, Guid iid) + { + if (thisPtr == IntPtr.Zero) + { + return null; + } + var vftblT = GetVtable(thisPtr); + return FromAbi(thisPtr, vftblT, iid); + } + private static unsafe T GetVtable(IntPtr thisPtr) { var vftblPtr = Unsafe.AsRef(thisPtr.ToPointer()); @@ -504,7 +606,7 @@ internal sealed class ObjectReferenceWithContext< #endif T> : ObjectReference { - private readonly IntPtr _contextCallbackPtr; + private readonly IntPtr _contextCallbackPtr; private readonly IntPtr _contextToken; private volatile ConcurrentDictionary> __cachedContext; @@ -520,7 +622,7 @@ private ConcurrentDictionary> Make_CachedContext() private volatile AgileReference __agileReference; private AgileReference AgileReference => _isAgileReferenceSet ? __agileReference : Make_AgileReference(); private AgileReference Make_AgileReference() - { + { Context.CallInContext(_contextCallbackPtr, _contextToken, InitAgileReference, null); // Set after CallInContext callback given callback can fail to occur. @@ -536,16 +638,29 @@ void InitAgileReference() private readonly Guid _iid; internal ObjectReferenceWithContext(IntPtr thisPtr, IntPtr contextCallbackPtr, IntPtr contextToken) - :base(thisPtr) + : base(thisPtr) { - _contextCallbackPtr = contextCallbackPtr; + _contextCallbackPtr = contextCallbackPtr; _contextToken = contextToken; - } - + } + internal ObjectReferenceWithContext(IntPtr thisPtr, IntPtr contextCallbackPtr, IntPtr contextToken, Guid iid) : this(thisPtr, contextCallbackPtr, contextToken) { _iid = iid; + } + + internal ObjectReferenceWithContext(IntPtr thisPtr, T vftblT, IntPtr contextCallbackPtr, IntPtr contextToken) + : base(thisPtr, vftblT) + { + _contextCallbackPtr = contextCallbackPtr; + _contextToken = contextToken; + } + + internal ObjectReferenceWithContext(IntPtr thisPtr, T vftblT, IntPtr contextCallbackPtr, IntPtr contextToken, Guid iid) + : this(thisPtr, vftblT, contextCallbackPtr, contextToken) + { + _iid = iid; } private protected override IntPtr GetThisPtrForCurrentContext() @@ -559,6 +674,11 @@ private protected override IntPtr GetThisPtrForCurrentContext() return cachedObjRef.ThisPtr; } + private protected override IntPtr GetContextToken() + { + return this._contextToken; + } + private protected override T GetVftblForCurrentContext() { ObjectReference cachedObjRef = GetCurrentContext(); diff --git a/src/cswinrt/code_writers.h b/src/cswinrt/code_writers.h index 38a3f97d3..53dc2aca5 100644 --- a/src/cswinrt/code_writers.h +++ b/src/cswinrt/code_writers.h @@ -1935,13 +1935,13 @@ private static % _% = new %("%.%", %.IID); { auto objrefname = w.write_temp("%", bind(classType)); w.write(R"( -private static volatile FactoryObjectReference __%; -private static FactoryObjectReference % +private static volatile ObjectReference __%; +private static ObjectReference % { get { var factory = __%; - if (factory != null && factory.IsObjectInContext()) + if (factory != null && factory.IsInCurrentContext) { return factory; } @@ -1993,13 +1993,13 @@ private static ObjectReference<%> % => __% ?? Make__%(); { auto objrefname = w.write_temp("%", bind(staticsType)); w.write(R"( -private static volatile FactoryObjectReference<%> __%; -private static FactoryObjectReference<%> % +private static volatile ObjectReference<%> __%; +private static ObjectReference<%> % { get { var factory = __%; - if (factory != null && factory.IsObjectInContext()) + if (factory != null && factory.IsInCurrentContext) { return factory; } diff --git a/src/cswinrt/strings/WinRT.cs b/src/cswinrt/strings/WinRT.cs index 430a13778..32e3a58cc 100644 --- a/src/cswinrt/strings/WinRT.cs +++ b/src/cswinrt/strings/WinRT.cs @@ -120,8 +120,8 @@ internal static unsafe int CoCreateInstance(ref Guid clsid, IntPtr outer, uint c internal static extern int CoDecrementMTAUsage(IntPtr cookie); [DllImport("api-ms-win-core-com-l1-1-0.dll")] - internal static extern unsafe int CoIncrementMTAUsage(IntPtr* cookie); - + internal static extern unsafe int CoIncrementMTAUsage(IntPtr* cookie); + #if NET6_0_OR_GREATER internal static bool FreeLibrary(IntPtr moduleHandle) { @@ -142,8 +142,8 @@ internal static bool FreeLibrary(IntPtr moduleHandle) // Local P/Invoke [DllImportAttribute("kernel32.dll", EntryPoint = "FreeLibrary", ExactSpelling = true)] static extern unsafe int PInvoke(IntPtr nativeModuleHandle); - } - + } + internal static unsafe void* TryGetProcAddress(IntPtr moduleHandle, sbyte* functionName) { int lastError; @@ -152,15 +152,15 @@ internal static bool FreeLibrary(IntPtr moduleHandle) Marshal.SetLastSystemError(0); returnValue = PInvoke(moduleHandle, functionName); lastError = Marshal.GetLastSystemError(); - } - + } + Marshal.SetLastPInvokeError(lastError); return returnValue; // Local P/Invoke [DllImportAttribute("kernel32.dll", EntryPoint = "GetProcAddress", ExactSpelling = true)] static extern unsafe void* PInvoke(IntPtr nativeModuleHandle, sbyte* nativeFunctionName); - } + } #else [DllImport("kernel32.dll", SetLastError = true)] [return: MarshalAs(UnmanagedType.Bool)] @@ -244,9 +244,9 @@ internal static bool FreeLibrary(IntPtr moduleHandle) Marshal.ThrowExceptionForHR(Marshal.GetHRForLastWin32Error(), new IntPtr(-1)); } return functionPtr; - } - -#if NET6_0_OR_GREATER + } + +#if NET6_0_OR_GREATER internal static unsafe IntPtr LoadLibraryExW(ushort* fileName, IntPtr fileHandle, uint flags) { int lastError; @@ -255,15 +255,15 @@ internal static unsafe IntPtr LoadLibraryExW(ushort* fileName, IntPtr fileHandle Marshal.SetLastSystemError(0); returnValue = PInvoke(fileName, fileHandle, flags); lastError = Marshal.GetLastSystemError(); - } - + } + Marshal.SetLastPInvokeError(lastError); return returnValue; // Local P/Invoke [DllImportAttribute("kernel32.dll", EntryPoint = "LoadLibraryExW", ExactSpelling = true)] static extern unsafe IntPtr PInvoke(ushort* nativeFileName, IntPtr nativeFileHandle, uint nativeFlags); - } + } #else [DllImport("kernel32.dll", SetLastError = true)] internal static unsafe extern IntPtr LoadLibraryExW(ushort* fileName, IntPtr fileHandle, uint flags); @@ -273,7 +273,7 @@ internal static unsafe IntPtr LoadLibraryExW(string fileName, IntPtr fileHandle, fixed (char* lpFileName = fileName) return LoadLibraryExW((ushort*)lpFileName, fileHandle, flags); } - + [DllImport("api-ms-win-core-winrt-l1-1-0.dll")] internal static extern unsafe int RoGetActivationFactory(IntPtr runtimeClassId, Guid* iid, IntPtr* factory); @@ -335,17 +335,17 @@ internal struct VftblPtr { public IntPtr Vftbl; } - internal static partial class Context - { - [DllImport("api-ms-win-core-com-l1-1-0.dll")] - private static extern unsafe int CoGetContextToken(IntPtr* contextToken); - - public unsafe static IntPtr GetContextToken() - { - IntPtr contextToken; - Marshal.ThrowExceptionForHR(CoGetContextToken(&contextToken)); - return contextToken; - } + internal static partial class Context + { + [DllImport("api-ms-win-core-com-l1-1-0.dll")] + private static extern unsafe int CoGetContextToken(IntPtr* contextToken); + + public unsafe static IntPtr GetContextToken() + { + IntPtr contextToken; + Marshal.ThrowExceptionForHR(CoGetContextToken(&contextToken)); + return contextToken; + } } internal unsafe sealed class DllModule @@ -406,11 +406,11 @@ private static unsafe bool TryCreate(string fileName, out DllModule module) { module = null; return false; - } - + } + module = new DllModule( - fileName, - moduleHandle, + fileName, + moduleHandle, getActivationFactory); return true; } @@ -435,7 +435,7 @@ private DllModule(string fileName, IntPtr moduleHandle, void* getActivationFacto } } - public unsafe (FactoryObjectReference obj, int hr) GetActivationFactory(string runtimeClassId) + public unsafe (ObjectReference obj, int hr) GetActivationFactory(string runtimeClassId) { IntPtr instancePtr = IntPtr.Zero; try @@ -446,7 +446,7 @@ public unsafe (FactoryObjectReference obj, int hr) GetA int hr = _GetActivationFactory(MarshalString.GetAbi(ref __runtimeClassId), &instancePtr); if (hr == 0) { - var objRef = FactoryObjectReference.Attach(ref instancePtr); + var objRef = ObjectReference.Attach(ref instancePtr); return (objRef, hr); } else @@ -493,7 +493,7 @@ public unsafe WinrtModule() _mtaCookie = mtaCookie; } - public static unsafe (FactoryObjectReference obj, int hr) GetActivationFactory(string runtimeClassId, Guid iid) + public static unsafe (ObjectReference obj, int hr) GetActivationFactory(string runtimeClassId, Guid iid) { var module = Instance; // Ensure COM is initialized IntPtr instancePtr = IntPtr.Zero; @@ -501,11 +501,11 @@ public static unsafe (FactoryObjectReference obj, int hr) GetActivationFactor { MarshalString.Pinnable __runtimeClassId = new(runtimeClassId); fixed (void* ___runtimeClassId = __runtimeClassId) - { + { int hr = Platform.RoGetActivationFactory(MarshalString.GetAbi(ref __runtimeClassId), &iid, &instancePtr); if (hr == 0) { - var objRef = FactoryObjectReference.Attach(ref instancePtr); + var objRef = ObjectReference.Attach(ref instancePtr); return (objRef, hr); } else @@ -526,72 +526,7 @@ public static unsafe (FactoryObjectReference obj, int hr) GetActivationFactor } } - internal sealed class FactoryObjectReference< -#if NET - [DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.NonPublicConstructors)] -#endif - T> : IObjectReference - { - private readonly IntPtr _contextToken; - - public static FactoryObjectReference Attach(ref IntPtr thisPtr) - { - if (thisPtr == IntPtr.Zero) - { - return null; - } - var obj = new FactoryObjectReference(thisPtr); - thisPtr = IntPtr.Zero; - return obj; - } - - internal FactoryObjectReference(IntPtr thisPtr) : - base(thisPtr) - { - if (!IsFreeThreaded(this)) - { - _contextToken = Context.GetContextToken(); - } - } - - internal FactoryObjectReference(IntPtr thisPtr, IntPtr contextToken) - : base(thisPtr) - { - _contextToken = contextToken; - } - - public static new unsafe FactoryObjectReference FromAbi(IntPtr thisPtr) - { - if (thisPtr == IntPtr.Zero) - { - return null; - } - var obj = new FactoryObjectReference(thisPtr); - obj.VftblIUnknown.AddRef(obj.ThisPtr); - return obj; - } - - public bool IsObjectInContext() - { - return _contextToken == IntPtr.Zero || _contextToken == Context.GetContextToken(); - } - - // If we are free threaded, we do not need to keep track of context. - // This can either be if the object implements IAgileObject or the free threaded marshaler. - // We only check IAgileObject for now as the necessary code to check the - // free threaded marshaler is not exposed from WinRT.Runtime. - private unsafe static bool IsFreeThreaded(IObjectReference objRef) - { - if (objRef.TryAs(InterfaceIIDs.IAgileObject_IID, out var agilePtr) >= 0) - { - Marshal.Release(agilePtr); - return true; - } - return false; - } - } - - internal static class IActivationFactoryMethods + internal static class IActivationFactoryMethods { public static unsafe ObjectReference ActivateInstance(IObjectReference obj) { @@ -605,20 +540,20 @@ public static unsafe ObjectReference ActivateInstance(IObjectReference obj { MarshalInspectable.DisposeAbi(instancePtr); } - } + } } internal static class ActivationFactory { - public static FactoryObjectReference Get(string typeName) - { + public static ObjectReference Get(string typeName) + { // Prefer the RoGetActivationFactory HRESULT failure over the LoadLibrary/etc. failure int hr; - FactoryObjectReference factory; + ObjectReference factory; (factory, hr) = WinrtModule.GetActivationFactory(typeName, InterfaceIIDs.IActivationFactory_IID); - if (factory != null) + if (factory != null) { - return factory; + return factory; } var moduleName = typeName; @@ -634,37 +569,37 @@ public static FactoryObjectReference Get(string typeNam DllModule module = null; if (DllModule.TryLoad(moduleName + ".dll", out module)) { - (factory, hr) = module.GetActivationFactory(typeName); - if (factory != null) - { - return factory; + (factory, hr) = module.GetActivationFactory(typeName); + if (factory != null) + { + return factory; } } - } - } - + } + } + #if NET - public static FactoryObjectReference Get< - [DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.NonPublicConstructors | DynamicallyAccessedMemberTypes.PublicFields)] -#else public static ObjectReference Get< -#endif - I>(string typeName, Guid iid) - { + [DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.NonPublicConstructors | DynamicallyAccessedMemberTypes.PublicFields)] +#else + public static ObjectReference Get< +#endif + I>(string typeName, Guid iid) + { // Prefer the RoGetActivationFactory HRESULT failure over the LoadLibrary/etc. failure int hr; - FactoryObjectReference factory; + ObjectReference factory; (factory, hr) = WinrtModule.GetActivationFactory(typeName, iid); - if (factory != null) - { + if (factory != null) + { #if NET - return factory; -#else - using (factory) - { - return factory.As(iid); - } -#endif + return factory; +#else + using (factory) + { + return factory.As(iid); + } +#endif } var moduleName = typeName; @@ -680,24 +615,17 @@ public static ObjectReference Get< DllModule module = null; if (DllModule.TryLoad(moduleName + ".dll", out module)) { - FactoryObjectReference activationFactory; - (activationFactory, hr) = module.GetActivationFactory(typeName); - if (activationFactory != null) - { - using (activationFactory) - { -#if NET - if (activationFactory.TryAs(iid, out IntPtr iidPtr) >= 0) - { - return FactoryObjectReference.Attach(ref iidPtr); - } -#else - return activationFactory.As(iid); -#endif - } + ObjectReference activationFactory; + (activationFactory, hr) = module.GetActivationFactory(typeName); + if (activationFactory != null) + { + using (activationFactory) + { + return activationFactory.As(iid); + } } } - } + } } } @@ -781,8 +709,7 @@ public void Dispose() public void InitalizeReferenceTracking(IntPtr ptr) { eventInvokePtr = ptr; - Guid iid = IReferenceTrackerTargetVftbl.IID; - int hr = Marshal.QueryInterface(ptr, ref iid, out referenceTrackerTargetPtr); + int hr = Marshal.QueryInterface(ptr, ref Unsafe.AsRef(IReferenceTrackerTargetVftbl.IID), out referenceTrackerTargetPtr); if (hr != 0) { referenceTrackerTargetPtr = default; @@ -1112,7 +1039,7 @@ protected override Delegate GetEventInvoke() #pragma warning restore CA2002 internal static class InterfaceIIDs - { + { #if NET internal static readonly Guid IInspectable_IID = new Guid(new global::System.ReadOnlySpan(new byte[] { 0xE0, 0xE2, 0x86, 0xAF, 0x2D, 0xB1, 0x6A, 0x4C, 0x9C, 0x5A, 0xD7, 0xAA, 0x65, 0x10, 0x1E, 0x90 })); internal static readonly Guid IUnknown_IID = new Guid(new global::System.ReadOnlySpan(new byte[] { 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0xC0, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x46 }));