Skip to content

Commit 59414b5

Browse files
authored
Specialize Contains for Iterators in LINQ (#112684)
* Specialize Contains for Iterators in LINQ It appears that Contains ends up being reasonably common after a series of LINQ operations, whether explicitly at the call site or because a method that returns an enumerable uses LINQ internally and then the call site does Contains. We can optimize Contains for a bunch of operators, just as we can for First/Last. In some cases, we can skip the operator completely, e.g. Contains on a Shuffle or OrderBy is no different from one on the underlying source, in other cases we can optimize by processing the source directly, e.g. a Contains on a Concat can end up doing a Contains on each source, which can in turn pick up vectorized implementations if those individual sources support them. Some of the operators actually already provided Contains implementations as part of implementing IList, and this just makes those implementations accessible. In other cases, new overrides of a new virtual Contains on Iterator are added.
1 parent 03324cd commit 59414b5

19 files changed

+570
-26
lines changed

src/libraries/System.Linq/src/System/Linq/AppendPrepend.SpeedOpt.cs

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -176,6 +176,10 @@ public override int GetCount(bool onlyIfCheap)
176176

177177
return base.TryGetElementAt(index, out found);
178178
}
179+
180+
public override bool Contains(TSource value) =>
181+
EqualityComparer<TSource>.Default.Equals(_item, value) ||
182+
_source.Contains(value);
179183
}
180184

181185
private sealed partial class AppendPrependN<TSource>
@@ -278,6 +282,22 @@ public override int GetCount(bool onlyIfCheap)
278282

279283
return !onlyIfCheap || _source is ICollection<TSource> ? _source.Count() + _appendCount + _prependCount : -1;
280284
}
285+
286+
public override bool Contains(TSource value)
287+
{
288+
foreach (SingleLinkedNode<TSource>? head in (ReadOnlySpan<SingleLinkedNode<TSource>?>)[_appended, _prepended])
289+
{
290+
for (SingleLinkedNode<TSource>? node = head; node is not null; node = node.Linked)
291+
{
292+
if (EqualityComparer<TSource>.Default.Equals(node.Item, value))
293+
{
294+
return true;
295+
}
296+
}
297+
}
298+
299+
return _source.Contains(value);
300+
}
281301
}
282302
}
283303
}

src/libraries/System.Linq/src/System/Linq/Cast.SpeedOpt.cs

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,19 @@ public override List<TResult> ToList()
110110
(e as IDisposable)?.Dispose();
111111
}
112112
}
113+
114+
public override bool Contains(TResult value)
115+
{
116+
foreach (TResult item in _source)
117+
{
118+
if (EqualityComparer<TResult>.Default.Equals(item, value))
119+
{
120+
return true;
121+
}
122+
}
123+
124+
return false;
125+
}
113126
}
114127
}
115128
}

src/libraries/System.Linq/src/System/Linq/Concat.SpeedOpt.cs

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -148,6 +148,10 @@ public override TSource[] ToArray()
148148

149149
return result;
150150
}
151+
152+
public override bool Contains(TSource value) =>
153+
_first.Contains(value) ||
154+
_second.Contains(value);
151155
}
152156

153157
private sealed partial class ConcatNIterator<TSource> : ConcatIterator<TSource>
@@ -342,6 +346,23 @@ private TSource[] PreallocatingToArray()
342346
Debug.Assert(node._tail is Concat2Iterator<TSource>);
343347
return node._tail.TryGetLast(out found);
344348
}
349+
350+
public override bool Contains(TSource value)
351+
{
352+
ConcatNIterator<TSource>? node, previousN = this;
353+
do
354+
{
355+
node = previousN;
356+
if (node._head.Contains(value))
357+
{
358+
return true;
359+
}
360+
}
361+
while ((previousN = node.PreviousN) is not null);
362+
363+
Debug.Assert(node._tail is Concat2Iterator<TSource>);
364+
return node._tail.Contains(value);
365+
}
345366
}
346367

347368
private abstract partial class ConcatIterator<TSource>
@@ -364,7 +385,6 @@ public override List<TSource> ToList()
364385

365386
return list;
366387
}
367-
368388
}
369389
}
370390
}

src/libraries/System.Linq/src/System/Linq/Contains.cs

