Skip to content

[Merged by Bors] - Add ExactSizeIterator implementation for QueryCombinatonIter #5148

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 32 additions & 8 deletions crates/bevy_ecs/src/query/iter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -343,17 +343,26 @@ where
if max_size < K {
return (0, Some(0));
}
if max_size == K {
return (1, Some(1));
}

// n! / k!(n-k)! = (n*n-1*...*n-k+1) / k!
let max_combinations = (0..K)
.try_fold(1usize, |n, i| n.checked_mul(max_size - i))
.map(|n| {
let k_factorial: usize = (1..=K).product();
n / k_factorial
});
// binomial coefficient: (n ; k) = n! / k!(n-k)! = (n*n-1*...*n-k+1) / k!
// See https://en.wikipedia.org/wiki/Binomial_coefficient
// See https://blog.plover.com/math/choose.html for implementation
// It was chosen to reduce overflow potential.
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note: It reduces overflow potential at least compared to the previous solution which computed n! and k!(n-k)!. This still can overflow even when the result is not an overflow. An example is provided in https://blog.plover.com/math/choose-2.html

It's also probably slower than the previous solution, because it does k div and k checked_mul, while the previous one did 1 div, k mul and k checked_mul.

fn choose(n: usize, k: usize) -> Option<usize> {
let ks = 1..=k;
let ns = (n - k + 1..=n).rev();
ks.zip(ns)
.try_fold(1_usize, |acc, (k, n)| Some(acc.checked_mul(n)? / k))
}
let smallest = K.min(max_size - K);
let max_combinations = choose(max_size, smallest);

let archetype_query = F::Fetch::IS_ARCHETYPAL && Q::Fetch::IS_ARCHETYPAL;
let min_combinations = if archetype_query { max_size } else { 0 };
let known_max = max_combinations.unwrap_or(usize::MAX);
let min_combinations = if archetype_query { known_max } else { 0 };
(min_combinations, max_combinations)
}
}
Expand All @@ -372,6 +381,21 @@ where
}
}

impl<'w, 's, Q: ReadOnlyWorldQuery, F: ReadOnlyWorldQuery + ArchetypeFilter, const K: usize>
ExactSizeIterator for QueryCombinationIter<'w, 's, Q, F, K>
where
QueryFetch<'w, Q>: Clone,
QueryFetch<'w, F>: Clone,
{
/// Returns the exact length of the iterator.
///
/// **NOTE**: When the iterator length overflows `usize`, this will
/// return `usize::MAX`.
fn len(&self) -> usize {
self.size_hint().0
}
}

