Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Override QueryIter::fold to port Query::for_each perf gains to select Iterator combinators #6773

Merged
merged 30 commits into from
Dec 1, 2023
Merged
Show file tree
Hide file tree
Changes from 13 commits
Commits
Show all changes
30 commits
Select commit Hold shift + click to select a range
a2d6b13
Move for_each implementation onto Iterator::fold
james7132 Nov 26, 2022
7d17204
Fix starting fold from the middle of iteration
james7132 Nov 27, 2022
2c5af86
Simplify for_each cases and add safety docs
james7132 Nov 27, 2022
16b94dc
Undo test changes to empty.rs
james7132 Nov 27, 2022
fe681cb
Formatting
james7132 Nov 27, 2022
da27171
Merge branch 'main' into query-iter-fold
james7132 Dec 4, 2022
e0ea6c8
Merge branch 'main' into query-iter-fold
james7132 Jan 7, 2023
a11827f
Fix build
james7132 Jan 7, 2023
2abc327
Fix CI
james7132 Jan 7, 2023
abadb6f
Merge branch 'main' into query-iter-fold
james7132 Feb 17, 2023
1883adc
Revert QueryIterationCursor::fetch
james7132 Feb 17, 2023
c57481a
Deprecate the functions
james7132 Feb 18, 2023
65028bf
Apply suggestions from code review
james7132 Feb 19, 2023
2f997d0
Formatting
james7132 Feb 19, 2023
4bde3f4
Address deprecation
james7132 Feb 19, 2023
bca5e61
Update the safety comments.
james7132 Feb 19, 2023
df3d3bb
Address JoJoJet's comment
james7132 Feb 19, 2023
ca81c1c
make Query::for_each safer
james7132 Feb 19, 2023
dab9734
Remove mentions of QueryState::for_each_unchecked_manual
james7132 Feb 19, 2023
d166843
Merge branch 'main' into query-iter-fold
james7132 Mar 5, 2023
0d7d7da
Merge branch 'main' into query-iter-fold
james7132 Nov 26, 2023
d13e9a6
Update deprecation versions
james7132 Nov 26, 2023
144bb14
Fix CI and rename fold_*/for_each_*
james7132 Nov 26, 2023
624c3c5
Fix CI
james7132 Nov 26, 2023
94bc428
Fix UX tests
james7132 Nov 26, 2023
381bfbb
Fix deprecation notice for iter_mut().for_each
james7132 Nov 26, 2023
6307015
Merge branch 'main' into query-iter-fold
alice-i-cecile Nov 28, 2023
541f4c5
Inline for_each_in*
james7132 Dec 1, 2023
997ab70
Fix CI
james7132 Dec 1, 2023
3a9c92c
Remove the for_each_* functions when multithreading is not required.
james7132 Dec 1, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
180 changes: 176 additions & 4 deletions crates/bevy_ecs/src/query/iter.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
use crate::{
archetype::{ArchetypeEntity, ArchetypeId, Archetypes},
archetype::{Archetype, ArchetypeEntity, ArchetypeId, Archetypes},
entity::{Entities, Entity},
prelude::World,
query::{ArchetypeFilter, DebugCheckedUnwrap, QueryState, WorldQuery},
storage::{TableId, TableRow, Tables},
storage::{Table, TableId, TableRow, Tables},
};
use std::{borrow::Borrow, iter::FusedIterator, marker::PhantomData, mem::MaybeUninit};
use std::{borrow::Borrow, iter::FusedIterator, marker::PhantomData, mem::MaybeUninit, ops::Range};

use super::ReadOnlyWorldQuery;

Expand Down Expand Up @@ -39,6 +39,146 @@ impl<'w, 's, Q: WorldQuery, F: ReadOnlyWorldQuery> QueryIter<'w, 's, Q, F> {
cursor: QueryIterationCursor::init(world, query_state, last_change_tick, change_tick),
}
}

