Skip to content

Commit 61d986c

Browse files
authored
Fix LINQ handling of iterator.Take(...).Last(...) (#112680) (#112714)
When the Take amount is larger than the number of elements in the source `Iterator<T>`, Last ends up throwing an exception and LastOrDefault ends up returning the default value, rather than returning the last value in the taken region. As part of fixing this, I sured up the tests to try to cover more such sequences of operations. In doing so, the tests got a lot slower, so I tracked down and fixed places where we were doing a lot of unnecessary work.
1 parent 98c77e4 commit 61d986c

14 files changed

+707
-1305
lines changed

src/libraries/System.Linq/src/System/Linq/SkipTake.SpeedOpt.cs

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -430,9 +430,12 @@ public override Iterator<TSource> Take(int count)
430430
{
431431
if (_source is Iterator<TSource> iterator &&
432432
iterator.GetCount(onlyIfCheap: true) is int count &&
433-
count >= _minIndexInclusive)
433+
count > _minIndexInclusive)
434434
{
435-
return !HasLimit ?
435+
// If there's no upper bound, or if there are fewer items in the list
436+
// than the upper bound allows, just return the last element of the list.
437+
// Otherwise, get the element at the upper bound.
438+
return (uint)count <= (uint)_maxIndexInclusive ?
436439
iterator.TryGetLast(out found) :
437440
iterator.TryGetElementAt(_maxIndexInclusive, out found);
438441
}

src/libraries/System.Linq/tests/AggregateByTests.cs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,10 +11,10 @@ public class AggregateByTests : EnumerableTests
1111
[Fact]
1212
public void Empty()
1313
{
14-
Assert.All(IdentityTransforms<int>(), transform =>
14+
Assert.All(CreateSources<int>([]), source =>
1515
{
16-
Assert.Equal(Enumerable.Empty<KeyValuePair<int, int>>(), transform(Enumerable.Empty<int>()).AggregateBy(i => i, i => i, (a, i) => a + i));
17-
Assert.Equal(Enumerable.Empty<KeyValuePair<int, int>>(), transform(Enumerable.Empty<int>()).AggregateBy(i => i, 0, (a, i) => a + i));
16+
Assert.Equal([], source.AggregateBy(i => i, i => i, (a, i) => a + i));
17+
Assert.Equal([], source.AggregateBy(i => i, 0, (a, i) => a + i));
1818
});
1919
}
2020

src/libraries/System.Linq/tests/ChunkTests.cs

Lines changed: 5 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -42,10 +42,8 @@ public void ChunkSourceLazily()
4242
[InlineData(new[] {9999, 0, 888, -1, 66, -777, 1, 2, -12345})]
4343
public void ChunkSourceRepeatCalls(int[] array)
4444
{
45-
Assert.All(IdentityTransforms<int>(), t =>
45+
Assert.All(CreateSources(array), source =>
4646
{
47-
IEnumerable<int> source = t(array);
48-
4947
Assert.Equal(source.Chunk(3), source.Chunk(3));
5048
});
5149
}
@@ -54,10 +52,8 @@ public void ChunkSourceRepeatCalls(int[] array)
5452
[InlineData(new[] {9999, 0, 888, -1, 66, -777, 1, 2, -12345})]
5553
public void ChunkSourceEvenly(int[] array)
5654
{
57-
Assert.All(IdentityTransforms<int>(), t =>
55+
Assert.All(CreateSources(array), source =>
5856
{
59-
IEnumerable<int> source = t(array);
60-
6157
using IEnumerator<int[]> chunks = source.Chunk(3).GetEnumerator();
6258
chunks.MoveNext();
6359
Assert.Equal(new[] { 9999, 0, 888 }, chunks.Current);
@@ -73,10 +69,8 @@ public void ChunkSourceEvenly(int[] array)
7369
[InlineData(new[] {9999, 0, 888, -1, 66, -777, 1, 2})]
7470
public void ChunkSourceUnevenly(int[] array)
7571
{
76-
Assert.All(IdentityTransforms<int>(), t =>
72+
Assert.All(CreateSources(array), source =>
7773
{
78-
IEnumerable<int> source = t(array);
79-
8074
using IEnumerator<int[]> chunks = source.Chunk(3).GetEnumerator();
8175
chunks.MoveNext();
8276
Assert.Equal(new[] { 9999, 0, 888 }, chunks.Current);
@@ -92,10 +86,8 @@ public void ChunkSourceUnevenly(int[] array)
9286
[InlineData(new[] {9999, 0})]
9387
public void ChunkSourceSmallerThanMaxSize(int[] array)
9488
{
95-
Assert.All(IdentityTransforms<int>(), t =>
89+
Assert.All(CreateSources(array), source =>
9690
{
97-
IEnumerable<int> source = t(array);
98-
9991
using IEnumerator<int[]> chunks = source.Chunk(3).GetEnumerator();
10092
chunks.MoveNext();
10193
Assert.Equal(new[] { 9999, 0 }, chunks.Current);
@@ -107,10 +99,8 @@ public void ChunkSourceSmallerThanMaxSize(int[] array)
10799
[InlineData(new int[0])]
108100
public void EmptySourceYieldsNoChunks(int[] array)
109101
{
110-
Assert.All(IdentityTransforms<int>(), t =>
102+
Assert.All(CreateSources(array), source =>
111103
{
112-
IEnumerable<int> source = t(array);
113-
114104
using IEnumerator<int[]> chunks = source.Chunk(3).GetEnumerator();
115105
Assert.False(chunks.MoveNext());
116106
});

src/libraries/System.Linq/tests/ConcatTests.cs

Lines changed: 6 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -31,8 +31,8 @@ private static void SameResultsWithQueryAndRepeatCallsWorker<T>(IEnumerable<T> f
3131
first = from item in first select item;
3232
second = from item in second select item;
3333

34-
VerifyEqualsWorker(first.Concat(second), first.Concat(second));
35-
VerifyEqualsWorker(second.Concat(first), second.Concat(first));
34+
Assert.Equal(first.Concat(second), first.Concat(second));
35+
Assert.Equal(second.Concat(first), second.Concat(first));
3636
}
3737

3838
[Theory]
@@ -41,8 +41,8 @@ private static void SameResultsWithQueryAndRepeatCallsWorker<T>(IEnumerable<T> f
4141
[InlineData(new int[] { 2, 3, 5, 9 }, new int[] { 8, 10 }, new int[] { 2, 3, 5, 9, 8, 10 })] // Neither side is empty
4242
public void PossiblyEmptyInputs(IEnumerable<int> first, IEnumerable<int> second, IEnumerable<int> expected)
4343
{
44-
VerifyEqualsWorker(expected, first.Concat(second));
45-
VerifyEqualsWorker(expected.Skip(first.Count()).Concat(expected.Take(first.Count())), second.Concat(first)); // Swap the inputs around
44+
Assert.Equal(expected, first.Concat(second));
45+
Assert.Equal(expected.Skip(first.Count()).Concat(expected.Take(first.Count())), second.Concat(first)); // Swap the inputs around
4646
}
4747

4848
[Fact]
@@ -80,7 +80,7 @@ public void SecondNull()
8080
public void VerifyEquals(IEnumerable<int> expected, IEnumerable<int> actual)
8181
{
8282
// workaround: xUnit type inference doesn't work if the input type is not T (like IEnumerable<T>)
83-
VerifyEqualsWorker(expected, actual);
83+
Assert.Equal(expected, actual);
8484
}
8585

8686
[Theory]
@@ -133,23 +133,6 @@ public void First_Last_ElementAt(IEnumerable<int> _, IEnumerable<int> actual)
133133
}
134134
}
135135

136-
private static void VerifyEqualsWorker<T>(IEnumerable<T> expected, IEnumerable<T> actual)
137-
{
138-
// Returns a list of functions that, when applied to enumerable, should return
139-
// another one that has equivalent contents.
140-
var identityTransforms = IdentityTransforms<T>();
141-
142-
// We run the transforms N^2 times, by testing all transforms
143-
// of expected against all transforms of actual.
144-
foreach (var outTransform in identityTransforms)
145-
{
146-
foreach (var inTransform in identityTransforms)
147-
{
148-
Assert.Equal(outTransform(expected), inTransform(actual));
149-
}
150-
}
151-
}
152-
153136
public static IEnumerable<object[]> ArraySourcesData() => GenerateSourcesData(outerTransform: e => e.ToArray());
154137

155138
public static IEnumerable<object[]> SelectArraySourcesData() => GenerateSourcesData(outerTransform: e => e.Select(i => i).ToArray());
@@ -292,7 +275,7 @@ public void ManyConcats(IEnumerable<IEnumerable<int>> sources)
292275
}
293276

294277
Assert.Equal(sources.Sum(s => s.Count()), concatee.Count());
295-
VerifyEqualsWorker(sources.SelectMany(s => s), concatee);
278+
Assert.Equal(sources.SelectMany(s => s), concatee);
296279
}
297280
}
298281

src/libraries/System.Linq/tests/CountTests.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -99,9 +99,9 @@ public void RunOnce<T>(int count, IEnumerable<T> enumerable)
9999

100100
private static IEnumerable<object[]> EnumerateCollectionTypesAndCounts<T>(int count, IEnumerable<T> enumerable)
101101
{
102-
foreach (var transform in IdentityTransforms<T>())
102+
foreach (IEnumerable<T> source in CreateSources(enumerable))
103103
{
104-
yield return new object[] { count, transform(enumerable) };
104+
yield return [count, source];
105105
}
106106
}
107107

src/libraries/System.Linq/tests/EnumerableTests.cs

Lines changed: 75 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
using System.Collections.Generic;
66
using System.Collections.ObjectModel;
77
using Xunit;
8+
using Xunit.Sdk;
89

910
namespace System.Linq.Tests
1011
{
@@ -243,6 +244,7 @@ protected static IEnumerable<T> FlipIsCollection<T>(IEnumerable<T> source)
243244
{
244245
return source is ICollection<T> ? ForceNotCollection(source) : new List<T>(source);
245246
}
247+
246248
protected static T[] Repeat<T>(Func<int, T> factory, int count)
247249
{
248250
T[] results = new T[count];
@@ -316,26 +318,83 @@ protected static IEnumerable<IEnumerable<T>> CreateSources<T>(IEnumerable<T> sou
316318
}
317319
}
318320

319-
protected static List<Func<IEnumerable<T>, IEnumerable<T>>> IdentityTransforms<T>()
321+
protected static IEnumerable<Func<IEnumerable<T>, IEnumerable<T>>> IdentityTransforms<T>()
320322
{
321-
// All of these transforms should take an enumerable and produce
322-
// another enumerable with the same contents.
323-
return new List<Func<IEnumerable<T>, IEnumerable<T>>>
323+
// Various collection types all representing the same source.
324+
List<Func<IEnumerable<T>, IEnumerable<T>>> sources =
325+
[
326+
e => e, // original
327+
e => e.ToArray(), // T[]
328+
e => e.ToList(), // List<T>
329+
e => new ReadOnlyCollection<T>(e.ToArray()), // IList<T> that's not List<T>/T[]
330+
e => new TestCollection<T>(e.ToArray()), // ICollection<T> that's not IList<T>
331+
e => new TestReadOnlyCollection<T>(e.ToArray()), // IReadOnlyCollection<T> that's not ICollection<T>
332+
e => ForceNotCollection(e), // IEnumerable<T> with no other interfaces
333+
];
334+
if (typeof(T) == typeof(char))
324335
{
325-
e => e,
326-
e => e.ToArray(),
327-
e => e.ToList(),
328-
e => e.ToList().Take(int.MaxValue),
336+
sources.Add(e => (IEnumerable<T>)(object)string.Concat((IEnumerable<char>)(object)e)); // string
337+
}
338+
339+
// Various transforms that all yield the same elements as the source.
340+
List<Func<IEnumerable<T>, IEnumerable<T>>> transforms =
341+
[
342+
// Append
343+
e =>
344+
{
345+
T[] values = e.ToArray();
346+
return values.Length == 0 ? [] : values[0..^1].Append(values[^1]);
347+
},
348+
349+
// Concat
350+
e => e.Concat(ForceNotCollection<T>([])),
351+
e => ForceNotCollection<T>([]).Concat(e),
352+
353+
// Prepend
354+
e =>
355+
{
356+
T[] values = e.ToArray();
357+
return values.Length == 0 ? [] : values[1..].Prepend(values[0]);
358+
},
359+
360+
// Reverse
361+
e => e.Reverse().Reverse(),
362+
363+
// Select
329364
e => e.Select(i => i),
330-
e => e.Select(i => i).Take(int.MaxValue),
331-
e => e.Select(i => i).Where(i => true),
365+
366+
// SelectMany
367+
e => e.SelectMany<T, T>(i => [i]),
368+
369+
// Take
370+
e => e.Take(int.MaxValue),
371+
e => e.TakeLast(int.MaxValue),
372+
e => e.TakeWhile(i => true),
373+
374+
// Skip
375+
e => e.SkipWhile(i => false),
376+
377+
// Where
332378
e => e.Where(i => true),
333-
e => e.Concat(Array.Empty<T>()),
334-
e => e.Concat(ForceNotCollection(Array.Empty<T>())),
335-
e => ForceNotCollection(e),
336-
e => ForceNotCollection(e).Skip(0),
337-
e => new ReadOnlyCollection<T>(e.ToArray()),
338-
};
379+
];
380+
381+
foreach (Func<IEnumerable<T>, IEnumerable<T>> source in sources)
382+
{
383+
// Yield the source itself.
384+
yield return source;
385+
386+
foreach (Func<IEnumerable<T>, IEnumerable<T>> transform in transforms)
387+
{
388+
// Yield a single transform on the source
389+
yield return e => transform(source(e));
390+
391+
foreach (Func<IEnumerable<T>, IEnumerable<T>> transform2 in transforms)
392+
{
393+
// Yield a second transform on the first transform on the source.
394+
yield return e => transform2(transform(source(e)));
395+
}
396+
}
397+
}
339398
}
340399

341400
protected sealed class DelegateIterator<TSource> : IEnumerable<TSource>, IEnumerator<TSource>

src/libraries/System.Linq/tests/SelectManyTests.cs

Lines changed: 13 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -372,31 +372,23 @@ public void ForcedToEnumeratorDoesntEnumerateIndexedResultSel()
372372
Assert.False(en is not null && en.MoveNext());
373373
}
374374

375-
[Theory]
376-
[MemberData(nameof(ParameterizedTestsData))]
377-
public void ParameterizedTests(IEnumerable<int> source, Func<int, IEnumerable<int>> selector)
375+
[Fact]
376+
public void ParameterizedTests()
378377
{
379-
Assert.All(CreateSources(source), source =>
378+
for (int i = 1; i <= 20; i++)
380379
{
381-
var expected = source.Select(i => selector(i)).Aggregate((l, r) => l.Concat(r));
382-
var actual = source.SelectMany(selector);
380+
Assert.All(CreateSources(Enumerable.Range(1, i)), source =>
381+
{
382+
Func<int, IEnumerable<int>> selector = n => Enumerable.Range(i, n);
383383

384-
Assert.Equal(expected, actual);
385-
Assert.Equal(expected.Count(), actual.Count()); // SelectMany may employ an optimized Count implementation.
386-
Assert.Equal(expected.ToArray(), actual.ToArray());
387-
Assert.Equal(expected.ToList(), actual.ToList());
388-
});
389-
}
384+
var expected = source.Select(i => selector(i)).Aggregate((l, r) => l.Concat(r)).ToArray();
385+
var actual = source.SelectMany(selector);
390386

391-
public static IEnumerable<object[]> ParameterizedTestsData()
392-
{
393-
foreach (Func<IEnumerable<int>, IEnumerable<int>> transform in IdentityTransforms<int>())
394-
{
395-
for (int i = 1; i <= 20; i++)
396-
{
397-
Func<int, IEnumerable<int>> selector = n => transform(Enumerable.Range(i, n));
398-
yield return new object[] { Enumerable.Range(1, i), selector };
399-
}
387+
Assert.Equal(expected, actual);
388+
Assert.Equal(expected.Length, actual.Count()); // SelectMany may employ an optimized Count implementation.
389+
Assert.Equal(expected, actual.ToArray());
390+
Assert.Equal(expected, actual.ToList());
391+
});
400392
}
401393
}
402394

0 commit comments

Comments
 (0)