// This is correct as [`QueryCombinationIter`] always returns `None` once exhausted.
impl<'w, 's, Q: ReadOnlyWorldQuery, F: ReadOnlyWorldQuery, const K: usize> FusedIterator
for QueryCombinationIter<'w, 's, Q, F, K>
Expand Down
267 changes: 101 additions & 166 deletions crates/bevy_ecs/src/query/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,11 @@ pub(crate) unsafe fn debug_checked_unreachable() -> ! {
#[cfg(test)]
mod tests {
use super::WorldQuery;
use crate::prelude::{AnyOf, Entity, Or, With, Without};
use crate::prelude::{AnyOf, Entity, Or, QueryState, With, Without};
use crate::query::{ArchetypeFilter, QueryCombinationIter, QueryFetch, ReadOnlyWorldQuery};
use crate::system::{IntoSystem, Query, System};
use crate::{self as bevy_ecs, component::Component, world::World};
use std::any::type_name;
use std::collections::HashSet;

#[derive(Component, Debug, Hash, Eq, PartialEq, Clone, Copy)]
Expand Down Expand Up @@ -54,24 +56,81 @@ mod tests {
}

#[test]
fn query_filtered_len() {
fn query_filtered_exactsizeiterator_len() {
fn choose(n: usize, k: usize) -> usize {
if n == 0 || k == 0 || n < k {
return 0;
}
let ks = 1..=k;
let ns = (n - k + 1..=n).rev();
ks.zip(ns).fold(1, |acc, (k, n)| acc * n / k)
}
fn assert_combination<Q, F, const K: usize>(world: &mut World, expected_size: usize)
where
Q: ReadOnlyWorldQuery,
F: ReadOnlyWorldQuery + ArchetypeFilter,
for<'w> QueryFetch<'w, Q>: Clone,
for<'w> QueryFetch<'w, F>: Clone,
{
let mut query = world.query_filtered::<Q, F>();
let iter = query.iter_combinations::<K>(world);
let query_type = type_name::<QueryCombinationIter<Q, F, K>>();
assert_all_sizes_iterator_equal(iter, expected_size, query_type);
}
fn assert_all_sizes_equal<Q, F>(world: &mut World, expected_size: usize)
where
Q: ReadOnlyWorldQuery,
F: ReadOnlyWorldQuery + ArchetypeFilter,
for<'w> QueryFetch<'w, Q>: Clone,
for<'w> QueryFetch<'w, F>: Clone,
{
let mut query = world.query_filtered::<Q, F>();
let iter = query.iter(world);
let query_type = type_name::<QueryState<Q, F>>();
assert_all_sizes_iterator_equal(iter, expected_size, query_type);

let expected = expected_size;
assert_combination::<Q, F, 0>(world, choose(expected, 0));
assert_combination::<Q, F, 1>(world, choose(expected, 1));
assert_combination::<Q, F, 2>(world, choose(expected, 2));
assert_combination::<Q, F, 5>(world, choose(expected, 5));
assert_combination::<Q, F, 43>(world, choose(expected, 43));
assert_combination::<Q, F, 128>(world, choose(expected, 128));
}
fn assert_all_sizes_iterator_equal(
iterator: impl ExactSizeIterator,
expected_size: usize,
query_type: &'static str,
) {
let size_hint_0 = iterator.size_hint().0;
let size_hint_1 = iterator.size_hint().1;
let len = iterator.len();
// `count` tests that not only it is the expected value, but also
// the value is accurate to what the query returns.
let count = iterator.count();
// This will show up when one of the asserts in this function fails
println!(
r#"query declared sizes:
for query: {query_type}
expected: {expected_size}
len(): {len}
size_hint().0: {size_hint_0}
size_hint().1: {size_hint_1:?}
count(): {count}"#
);
assert_eq!(len, expected_size);
assert_eq!(size_hint_0, expected_size);
assert_eq!(size_hint_1, Some(expected_size));
assert_eq!(count, expected_size);
}

let mut world = World::new();
world.spawn().insert_bundle((A(1), B(1)));
world.spawn().insert_bundle((A(2),));
world.spawn().insert_bundle((A(3),));

let mut values = world.query_filtered::<&A, With<B>>();
let n = 1;
assert_eq!(values.iter(&world).size_hint().0, n);
assert_eq!(values.iter(&world).size_hint().1.unwrap(), n);
assert_eq!(values.iter(&world).len(), n);
assert_eq!(values.iter(&world).count(), n);
let mut values = world.query_filtered::<&A, Without<B>>();
let n = 2;
assert_eq!(values.iter(&world).size_hint().0, n);
assert_eq!(values.iter(&world).size_hint().1.unwrap(), n);
assert_eq!(values.iter(&world).len(), n);
assert_eq!(values.iter(&world).count(), n);
assert_all_sizes_equal::<&A, With<B>>(&mut world, 1);
assert_all_sizes_equal::<&A, Without<B>>(&mut world, 2);

let mut world = World::new();
world.spawn().insert_bundle((A(1), B(1), C(1)));
Expand All @@ -86,110 +145,37 @@ mod tests {
world.spawn().insert_bundle((A(10),));

// With/Without for B and C
let mut values = world.query_filtered::<&A, With<B>>();
let n = 3;
assert_eq!(values.iter(&world).size_hint().0, n);
assert_eq!(values.iter(&world).size_hint().1.unwrap(), n);
assert_eq!(values.iter(&world).len(), n);
assert_eq!(values.iter(&world).count(), n);
let mut values = world.query_filtered::<&A, With<C>>();
let n = 4;
assert_eq!(values.iter(&world).size_hint().0, n);
assert_eq!(values.iter(&world).size_hint().1.unwrap(), n);
assert_eq!(values.iter(&world).len(), n);
assert_eq!(values.iter(&world).count(), n);
let mut values = world.query_filtered::<&A, Without<B>>();
let n = 7;
assert_eq!(values.iter(&world).size_hint().0, n);
assert_eq!(values.iter(&world).size_hint().1.unwrap(), n);
assert_eq!(values.iter(&world).len(), n);
assert_eq!(values.iter(&world).count(), n);
let mut values = world.query_filtered::<&A, Without<C>>();
let n = 6;
assert_eq!(values.iter(&world).size_hint().0, n);
assert_eq!(values.iter(&world).size_hint().1.unwrap(), n);
assert_eq!(values.iter(&world).len(), n);
assert_eq!(values.iter(&world).count(), n);
assert_all_sizes_equal::<&A, With<B>>(&mut world, 3);
assert_all_sizes_equal::<&A, With<C>>(&mut world, 4);
assert_all_sizes_equal::<&A, Without<B>>(&mut world, 7);
assert_all_sizes_equal::<&A, Without<C>>(&mut world, 6);

// With/Without (And) combinations
let mut values = world.query_filtered::<&A, (With<B>, With<C>)>();
let n = 1;
assert_eq!(values.iter(&world).size_hint().0, n);
assert_eq!(values.iter(&world).size_hint().1.unwrap(), n);
assert_eq!(values.iter(&world).len(), n);
assert_eq!(values.iter(&world).count(), n);
let mut values = world.query_filtered::<&A, (With<B>, Without<C>)>();
let n = 2;
assert_eq!(values.iter(&world).size_hint().0, n);
assert_eq!(values.iter(&world).size_hint().1.unwrap(), n);
assert_eq!(values.iter(&world).len(), n);
assert_eq!(values.iter(&world).count(), n);
let mut values = world.query_filtered::<&A, (Without<B>, With<C>)>();
let n = 3;
assert_eq!(values.iter(&world).size_hint().0, n);
assert_eq!(values.iter(&world).size_hint().1.unwrap(), n);
assert_eq!(values.iter(&world).len(), n);
assert_eq!(values.iter(&world).count(), n);
let mut values = world.query_filtered::<&A, (Without<B>, Without<C>)>();
let n = 4;
assert_eq!(values.iter(&world).size_hint().0, n);
assert_eq!(values.iter(&world).size_hint().1.unwrap(), n);
assert_eq!(values.iter(&world).len(), n);
assert_eq!(values.iter(&world).count(), n);
assert_all_sizes_equal::<&A, (With<B>, With<C>)>(&mut world, 1);
assert_all_sizes_equal::<&A, (With<B>, Without<C>)>(&mut world, 2);
assert_all_sizes_equal::<&A, (Without<B>, With<C>)>(&mut world, 3);
assert_all_sizes_equal::<&A, (Without<B>, Without<C>)>(&mut world, 4);

// With/Without Or<()> combinations
let mut values = world.query_filtered::<&A, Or<(With<B>, With<C>)>>();
let n = 6;
assert_eq!(values.iter(&world).size_hint().0, n);
assert_eq!(values.iter(&world).size_hint().1.unwrap(), n);
assert_eq!(values.iter(&world).len(), n);
assert_eq!(values.iter(&world).count(), n);
let mut values = world.query_filtered::<&A, Or<(With<B>, Without<C>)>>();
let n = 7;
assert_eq!(values.iter(&world).size_hint().0, n);
assert_eq!(values.iter(&world).size_hint().1.unwrap(), n);
assert_eq!(values.iter(&world).len(), n);
assert_eq!(values.iter(&world).count(), n);
let mut values = world.query_filtered::<&A, Or<(Without<B>, With<C>)>>();
let n = 8;
assert_eq!(values.iter(&world).size_hint().0, n);
assert_eq!(values.iter(&world).size_hint().1.unwrap(), n);
assert_eq!(values.iter(&world).len(), n);
assert_eq!(values.iter(&world).count(), n);
let mut values = world.query_filtered::<&A, Or<(Without<B>, Without<C>)>>();
let n = 9;
assert_eq!(values.iter(&world).size_hint().0, n);
assert_eq!(values.iter(&world).size_hint().1.unwrap(), n);
assert_eq!(values.iter(&world).len(), n);
assert_eq!(values.iter(&world).count(), n);

let mut values = world.query_filtered::<&A, (Or<(With<B>,)>, Or<(With<C>,)>)>();
let n = 1;
assert_eq!(values.iter(&world).size_hint().0, n);
assert_eq!(values.iter(&world).size_hint().1.unwrap(), n);
assert_eq!(values.iter(&world).len(), n);
assert_eq!(values.iter(&world).count(), n);
let mut values = world.query_filtered::<&A, Or<(Or<(With<B>, With<C>)>, With<D>)>>();
let n = 6;
assert_eq!(values.iter(&world).size_hint().0, n);
assert_eq!(values.iter(&world).size_hint().1.unwrap(), n);
assert_eq!(values.iter(&world).len(), n);
assert_eq!(values.iter(&world).count(), n);

world.spawn().insert_bundle((A(11), D(11)));

let mut values = world.query_filtered::<&A, Or<(Or<(With<B>, With<C>)>, With<D>)>>();
let n = 7;
assert_eq!(values.iter(&world).size_hint().0, n);
assert_eq!(values.iter(&world).size_hint().1.unwrap(), n);
assert_eq!(values.iter(&world).len(), n);
assert_eq!(values.iter(&world).count(), n);
let mut values = world.query_filtered::<&A, Or<(Or<(With<B>, With<C>)>, Without<D>)>>();
let n = 10;
assert_eq!(values.iter(&world).size_hint().0, n);
assert_eq!(values.iter(&world).size_hint().1.unwrap(), n);
assert_eq!(values.iter(&world).len(), n);
assert_eq!(values.iter(&world).count(), n);
assert_all_sizes_equal::<&A, Or<(With<B>, With<C>)>>(&mut world, 6);
assert_all_sizes_equal::<&A, Or<(With<B>, Without<C>)>>(&mut world, 7);
assert_all_sizes_equal::<&A, Or<(Without<B>, With<C>)>>(&mut world, 8);
assert_all_sizes_equal::<&A, Or<(Without<B>, Without<C>)>>(&mut world, 9);
assert_all_sizes_equal::<&A, (Or<(With<B>,)>, Or<(With<C>,)>)>(&mut world, 1);
assert_all_sizes_equal::<&A, Or<(Or<(With<B>, With<C>)>, With<D>)>>(&mut world, 6);

for i in 11..14 {
world.spawn().insert_bundle((A(i), D(i)));
}

assert_all_sizes_equal::<&A, Or<(Or<(With<B>, With<C>)>, With<D>)>>(&mut world, 9);
assert_all_sizes_equal::<&A, Or<(Or<(With<B>, With<C>)>, Without<D>)>>(&mut world, 10);

// a fair amount of entities
for i in 14..20 {
world.spawn().insert_bundle((C(i), D(i)));
}
assert_all_sizes_equal::<Entity, (With<C>, With<D>)>(&mut world, 6);
}

#[test]
Expand All @@ -201,23 +187,6 @@ mod tests {
world.spawn().insert_bundle((A(3),));
world.spawn().insert_bundle((A(4),));

let mut a_query = world.query::<&A>();
let w = &world;
assert_eq!(a_query.iter_combinations::<0>(w).count(), 0);
assert_eq!(a_query.iter_combinations::<0>(w).size_hint().1, Some(0));
assert_eq!(a_query.iter_combinations::<1>(w).count(), 4);
assert_eq!(a_query.iter_combinations::<1>(w).size_hint().1, Some(4));
assert_eq!(a_query.iter_combinations::<2>(w).count(), 6);
assert_eq!(a_query.iter_combinations::<2>(w).size_hint().1, Some(6));
assert_eq!(a_query.iter_combinations::<3>(w).count(), 4);
assert_eq!(a_query.iter_combinations::<3>(w).size_hint().1, Some(4));
assert_eq!(a_query.iter_combinations::<4>(w).count(), 1);
assert_eq!(a_query.iter_combinations::<4>(w).size_hint().1, Some(1));
assert_eq!(a_query.iter_combinations::<5>(w).count(), 0);
assert_eq!(a_query.iter_combinations::<5>(w).size_hint().1, Some(0));
assert_eq!(a_query.iter_combinations::<128>(w).count(), 0);
assert_eq!(a_query.iter_combinations::<128>(w).size_hint().1, Some(0));
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I moved those tests to query_filtered_exactsizeiterator_len, so that they can be deduplicated and tested on more stuff.


let values: Vec<[&A; 2]> = world.query::<&A>().iter_combinations(&world).collect();
assert_eq!(
values,
Expand All @@ -230,8 +199,7 @@ mod tests {
[&A(3), &A(4)],
]
);
let size = a_query.iter_combinations::<3>(&world).size_hint();
assert_eq!(size.1, Some(4));
let mut a_query = world.query::<&A>();
let values: Vec<[&A; 3]> = a_query.iter_combinations(&world).collect();
assert_eq!(
values,
Expand Down Expand Up @@ -282,40 +250,7 @@ mod tests {
world.spawn().insert_bundle((A(3),));
world.spawn().insert_bundle((A(4),));

let mut a_with_b = world.query_filtered::<&A, With<B>>();
let w = &world;
assert_eq!(a_with_b.iter_combinations::<0>(w).count(), 0);
assert_eq!(a_with_b.iter_combinations::<0>(w).size_hint().1, Some(0));
assert_eq!(a_with_b.iter_combinations::<1>(w).count(), 1);
assert_eq!(a_with_b.iter_combinations::<1>(w).size_hint().1, Some(1));
assert_eq!(a_with_b.iter_combinations::<2>(w).count(), 0);
assert_eq!(a_with_b.iter_combinations::<2>(w).size_hint().1, Some(0));
assert_eq!(a_with_b.iter_combinations::<3>(w).count(), 0);
assert_eq!(a_with_b.iter_combinations::<3>(w).size_hint().1, Some(0));
assert_eq!(a_with_b.iter_combinations::<4>(w).count(), 0);
assert_eq!(a_with_b.iter_combinations::<4>(w).size_hint().1, Some(0));
assert_eq!(a_with_b.iter_combinations::<5>(w).count(), 0);
assert_eq!(a_with_b.iter_combinations::<5>(w).size_hint().1, Some(0));
assert_eq!(a_with_b.iter_combinations::<128>(w).count(), 0);
assert_eq!(a_with_b.iter_combinations::<128>(w).size_hint().1, Some(0));

let mut a_wout_b = world.query_filtered::<&A, Without<B>>();
let w = &world;
assert_eq!(a_wout_b.iter_combinations::<0>(w).count(), 0);
assert_eq!(a_wout_b.iter_combinations::<0>(w).size_hint().1, Some(0));
assert_eq!(a_wout_b.iter_combinations::<1>(w).count(), 3);
assert_eq!(a_wout_b.iter_combinations::<1>(w).size_hint().1, Some(3));
assert_eq!(a_wout_b.iter_combinations::<2>(w).count(), 3);
assert_eq!(a_wout_b.iter_combinations::<2>(w).size_hint().1, Some(3));
assert_eq!(a_wout_b.iter_combinations::<3>(w).count(), 1);
assert_eq!(a_wout_b.iter_combinations::<3>(w).size_hint().1, Some(1));
assert_eq!(a_wout_b.iter_combinations::<4>(w).count(), 0);
assert_eq!(a_wout_b.iter_combinations::<4>(w).size_hint().1, Some(0));
assert_eq!(a_wout_b.iter_combinations::<5>(w).count(), 0);
assert_eq!(a_wout_b.iter_combinations::<5>(w).size_hint().1, Some(0));
assert_eq!(a_wout_b.iter_combinations::<128>(w).count(), 0);
assert_eq!(a_wout_b.iter_combinations::<128>(w).size_hint().1, Some(0));

let values: HashSet<[&A; 2]> = a_wout_b.iter_combinations(&world).collect();
assert_eq!(
values,
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
use bevy_ecs::prelude::*;

#[derive(Component)]
struct Foo;
#[derive(Component)]
struct Bar;

fn on_changed(query: Query<&Foo, Or<(Changed<Foo>, With<Bar>)>>) {
// this should fail to compile
is_exact_size_iterator(query.iter_combinations::<2>());
}

fn on_added(query: Query<&Foo, (Added<Foo>, Without<Bar>)>) {
// this should fail to compile
is_exact_size_iterator(query.iter_combinations::<2>());
}

fn is_exact_size_iterator<T: ExactSizeIterator>(_iter: T) {}

fn main() {}
Loading