Skip to content

implement Iterator::nth for Combinations #329

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
95 changes: 69 additions & 26 deletions src/combinations.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@ use std::fmt;
use super::lazy_buffer::LazyBuffer;

/// An iterator to iterate through all the `k`-length combinations in an iterator.
/// Note: it iterates over combinations in lexicographic order and
/// thus may not work as expected with infinite iterators.
///
/// See [`.combinations()`](../trait.Itertools.html#method.combinations) for more information.
#[must_use = "iterator adaptors are lazy and do nothing unless consumed"]
Expand All @@ -20,9 +22,27 @@ impl<I> Clone for Combinations<I>
clone_fields!(k, indices, pool, first);
}

impl<I: Iterator> Combinations<I> {
fn advance(&mut self) {
let pool_len = self.pool.len();

// Scan from the end, looking for an index to increment
let mut i = self.k - 1;
while self.indices[i] + self.k == i + pool_len {
i -= 1;
}

// Increment index, and reset the ones to its right
self.indices[i] += 1;
for j in i + 1..self.k {
self.indices[j] = self.indices[j - 1] + 1;
}
}
}

impl<I> fmt::Debug for Combinations<I>
where I: Iterator + fmt::Debug,
I::Item: fmt::Debug,
I::Item: fmt::Debug
{
debug_fmt_fields!(Combinations, k, indices, pool, first);
}
Expand All @@ -31,10 +51,7 @@ impl<I> fmt::Debug for Combinations<I>
pub fn combinations<I>(iter: I, k: usize) -> Combinations<I>
where I: Iterator
{
let mut indices: Vec<usize> = Vec::with_capacity(k);
for i in 0..k {
indices.push(i);
}
let indices: Vec<usize> = (0..k).collect();
let mut pool: LazyBuffer<I> = LazyBuffer::new(iter);

for _ in 0..k {
Expand Down Expand Up @@ -64,35 +81,61 @@ impl<I> Iterator for Combinations<I>
return None;
}
self.first = false;
} else if self.k == 0 {
return None;
} else {
// Scan from the end, looking for an index to increment
let mut i: usize = self.k - 1;
if self.k == 0 {
return None;
}

// Check if we need to consume more from the iterator
if self.indices[i] == pool_len - 1 && !self.pool.is_done() {
if self.pool.get_next() {
pool_len += 1;
}
if self.indices[self.k - 1] == pool_len - 1 && self.pool.get_next() {
pool_len += 1;
}

while self.indices[i] == i + pool_len - self.k {
if i > 0 {
i -= 1;
} else {
// Reached the last combination
return None;
}
if self.indices[0] == pool_len - self.k {
return None;
}

// Increment index, and reset the ones to its right
self.indices[i] += 1;
let mut j = i + 1;
while j < self.k {
self.indices[j] = self.indices[j - 1] + 1;
j += 1;
self.advance();
}

// Create result vector based on the indices
let mut result = Vec::with_capacity(self.k);
for i in self.indices.iter() {
result.push(self.pool[*i].clone());
}
Some(result)
}

fn nth(&mut self, n: usize) -> Option<Self::Item> {
if n == 0 {
return self.next();
}

let mut pool_len = self.pool.len();
if self.k == 0 || self.pool.is_done() && (pool_len == 0 || self.k > pool_len) {
return None;
}

let mut n = n;
if self.first {
self.first = false;
} else {
n += 1;
}

// Drain iterator and increase last index.
while n > 0 && self.pool.get_next() {
self.indices[self.k - 1] += 1;
pool_len += 1;
n -= 1;
}

for _ in 0..n {
// check if we have reached the end
if self.indices[0] == pool_len - self.k {
return None;
}
self.advance();
}

// Create result vector based on the indices
Expand Down
45 changes: 45 additions & 0 deletions tests/quick.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ use quickcheck as qc;
use std::ops::Range;
use std::cmp::{max, min, Ordering};
use std::collections::HashSet;
use std::panic::catch_unwind;
use itertools::Itertools;
use itertools::{
multizip,
Expand All @@ -34,6 +35,25 @@ use rand::Rng;
use rand::seq::SliceRandom;
use quickcheck::TestResult;

/// An unspecialized wrapper around another iterator.
struct Unspecialized<I>(I);

impl<I> Iterator for Unspecialized<I>
where I: Iterator
{
type Item = I::Item;

#[inline(always)]
fn next(&mut self) -> Option<I::Item> {
self.0.next()
}

#[inline(always)]
fn size_hint(&self) -> (usize, Option<usize>) {
self.0.size_hint()
}
}

/// Trait for size hint modifier types
trait HintKind: Copy + Send + qc::Arbitrary {
fn loosen_bounds(&self, org_hint: (usize, Option<usize>)) -> (usize, Option<usize>);
Expand Down Expand Up @@ -1159,3 +1179,28 @@ quickcheck! {
}
}
}

quickcheck! {
// check that the specialization of `Combinations::nth` produces the same
// results as the unspecialized version.
fn combinations_nth(l: u8, n: u8, i: u8) -> TestResult {
let (l, n, i) = (l as usize, n as usize, i as usize);

let s_ith = catch_unwind(|| {
(0..l).combinations(n).nth(i)
}).map_err(|_| "PANICKED");

let u_ith = catch_unwind(|| {
Unspecialized((0..l).combinations(n as usize)).nth(i)
}).map_err(|_| "PANICKED");

if s_ith == u_ith {
TestResult::passed()
} else {
TestResult::error(
format!(
"(0..{}).combinations({}).nth({}): {:?} != {:?}",
l, n, i, s_ith, u_ith))
}
}
}
16 changes: 16 additions & 0 deletions tests/test_std.rs
Original file line number Diff line number Diff line change
Expand Up @@ -577,6 +577,22 @@ fn combinations() {
vec![3, 4],
]);

assert!((1..3).combinations(0).nth(1).is_none());
assert_eq!((1..3).combinations(0).nth(0), Some(vec![]));

let mut it = (1..3).combinations(0);
assert_eq!(it.nth(0), Some(vec![]));
assert_eq!(it.nth(0), None);

let mut it = (1..6).combinations(2);
assert_eq!(it.nth(0), Some(vec![1, 2]));
assert_eq!(it.nth(3), Some(vec![2, 3]));
assert_eq!(it.nth(2), Some(vec![3, 4]));
assert_eq!(it.nth(2), None);

let mut it = (0..8).combinations(4);
assert_eq!(it.nth(1), Some(vec![0, 1, 2, 4]));

it::assert_equal((0..0).tuple_combinations::<(_, _)>(), <Vec<_>>::new());
it::assert_equal((0..1).tuple_combinations::<(_, _)>(), <Vec<_>>::new());
it::assert_equal((0..2).tuple_combinations::<(_, _)>(), vec![(0, 1)]);
Expand Down