Skip to content

Commit 6c06fc5

Browse files
committed
Add ExactSizeIterator implementation for QueryCombinatonIter (#5148)
Following #5124 I decided to add the `ExactSizeIterator` impl for `QueryCombinationIter`. Also: - Clean up the tests for `size_hint` and `len` for both the normal `QueryIter` and `QueryCombinationIter`. - Add tests to `QueryCombinationIter` when it shouldn't be `ExactSizeIterator` --- ## Changelog - Added `ExactSizeIterator` implementation for `QueryCombinatonIter`
1 parent 1dbb1f7 commit 6c06fc5

File tree

4 files changed

+204
-174
lines changed

4 files changed

+204
-174
lines changed

crates/bevy_ecs/src/query/iter.rs

+32-8
Original file line numberDiff line numberDiff line change
@@ -343,17 +343,26 @@ where
343343
if max_size < K {
344344
return (0, Some(0));
345345
}
346+
if max_size == K {
347+
return (1, Some(1));
348+
}
346349

347-
// n! / k!(n-k)! = (n*n-1*...*n-k+1) / k!
348-
let max_combinations = (0..K)
349-
.try_fold(1usize, |n, i| n.checked_mul(max_size - i))
350-
.map(|n| {
351-
let k_factorial: usize = (1..=K).product();
352-
n / k_factorial
353-
});
350+
// binomial coefficient: (n ; k) = n! / k!(n-k)! = (n*n-1*...*n-k+1) / k!
351+
// See https://en.wikipedia.org/wiki/Binomial_coefficient
352+
// See https://blog.plover.com/math/choose.html for implementation
353+
// It was chosen to reduce overflow potential.
354+
fn choose(n: usize, k: usize) -> Option<usize> {
355+
let ks = 1..=k;
356+
let ns = (n - k + 1..=n).rev();
357+
ks.zip(ns)
358+
.try_fold(1_usize, |acc, (k, n)| Some(acc.checked_mul(n)? / k))
359+
}
360+
let smallest = K.min(max_size - K);
361+
let max_combinations = choose(max_size, smallest);
354362

355363
let archetype_query = F::Fetch::IS_ARCHETYPAL && Q::Fetch::IS_ARCHETYPAL;
356-
let min_combinations = if archetype_query { max_size } else { 0 };
364+
let known_max = max_combinations.unwrap_or(usize::MAX);
365+
let min_combinations = if archetype_query { known_max } else { 0 };
357366
(min_combinations, max_combinations)
358367
}
359368
}
@@ -372,6 +381,21 @@ where
372381
}
373382
}
374383

384+
impl<'w, 's, Q: ReadOnlyWorldQuery, F: ReadOnlyWorldQuery + ArchetypeFilter, const K: usize>
385+
ExactSizeIterator for QueryCombinationIter<'w, 's, Q, F, K>
386+
where
387+
QueryFetch<'w, Q>: Clone,
388+
QueryFetch<'w, F>: Clone,
389+
{
390+
/// Returns the exact length of the iterator.
391+
///
392+
/// **NOTE**: When the iterator length overflows `usize`, this will
393+
/// return `usize::MAX`.
394+
fn len(&self) -> usize {
395+
self.size_hint().0
396+
}
397+
}
398+
375399
// This is correct as [`QueryCombinationIter`] always returns `None` once exhausted.
376400
impl<'w, 's, Q: ReadOnlyWorldQuery, F: ReadOnlyWorldQuery, const K: usize> FusedIterator
377401
for QueryCombinationIter<'w, 's, Q, F, K>

crates/bevy_ecs/src/query/mod.rs

