Skip to content

Commit 297fdf3

Browse files
committed
Avoid CollectionsMarshal.AsSpan for derived List<T> in Linq
`CollectionsMarshal.AsSpan` relies on the exact internal layout of `List<T>`. Passing a subclass is unsafe because a derived list may reimplement `IEnumerable<T>` with altered enumeration semantics.
1 parent b622b6b commit 297fdf3

File tree

4 files changed

+31
-46
lines changed

4 files changed

+31
-46
lines changed

src/libraries/System.Linq/src/System/Linq/Enumerable.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ internal static bool TryGetSpan<TSource>(this IEnumerable<TSource> source, out R
5555
{
5656
span = Unsafe.As<TSource[]>(source);
5757
}
58-
else if (source.GetType() == typeof(List<TSource>))
58+
else if (source.GetType() == typeof(List<TSource>)) // avoid accidentally bypassing a derived type's reimplementation of IEnumerable<T>
5959
{
6060
span = CollectionsMarshal.AsSpan(Unsafe.As<List<TSource>>(source));
6161
}

src/libraries/System.Linq/src/System/Linq/Select.cs

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
using System.Collections.Generic;
55
using System.Diagnostics;
66
using System.Diagnostics.CodeAnalysis;
7+
using System.Runtime.CompilerServices;
78
using static System.Linq.Utilities;
89

910
namespace System.Linq
@@ -52,9 +53,9 @@ public static IEnumerable<TResult> Select<TSource, TResult>(
5253
return new ArraySelectIterator<TSource, TResult>(array, selector);
5354
}
5455

55-
if (source is List<TSource> list)
56+
if (source.GetType() == typeof(List<TSource>)) // avoid accidentally bypassing a derived type's reimplementation of IEnumerable<T>
5657
{
57-
return new ListSelectIterator<TSource, TResult>(list, selector);
58+
return new ListSelectIterator<TSource, TResult>(Unsafe.As<List<TSource>>(source), selector);
5859
}
5960

6061
return new IListSelectIterator<TSource, TResult>(ilist, selector);

src/libraries/System.Linq/src/System/Linq/ToCollection.cs

Lines changed: 24 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -148,42 +148,34 @@ public static Dictionary<TKey, TSource> ToDictionary<TSource, TKey>(this IEnumer
148148
ThrowHelper.ThrowArgumentNullException(ExceptionArgument.keySelector);
149149
}
150150

151+
Dictionary<TKey, TSource> dict;
152+
151153
if (source.TryGetNonEnumeratedCount(out int capacity))
152154
{
153155
if (capacity == 0)
154156
{
155157
return new Dictionary<TKey, TSource>(comparer);
156158
}
157159

158-
if (source is TSource[] array)
160+
if (source.TryGetSpan(out ReadOnlySpan<TSource> span))
159161
{
160-
return SpanToDictionary(array, keySelector, comparer);
161-
}
162+
dict = new Dictionary<TKey, TSource>(span.Length, comparer);
163+
foreach (TSource element in span)
164+
{
165+
dict.Add(keySelector(element), element);
166+
}
162167

163-
if (source is List<TSource> list)
164-
{
165-
ReadOnlySpan<TSource> span = CollectionsMarshal.AsSpan(list);
166-
return SpanToDictionary(span, keySelector, comparer);
168+
return dict;
167169
}
168170
}
169171

170-
Dictionary<TKey, TSource> d = new Dictionary<TKey, TSource>(capacity, comparer);
172+
dict = new Dictionary<TKey, TSource>(capacity, comparer);
171173
foreach (TSource element in source)
172174
{
173-
d.Add(keySelector(element), element);
175+
dict.Add(keySelector(element), element);
174176
}
175177

176-
return d;
177-
}
178-
179-
private static Dictionary<TKey, TSource> SpanToDictionary<TSource, TKey>(ReadOnlySpan<TSource> source, Func<TSource, TKey> keySelector, IEqualityComparer<TKey>? comparer) where TKey : notnull
180-
{
181-
Dictionary<TKey, TSource> d = new Dictionary<TKey, TSource>(source.Length, comparer);
182-
foreach (TSource element in source)
183-
{
184-
d.Add(keySelector(element), element);
185-
}
186-
return d;
178+
return dict;
187179
}
188180

