Skip to content

Commit

Permalink
Implement 'ConditionalWeakTable<TKey,TValue>.GetOrAdd' APIs (dotnet#1…
Browse files Browse the repository at this point in the history
…11204)

* Add '[EditorBrowsable(Never)]' to APIs

* Add 'GetOrAdd' API

* Add 'GetOrAdd' API

* Add 'GetOrAdd' API

* Update ref assembly

* Add unit tests

* Add XML docs for new APIs

* Remove 'Atomically' to clarify docs

* Convert uses to new APIs

* Remove leftover unused method

* Switch 'GetOrCreateComInterfaceForObject' to local type

* Apply suggestions from code review

* Lower threshold time for new tests

---------

Co-authored-by: Jan Kotas <jkotas@microsoft.com>
  • Loading branch information
Sergio0694 and jkotas authored Jan 29, 2025
1 parent d9b7515 commit e66d834
Show file tree
Hide file tree
Showing 10 changed files with 344 additions and 49 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -14,23 +14,16 @@ internal sealed class DispenserThatReusesAsLongAsKeyIsAlive<K, [DynamicallyAcces
{
public DispenserThatReusesAsLongAsKeyIsAlive(Func<K, V> factory)
{
_createValueCallback = CreateValue;
_conditionalWeakTable = new ConditionalWeakTable<K, V>();
_factory = factory;
}

public sealed override V GetOrAdd(K key)
{
return _conditionalWeakTable.GetValue(key, _createValueCallback);
}

private V CreateValue(K key)
{
return _factory(key);
return _conditionalWeakTable.GetOrAdd(key, _factory);
}

private readonly Func<K, V> _factory;
private readonly ConditionalWeakTable<K, V> _conditionalWeakTable;
private readonly ConditionalWeakTable<K, V>.CreateValueCallback _createValueCallback;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -701,6 +701,13 @@ public void DisconnectTracker()
}
}

// Custom type instead of a value tuple to avoid rooting 'ITuple' and other value tuple stuff
private struct GetOrCreateComInterfaceForObjectParameters
{
public ComWrappers? This;
public CreateComInterfaceFlags Flags;
}

/// <summary>
/// Create a COM representation of the supplied object that can be passed to a non-managed environment.
/// </summary>
Expand All @@ -716,18 +723,12 @@ public unsafe IntPtr GetOrCreateComInterfaceForObject(object instance, CreateCom
{
ArgumentNullException.ThrowIfNull(instance);

ManagedObjectWrapperHolder? managedObjectWrapper;
if (_managedObjectWrapperTable.TryGetValue(instance, out managedObjectWrapper))
ManagedObjectWrapperHolder managedObjectWrapper = _managedObjectWrapperTable.GetOrAdd(instance, static (c, items) =>
{
managedObjectWrapper.AddRef();
return managedObjectWrapper.ComIp;
}

managedObjectWrapper = _managedObjectWrapperTable.GetValue(instance, (c) =>
{
ManagedObjectWrapper* value = CreateManagedObjectWrapper(c, flags);
ManagedObjectWrapper* value = items.This!.CreateManagedObjectWrapper(c, items.Flags);
return new ManagedObjectWrapperHolder(value, c);
});
}, new GetOrCreateComInterfaceForObjectParameters { This = this, Flags = flags });

managedObjectWrapper.AddRef();
return managedObjectWrapper.ComIp;
}
Expand Down Expand Up @@ -1069,15 +1070,11 @@ private void RegisterWrapperForObject(NativeObjectWrapper wrapper, object comPro
Debug.Assert(wrapper.ProxyHandle.Target == comProxy);
Debug.Assert(wrapper.IsUniqueInstance || _rcwCache.FindProxyForComInstance(wrapper.ExternalComObject) == comProxy);

if (s_nativeObjectWrapperTable.TryGetValue(comProxy, out NativeObjectWrapper? registeredWrapper)
&& registeredWrapper != wrapper)
{
Debug.Assert(registeredWrapper.ExternalComObject != wrapper.ExternalComObject);
wrapper.Release();
throw new NotSupportedException();
}
// Add the input wrapper bound to the COM proxy, if there isn't one already. If another thread raced
// against this one and this lost, we'd get the wrapper added from that thread instead.
NativeObjectWrapper registeredWrapper = s_nativeObjectWrapperTable.GetOrAdd(comProxy, wrapper);

registeredWrapper = GetValueFromRcwTable(comProxy, wrapper);
// We lost the race, so we cannot register the incoming wrapper with the target object
if (registeredWrapper != wrapper)
{
Debug.Assert(registeredWrapper.ExternalComObject != wrapper.ExternalComObject);
Expand All @@ -1091,9 +1088,6 @@ private void RegisterWrapperForObject(NativeObjectWrapper wrapper, object comPro
// TrackerObjectManager and we could end up missing a section of the object graph.
// This cache deduplicates, so it is okay that the wrapper will be registered multiple times.
AddWrapperToReferenceTrackerHandleCache(registeredWrapper);

// Separate out into a local function to avoid the closure and delegate allocation unless we need it.
static NativeObjectWrapper GetValueFromRcwTable(object userObject, NativeObjectWrapper newWrapper) => s_nativeObjectWrapperTable.GetValue(userObject, _ => newWrapper);
}

private static void AddWrapperToReferenceTrackerHandleCache(NativeObjectWrapper wrapper)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ private static IntPtr CreateReferenceTrackingHandleInternal(
throw new InvalidOperationException(SR.InvalidOperation_ObjectiveCTypeNoFinalizer);
}

var trackerInfo = s_objects.GetValue(obj, static o => new ObjcTrackingInformation());
var trackerInfo = s_objects.GetOrAdd(obj, static o => new ObjcTrackingInformation());
trackerInfo.EnsureInitialized(obj);
trackerInfo.GetTaggedMemory(out memInSizeT, out mem);
return RuntimeImports.RhHandleAllocRefCounted(obj);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -70,15 +70,15 @@ public static unsafe IntPtr GetFunctionPointerForDelegate(Delegate del)
//
// Marshalling a managed delegate created from managed code into a native function pointer
//
return GetPInvokeDelegates().GetValue(del, s_AllocateThunk ??= AllocateThunk).Thunk;
return GetPInvokeDelegates().GetOrAdd(del, s_AllocateThunk ??= AllocateThunk).Thunk;
}
}