+101-166
Original file line numberDiff line numberDiff line change
@@ -20,9 +20,11 @@ pub(crate) unsafe fn debug_checked_unreachable() -> ! {
2020
#[cfg(test)]
2121
mod tests {
2222
use super::WorldQuery;
23-
use crate::prelude::{AnyOf, Entity, Or, With, Without};
23+
use crate::prelude::{AnyOf, Entity, Or, QueryState, With, Without};
24+
use crate::query::{ArchetypeFilter, QueryCombinationIter, QueryFetch, ReadOnlyWorldQuery};
2425
use crate::system::{IntoSystem, Query, System};
2526
use crate::{self as bevy_ecs, component::Component, world::World};
27+
use std::any::type_name;
2628
use std::collections::HashSet;
2729

2830
#[derive(Component, Debug, Hash, Eq, PartialEq, Clone, Copy)]
@@ -54,24 +56,81 @@ mod tests {
5456
}
5557

5658
#[test]
57-
fn query_filtered_len() {
59+
fn query_filtered_exactsizeiterator_len() {
60+
fn choose(n: usize, k: usize) -> usize {
61+
if n == 0 || k == 0 || n < k {
62+
return 0;
63+
}
64+
let ks = 1..=k;
65+
let ns = (n - k + 1..=n).rev();
66+
ks.zip(ns).fold(1, |acc, (k, n)| acc * n / k)
67+
}
68+
fn assert_combination<Q, F, const K: usize>(world: &mut World, expected_size: usize)
69+
where
70+
Q: ReadOnlyWorldQuery,
71+
F: ReadOnlyWorldQuery + ArchetypeFilter,
72+
for<'w> QueryFetch<'w, Q>: Clone,
73+
for<'w> QueryFetch<'w, F>: Clone,
74+
{
75+
let mut query = world.query_filtered::<Q, F>();
76+
let iter = query.iter_combinations::<K>(world);
77+
let query_type = type_name::<QueryCombinationIter<Q, F, K>>();
78+
assert_all_sizes_iterator_equal(iter, expected_size, query_type);
79+
}
80+
fn assert_all_sizes_equal<Q, F>(world: &mut World, expected_size: usize)
81+
where
82+
Q: ReadOnlyWorldQuery,
83+
F: ReadOnlyWorldQuery + ArchetypeFilter,
84+
for<'w> QueryFetch<'w, Q>: Clone,
85+
for<'w> QueryFetch<'w, F>: Clone,
86+
{
87+
let mut query = world.query_filtered::<Q, F>();
88+
let iter = query.iter(world);
89+
let query_type = type_name::<QueryState<Q, F>>();
90+
assert_all_sizes_iterator_equal(iter, expected_size, query_type);
91+
92+
let expected = expected_size;
93+
assert_combination::<Q, F, 0>(world, choose(expected, 0));
94+
assert_combination::<Q, F, 1>(world, choose(expected, 1));
95+
assert_combination::<Q, F, 2>(world, choose(expected, 2));
96+
assert_combination::<Q, F, 5>(world, choose(expected, 5));
97+
assert_combination::<Q, F, 43>(world, choose(expected, 43));
98+
assert_combination::<Q, F, 128>(world, choose(expected, 128));
99+
}
100+
fn assert_all_sizes_iterator_equal(
101+
iterator: impl ExactSizeIterator,
102+
expected_size: usize,
103+
query_type: &'static str,
104+
) {
105+
let size_hint_0 = iterator.size_hint().0;
106+
let size_hint_1 = iterator.size_hint().1;
107+
let len = iterator.len();
108+
// `count` tests that not only it is the expected value, but also
109+
// the value is accurate to what the query returns.
110+
let count = iterator.count();
111+
// This will show up when one of the asserts in this function fails
112+
println!(
113+
r#"query declared sizes:
114+
for query: {query_type}
115+
expected: {expected_size}
116+
len(): {len}
117+
size_hint().0: {size_hint_0}
118+
size_hint().1: {size_hint_1:?}
119+
count(): {count}"#
120+
);
121+
assert_eq!(len, expected_size);
122+
assert_eq!(size_hint_0, expected_size);
123+
assert_eq!(size_hint_1, Some(expected_size));
124+
assert_eq!(count, expected_size);
125+
}
126+
58127
let mut world = World::new();
59128
world.spawn().insert_bundle((A(1), B(1)));
60129
world.spawn().insert_bundle((A(2),));
61130
world.spawn().insert_bundle((A(3),));
62131

63-
let mut values = world.query_filtered::<&A, With<B>>();
64-
let n = 1;
65-
assert_eq!(values.iter(&world).size_hint().0, n);
66-
assert_eq!(values.iter(&world).size_hint().1.unwrap(), n);
67-
assert_eq!(values.iter(&world).len(), n);
68-
assert_eq!(values.iter(&world).count(), n);
69-
let mut values = world.query_filtered::<&A, Without<B>>();
70-
let n = 2;
71-
assert_eq!(values.iter(&world).size_hint().0, n);
72-
assert_eq!(values.iter(&world).size_hint().1.unwrap(), n);
73-
assert_eq!(values.iter(&world).len(), n);
74-
assert_eq!(values.iter(&world).count(), n);
132+
assert_all_sizes_equal::<&A, With<B>>(&mut world, 1);
133+
assert_all_sizes_equal::<&A, Without<B>>(&mut world, 2);
75134

76135
let mut world = World::new();
77136
world.spawn().insert_bundle((A(1), B(1), C(1)));
@@ -86,110 +145,37 @@ mod tests {
86145
world.spawn().insert_bundle((A(10),));
87146

88147
// With/Without for B and C
89-
let mut values = world.query_filtered::<&A, With<B>>();
90-
let n = 3;
91-
assert_eq!(values.iter(&world).size_hint().0, n);
92-
assert_eq!(values.iter(&world).size_hint().1.unwrap(), n);
93-
assert_eq!(values.iter(&world).len(), n);
94-
assert_eq!(values.iter(&world).count(), n);
95-
let mut values = world.query_filtered::<&A, With<C>>();
96-
let n = 4;
97-
assert_eq!(values.iter(&world).size_hint().0, n);
98-
assert_eq!(values.iter(&world).size_hint().1.unwrap(), n);
99-
assert_eq!(values.iter(&world).len(), n);
100-
assert_eq!(values.iter(&world).count(), n);
101-
let mut values = world.query_filtered::<&A, Without<B>>();
102-
let n = 7;
103-
assert_eq!(values.iter(&world).size_hint().0, n);
104-
assert_eq!(values.iter(&world).size_hint().1.unwrap(), n);
105-
assert_eq!(values.iter(&world).len(), n);
106-
assert_eq!(values.iter(&world).count(), n);
107-
let mut values = world.query_filtered::<&A, Without<C>>();
108-
let n = 6;
109-
assert_eq!(values.iter(&world).size_hint().0, n);
110-
assert_eq!(values.iter(&world).size_hint().1.unwrap(), n);
111-
assert_eq!(values.iter(&world).len(), n);
112-
assert_eq!(values.iter(&world).count(), n);
148+
assert_all_sizes_equal::<&A, With<B>>(&mut world, 3);
149+
assert_all_sizes_equal::<&A, With<C>>(&mut world, 4);
150+
assert_all_sizes_equal::<&A, Without<B>>(&mut world, 7);
151+
assert_all_sizes_equal::<&A, Without<C>>(&mut world, 6);
113152

114153
// With/Without (And) combinations
115-
let mut values = world.query_filtered::<&A, (With<B>, With<C>)>();
116-
let n = 1;
117-
assert_eq!(values.iter(&world).size_hint().0, n);
118-
assert_eq!(values.iter(&world).size_hint().1.unwrap(), n);
119-
assert_eq!(values.iter(&world).len(), n);
120-
assert_eq!(values.iter(&world).count(), n);
121-
let mut values = world.query_filtered::<&A, (With<B>, Without<C>)>();
122-
let n = 2;
123-
assert_eq!(values.iter(&world).size_hint().0, n);
124-
assert_eq!(values.iter(&world).size_hint().1.unwrap(), n);
125-
assert_eq!(values.iter(&world).len(), n);
126-
assert_eq!(values.iter(&world).count(), n);
127-
let mut values = world.query_filtered::<&A, (Without<B>, With<C>)>();
128-
let n = 3;
129-
assert_eq!(values.iter(&world).size_hint().0, n);
130-
assert_eq!(values.iter(&world).size_hint().1.unwrap(), n);
131-
assert_eq!(values.iter(&world).len(), n);
132-
assert_eq!(values.iter(&world).count(), n);
133-
let mut values = world.query_filtered::<&A, (Without<B>, Without<C>)>();
134-
let n = 4;
135-
assert_eq!(values.iter(&world).size_hint().0, n);
136-
assert_eq!(values.iter(&world).size_hint().1.unwrap(), n);
137-
assert_eq!(values.iter(&world).len(), n);
138-
assert_eq!(values.iter(&world).count(), n);
154+
assert_all_sizes_equal::<&A, (With<B>, With<C>)>(&mut world, 1);
155+
assert_all_sizes_equal::<&A, (With<B>, Without<C>)>(&mut world, 2);
156+
assert_all_sizes_equal::<&A, (Without<B>, With<C>)>(&mut world, 3);
157+
assert_all_sizes_equal::<&A, (Without<B>, Without<C>)>(&mut world, 4);
139158

140159
// With/Without Or<()> combinations
141-
let mut values = world.query_filtered::<&A, Or<(With<B>, With<C>)>>();
142-
let n = 6;
143-
assert_eq!(values.iter(&world).size_hint().0, n);
144-
assert_eq!(values.iter(&world).size_hint().1.unwrap(), n);
145-
assert_eq!(values.iter(&world).len(), n);
146-
assert_eq!(values.iter(&world).count(), n);
147-
let mut values = world.query_filtered::<&A, Or<(With<B>, Without<C>)>>();
148-
let n = 7;
149-
assert_eq!(values.iter(&world).size_hint().0, n);
150-
assert_eq!(values.iter(&world).size_hint().1.unwrap(), n);
151-
assert_eq!(values.iter(&world).len(), n);
152-
assert_eq!(values.iter(&world).count(), n);
153-
let mut values = world.query_filtered::<&A, Or<(Without<B>, With<C>)>>();
154-
let n = 8;
155-
assert_eq!(values.iter(&world).size_hint().0, n);
156-
assert_eq!(values.iter(&world).size_hint().1.unwrap(), n);
157-
assert_eq!(values.iter(&world).len(), n);
158-
assert_eq!(values.iter(&world).count(), n);
159-
let mut values = world.query_filtered::<&A, Or<(Without<B>, Without<C>)>>();
160-
let n = 9;
161-
assert_eq!(values.iter(&world).size_hint().0, n);
162-
assert_eq!(values.iter(&world).size_hint().1.unwrap(), n);
163-
assert_eq!(values.iter(&world).len(), n);
164-
assert_eq!(values.iter(&world).count(), n);
165-
166-
let mut values = world.query_filtered::<&A, (Or<(With<B>,)>, Or<(With<C>,)>)>();
167-
let n = 1;
168-
assert_eq!(values.iter(&world).size_hint().0, n);
169-
assert_eq!(values.iter(&world).size_hint().1.unwrap(), n);
170-
assert_eq!(values.iter(&world).len(), n);
171-
assert_eq!(values.iter(&world).count(), n);
172-
let mut values = world.query_filtered::<&A, Or<(Or<(With<B>, With<C>)>, With<D>)>>();
173-
let n = 6;
174-
assert_eq!(values.iter(&world).size_hint().0, n);
175-
assert_eq!(values.iter(&world).size_hint().1.unwrap(), n);
176-
assert_eq!(values.iter(&world).len(), n);
177-
assert_eq!(values.iter(&world).count(), n);
178-
179-
world.spawn().insert_bundle((A(11), D(11)));
180-
181-
let mut values = world.query_filtered::<&A, Or<(Or<(With<B>, With<C>)>, With<D>)>>();
182-
let n = 7;
183-
assert_eq!(values.iter(&world).size_hint().0, n);
184-
assert_eq!(values.iter(&world).size_hint().1.unwrap(), n);
185-
assert_eq!(values.iter(&world).len(), n);
186-
assert_eq!(values.iter(&world).count(), n);
187-
let mut values = world.query_filtered::<&A, Or<(Or<(With<B>, With<C>)>, Without<D>)>>();
188-
let n = 10;
189-
assert_eq!(values.iter(&world).size_hint().0, n);
190-
assert_eq!(values.iter(&world).size_hint().1.unwrap(), n);
191-
assert_eq!(values.iter(&world).len(), n);
192-
assert_eq!(values.iter(&world).count(), n);
160+
assert_all_sizes_equal::<&A, Or<(With<B>, With<C>)>>(&mut world, 6);
161+
assert_all_sizes_equal::<&A, Or<(With<B>, Without<C>)>>(&mut world, 7);
162+
assert_all_sizes_equal::<&A, Or<(Without<B>, With<C>)>>(&mut world, 8);
163+
assert_all_sizes_equal::<&A, Or<(Without<B>, Without<C>)>>(&mut world, 9);
164+
assert_all_sizes_equal::<&A, (Or<(With<B>,)>, Or<(With<C>,)>)>(&mut world, 1);
165+
assert_all_sizes_equal::<&A, Or<(Or<(With<B>, With<C>)>, With<D>)>>(&mut world, 6);
166+
167+
for i in 11..14 {
168+
world.spawn().insert_bundle((A(i), D(i)));
169+
}
170+
171+
assert_all_sizes_equal::<&A, Or<(Or<(With<B>, With<C>)>, With<D>)>>(&mut world, 9);
172+
assert_all_sizes_equal::<&A, Or<(Or<(With<B>, With<C>)>, Without<D>)>>(&mut world, 10);
173+
174+
// a fair amount of entities
175+
for i in 14..20 {
176+
world.spawn().insert_bundle((C(i), D(i)));
177+
}
178+
assert_all_sizes_equal::<Entity, (With<C>, With<D>)>(&mut world, 6);
193179
}
194180

195181
#[test]
@@ -201,23 +187,6 @@ mod tests {
201187
world.spawn().insert_bundle((A(3),));
202188
world.spawn().insert_bundle((A(4),));
203189

204-
let mut a_query = world.query::<&A>();
205-
let w = &world;
206-
assert_eq!(a_query.iter_combinations::<0>(w).count(), 0);
207-
assert_eq!(a_query.iter_combinations::<0>(w).size_hint().1, Some(0));
208-
assert_eq!(a_query.iter_combinations::<1>(w).count(), 4);
209-
assert_eq!(a_query.iter_combinations::<1>(w).size_hint().1, Some(4));
210-
assert_eq!(a_query.iter_combinations::<2>(w).count(), 6);
211-
assert_eq!(a_query.iter_combinations::<2>(w).size_hint().1, Some(6));
212-
assert_eq!(a_query.iter_combinations::<3>(w).count(), 4);
213-
assert_eq!(a_query.iter_combinations::<3>(w).size_hint().1, Some(4));
214-
assert_eq!(a_query.iter_combinations::<4>(w).count(), 1);
215-
assert_eq!(a_query.iter_combinations::<4>(w).size_hint().1, Some(1));
216-
assert_eq!(a_query.iter_combinations::<5>(w).count(), 0);
217-
assert_eq!(a_query.iter_combinations::<5>(w).size_hint().1, Some(0));
218-
assert_eq!(a_query.iter_combinations::<128>(w).count(), 0);
219-
assert_eq!(a_query.iter_combinations::<128>(w).size_hint().1, Some(0));
220-
221190
let values: Vec<[&A; 2]> = world.query::<&A>().iter_combinations(&world).collect();
222191
assert_eq!(
223192
values,
@@ -230,8 +199,7 @@ mod tests {
230199
[&A(3), &A(4)],
231200
]
232201
);
233-
let size = a_query.iter_combinations::<3>(&world).size_hint();
234-
assert_eq!(size.1, Some(4));
202+
let mut a_query = world.query::<&A>();
235203
let values: Vec<[&A; 3]> = a_query.iter_combinations(&world).collect();
236204
assert_eq!(
237205
values,
@@ -282,40 +250,7 @@ mod tests {
282250
world.spawn().insert_bundle((A(3),));
283251
world.spawn().insert_bundle((A(4),));
284252

285-
let mut a_with_b = world.query_filtered::<&A, With<B>>();
286-
let w = &world;
287-
assert_eq!(a_with_b.iter_combinations::<0>(w).count(), 0);
288-
assert_eq!(a_with_b.iter_combinations::<0>(w).size_hint().1, Some(0));
289-
assert_eq!(a_with_b.iter_combinations::<1>(w).count(), 1);
290-
assert_eq!(a_with_b.iter_combinations::<1>(w).size_hint().1, Some(1));
291-
assert_eq!(a_with_b.iter_combinations::<2>(w).count(), 0);
292-
assert_eq!(a_with_b.iter_combinations::<2>(w).size_hint().1, Some(0));
293-
assert_eq!(a_with_b.iter_combinations::<3>(w).count(), 0);
294-
assert_eq!(a_with_b.iter_combinations::<3>(w).size_hint().1, Some(0));
295-
assert_eq!(a_with_b.iter_combinations::<4>(w).count(), 0);
296-
assert_eq!(a_with_b.iter_combinations::<4>(w).size_hint().1, Some(0));
297-
assert_eq!(a_with_b.iter_combinations::<5>(w).count(), 0);
298-
assert_eq!(a_with_b.iter_combinations::<5>(w).size_hint().1, Some(0));
299-
assert_eq!(a_with_b.iter_combinations::<128>(w).count(), 0);
300-
assert_eq!(a_with_b.iter_combinations::<128>(w).size_hint().1, Some(0));
301-
302253
let mut a_wout_b = world.query_filtered::<&A, Without<B>>();
303-
let w = &world;
304-
assert_eq!(a_wout_b.iter_combinations::<0>(w).count(), 0);
305-
assert_eq!(a_wout_b.iter_combinations::<0>(w).size_hint().1, Some(0));
306-
assert_eq!(a_wout_b.iter_combinations::<1>(w).count(), 3);
307-
assert_eq!(a_wout_b.iter_combinations::<1>(w).size_hint().1, Some(3));
308-
assert_eq!(a_wout_b.iter_combinations::<2>(w).count(), 3);
309-
assert_eq!(a_wout_b.iter_combinations::<2>(w).size_hint().1, Some(3));
310-
assert_eq!(a_wout_b.iter_combinations::<3>(w).count(), 1);
311-
assert_eq!(a_wout_b.iter_combinations::<3>(w).size_hint().1, Some(1));
312-
assert_eq!(a_wout_b.iter_combinations::<4>(w).count(), 0);
313-
assert_eq!(a_wout_b.iter_combinations::<4>(w).size_hint().1, Some(0));
314-
assert_eq!(a_wout_b.iter_combinations::<5>(w).count(), 0);
315-
assert_eq!(a_wout_b.iter_combinations::<5>(w).size_hint().1, Some(0));
316-
assert_eq!(a_wout_b.iter_combinations::<128>(w).count(), 0);
317-
assert_eq!(a_wout_b.iter_combinations::<128>(w).size_hint().1, Some(0));
318-
319254
let values: HashSet<[&A; 2]> = a_wout_b.iter_combinations(&world).collect();
320255
assert_eq!(
321256
values,
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
use bevy_ecs::prelude::*;
2+
3+
#[derive(Component)]
4+
struct Foo;
5+
#[derive(Component)]
6+
struct Bar;
7+
8+
fn on_changed(query: Query<&Foo, Or<(Changed<Foo>, With<Bar>)>>) {
9+
// this should fail to compile
10+
is_exact_size_iterator(query.iter_combinations::<2>());
11+
}
12+
13+
fn on_added(query: Query<&Foo, (Added<Foo>, Without<Bar>)>) {
14+
// this should fail to compile
15+
is_exact_size_iterator(query.iter_combinations::<2>());
16+
}
17+
18+
fn is_exact_size_iterator<T: ExactSizeIterator>(_iter: T) {}
19+
20+
fn main() {}

0 commit comments

Comments
 (0)