Skip to content

Commit f8787f8

Browse files
committed
Fix size_hint for partially consumed QueryIter
Instead of returning the total count of elements in the `QueryIter` in `size_hint`, we return the count of remaining elements in it. This Fixes bevyengine#5149. This is also true of `QueryCombinationIter`. - bevyengine#5149 - bevyengine#5148
1 parent 9e34c74 commit f8787f8

File tree

2 files changed

+51
-41
lines changed

2 files changed

+51
-41
lines changed

crates/bevy_ecs/src/query/iter.rs

+35-35
Original file line numberDiff line numberDiff line change
@@ -56,13 +56,7 @@ impl<'w, 's, Q: WorldQuery, F: WorldQuery> Iterator for QueryIter<'w, 's, Q, F>
5656
}
5757

5858
fn size_hint(&self) -> (usize, Option<usize>) {
59-
let max_size = self
60-
.query_state
61-
.matched_archetype_ids
62-
.iter()
63-
.map(|id| self.archetypes[*id].len())
64-
.sum();
65-
59+
let max_size = self.cursor.remaining(self.tables, self.archetypes);
6660
let archetype_query = Q::IS_ARCHETYPAL && F::IS_ARCHETYPAL;
6761
let min_size = if archetype_query { max_size } else { 0 };
6862
(min_size, Some(max_size))
@@ -333,11 +327,16 @@ impl<'w, 's, Q: WorldQuery, F: WorldQuery, const K: usize> QueryCombinationIter<
333327
return None;
334328
}
335329

