Skip to content

Add CollectionsMarshal.GetValueRefOrAddDefault #54611

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,17 @@ public static void ComparerImplementations_Dictionary_WithWellKnownStringCompare
expectedInternalComparerTypeBeforeCollisionThreshold: StringComparer.InvariantCulture.GetType(),
expectedPublicComparerBeforeCollisionThreshold: StringComparer.InvariantCulture,
expectedInternalComparerTypeAfterCollisionThreshold: StringComparer.InvariantCulture.GetType());

// CollectionsMarshal.GetValueRefOrAddDefault

RunCollectionTestCommon(
() => new Dictionary<string, object>(StringComparer.Ordinal),
(dictionary, key) => CollectionsMarshal.GetValueRefOrAddDefault(dictionary, key, out _) = null,
(dictionary, key) => dictionary.ContainsKey(key),
dictionary => dictionary.Comparer,
expectedInternalComparerTypeBeforeCollisionThreshold: nonRandomizedOrdinalComparerType,
expectedPublicComparerBeforeCollisionThreshold: StringComparer.Ordinal,
expectedInternalComparerTypeAfterCollisionThreshold: randomizedOrdinalComparerType);

static void RunDictionaryTest(
IEqualityComparer<string> equalityComparer,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -497,6 +497,9 @@ private int Initialize(int capacity)

private bool TryInsert(TKey key, TValue value, InsertionBehavior behavior)
{
// NOTE: this method is mirrored in CollectionsMarshal.GetValueRefOrAddDefault below.
// If you make any changes here, make sure to keep that version in sync as well.

if (key == null)
{
ThrowHelper.ThrowArgumentNullException(ExceptionArgument.key);
Expand Down Expand Up @@ -681,6 +684,190 @@ private bool TryInsert(TKey key, TValue value, InsertionBehavior behavior)
return true;
}

/// <summary>
/// A helper class containing APIs exposed through <see cref="Runtime.InteropServices.CollectionsMarshal"/>.
/// These methods are relatively niche and only used in specific scenarios, so adding them in a separate type avoids
/// the additional overhead on each <see cref="Dictionary{TKey, TValue}"/> instantiation, especially in AOT scenarios.
/// </summary>
internal static class CollectionsMarshalHelper
{
/// <inheritdoc cref="Runtime.InteropServices.CollectionsMarshal.GetValueRefOrAddDefault{TKey, TValue}(Dictionary{TKey, TValue}, TKey, out bool)"/>
public static ref TValue? GetValueRefOrAddDefault(Dictionary<TKey, TValue> dictionary, TKey key, out bool exists)
{
// NOTE: this method is mirrored by Dictionary<TKey, TValue>.TryInsert above.
// If you make any changes here, make sure to keep that version in sync as well.

if (key == null)
{
ThrowHelper.ThrowArgumentNullException(ExceptionArgument.key);
}

if (dictionary._buckets == null)
{
dictionary.Initialize(0);
}
Debug.Assert(dictionary._buckets != null);

Entry[]? entries = dictionary._entries;
Debug.Assert(entries != null, "expected entries to be non-null");

IEqualityComparer<TKey>? comparer = dictionary._comparer;
uint hashCode = (uint)((comparer == null) ? key.GetHashCode() : comparer.GetHashCode(key));

uint collisionCount = 0;
ref int bucket = ref dictionary.GetBucket(hashCode);
int i = bucket - 1; // Value in _buckets is 1-based

if (comparer == null)
{
if (typeof(TKey).IsValueType)
{
// ValueType: Devirtualize with EqualityComparer<TValue>.Default intrinsic
while (true)
{
// Should be a while loop https://github.com/dotnet/runtime/issues/9422
// Test uint in if rather than loop condition to drop range check for following array access
if ((uint)i >= (uint)entries.Length)
{
break;
}

if (entries[i].hashCode == hashCode && EqualityComparer<TKey>.Default.Equals(entries[i].key, key))
{
exists = true;

return ref entries[i].value!;
}

i = entries[i].next;

collisionCount++;
if (collisionCount > (uint)entries.Length)
{
// The chain of entries forms a loop; which means a concurrent update has happened.
// Break out of the loop and throw, rather than looping forever.
ThrowHelper.ThrowInvalidOperationException_ConcurrentOperationsNotSupported();
}
}
}
else
{
// Object type: Shared Generic, EqualityComparer<TValue>.Default won't devirtualize
// https://github.com/dotnet/runtime/issues/10050
// So cache in a local rather than get EqualityComparer per loop iteration
EqualityComparer<TKey> defaultComparer = EqualityComparer<TKey>.Default;
while (true)
{
// Should be a while loop https://github.com/dotnet/runtime/issues/9422
// Test uint in if rather than loop condition to drop range check for following array access
if ((uint)i >= (uint)entries.Length)
{
break;
}

if (entries[i].hashCode == hashCode && defaultComparer.Equals(entries[i].key, key))
{
exists = true;

return ref entries[i].value!;
}

i = entries[i].next;

collisionCount++;
if (collisionCount > (uint)entries.Length)
{
// The chain of entries forms a loop; which means a concurrent update has happened.
// Break out of the loop and throw, rather than looping forever.
ThrowHelper.ThrowInvalidOperationException_ConcurrentOperationsNotSupported();
}
}
}
}
else
{
while (true)
{
// Should be a while loop https://github.com/dotnet/runtime/issues/9422
// Test uint in if rather than loop condition to drop range check for following array access
if ((uint)i >= (uint)entries.Length)
{
break;
}

if (entries[i].hashCode == hashCode && comparer.Equals(entries[i].key, key))
{
exists = true;

return ref entries[i].value!;
}

i = entries[i].next;

collisionCount++;
if (collisionCount > (uint)entries.Length)
{
// The chain of entries forms a loop; which means a concurrent update has happened.
// Break out of the loop and throw, rather than looping forever.
ThrowHelper.ThrowInvalidOperationException_ConcurrentOperationsNotSupported();
}
}
}

int index;
if (dictionary._freeCount > 0)
{
index = dictionary._freeList;
Debug.Assert((StartOfFreeList - entries[dictionary._freeList].next) >= -1, "shouldn't overflow because `next` cannot underflow");
dictionary._freeList = StartOfFreeList - entries[dictionary._freeList].next;
dictionary._freeCount--;
}
else
{
int count = dictionary._count;
if (count == entries.Length)
{
dictionary.Resize();
bucket = ref dictionary.GetBucket(hashCode);
}
index = count;
dictionary._count = count + 1;
entries = dictionary._entries;
}

ref Entry entry = ref entries![index];
entry.hashCode = hashCode;
entry.next = bucket - 1; // Value in _buckets is 1-based
entry.key = key;
entry.value = default!;
bucket = index + 1; // Value in _buckets is 1-based
dictionary._version++;

// Value types never rehash
if (!typeof(TKey).IsValueType && collisionCount > HashHelpers.HashCollisionThreshold && comparer is NonRandomizedStringEqualityComparer)
{
// If we hit the collision threshold we'll need to switch to the comparer which is using randomized string hashing
// i.e. EqualityComparer<string>.Default.
dictionary.Resize(entries.Length, true);

exists = false;

// At this point the entries array has been resized, so the current reference we have is no longer valid.
// We're forced to do a new lookup and return an updated reference to the new entry instance. This new
// lookup is guaranteed to always find a value though and it will never return a null reference here.
ref TValue? value = ref dictionary.FindValue(key)!;

Debug.Assert(!Unsafe.IsNullRef(ref value), "the lookup result cannot be a null ref here");

return ref value;
}

exists = false;

return ref entry.value!;
}
}

public virtual void OnDeserialization(object? sender)
{
HashHelpers.SerializationInfoTable.TryGetValue(this, out SerializationInfo? siInfo);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,5 +29,15 @@ public static Span<T> AsSpan<T>(List<T>? list)
/// </remarks>
public static ref TValue GetValueRefOrNullRef<TKey, TValue>(Dictionary<TKey, TValue> dictionary, TKey key) where TKey : notnull
=> ref dictionary.FindValue(key);

/// <summary>
/// Gets a ref to a <typeparamref name="TValue"/> in the <see cref="Dictionary{TKey, TValue}"/>, adding a new entry with a default value if it does not exist in the <paramref name="dictionary"/>.
/// </summary>
/// <param name="dictionary">The dictionary to get the ref to <typeparamref name="TValue"/> from.</param>
/// <param name="key">The key used for lookup.</param>
/// <param name="exists">Whether or not a new entry for the given key was added to the dictionary.</param>
/// <remarks>Items should not be added to or removed from the <see cref="Dictionary{TKey, TValue}"/> while the ref <typeparamref name="TValue"/> is in use.</remarks>
public static ref TValue? GetValueRefOrAddDefault<TKey, TValue>(Dictionary<TKey, TValue> dictionary, TKey key, out bool exists) where TKey : notnull
=> ref Dictionary<TKey, TValue>.CollectionsMarshalHelper.GetValueRefOrAddDefault(dictionary, key, out exists);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,7 @@ public static partial class CollectionsMarshal
{
public static System.Span<T> AsSpan<T>(System.Collections.Generic.List<T>? list) { throw null; }
public static ref TValue GetValueRefOrNullRef<TKey, TValue>(System.Collections.Generic.Dictionary<TKey, TValue> dictionary, TKey key) where TKey : notnull { throw null; }
public static ref TValue? GetValueRefOrAddDefault<TKey, TValue>(System.Collections.Generic.Dictionary<TKey, TValue> dictionary, TKey key, out bool exists) where TKey : notnull { throw null; }
}
[System.AttributeUsageAttribute(System.AttributeTargets.Field | System.AttributeTargets.Parameter | System.AttributeTargets.Property | System.AttributeTargets.ReturnValue, Inherited=false)]
public sealed partial class ComAliasNameAttribute : System.Attribute
Expand Down
Loading