/// <summary>
/// Used to lookup whether a delegate already has thunk allocated for it
/// </summary>
private static ConditionalWeakTable<Delegate, PInvokeDelegateThunk> s_pInvokeDelegates;
private static ConditionalWeakTable<Delegate, PInvokeDelegateThunk>.CreateValueCallback s_AllocateThunk;
private static Func<Delegate, PInvokeDelegateThunk> s_AllocateThunk;

private static ConditionalWeakTable<Delegate, PInvokeDelegateThunk> GetPInvokeDelegates()
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,14 +25,14 @@ public static partial class Monitor
#region Object->Lock/Condition mapping

private static readonly ConditionalWeakTable<object, Condition> s_conditionTable = new ConditionalWeakTable<object, Condition>();
private static readonly ConditionalWeakTable<object, Condition>.CreateValueCallback s_createCondition = (o) => new Condition(ObjectHeader.GetLockObject(o));
private static readonly Func<object, Condition> s_createCondition = (o) => new Condition(ObjectHeader.GetLockObject(o));

private static Condition GetCondition(object obj)
{
Debug.Assert(
!(obj is Condition),
"Do not use Monitor.Pulse or Wait on a Condition instance; use the methods on Condition instead.");
return s_conditionTable.GetValue(obj, s_createCondition);
return s_conditionTable.GetOrAdd(obj, s_createCondition);
}
#endregion

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ public static SocketError Set(SafeSocketHandle handle, SocketOptionName optionNa

public static SocketError Set(SafeSocketHandle handle, SocketOptionName optionName, int optionValueSeconds)
{
IOControlKeepAlive ioControlKeepAlive = s_socketKeepAliveTable.GetValue(handle, (SafeSocketHandle handle) => new IOControlKeepAlive());
IOControlKeepAlive ioControlKeepAlive = s_socketKeepAliveTable.GetOrAdd(handle, (SafeSocketHandle handle) => new IOControlKeepAlive());
if (optionName == SocketOptionName.TcpKeepAliveTime)
{
ioControlKeepAlive._timeMs = SecondsToMilliseconds(optionValueSeconds);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

using System.Collections;
using System.Collections.Generic;
using System.ComponentModel;
using System.Diagnostics;
using System.Diagnostics.CodeAnalysis;
using System.Numerics;
Expand Down Expand Up @@ -188,36 +189,125 @@ public void Clear()
}

/// <summary>
/// Atomically searches for a specified key in the table and returns the corresponding value.
/// If the key does not exist in the table, the method invokes a callback method to create a
/// value that is bound to the specified key.
/// Searches for a specified key in the table and returns the corresponding value. If the key does
/// not exist in the table, the method adds the given value and binds it to the specified key.
/// </summary>
/// <param name="key">The key of the value to find. It cannot be <see langword="null"/>.</param>
/// <param name="value">The value to add and bind to <typeparamref name="TKey"/>, if one does not exist already.</param>
/// <returns>The value bound to <typeparamref name="TKey"/> in the current <see cref="ConditionalWeakTable{TKey, TValue}"/> instance, after the method completes.</returns>
/// <exception cref="ArgumentNullException"><paramref name="key"/> is <see langword="null"/>.</exception>
public TValue GetOrAdd(TKey key, TValue value)
{
// key is validated by TryGetValue
if (TryGetValue(key, out TValue? existingValue))
{
return existingValue;
}

return GetOrAddLocked(key, value);
}

/// <summary>
/// Searches for a specified key in the table and returns the corresponding value. If the key does not exist
/// in the table, the method invokes the supplied factory to create a value that is bound to the specified key.
/// </summary>
/// <param name="key">The key of the value to find. It cannot be <see langword="null"/>.</param>
/// <param name="valueFactory">The callback that creates a value for key, if one does not exist already. It cannot be <see langword="null"/>.</param>
/// <returns>The value bound to <typeparamref name="TKey"/> in the current <see cref="ConditionalWeakTable{TKey, TValue}"/> instance, after the method completes.</returns>
/// <exception cref="ArgumentNullException"><paramref name="key"/> or <paramref name="valueFactory"/> are <see langword="null"/>.</exception>
/// <remarks>
/// If multiple threads try to initialize the same key, the table may invoke <paramref name="valueFactory"/> multiple times
/// with the same key. Exactly one of these calls will succeed and the returned value of that call will be the one added to
/// the table and returned by all the racing <see cref="GetOrAdd(TKey, Func{TKey, TValue})"/> calls. This rule permits the
/// table to invoke <paramref name="valueFactory"/> outside the internal table lock, to prevent deadlocks.
/// </remarks>
public TValue GetOrAdd(TKey key, Func<TKey, TValue> valueFactory)
{
ArgumentNullException.ThrowIfNull(valueFactory);

// key is validated by TryGetValue
if (TryGetValue(key, out TValue? existingValue))
{
return existingValue;
}

// create the value outside of the lock
TValue value = valueFactory(key);

return GetOrAddLocked(key, value);
}

/// <summary>
/// Searches for a specified key in the table and returns the corresponding value. If the key does not exist
/// in the table, the method invokes the supplied factory to create a value that is bound to the specified key.
/// </summary>
/// <typeparam name="TArg">The type of the additional argument to use with the value factory.</typeparam>
/// <param name="key">The key of the value to find. It cannot be <see langword="null"/>.</param>
/// <param name="valueFactory">The callback that creates a value for key, if one does not exist already. It cannot be <see langword="null"/>.</param>
/// <param name="factoryArgument">The additional argument to supply to <paramref name="valueFactory"/> upon invocation.</param>
/// <returns>The value bound to <typeparamref name="TKey"/> in the current <see cref="ConditionalWeakTable{TKey, TValue}"/> instance, after the method completes.</returns>
/// <exception cref="ArgumentNullException"><paramref name="key"/> or <paramref name="valueFactory"/> are <see langword="null"/>.</exception>
/// <remarks>
/// If multiple threads try to initialize the same key, the table may invoke <paramref name="valueFactory"/> multiple times with the
/// same key. Exactly one of these calls will succeed and the returned value of that call will be the one added to the table and
/// returned by all the racing <see cref="GetOrAdd{TArg}(TKey, Func{TKey, TArg, TValue}, TArg)"/> calls. This rule permits the
/// table to invoke <paramref name="valueFactory"/> outside the internal table lock, to prevent deadlocks.
/// </remarks>
public TValue GetOrAdd<TArg>(TKey key, Func<TKey, TArg, TValue> valueFactory, TArg factoryArgument)
where TArg : allows ref struct
{
ArgumentNullException.ThrowIfNull(valueFactory);

// key is validated by TryGetValue
if (TryGetValue(key, out TValue? existingValue))
{
return existingValue;
}

// create the value outside of the lock
TValue value = valueFactory(key, factoryArgument);

return GetOrAddLocked(key, value);
}

/// <summary>
/// Searches for a specified key in the table and returns the corresponding value. If the key does not exist
/// in the table, the method invokes a callback method to create a value that is bound to the specified key.
/// </summary>
/// <param name="key">key of the value to find. Cannot be null.</param>
/// <param name="createValueCallback">callback that creates value for key. Cannot be null.</param>
/// <returns></returns>
/// <remarks>
/// <para>
/// If multiple threads try to initialize the same key, the table may invoke createValueCallback
/// multiple times with the same key. Exactly one of these calls will succeed and the returned
/// value of that call will be the one added to the table and returned by all the racing GetValue() calls.
/// This rule permits the table to invoke createValueCallback outside the internal table lock
/// to prevent deadlocks.
/// </para>
/// <para>
/// Consider using <see cref="GetOrAdd(TKey, Func{TKey, TValue})"/> (or one of its overloads) instead.
/// </para>
/// </remarks>
[EditorBrowsable(EditorBrowsableState.Never)]
public TValue GetValue(TKey key, CreateValueCallback createValueCallback)
{
ArgumentNullException.ThrowIfNull(createValueCallback);

// key is validated by TryGetValue
return TryGetValue(key, out TValue? existingValue) ?
existingValue :
GetValueLocked(key, createValueCallback);
if (TryGetValue(key, out TValue? existingValue))
{
return existingValue;
}

// create the value outside of the lock
TValue value = createValueCallback(key);

return GetOrAddLocked(key, value);
}

private TValue GetValueLocked(TKey key, CreateValueCallback createValueCallback)
private TValue GetOrAddLocked(TKey key, TValue value)
{
// If we got here, the key was not in the table. Invoke the callback (outside the lock)
// to generate the new value for the key.
TValue newValue = createValueCallback(key);

lock (_lock)
{
// Now that we've taken the lock, must recheck in case we lost a race to add the key.
Expand All @@ -228,8 +318,8 @@ private TValue GetValueLocked(TKey key, CreateValueCallback createValueCallback)
else
{
// Verified in-lock that we won the race to add the key. Add it now.
CreateEntry(key, newValue);
return newValue;
CreateEntry(key, value);
return value;
}
}
}
Expand All @@ -239,8 +329,13 @@ private TValue GetValueLocked(TKey key, CreateValueCallback createValueCallback)
/// to create new instances as needed. If TValue does not have a default constructor, this will throw.
/// </summary>
/// <param name="key">key of the value to find. Cannot be null.</param>
/// <remarks>
/// Consider using <see cref="GetOrAdd(TKey, Func{TKey, TValue})"/> (or one of its overloads) instead.
/// </remarks>
[EditorBrowsable(EditorBrowsableState.Never)]
public TValue GetOrCreateValue(TKey key) => GetValue(key, _ => Activator.CreateInstance<TValue>());

[EditorBrowsable(EditorBrowsableState.Never)]
public delegate TValue CreateValueCallback(TKey key);

/// <summary>Gets an enumerator for the table.</summary>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ internal static object CreateProxyInstance(
AssemblyLoadContext? alc = AssemblyLoadContext.GetLoadContext(baseType.Assembly);
Debug.Assert(alc != null);

ProxyAssembly proxyAssembly = s_alcProxyAssemblyMap.GetValue(alc, static x => new ProxyAssembly(x));
ProxyAssembly proxyAssembly = s_alcProxyAssemblyMap.GetOrAdd(alc, static x => new ProxyAssembly(x));
GeneratedTypeInfo proxiedType = proxyAssembly.GetProxyType(baseType, interfaceType, interfaceParameter, proxyParameter);
return Activator.CreateInstance(proxiedType.GeneratedType, new object[] { proxiedType.MethodInfos })!;
}
Expand Down
6 changes: 6 additions & 0 deletions src/libraries/System.Runtime/ref/System.Runtime.cs
Original file line number Diff line number Diff line change
Expand Up @@ -13276,13 +13276,19 @@ public ConditionalWeakTable() { }
public void Add(TKey key, TValue value) { }
public void AddOrUpdate(TKey key, TValue value) { }
public void Clear() { }
public TValue GetOrAdd(TKey key, TValue value) { throw null; }
public TValue GetOrAdd(TKey key, System.Func<TKey, TValue> valueFactory) { throw null; }
public TValue GetOrAdd<TArg>(TKey key, System.Func<TKey, TArg, TValue> valueFactory, TArg factoryArgument) where TArg : allows ref struct { throw null; }
[System.ComponentModel.EditorBrowsableAttribute(System.ComponentModel.EditorBrowsableState.Never)]
public TValue GetOrCreateValue(TKey key) { throw null; }
[System.ComponentModel.EditorBrowsableAttribute(System.ComponentModel.EditorBrowsableState.Never)]
public TValue GetValue(TKey key, System.Runtime.CompilerServices.ConditionalWeakTable<TKey, TValue>.CreateValueCallback createValueCallback) { throw null; }
public bool Remove(TKey key) { throw null; }
System.Collections.Generic.IEnumerator<System.Collections.Generic.KeyValuePair<TKey, TValue>> System.Collections.Generic.IEnumerable<System.Collections.Generic.KeyValuePair<TKey, TValue>>.GetEnumerator() { throw null; }
System.Collections.IEnumerator System.Collections.IEnumerable.GetEnumerator() { throw null; }
public bool TryAdd(TKey key, TValue value) { throw null; }
public bool TryGetValue(TKey key, [System.Diagnostics.CodeAnalysis.MaybeNullWhenAttribute(false)] out TValue value) { throw null; }
[System.ComponentModel.EditorBrowsableAttribute(System.ComponentModel.EditorBrowsableState.Never)]
public delegate TValue CreateValueCallback(TKey key);
}
public readonly partial struct ConfiguredAsyncDisposable
Expand Down
Loading

0 comments on commit e66d834

Please sign in to comment.