Lines changed: 29 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,25 @@ namespace System.Linq
77
{
88
public static partial class Enumerable
99
{
10-
public static bool Contains<TSource>(this IEnumerable<TSource> source, TSource value) =>
11-
source is ICollection<TSource> collection ? collection.Contains(value) :
12-
Contains(source, value, null);
10+
public static bool Contains<TSource>(this IEnumerable<TSource> source, TSource value)
11+
{
12+
if (source is ICollection<TSource> collection)
13+
{
14+
return collection.Contains(value);
15+
}
16+
17+
if (!IsSizeOptimized && source is Iterator<TSource> iterator)
18+
{
19+
return iterator.Contains(value);
20+
}
21+
22+
if (source is null)
23+
{
24+
ThrowHelper.ThrowArgumentNullException(ExceptionArgument.source);
25+
}
26+
27+
return ContainsIterate(source, value, null);
28+
}
1329

1430
public static bool Contains<TSource>(this IEnumerable<TSource> source, TSource value, IEqualityComparer<TSource>? comparer)
1531
{
@@ -18,17 +34,22 @@ public static bool Contains<TSource>(this IEnumerable<TSource> source, TSource v
1834
ThrowHelper.ThrowArgumentNullException(ExceptionArgument.source);
1935
}
2036

37+
// While it's tempting, this must not delegate to ICollection<TSource>.Contains, as the historical semantics
38+
// of a null comparer with this method are to use EqualityComparer<TSource>.Default, and that might differ
39+
// from the semantics encoded in ICollection<TSource>.Contains.
40+
2141
if (source.TryGetSpan(out ReadOnlySpan<TSource> span))
2242
{
2343
return span.Contains(value, comparer);
2444
}
2545

46+
return ContainsIterate(source, value, comparer);
47+
}
48+
49+
private static bool ContainsIterate<TSource>(IEnumerable<TSource> source, TSource value, IEqualityComparer<TSource>? comparer)
50+
{
2651
if (comparer is null)
2752
{
28-
// While it's tempting, this must not delegate to ICollection<TSource>.Contains, as the historical semantics
29-
// of a null comparer with this method are to use EqualityComparer<TSource>.Default, and that might differ
30-
// from the semantics encoded in ICollection<TSource>.Contains.
31-
3253
if (typeof(TSource).IsValueType)
3354
{
3455
foreach (TSource element in source)
@@ -44,6 +65,7 @@ public static bool Contains<TSource>(this IEnumerable<TSource> source, TSource v
4465

4566
comparer = EqualityComparer<TSource>.Default;
4667
}
68+
4769
foreach (TSource element in source)
4870
{
4971
if (comparer.Equals(element, value))

src/libraries/System.Linq/src/System/Linq/DefaultIfEmpty.SpeedOpt.cs

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,40 @@ public override int GetCount(bool onlyIfCheap)
8181

8282
return _default;
8383
}
84+
85+
public override bool Contains(TSource value)
86+
{
87+
if (_source.TryGetNonEnumeratedCount(out int count))
88+
{
89+
return count > 0 ?
90+
_source.Contains(value) :
91+
EqualityComparer<TSource>.Default.Equals(value, _default);
92+
}
93+
94+
IEnumerator<TSource> enumerator = _source.GetEnumerator();
95+
try
96+
{
97+
if (!enumerator.MoveNext())
98+
{
99+
return EqualityComparer<TSource>.Default.Equals(value, _default);
100+
}
101+
102+
do
103+
{
104+
if (EqualityComparer<TSource>.Default.Equals(enumerator.Current, value))
105+
{
106+
return true;
107+
}
108+
}
109+
while (enumerator.MoveNext());
110+
111+
return false;
112+
}
113+
finally
114+
{
115+
enumerator.Dispose();
116+
}
117+
}
84118
}
85119
}
86120
}

src/libraries/System.Linq/src/System/Linq/Distinct.SpeedOpt.cs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@ private sealed partial class DistinctIterator<TSource>
1616
public override int GetCount(bool onlyIfCheap) => onlyIfCheap ? -1 : new HashSet<TSource>(_source, _comparer).Count;
1717

1818
public override TSource? TryGetFirst(out bool found) => _source.TryGetFirst(out found);
19+
20+
public override bool Contains(TSource value) => _source.Contains(value);
1921
}
2022
}
2123
}

src/libraries/System.Linq/src/System/Linq/Iterator.SpeedOpt.cs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,8 @@ private abstract partial class Iterator<TSource>
6666
/// <param name="found"><c>true</c> if the sequence contains an element, <c>false</c> otherwise.</param>
6767
/// <returns>The element if <paramref name="found"/> is <c>true</c>, otherwise, the default value of <typeparamref name="TSource"/>.</returns>
6868
public virtual TSource? TryGetLast(out bool found) => TryGetLastNonIterator(this, out found);
69+
70+
public virtual bool Contains(TSource value) => ContainsIterate(this, value, null);
6971
}
7072
}
7173
}

src/libraries/System.Linq/src/System/Linq/OfType.SpeedOpt.cs

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -174,6 +174,25 @@ public override IEnumerable<TResult2> Select<TResult2>(Func<TResult, TResult2> s
174174

175175
return base.Select(selector);
176176
}
177+
178+
public override bool Contains(TResult value)
179+
{
180+
if (!typeof(TResult).IsValueType && // don't box TResult
181+
_source is IList list)
182+
{
183+
return list.Contains(value);
184+
}
185+
186+
foreach (object? item in _source)
187+
{
188+
if (item is TResult castItem && EqualityComparer<TResult>.Default.Equals(castItem, value))
189+
{
190+
return true;
191+
}
192+
}
193+
194+
return false;
195+
}
177196
}
178197
}
179198
}