336-
// first, iterate from last to first until next item is found
330+
// TODO: can speed up the following code using `cursor.remaining()` instead of `next_item.is_none()`
331+
// when Q::IS_ARCHETYPAL && F::IS_ARCHETYPAL
332+
//
333+
// let `i` be the index of `c`, the last cursor in `self.cursors` that
334+
// returns `K-i` or more elements.
335+
// Make cursor in index `j` for all `j` in `[i, K)` a copy of `c` advanced `j-i+1` times.
336+
// If no such `c` exists, return `None`
337337
'outer: for i in (0..K).rev() {
338338
match self.cursors[i].next(self.tables, self.archetypes, self.query_state) {
339339
Some(_) => {
340-
// walk forward up to last element, propagating cursor state forward
341340
for j in (i + 1)..K {
342341
self.cursors[j] = self.cursors[j - 1].clone();
343342
match self.cursors[j].next(self.tables, self.archetypes, self.query_state) {
@@ -398,36 +397,29 @@ where
398397
}
399398

400399
fn size_hint(&self) -> (usize, Option<usize>) {
401-
if K == 0 {
402-
return (0, Some(0));
403-
}
404-
405-
let max_size: usize = self
406-
.query_state
407-
.matched_archetype_ids
408-
.iter()
409-
.map(|id| self.archetypes[*id].len())
410-
.sum();
411-
412-
if max_size < K {
413-
return (0, Some(0));
414-
}
415-
if max_size == K {
416-
return (1, Some(1));
417-
}
418-
419400
// binomial coefficient: (n ; k) = n! / k!(n-k)! = (n*n-1*...*n-k+1) / k!
420401
// See https://en.wikipedia.org/wiki/Binomial_coefficient
421402
// See https://blog.plover.com/math/choose.html for implementation
422403
// It was chosen to reduce overflow potential.
423404
fn choose(n: usize, k: usize) -> Option<usize> {
405+
if k > n || n == 0 {
406+
return Some(0);
407+
}
408+
let k = k.min(n - k);
424409
let ks = 1..=k;
425-
let ns = (n - k + 1..=n).rev();
410+
let ns = (n + 1 - k..=n).rev();
426411
ks.zip(ns)
427412
.try_fold(1_usize, |acc, (k, n)| Some(acc.checked_mul(n)? / k))
428413
}
429-
let smallest = K.min(max_size - K);
430-
let max_combinations = choose(max_size, smallest);
414+
// sum_i=0..k choose(cursors[i].remaining, k-i)
415+
let max_combinations = self
416+
.cursors
417+
.iter()
418+
.enumerate()
419+
.try_fold(0, |acc, (i, cursor)| {
420+
let n = cursor.remaining(self.tables, self.archetypes);
421+
Some(acc + choose(n, K - i)?)
422+
});
431423

432424
let archetype_query = F::IS_ARCHETYPAL && Q::IS_ARCHETYPAL;
433425
let known_max = max_combinations.unwrap_or(usize::MAX);
@@ -441,11 +433,7 @@ where
441433
F: ArchetypeFilter,
442434
{
443435
fn len(&self) -> usize {
444-
self.query_state
445-
.matched_archetype_ids
446-
.iter()
447-
.map(|id| self.archetypes[*id].len())
448-
.sum()
436+
self.size_hint().0
449437
}
450438
}
451439

@@ -562,6 +550,18 @@ impl<'w, 's, Q: WorldQuery, F: WorldQuery> QueryIterationCursor<'w, 's, Q, F> {
562550
}
563551
}
564552

553+
/// How many values will this cursor return?
554+
fn remaining(&self, tables: &'w Tables, archetypes: &'w Archetypes) -> usize {
555+
let remaining_matched: usize = if Self::IS_DENSE {
556+
let ids = self.table_id_iter.clone();
557+
ids.map(|id| tables[*id].len()).sum()
558+
} else {
559+
let ids = self.archetype_id_iter.clone();
560+
ids.map(|id| archetypes[*id].len()).sum()
561+
};
562+
remaining_matched + self.current_len - self.current_index
563+
}
564+
565565
// NOTE: If you are changing query iteration code, remember to update the following places, where relevant:
566566
// QueryIter, QueryIterationCursor, QueryManyIter, QueryCombinationIter, QueryState::for_each_unchecked_manual, QueryState::par_for_each_unchecked_manual
567567
/// # Safety

crates/bevy_ecs/src/query/mod.rs

+16-6
Original file line numberDiff line numberDiff line change
@@ -74,9 +74,13 @@ mod tests {
7474
for<'w> QueryFetch<'w, F::ReadOnly>: Clone,
7575
{
7676
let mut query = world.query_filtered::<Q, F>();
77-
let iter = query.iter_combinations::<K>(world);
7877
let query_type = type_name::<QueryCombinationIter<Q, F, K>>();
79-
assert_all_sizes_iterator_equal(iter, expected_size, query_type);
78+
let iter = query.iter_combinations::<K>(world);
79+
assert_all_sizes_iterator_equal(iter, expected_size, 0, query_type);
80+
let iter = query.iter_combinations::<K>(world);
81+
assert_all_sizes_iterator_equal(iter, expected_size, 1, query_type);
82+
let iter = query.iter_combinations::<K>(world);
83+
assert_all_sizes_iterator_equal(iter, expected_size, 5, query_type);
8084
}
8185
fn assert_all_sizes_equal<Q, F>(world: &mut World, expected_size: usize)
8286
where
@@ -87,23 +91,29 @@ mod tests {
8791
for<'w> QueryFetch<'w, F::ReadOnly>: Clone,
8892
{
8993
let mut query = world.query_filtered::<Q, F>();
90-
let iter = query.iter(world);
9194
let query_type = type_name::<QueryState<Q, F>>();
92-
assert_all_sizes_iterator_equal(iter, expected_size, query_type);
95+
assert_all_sizes_iterator_equal(query.iter(world), expected_size, 0, query_type);
96+
assert_all_sizes_iterator_equal(query.iter(world), expected_size, 1, query_type);
97+
assert_all_sizes_iterator_equal(query.iter(world), expected_size, 5, query_type);
9398

9499
let expected = expected_size;
95100
assert_combination::<Q, F, 0>(world, choose(expected, 0));
96101
assert_combination::<Q, F, 1>(world, choose(expected, 1));
97102
assert_combination::<Q, F, 2>(world, choose(expected, 2));
98103
assert_combination::<Q, F, 5>(world, choose(expected, 5));
99104
assert_combination::<Q, F, 43>(world, choose(expected, 43));
100-
assert_combination::<Q, F, 128>(world, choose(expected, 128));
105+
assert_combination::<Q, F, 64>(world, choose(expected, 64));
101106
}
102107
fn assert_all_sizes_iterator_equal(
103-
iterator: impl ExactSizeIterator,
108+
mut iterator: impl ExactSizeIterator,
104109
expected_size: usize,
110+
skip: usize,
105111
query_type: &'static str,
106112
) {
113+
let expected_size = expected_size.saturating_sub(skip);
114+
for _ in 0..skip {
115+
iterator.next();
116+
}
107117
let size_hint_0 = iterator.size_hint().0;
108118
let size_hint_1 = iterator.size_hint().1;
109119
let len = iterator.len();

0 commit comments

Comments
 (0)