Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Simplify Iterator::{min, max} #59138

Merged
merged 5 commits into from
Mar 13, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/libcore/benches/iter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ fn scatter(x: i32) -> i32 { (x * 31) % 127 }
fn bench_max_by_key(b: &mut Bencher) {
b.iter(|| {
let it = 0..100;
it.max_by_key(|&x| scatter(x))
it.map(black_box).max_by_key(|&x| scatter(x))
})
}

Expand All @@ -56,7 +56,7 @@ fn bench_max_by_key2(b: &mut Bencher) {
fn bench_max(b: &mut Bencher) {
b.iter(|| {
let it = 0..100;
it.map(scatter).max()
it.map(black_box).map(scatter).max()
})
}

Expand Down
75 changes: 19 additions & 56 deletions src/libcore/iter/traits/iterator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2008,12 +2008,7 @@ pub trait Iterator {
#[stable(feature = "rust1", since = "1.0.0")]
fn max(self) -> Option<Self::Item> where Self: Sized, Self::Item: Ord
{
select_fold1(self,
|_| (),
// switch to y even if it is only equal, to preserve
// stability.
|_, x, _, y| *x <= *y)
.map(|(_, x)| x)
self.max_by(Ord::cmp)
}

/// Returns the minimum element of an iterator.
Expand All @@ -2038,12 +2033,7 @@ pub trait Iterator {
#[stable(feature = "rust1", since = "1.0.0")]
fn min(self) -> Option<Self::Item> where Self: Sized, Self::Item: Ord
{
select_fold1(self,
|_| (),
// only switch to y if it is strictly smaller, to
// preserve stability.
|_, x, _, y| *x > *y)
.map(|(_, x)| x)
self.min_by(Ord::cmp)
}

/// Returns the element that gives the maximum value from the
Expand All @@ -2062,15 +2052,11 @@ pub trait Iterator {
/// ```
#[inline]
#[stable(feature = "iter_cmp_by_key", since = "1.6.0")]
fn max_by_key<B: Ord, F>(self, f: F) -> Option<Self::Item>
fn max_by_key<B: Ord, F>(self, mut f: F) -> Option<Self::Item>
where Self: Sized, F: FnMut(&Self::Item) -> B,
{
select_fold1(self,
f,
// switch to y even if it is only equal, to preserve
// stability.
|x_p, _, y_p, _| x_p <= y_p)
.map(|(_, x)| x)
// switch to y even if it is only equal, to preserve stability.
select_fold1(self.map(|x| (f(&x), x)), |(x_p, _), (y_p, _)| x_p <= y_p).map(|(_, x)| x)
}

/// Returns the element that gives the maximum value with respect to the
Expand All @@ -2092,12 +2078,8 @@ pub trait Iterator {
fn max_by<F>(self, mut compare: F) -> Option<Self::Item>
where Self: Sized, F: FnMut(&Self::Item, &Self::Item) -> Ordering,
{
select_fold1(self,
|_| (),
// switch to y even if it is only equal, to preserve
// stability.
|_, x, _, y| Ordering::Greater != compare(x, y))
.map(|(_, x)| x)
// switch to y even if it is only equal, to preserve stability.
select_fold1(self, |x, y| compare(x, y) != Ordering::Greater)
}

/// Returns the element that gives the minimum value from the
Expand All @@ -2115,15 +2097,11 @@ pub trait Iterator {
/// assert_eq!(*a.iter().min_by_key(|x| x.abs()).unwrap(), 0);
/// ```
#[stable(feature = "iter_cmp_by_key", since = "1.6.0")]
fn min_by_key<B: Ord, F>(self, f: F) -> Option<Self::Item>
fn min_by_key<B: Ord, F>(self, mut f: F) -> Option<Self::Item>
where Self: Sized, F: FnMut(&Self::Item) -> B,
{
select_fold1(self,
f,
// only switch to y if it is strictly smaller, to
// preserve stability.
|x_p, _, y_p, _| x_p > y_p)
.map(|(_, x)| x)
// only switch to y if it is strictly smaller, to preserve stability.
select_fold1(self.map(|x| (f(&x), x)), |(x_p, _), (y_p, _)| x_p > y_p).map(|(_, x)| x)
}

/// Returns the element that gives the minimum value with respect to the
Expand All @@ -2145,12 +2123,8 @@ pub trait Iterator {
fn min_by<F>(self, mut compare: F) -> Option<Self::Item>
where Self: Sized, F: FnMut(&Self::Item, &Self::Item) -> Ordering,
{
select_fold1(self,
|_| (),
// switch to y even if it is strictly smaller, to
// preserve stability.
|_, x, _, y| Ordering::Greater == compare(x, y))
.map(|(_, x)| x)
// only switch to y if it is strictly smaller, to preserve stability.
select_fold1(self, |x, y| compare(x, y) == Ordering::Greater)
}


Expand Down Expand Up @@ -2693,34 +2667,23 @@ pub trait Iterator {
}
}

/// Select an element from an iterator based on the given "projection"
/// and "comparison" function.
/// Select an element from an iterator based on the given "comparison"
/// function.
///
/// This is an idiosyncratic helper to try to factor out the
/// commonalities of {max,min}{,_by}. In particular, this avoids
/// having to implement optimizations several times.
#[inline]
fn select_fold1<I, B, FProj, FCmp>(mut it: I,
mut f_proj: FProj,
mut f_cmp: FCmp) -> Option<(B, I::Item)>
where I: Iterator,
FProj: FnMut(&I::Item) -> B,
FCmp: FnMut(&B, &I::Item, &B, &I::Item) -> bool
fn select_fold1<I, F>(mut it: I, mut f: F) -> Option<I::Item>
where
I: Iterator,
F: FnMut(&I::Item, &I::Item) -> bool,
{
// start with the first element as our selection. This avoids
// having to use `Option`s inside the loop, translating to a
// sizeable performance gain (6x in one case).
it.next().map(|first| {
let first_p = f_proj(&first);

it.fold((first_p, first), |(sel_p, sel), x| {
let x_p = f_proj(&x);
if f_cmp(&sel_p, &sel, &x_p, &x) {
(x_p, x)
} else {
(sel_p, sel)
}
})
it.fold(first, |sel, x| if f(&sel, &x) { x } else { sel })
})
}

Expand Down
28 changes: 28 additions & 0 deletions src/libcore/tests/iter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1082,12 +1082,39 @@ fn test_iterator_product_result() {
assert_eq!(v.iter().cloned().product::<Result<i32, _>>(), Err(()));
}

/// A wrapper struct that implements `Eq` and `Ord` based on the wrapped
/// integer modulo 3. Used to test that `Iterator::max` and `Iterator::min`
/// return the correct element if some of them are equal.
#[derive(Debug)]
struct Mod3(i32);

impl PartialEq for Mod3 {
fn eq(&self, other: &Self) -> bool {
self.0 % 3 == other.0 % 3
}
}

impl Eq for Mod3 {}

impl PartialOrd for Mod3 {
fn partial_cmp(&self, other: &Self) -> Option<core::cmp::Ordering> {
Some(self.cmp(other))
}
}

impl Ord for Mod3 {
fn cmp(&self, other: &Self) -> core::cmp::Ordering {
(self.0 % 3).cmp(&(other.0 % 3))
}
}

#[test]
fn test_iterator_max() {
let v: &[_] = &[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10];
assert_eq!(v[..4].iter().cloned().max(), Some(3));
assert_eq!(v.iter().cloned().max(), Some(10));
assert_eq!(v[..0].iter().cloned().max(), None);
assert_eq!(v.iter().cloned().map(Mod3).max().map(|x| x.0), Some(8));
}

#[test]
Expand All @@ -1096,6 +1123,7 @@ fn test_iterator_min() {
assert_eq!(v[..4].iter().cloned().min(), Some(0));
assert_eq!(v.iter().cloned().min(), Some(0));
assert_eq!(v[..0].iter().cloned().min(), None);
assert_eq!(v.iter().cloned().map(Mod3).min().map(|x| x.0), Some(0));
}

#[test]
Expand Down