src/libraries/System.Linq/src/System/Linq/OrderedEnumerable.SpeedOpt.cs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -227,6 +227,9 @@ internal int GetCount(int minIdx, int maxIdx, bool onlyIfCheap)
227227
return default;
228228
}
229229

230+
public override bool Contains(TElement value) =>
231+
_source.Contains(value);
232+
230233
private TElement Last(TElement[] items)
231234
{
232235
CachingComparer<TElement> comparer = GetComparer();

src/libraries/System.Linq/src/System/Linq/Range.SpeedOpt.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ public override int TryGetLast(out bool found)
8282
return _end - 1;
8383
}
8484

85-
public bool Contains(int item) =>
85+
public override bool Contains(int item) =>
8686
(uint)(item - _start) < (uint)(_end - _start);
8787

8888
public int IndexOf(int item) =>

src/libraries/System.Linq/src/System/Linq/Repeat.SpeedOpt.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,7 @@ public override TResult TryGetLast(out bool found)
8181
return _current;
8282
}
8383

84-
public bool Contains(TResult item)
84+
public override bool Contains(TResult item)
8585
{
8686
Debug.Assert(_count > 0);
8787
return EqualityComparer<TResult>.Default.Equals(_current, item);

src/libraries/System.Linq/src/System/Linq/Reverse.SpeedOpt.cs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,8 @@ public override int GetCount(bool onlyIfCheap) =>
116116
found = false;
117117
return default;
118118
}
119+
120+
public override bool Contains(TSource value) => _source.Contains(value);
119121
}
120122
}
121123
}

src/libraries/System.Linq/src/System/Linq/Select.SpeedOpt.cs

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -231,6 +231,19 @@ public override TResult TryGetLast(out bool found)
231231
found = true;
232232
return _selector(_source[^1]);
233233
}
234+
235+
public override bool Contains(TResult value)
236+
{
237+
foreach (TSource item in _source)
238+
{
239+
if (EqualityComparer<TResult>.Default.Equals(_selector(item), value))
240+
{
241+
return true;
242+
}
243+
}
244+
245+
return false;
246+
}
234247
}
235248

236249
private sealed partial class RangeSelectIterator<TResult> : Iterator<TResult>
@@ -358,6 +371,19 @@ public override TResult TryGetLast(out bool found)
358371
found = true;
359372
return _selector(_end - 1);
360373
}
374+
375+
public override bool Contains(TResult value)
376+
{
377+
for (int i = _start; i != _end; i++)
378+
{
379+
if (EqualityComparer<TResult>.Default.Equals(_selector(i), value))
380+
{
381+
return true;
382+
}
383+
}
384+
385+
return false;
386+
}
361387
}
362388

363389
private sealed partial class ListSelectIterator<TSource, TResult>
@@ -460,6 +486,21 @@ public override Iterator<TResult> Take(int count)
460486
found = false;
461487
return default;
462488
}
489+
490+
public override bool Contains(TResult value)
491+
{
492+
int count = _source.Count;
493+
494+
for (int i = 0; i < count; i++)
495+
{
496+
if (EqualityComparer<TResult>.Default.Equals(_selector(_source[i]), value))
497+
{
498+
return true;
499+
}
500+
}
501+
502+
return false;
503+
}
463504
}
464505

465506
private sealed partial class IListSelectIterator<TSource, TResult>
@@ -563,6 +604,21 @@ public override Iterator<TResult> Take(int count)
563604
found = false;
564605
return default;
565606
}
607+
608+
public override bool Contains(TResult value)
609+
{
610+
int count = _source.Count;
611+
612+
for (int i = 0; i < count; i++)
613+
{
614+
if (EqualityComparer<TResult>.Default.Equals(_selector(_source[i]), value))
615+
{
616+
return true;
617+
}
618+
}
619+
620+
return false;
621+
}
566622
}
567623

568624
/// <summary>
@@ -943,6 +999,22 @@ public override int GetCount(bool onlyIfCheap)
943999

9441000
return count;
9451001
}
1002+
1003+
public override bool Contains(TResult value)
1004+
{
1005+
int count = Count;
1006+
1007+
int end = _minIndexInclusive + count;
1008+
for (int i = _minIndexInclusive; i != end; ++i)
1009+
{
1010+
if (EqualityComparer<TResult>.Default.Equals(_selector(_source[i]), value))
1011+
{
1012+
return true;
1013+
}
1014+
}
1015+
1016+
return false;
1017+
}
9461018
}
9471019
}
9481020
}

0 commit comments

Comments
 (0)