189181
public static Dictionary<TKey, TElement> ToDictionary<TSource, TKey, TElement>(this IEnumerable<TSource> source, Func<TSource, TKey> keySelector, Func<TSource, TElement> elementSelector) where TKey : notnull =>
@@ -206,42 +198,33 @@ public static Dictionary<TKey, TElement> ToDictionary<TSource, TKey, TElement>(t
206198
ThrowHelper.ThrowArgumentNullException(ExceptionArgument.elementSelector);
207199
}
208200

201+
Dictionary<TKey, TElement> dict;
202+
209203
if (source.TryGetNonEnumeratedCount(out int capacity))
210204
{
211205
if (capacity == 0)
212206
{
213207
return new Dictionary<TKey, TElement>(comparer);
214208
}
215209

216-
if (source is TSource[] array)
210+
if (source.TryGetSpan(out ReadOnlySpan<TSource> span))
217211
{
218-
return SpanToDictionary(array, keySelector, elementSelector, comparer);
219-
}
220-
221-
if (source is List<TSource> list)
222-
{
223-
ReadOnlySpan<TSource> span = CollectionsMarshal.AsSpan(list);
224-
return SpanToDictionary(span, keySelector, elementSelector, comparer);
212+
dict = new Dictionary<TKey, TElement>(span.Length, comparer);
213+
foreach (TSource element in span)
214+
{
215+
dict.Add(keySelector(element), elementSelector(element));
216+
}
217+
return dict;
225218
}
226219
}
227220

228-
Dictionary<TKey, TElement> d = new Dictionary<TKey, TElement>(capacity, comparer);
221+
dict = new Dictionary<TKey, TElement>(capacity, comparer);
229222
foreach (TSource element in source)
230223
{
231-
d.Add(keySelector(element), elementSelector(element));
224+
dict.Add(keySelector(element), elementSelector(element));
232225
}
233226

234-
return d;
235-
}
236-
237-
private static Dictionary<TKey, TElement> SpanToDictionary<TSource, TKey, TElement>(ReadOnlySpan<TSource> source, Func<TSource, TKey> keySelector, Func<TSource, TElement> elementSelector, IEqualityComparer<TKey>? comparer) where TKey : notnull
238-
{
239-
Dictionary<TKey, TElement> d = new Dictionary<TKey, TElement>(source.Length, comparer);
240-
foreach (TSource element in source)
241-
{
242-
d.Add(keySelector(element), elementSelector(element));
243-
}
244-
return d;
227+
return dict;
245228
}
246229

247230
public static HashSet<TSource> ToHashSet<TSource>(this IEnumerable<TSource> source) => source.ToHashSet(comparer: null);

src/libraries/System.Linq/src/System/Linq/Where.cs

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
using System.Collections.Generic;
55
using System.Diagnostics;
6+
using System.Runtime.CompilerServices;
67
using static System.Linq.Utilities;
78

89
namespace System.Linq
@@ -36,9 +37,9 @@ public static IEnumerable<TSource> Where<TSource>(this IEnumerable<TSource> sour
3637
return new ArrayWhereIterator<TSource>(array, predicate);
3738
}
3839

39-
if (source is List<TSource> list)
40+
if (source.GetType() == typeof(List<TSource>))
4041
{
41-
return new ListWhereIterator<TSource>(list, predicate);
42+
return new ListWhereIterator<TSource>(Unsafe.As<List<TSource>>(source), predicate);
4243
}
4344

4445
return new IEnumerableWhereIterator<TSource>(source, predicate);

0 commit comments

Comments
 (0)