/// Executes the equivalent of [`Iterator::fold`] over a contiguous range of
/// rows from a single table.
///
/// Points the cursor's fetch and filter at `table`, then visits every row in
/// `rows` in order. Rows for which `F::filter_fetch` returns `false` are
/// skipped; every other row is fetched with `Q::fetch` and folded into
/// `accum` through `func`. The final accumulator is returned.
///
/// # Safety
/// - all `rows` must be in `[0, table.entity_count())`.
/// - `table` must match Q and F
/// - Both `Q::IS_DENSE` and `F::IS_DENSE` must be true.
#[inline]
pub(super) unsafe fn fold_table<B, Func>(
&mut self,
mut accum: B,
func: &mut Func,
table: &'w Table,
rows: Range<usize>,
) -> B
where
Func: FnMut(B, Q::Item<'w>) -> B,
{
// Both the fetch and the filter must be pointed at `table` before any row
// of it is accessed below.
Q::set_table(&mut self.cursor.fetch, &self.query_state.fetch_state, table);
F::set_table(
&mut self.cursor.filter,
&self.query_state.filter_state,
table,
);

let entities = table.entities();
for row in rows {
// SAFETY: set_table was called above for the filter, and the caller
// guarantees that `row` is in range for `table`, which is what makes the
// unchecked index into `table.entities()` valid.
let entity = entities.get_unchecked(row);
let row = TableRow::new(row);
if !F::filter_fetch(&mut self.cursor.filter, *entity, row) {
// NOTE(review): the next two lines are review-UI residue from the page
// this text was captured from; they are not Rust.
james7132 marked this conversation as resolved.
Show resolved Hide resolved
continue;
}

// SAFETY: set_table was called above for the fetch, and the caller
// guarantees that `row` is in range for `table`.
let item = Q::fetch(&mut self.cursor.fetch, *entity, row);

accum = func(accum, item);
}
accum
}

/// Executes the equivalent of [`Iterator::fold`] over a contiguous segment
/// from an archetype.
///
/// Looks up the archetype's table, points the cursor's fetch and filter at
/// the archetype, then visits every index in `indices` in order. Entities for
/// which `F::filter_fetch` returns `false` are skipped; every other entity is
/// fetched with `Q::fetch` and folded into `accum` through `func`. The final
/// accumulator is returned.
///
/// # Safety
/// - all `indices` must be in `[0, archetype.len())`.
/// - `archetype` must match Q and F
/// - Either `Q::IS_DENSE` or `F::IS_DENSE` must be false.
#[inline]
pub(super) unsafe fn fold_archetype<B, Func>(
&mut self,
mut accum: B,
func: &mut Func,
archetype: &'w Archetype,
indices: Range<usize>,
) -> B
where
Func: FnMut(B, Q::Item<'w>) -> B,
{
// NOTE(review): relies on the archetype's table id being present in
// `self.tables`; `debug_checked_unwrap` presumably only verifies this in
// debug builds — confirm.
let table = self.tables.get(archetype.table_id()).debug_checked_unwrap();
// Both the fetch and the filter must be pointed at `archetype` (and its
// table) before any of its entities is accessed below.
Q::set_archetype(
&mut self.cursor.fetch,
&self.query_state.fetch_state,
archetype,
table,
);
F::set_archetype(
&mut self.cursor.filter,
&self.query_state.filter_state,
archetype,
table,
);

let entities = archetype.entities();
for index in indices {
// SAFETY: set_archetype was called above for the filter, and the caller
// guarantees that `index` is in range for `archetype`, which is what makes
// the unchecked index into `archetype.entities()` valid.
let archetype_entity = entities.get_unchecked(index);
if !F::filter_fetch(
&mut self.cursor.filter,
archetype_entity.entity(),
archetype_entity.table_row(),
) {
continue;
}

// SAFETY: set_archetype was called above for the fetch, and
// `archetype_entity` was read from `archetype` at an index the caller
// guarantees to be in range.
let item = Q::fetch(
&mut self.cursor.fetch,
archetype_entity.entity(),
archetype_entity.table_row(),
);

accum = func(accum, item);
}
accum
}

/// Calls `func` on every matching item in a contiguous range of rows from a
/// table, mirroring [`Iterator::for_each`].
///
/// This is a thin adapter over `fold_table` that threads the unit value
/// through as the accumulator.
///
/// # Safety
/// - all `rows` must be in `[0, table.entity_count())`.
/// - `table` must match Q and F
/// - Both `Q::IS_DENSE` and `F::IS_DENSE` must be true.
#[inline]
pub(super) unsafe fn for_each_table<Func>(
    &mut self,
    func: &mut Func,
    table: &'w Table,
    rows: Range<usize>,
) where
    Func: FnMut(Q::Item<'w>),
{
    // A `for_each` is a fold whose accumulator carries no information.
    let mut visit = |(): (), item: Q::Item<'w>| func(item);
    // SAFETY: the caller upholds the same contract that `fold_table` requires.
    self.fold_table((), &mut visit, table, rows);
}

/// Executes the equivalent of [`Iterator::for_each`] over a contiguous segment
/// from an archetype.
///
/// Implemented by forwarding to `fold_archetype` with the unit value as the
/// accumulator, so `func` is called once per item that passes the filter.
///
/// # Safety
/// - all `indices` must be in `[0, archetype.len())`.
/// - `archetype` must match Q and F
/// - Either `Q::IS_DENSE` or `F::IS_DENSE` must be false.
#[inline]
pub(super) unsafe fn for_each_archetype<Func>(
// NOTE(review): the next two lines are review-UI residue from the page this
// text was captured from; they are not Rust.
james7132 marked this conversation as resolved.
Show resolved Hide resolved
&mut self,
func: &mut Func,
archetype: &'w Archetype,
indices: Range<usize>,
) where
Func: FnMut(Q::Item<'w>),
{
// SAFETY: the caller upholds the same contract that `fold_archetype` requires.
self.fold_archetype((), &mut |_, item| func(item), archetype, indices);
}
}

impl<'w, 's, Q: WorldQuery, F: ReadOnlyWorldQuery> Iterator for QueryIter<'w, 's, Q, F> {
Expand All @@ -61,6 +201,38 @@ impl<'w, 's, Q: WorldQuery, F: ReadOnlyWorldQuery> Iterator for QueryIter<'w, 's
let min_size = if archetype_query { max_size } else { 0 };
(min_size, Some(max_size))
}

/// Folds every remaining item of the query into an accumulator.
///
/// Works in two phases:
/// 1. Items left over in the table/archetype the cursor is currently
///    positioned in are drained one at a time through `next()`, so a fold
///    that starts in the middle of iteration does not skip or repeat them.
/// 2. Every table (when both `Q::IS_DENSE` and `F::IS_DENSE` are true) or
///    archetype (otherwise) still pending in the cursor's id iterator is
///    folded as a whole through `fold_table` / `fold_archetype`.
#[inline]
fn fold<B, Func>(mut self, init: B, mut func: Func) -> B
where
Func: FnMut(B, Self::Item) -> B,
{
let mut accum = init;
// Empty any remaining uniterated values from the current table/archetype
while self.cursor.current_row != self.cursor.current_len {
let Some(item) = self.next() else { break };
accum = func(accum, item);
}
if Q::IS_DENSE && F::IS_DENSE {
// The id iterator is cloned, presumably so that walking it does not keep
// `self.cursor` borrowed while `fold_table` takes `&mut self`.
for table_id in self.cursor.table_id_iter.clone() {
// SAFETY: Matched table IDs are guaranteed to still exist.
let table = unsafe { self.tables.get(*table_id).debug_checked_unwrap() };
accum =
// SAFETY: The fetched table matches the query
unsafe { self.fold_table(accum, &mut func, table, 0..table.entity_count()) };
}
} else {
// Same cloning approach as the dense branch, but over archetype ids.
for archetype_id in self.cursor.archetype_id_iter.clone() {
let archetype =
// SAFETY: Matched archetype IDs are guaranteed to still exist.
unsafe { self.archetypes.get(*archetype_id).debug_checked_unwrap() };
accum =
// SAFETY: The fetched archetype and table matches the query
unsafe { self.fold_archetype(accum, &mut func, archetype, 0..archetype.len()) };
}
}
accum
}
}

// This is correct as [`QueryIter`] always returns `None` once exhausted.
Expand Down Expand Up @@ -623,9 +795,9 @@ impl<'w, 's, Q: WorldQuery, F: ReadOnlyWorldQuery> QueryIterationCursor<'w, 's,
if self.current_row == self.current_len {
let archetype_id = self.archetype_id_iter.next()?;
let archetype = archetypes.get(*archetype_id).debug_checked_unwrap();
let table = tables.get(archetype.table_id()).debug_checked_unwrap();
// SAFETY: `archetype` and `tables` are from the world that `fetch/filter` were created for,
// `fetch_state`/`filter_state` are the states that `fetch/filter` were initialized with
let table = tables.get(archetype.table_id()).debug_checked_unwrap();
Q::set_archetype(&mut self.fetch, &query_state.fetch_state, archetype, table);
F::set_archetype(
&mut self.filter,
Expand Down
13 changes: 7 additions & 6 deletions crates/bevy_ecs/src/query/par_iter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -148,12 +148,13 @@ impl<'w, 's, Q: WorldQuery, F: ReadOnlyWorldQuery> QueryParIter<'w, 's, Q, F> {
) {
let thread_count = ComputeTaskPool::get().thread_num();
if thread_count <= 1 {
self.state.for_each_unchecked_manual(
self.world,
func,
self.world.last_change_tick(),
self.world.read_change_tick(),
);
self.state
.iter_unchecked_manual(
self.world,
self.world.last_change_tick(),
self.world.read_change_tick(),
)
.for_each(func);
} else {
// Need a batch size of at least 1.
let batch_size = self.get_batch_size(thread_count).max(1);
Expand Down
Loading