Skip to content

Fix LINQ handling of iterator.Take(...).Last(...) #112680

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Feb 19, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -426,9 +426,12 @@ public override Iterator<TSource> Take(int count)
{
if (_source is Iterator<TSource> iterator &&
iterator.GetCount(onlyIfCheap: true) is int count &&
count >= _minIndexInclusive)
count > _minIndexInclusive)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Here we miss a fast path if count != -1 but count <= _minIndexInclusive.

Copy link
Member Author

@stephentoub stephentoub Feb 19, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The purpose of this change is to fix a bug. I don't want to add additional optimizations as part of it. Please feel free to submit a follow-up PR after this goes in.

{
return !HasLimit ?
// If there's no upper bound, or if there are fewer items in the list
// than the upper bound allows, just return the last element of the list.
// Otherwise, get the element at the upper bound.
return (uint)count <= (uint)_maxIndexInclusive ?
iterator.TryGetLast(out found) :
iterator.TryGetElementAt(_maxIndexInclusive, out found);
}
Expand Down
6 changes: 3 additions & 3 deletions src/libraries/System.Linq/tests/AggregateByTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,10 @@ public class AggregateByTests : EnumerableTests
[Fact]
public void Empty()
{
Assert.All(IdentityTransforms<int>(), transform =>
Assert.All(CreateSources<int>([]), source =>
{
Assert.Equal([], transform([]).AggregateBy(i => i, i => i, (a, i) => a + i));
Assert.Equal([], transform([]).AggregateBy(i => i, 0, (a, i) => a + i));
Assert.Equal([], source.AggregateBy(i => i, i => i, (a, i) => a + i));
Assert.Equal([], source.AggregateBy(i => i, 0, (a, i) => a + i));
});
}

Expand Down
20 changes: 5 additions & 15 deletions src/libraries/System.Linq/tests/ChunkTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -42,10 +42,8 @@ public void ChunkSourceLazily()
[InlineData(new[] {9999, 0, 888, -1, 66, -777, 1, 2, -12345})]
public void ChunkSourceRepeatCalls(int[] array)
{
Assert.All(IdentityTransforms<int>(), t =>
Assert.All(CreateSources(array), source =>
{
IEnumerable<int> source = t(array);

Assert.Equal(source.Chunk(3), source.Chunk(3));
});
}
Expand All @@ -54,10 +52,8 @@ public void ChunkSourceRepeatCalls(int[] array)
[InlineData(new[] {9999, 0, 888, -1, 66, -777, 1, 2, -12345})]
public void ChunkSourceEvenly(int[] array)
{
Assert.All(IdentityTransforms<int>(), t =>
Assert.All(CreateSources(array), source =>
{
IEnumerable<int> source = t(array);

using IEnumerator<int[]> chunks = source.Chunk(3).GetEnumerator();
chunks.MoveNext();
Assert.Equal(new[] { 9999, 0, 888 }, chunks.Current);
Expand All @@ -73,10 +69,8 @@ public void ChunkSourceEvenly(int[] array)
[InlineData(new[] {9999, 0, 888, -1, 66, -777, 1, 2})]
public void ChunkSourceUnevenly(int[] array)
{
Assert.All(IdentityTransforms<int>(), t =>
Assert.All(CreateSources(array), source =>
{
IEnumerable<int> source = t(array);

using IEnumerator<int[]> chunks = source.Chunk(3).GetEnumerator();
chunks.MoveNext();
Assert.Equal(new[] { 9999, 0, 888 }, chunks.Current);
Expand All @@ -92,10 +86,8 @@ public void ChunkSourceUnevenly(int[] array)
[InlineData(new[] {9999, 0})]
public void ChunkSourceSmallerThanMaxSize(int[] array)
{
Assert.All(IdentityTransforms<int>(), t =>
Assert.All(CreateSources(array), source =>
{
IEnumerable<int> source = t(array);

using IEnumerator<int[]> chunks = source.Chunk(3).GetEnumerator();
chunks.MoveNext();
Assert.Equal(new[] { 9999, 0 }, chunks.Current);
Expand All @@ -107,10 +99,8 @@ public void ChunkSourceSmallerThanMaxSize(int[] array)
[InlineData(new int[0])]
public void EmptySourceYieldsNoChunks(int[] array)
{
Assert.All(IdentityTransforms<int>(), t =>
Assert.All(CreateSources(array), source =>
{
IEnumerable<int> source = t(array);

using IEnumerator<int[]> chunks = source.Chunk(3).GetEnumerator();
Assert.False(chunks.MoveNext());
});
Expand Down
29 changes: 6 additions & 23 deletions src/libraries/System.Linq/tests/ConcatTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,8 @@ private static void SameResultsWithQueryAndRepeatCallsWorker<T>(IEnumerable<T> f
first = from item in first select item;
second = from item in second select item;

VerifyEqualsWorker(first.Concat(second), first.Concat(second));
VerifyEqualsWorker(second.Concat(first), second.Concat(first));
Assert.Equal(first.Concat(second), first.Concat(second));
Assert.Equal(second.Concat(first), second.Concat(first));
}

[Theory]
Expand All @@ -41,8 +41,8 @@ private static void SameResultsWithQueryAndRepeatCallsWorker<T>(IEnumerable<T> f
[InlineData(new int[] { 2, 3, 5, 9 }, new int[] { 8, 10 }, new int[] { 2, 3, 5, 9, 8, 10 })] // Neither side is empty
public void PossiblyEmptyInputs(IEnumerable<int> first, IEnumerable<int> second, IEnumerable<int> expected)
{
VerifyEqualsWorker(expected, first.Concat(second));
VerifyEqualsWorker(expected.Skip(first.Count()).Concat(expected.Take(first.Count())), second.Concat(first)); // Swap the inputs around
Assert.Equal(expected, first.Concat(second));
Assert.Equal(expected.Skip(first.Count()).Concat(expected.Take(first.Count())), second.Concat(first)); // Swap the inputs around
}

[Fact]
Expand Down Expand Up @@ -80,7 +80,7 @@ public void SecondNull()
public void VerifyEquals(IEnumerable<int> expected, IEnumerable<int> actual)
{
// workaround: xUnit type inference doesn't work if the input type is not T (like IEnumerable<T>)
VerifyEqualsWorker(expected, actual);
Assert.Equal(expected, actual);
}

[Theory]
Expand Down Expand Up @@ -133,23 +133,6 @@ public void First_Last_ElementAt(IEnumerable<int> _, IEnumerable<int> actual)
}
}

private static void VerifyEqualsWorker<T>(IEnumerable<T> expected, IEnumerable<T> actual)
{
// Returns a list of functions that, when applied to enumerable, should return
// another one that has equivalent contents.
var identityTransforms = IdentityTransforms<T>();

// We run the transforms N^2 times, by testing all transforms
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm puzzled by the purpose this method, it all seems entirely unnecessary. Am I missing something?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The purpose of the method I deleted because it served no worthwhile purpose? :)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah. I'm legitimately curious about what motivated its inclusion back then.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Misguided, I think. The stated intent is to validate that all transformations are equivalent to all other transformations, but that doesn't require N^2 tests, and we also don't need to test the test transformations themselves.

I guess it didn't hurt much when it was fast because there were only a few transformations, but my change causes N to grow from 13 to 1099, which means N^2 grows from 169 to 1,207,801... and that difference is very noticeable :)

// of expected against all transforms of actual.
foreach (var outTransform in identityTransforms)
{
foreach (var inTransform in identityTransforms)
{
Assert.Equal(outTransform(expected), inTransform(actual));
}
}
}

public static IEnumerable<object[]> ArraySourcesData() => GenerateSourcesData(outerTransform: e => e.ToArray());

public static IEnumerable<object[]> SelectArraySourcesData() => GenerateSourcesData(outerTransform: e => e.Select(i => i).ToArray());
Expand Down Expand Up @@ -292,7 +275,7 @@ public void ManyConcats(IEnumerable<IEnumerable<int>> sources)
}

Assert.Equal(sources.Sum(s => s.Count()), concatee.Count());
VerifyEqualsWorker(sources.SelectMany(s => s), concatee);
Assert.Equal(sources.SelectMany(s => s), concatee);
}
}

Expand Down
4 changes: 2 additions & 2 deletions src/libraries/System.Linq/tests/CountTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -99,9 +99,9 @@ public void RunOnce<T>(int count, IEnumerable<T> enumerable)

private static IEnumerable<object[]> EnumerateCollectionTypesAndCounts<T>(int count, IEnumerable<T> enumerable)
{
foreach (var transform in IdentityTransforms<T>())
foreach (IEnumerable<T> source in CreateSources(enumerable))
{
yield return [count, transform(enumerable)];
yield return [count, source];
}
}

Expand Down
89 changes: 74 additions & 15 deletions src/libraries/System.Linq/tests/EnumerableTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
using System.Collections.Generic;
using System.Collections.ObjectModel;
using Xunit;
using Xunit.Sdk;

namespace System.Linq.Tests
{
Expand Down Expand Up @@ -243,6 +244,7 @@ protected static IEnumerable<T> FlipIsCollection<T>(IEnumerable<T> source)
{
return source is ICollection<T> ? ForceNotCollection(source) : new List<T>(source);
}

protected static T[] Repeat<T>(Func<int, T> factory, int count)
{
T[] results = new T[count];
Expand Down Expand Up @@ -316,26 +318,83 @@ protected static IEnumerable<IEnumerable<T>> CreateSources<T>(IEnumerable<T> sou
}
}

protected static List<Func<IEnumerable<T>, IEnumerable<T>>> IdentityTransforms<T>()
protected static IEnumerable<Func<IEnumerable<T>, IEnumerable<T>>> IdentityTransforms<T>()
{
// All of these transforms should take an enumerable and produce
// another enumerable with the same contents.
return
// Various collection types all representing the same source.
List<Func<IEnumerable<T>, IEnumerable<T>>> sources =
[
e => e,
e => e.ToArray(),
e => e.ToList(),
e => e.ToList().Take(int.MaxValue),
e => e, // original
e => e.ToArray(), // T[]
e => e.ToList(), // List<T>
e => new ReadOnlyCollection<T>(e.ToArray()), // IList<T> that's not List<T>/T[]
e => new TestCollection<T>(e.ToArray()), // ICollection<T> that's not IList<T>
e => new TestReadOnlyCollection<T>(e.ToArray()), // IReadOnlyCollection<T> that's not ICollection<T>
e => ForceNotCollection(e), // IEnumerable<T> with no other interfaces
];
if (typeof(T) == typeof(char))
{
sources.Add(e => (IEnumerable<T>)(object)string.Concat((IEnumerable<char>)(object)e)); // string
}

// Various transforms that all yield the same elements as the source.
List<Func<IEnumerable<T>, IEnumerable<T>>> transforms =
[
// Append
e =>
{
T[] values = e.ToArray();
return values.Length == 0 ? [] : values[0..^1].Append(values[^1]);
},

// Concat
e => e.Concat(ForceNotCollection<T>([])),
e => ForceNotCollection<T>([]).Concat(e),

// Prepend
e =>
{
T[] values = e.ToArray();
return values.Length == 0 ? [] : values[1..].Prepend(values[0]);
},

// Reverse
e => e.Reverse().Reverse(),

// Select
e => e.Select(i => i),
e => e.Select(i => i).Take(int.MaxValue),
e => e.Select(i => i).Where(i => true),

// SelectMany
e => e.SelectMany<T, T>(i => [i]),

// Take
e => e.Take(int.MaxValue),
e => e.TakeLast(int.MaxValue),
e => e.TakeWhile(i => true),

// Skip
e => e.SkipWhile(i => false),

// Where
e => e.Where(i => true),
e => e.Concat(Array.Empty<T>()),
e => e.Concat(ForceNotCollection(Array.Empty<T>())),
e => ForceNotCollection(e),
e => ForceNotCollection(e).Skip(0),
e => new ReadOnlyCollection<T>(e.ToArray())
];

foreach (Func<IEnumerable<T>, IEnumerable<T>> source in sources)
{
// Yield the source itself.
yield return source;

foreach (Func<IEnumerable<T>, IEnumerable<T>> transform in transforms)
{
// Yield a single transform on the source
yield return e => transform(source(e));

foreach (Func<IEnumerable<T>, IEnumerable<T>> transform2 in transforms)
{
// Yield a second transform on the first transform on the source.
yield return e => transform2(transform(source(e)));
}
}
}
}

protected sealed class DelegateIterator<TSource> : IEnumerable<TSource>, IEnumerator<TSource>
Expand Down
34 changes: 13 additions & 21 deletions src/libraries/System.Linq/tests/SelectManyTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -372,31 +372,23 @@ public void ForcedToEnumeratorDoesntEnumerateIndexedResultSel()
Assert.False(en is not null && en.MoveNext());
}

[Theory]
[MemberData(nameof(ParameterizedTestsData))]
public void ParameterizedTests(IEnumerable<int> source, Func<int, IEnumerable<int>> selector)
[Fact]
public void ParameterizedTests()
{
Assert.All(CreateSources(source), source =>
for (int i = 1; i <= 20; i++)
{
var expected = source.Select(i => selector(i)).Aggregate((l, r) => l.Concat(r));
var actual = source.SelectMany(selector);
Assert.All(CreateSources(Enumerable.Range(1, i)), source =>
{
Func<int, IEnumerable<int>> selector = n => Enumerable.Range(i, n);

Assert.Equal(expected, actual);
Assert.Equal(expected.Count(), actual.Count()); // SelectMany may employ an optimized Count implementation.
Assert.Equal(expected.ToArray(), actual.ToArray());
Assert.Equal(expected.ToList(), actual.ToList());
});
}
var expected = source.Select(i => selector(i)).Aggregate((l, r) => l.Concat(r)).ToArray();
var actual = source.SelectMany(selector);

public static IEnumerable<object[]> ParameterizedTestsData()
{
foreach (Func<IEnumerable<int>, IEnumerable<int>> transform in IdentityTransforms<int>())
{
for (int i = 1; i <= 20; i++)
{
Func<int, IEnumerable<int>> selector = n => transform(Enumerable.Range(i, n));
yield return [Enumerable.Range(1, i), selector];
}
Assert.Equal(expected, actual);
Assert.Equal(expected.Length, actual.Count()); // SelectMany may employ an optimized Count implementation.
Assert.Equal(expected, actual.ToArray());
Assert.Equal(expected, actual.ToList());
});
}
}

Expand Down
Loading
Loading