Skip to content

Commit 6885a7d

Browse files
committed
Avoid CollectionsMarshal.AsSpan for derived List<T>
* `Task.WaitAll` * `Task.Whenall` * `string.Join` Related: dotnet#118682
1 parent 7bc40a5 commit 6885a7d

File tree

2 files changed

+24
-19
lines changed

2 files changed

+24
-19
lines changed

src/libraries/System.Private.CoreLib/src/System/String.Manipulation.cs

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -784,19 +784,19 @@ private static string JoinCore(ReadOnlySpan<char> separator, string?[] value, in
784784

785785
public static string Join(string? separator, IEnumerable<string?> values)
786786
{
787-
if (values is List<string?> valuesList)
787+
if (values is null)
788788
{
789-
return JoinCore(separator.AsSpan(), CollectionsMarshal.AsSpan(valuesList));
789+
ThrowHelper.ThrowArgumentNullException(ExceptionArgument.values);
790790
}
791791

792-
if (values is string?[] valuesArray)
792+
if (values.GetType() == typeof(List<string?>)) // avoid accidentally bypassing a derived type's reimplementation of IEnumerable<T>
793793
{
794-
return JoinCore(separator.AsSpan(), new ReadOnlySpan<string?>(valuesArray));
794+
return JoinCore(separator.AsSpan(), CollectionsMarshal.AsSpan((List<string?>)values));
795795
}
796796

797-
if (values == null)
797+
if (values is string?[] valuesArray)
798798
{
799-
ThrowHelper.ThrowArgumentNullException(ExceptionArgument.values);
799+
return JoinCore(separator.AsSpan(), new ReadOnlySpan<string?>(valuesArray));
800800
}
801801

802802
using (IEnumerator<string?> en = values.GetEnumerator())

src/libraries/System.Private.CoreLib/src/System/Threading/Tasks/Task.cs

Lines changed: 18 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -4936,10 +4936,15 @@ public static void WaitAll(IEnumerable<Task> tasks, CancellationToken cancellati
49364936
ThrowHelper.ThrowArgumentNullException(ExceptionArgument.tasks);
49374937
}
49384938

4939-
ReadOnlySpan<Task> span =
4940-
tasks is List<Task> list ? CollectionsMarshal.AsSpan(list) :
4941-
tasks is Task[] array ? array :
4942-
CollectionsMarshal.AsSpan(new List<Task>(tasks));
4939+
ReadOnlySpan<Task> span;
4940+
if (tasks.GetType() == typeof(List<Task>)) // avoid accidentally bypassing a derived type's reimplementation of IEnumerable<T>
4941+
{
4942+
span = CollectionsMarshal.AsSpan((List<Task>)tasks);
4943+
}
4944+
else
4945+
{
4946+
span = tasks is Task[] array ? array : CollectionsMarshal.AsSpan(new List<Task>(tasks));
4947+
}
49434948

49444949
WaitAllCore(span, Timeout.Infinite, cancellationToken);
49454950
}
@@ -5933,26 +5938,26 @@ public static Task WhenAll(IEnumerable<Task> tasks)
59335938
ThrowHelper.ThrowArgumentNullException(ExceptionArgument.tasks);
59345939
}
59355940

5936-
int? count = null;
5941+
if (tasks.GetType() == typeof(List<Task>)) // avoid accidentally bypassing a derived type's reimplementation of IEnumerable<T>
5942+
{
5943+
return WhenAll(CollectionsMarshal.AsSpan((List<Task>)tasks));
5944+
}
5945+
5946+
int capacity = 0;
59375947
if (tasks is ICollection<Task> taskCollection)
59385948
{
59395949
if (tasks is Task[] taskArray)
59405950
{
59415951
return WhenAll((ReadOnlySpan<Task>)taskArray);
59425952
}
59435953

5944-
if (tasks is List<Task> taskList)
5945-
{
5946-
return WhenAll(CollectionsMarshal.AsSpan(taskList));
5947-
}
5948-
5949-
count = taskCollection.Count;
5954+
capacity = taskCollection.Count;
59505955
}
59515956

59525957
// Buffer the tasks into a temporary span. Small sets of tasks are common,
59535958
// so for <= 8 we stack allocate.
5954-
ValueListBuilder<Task> builder = count is > 8 ?
5955-
new ValueListBuilder<Task>(count.Value) :
5959+
ValueListBuilder<Task> builder = capacity is > 8 ?
5960+
new ValueListBuilder<Task>(capacity) :
59565961
new ValueListBuilder<Task>([null, null, null, null, null, null, null, null]);
59575962
foreach (Task task in tasks)
59585963
{

0 commit comments

Comments
 (0)