Skip to content

Commit 886622c

Browse files
committed
Remove task_pool parameter from par_for_each(_mut) (bevyengine#4705)
# Objective Fixes bevyengine#3183. Requiring a `&TaskPool` parameter is sort of meaningless if the only correct one is to use the one provided by `Res<ComputeTaskPool>` all the time. ## Solution Have `QueryState` save a clone of the `ComputeTaskPool` which is used for all `par_for_each` functions. ~~Adds a small overhead of the internal `Arc` clone as a part of the startup, but the ergonomics win should be well worth this hardly-noticable overhead.~~ Updated the docs to note that it will panic the task pool is not present as a resource. # Future Work If bevyengine/rfcs#54 is approved, we can replace these resource lookups with a static function call instead to get the `ComputeTaskPool`. --- ## Changelog Removed: The `task_pool` parameter of `Query(State)::par_for_each(_mut)`. These calls will use the `World`'s `ComputeTaskPool` resource instead. ## Migration Guide The `task_pool` parameter for `Query(State)::par_for_each(_mut)` has been removed. Remove these parameters from all calls to these functions. Before: ```rust fn parallel_system( task_pool: Res<ComputeTaskPool>, query: Query<&MyComponent>, ) { query.par_for_each(&task_pool, 32, |comp| { ... }); } ``` After: ```rust fn parallel_system(query: Query<&MyComponent>) { query.par_for_each(32, |comp| { ... }); } ``` If using `Query(State)` outside of a system run by the scheduler, you may need to manually configure and initialize a `ComputeTaskPool` as a resource in the `World`.
1 parent 36049e9 commit 886622c

File tree

5 files changed

+149
-125
lines changed

5 files changed

+149
-125
lines changed

benches/benches/bevy_ecs/ecs_bench_suite/heavy_compute.rs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
use bevy_ecs::prelude::*;
2-
use bevy_tasks::TaskPool;
2+
use bevy_tasks::{ComputeTaskPool, TaskPool};
33
use glam::*;
44

55
#[derive(Component, Copy, Clone)]
@@ -29,8 +29,8 @@ impl Benchmark {
2929
)
3030
}));
3131

32-
fn sys(task_pool: Res<TaskPool>, mut query: Query<(&mut Position, &mut Transform)>) {
33-
query.par_for_each_mut(&task_pool, 128, |(mut pos, mut mat)| {
32+
fn sys(mut query: Query<(&mut Position, &mut Transform)>) {
33+
query.par_for_each_mut(128, |(mut pos, mut mat)| {
3434
for _ in 0..100 {
3535
mat.0 = mat.0.inverse();
3636
}
@@ -39,7 +39,7 @@ impl Benchmark {
3939
});
4040
}
4141

42-
world.insert_resource(TaskPool::default());
42+
world.insert_resource(ComputeTaskPool(TaskPool::default()));
4343
let mut system = IntoSystem::into_system(sys);
4444
system.initialize(&mut world);
4545
system.update_archetype_component_access(&world);

crates/bevy_ecs/src/lib.rs

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ mod tests {
5959
query::{Added, ChangeTrackers, Changed, FilteredAccess, With, Without, WorldQuery},
6060
world::{Mut, World},
6161
};
62-
use bevy_tasks::TaskPool;
62+
use bevy_tasks::{ComputeTaskPool, TaskPool};
6363
use std::{
6464
any::TypeId,
6565
sync::{
@@ -376,7 +376,7 @@ mod tests {
376376
#[test]
377377
fn par_for_each_dense() {
378378
let mut world = World::new();
379-
let task_pool = TaskPool::default();
379+
world.insert_resource(ComputeTaskPool(TaskPool::default()));
380380
let e1 = world.spawn().insert(A(1)).id();
381381
let e2 = world.spawn().insert(A(2)).id();
382382
let e3 = world.spawn().insert(A(3)).id();
@@ -385,7 +385,7 @@ mod tests {
385385
let results = Arc::new(Mutex::new(Vec::new()));
386386
world
387387
.query::<(Entity, &A)>()
388-
.par_for_each(&world, &task_pool, 2, |(e, &A(i))| {
388+
.par_for_each(&world, 2, |(e, &A(i))| {
389389
results.lock().unwrap().push((e, i));
390390
});
391391
results.lock().unwrap().sort();
@@ -398,8 +398,7 @@ mod tests {
398398
#[test]
399399
fn par_for_each_sparse() {
400400
let mut world = World::new();
401-
402-
let task_pool = TaskPool::default();
401+
world.insert_resource(ComputeTaskPool(TaskPool::default()));
403402
let e1 = world.spawn().insert(SparseStored(1)).id();
404403
let e2 = world.spawn().insert(SparseStored(2)).id();
405404
let e3 = world.spawn().insert(SparseStored(3)).id();
@@ -408,7 +407,6 @@ mod tests {
408407
let results = Arc::new(Mutex::new(Vec::new()));
409408
world.query::<(Entity, &SparseStored)>().par_for_each(
410409
&world,
411-
&task_pool,
412410
2,
413411
|(e, &SparseStored(i))| results.lock().unwrap().push((e, i)),
414412
);

crates/bevy_ecs/src/query/state.rs

Lines changed: 120 additions & 96 deletions
Original file line numberDiff line numberDiff line change
@@ -10,17 +10,18 @@ use crate::{
1010
storage::TableId,
1111
world::{World, WorldId},
1212
};
13-
use bevy_tasks::TaskPool;
13+
use bevy_tasks::{ComputeTaskPool, TaskPool};
1414
#[cfg(feature = "trace")]
1515
use bevy_utils::tracing::Instrument;
1616
use fixedbitset::FixedBitSet;
17-
use std::fmt;
17+
use std::{fmt, ops::Deref};
1818

1919
use super::{QueryFetch, QueryItem, ROQueryFetch, ROQueryItem};
2020

2121
/// Provides scoped access to a [`World`] state according to a given [`WorldQuery`] and query filter.
2222
pub struct QueryState<Q: WorldQuery, F: WorldQuery = ()> {
2323
world_id: WorldId,
24+
task_pool: Option<TaskPool>,
2425
pub(crate) archetype_generation: ArchetypeGeneration,
2526
pub(crate) matched_tables: FixedBitSet,
2627
pub(crate) matched_archetypes: FixedBitSet,
@@ -61,6 +62,9 @@ impl<Q: WorldQuery, F: WorldQuery> QueryState<Q, F> {
6162

6263
let mut state = Self {
6364
world_id: world.id(),
65+
task_pool: world
66+
.get_resource::<ComputeTaskPool>()
67+
.map(|task_pool| task_pool.deref().clone()),
6468
archetype_generation: ArchetypeGeneration::initial(),
6569
matched_table_ids: Vec::new(),
6670
matched_archetype_ids: Vec::new(),
@@ -689,15 +693,18 @@ impl<Q: WorldQuery, F: WorldQuery> QueryState<Q, F> {
689693
);
690694
}
691695

692-
/// Runs `func` on each query result in parallel using the given `task_pool`.
696+
/// Runs `func` on each query result in parallel.
693697
///
694698
/// This can only be called for read-only queries, see [`Self::par_for_each_mut`] for
695699
/// write-queries.
700+
///
701+
/// # Panics
702+
/// The [`ComputeTaskPool`] resource must be added to the `World` before using this method. If using this from a query
703+
/// that is being initialized and run from the ECS scheduler, this should never panic.
696704
#[inline]
697705
pub fn par_for_each<'w, FN: Fn(ROQueryItem<'w, Q>) + Send + Sync + Clone>(
698706
&mut self,
699707
world: &'w World,
700-
task_pool: &TaskPool,
701708
batch_size: usize,
702709
func: FN,
703710
) {
@@ -706,7 +713,6 @@ impl<Q: WorldQuery, F: WorldQuery> QueryState<Q, F> {
706713
self.update_archetypes(world);
707714
self.par_for_each_unchecked_manual::<ROQueryFetch<Q>, FN>(
708715
world,
709-
task_pool,
710716
batch_size,
711717
func,
712718
world.last_change_tick(),
@@ -715,12 +721,15 @@ impl<Q: WorldQuery, F: WorldQuery> QueryState<Q, F> {
715721
}
716722
}
717723

718-
/// Runs `func` on each query result in parallel using the given `task_pool`.
724+
/// Runs `func` on each query result in parallel.
725+
///
726+
/// # Panics
727+
/// The [`ComputeTaskPool`] resource must be added to the `World` before using this method. If using this from a query
728+
/// that is being initialized and run from the ECS scheduler, this should never panic.
719729
#[inline]
720730
pub fn par_for_each_mut<'w, FN: Fn(QueryItem<'w, Q>) + Send + Sync + Clone>(
721731
&mut self,
722732
world: &'w mut World,
723-
task_pool: &TaskPool,
724733
batch_size: usize,
725734
func: FN,
726735
) {
@@ -729,7 +738,6 @@ impl<Q: WorldQuery, F: WorldQuery> QueryState<Q, F> {
729738
self.update_archetypes(world);
730739
self.par_for_each_unchecked_manual::<QueryFetch<Q>, FN>(
731740
world,
732-
task_pool,
733741
batch_size,
734742
func,
735743
world.last_change_tick(),
@@ -738,10 +746,14 @@ impl<Q: WorldQuery, F: WorldQuery> QueryState<Q, F> {
738746
}
739747
}
740748

741-
/// Runs `func` on each query result in parallel using the given `task_pool`.
749+
/// Runs `func` on each query result in parallel.
742750
///
743751
/// This can only be called for read-only queries.
744752
///
753+
/// # Panics
754+
/// [`ComputeTaskPool`] was not stored in the world at initialzation. If using this from a query
755+
/// that is being initialized and run from the ECS scheduler, this should never panic.
756+
///
745757
/// # Safety
746758
///
747759
/// This does not check for mutable query correctness. To be safe, make sure mutable queries
@@ -750,14 +762,12 @@ impl<Q: WorldQuery, F: WorldQuery> QueryState<Q, F> {
750762
pub unsafe fn par_for_each_unchecked<'w, FN: Fn(QueryItem<'w, Q>) + Send + Sync + Clone>(
751763
&mut self,
752764
world: &'w World,
753-
task_pool: &TaskPool,
754765
batch_size: usize,
755766
func: FN,
756767
) {
757768
self.update_archetypes(world);
758769
self.par_for_each_unchecked_manual::<QueryFetch<Q>, FN>(
759770
world,
760-
task_pool,
761771
batch_size,
762772
func,
763773
world.last_change_tick(),
@@ -833,6 +843,10 @@ impl<Q: WorldQuery, F: WorldQuery> QueryState<Q, F> {
833843
/// the current change tick are given. This is faster than the equivalent
834844
/// iter() method, but cannot be chained like a normal [`Iterator`].
835845
///
846+
/// # Panics
847+
/// [`ComputeTaskPool`] was not stored in the world at initialzation. If using this from a query
848+
/// that is being initialized and run from the ECS scheduler, this should never panic.
849+
///
836850
/// # Safety
837851
///
838852
/// This does not check for mutable query correctness. To be safe, make sure mutable queries
@@ -846,103 +860,113 @@ impl<Q: WorldQuery, F: WorldQuery> QueryState<Q, F> {
846860
>(
847861
&self,
848862
world: &'w World,
849-
task_pool: &TaskPool,
850863
batch_size: usize,
851864
func: FN,
852865
last_change_tick: u32,
853866
change_tick: u32,
854867
) {
855868
// NOTE: If you are changing query iteration code, remember to update the following places, where relevant:
856869
// QueryIter, QueryIterationCursor, QueryState::for_each_unchecked_manual, QueryState::par_for_each_unchecked_manual
857-
task_pool.scope(|scope| {
858-
if QF::IS_DENSE && <QueryFetch<'static, F>>::IS_DENSE {
859-
let tables = &world.storages().tables;
860-
for table_id in &self.matched_table_ids {
861-
let table = &tables[*table_id];
862-
let mut offset = 0;
863-
while offset < table.len() {
864-
let func = func.clone();
865-
let len = batch_size.min(table.len() - offset);
866-
let task = async move {
867-
let mut fetch =
868-
QF::init(world, &self.fetch_state, last_change_tick, change_tick);
869-
let mut filter = <QueryFetch<F> as Fetch>::init(
870-
world,
871-
&self.filter_state,
872-
last_change_tick,
873-
change_tick,
874-
);
875-
let tables = &world.storages().tables;
876-
let table = &tables[*table_id];
877-
fetch.set_table(&self.fetch_state, table);
878-
filter.set_table(&self.filter_state, table);
879-
for table_index in offset..offset + len {
880-
if !filter.table_filter_fetch(table_index) {
881-
continue;
870+
self.task_pool
871+
.as_ref()
872+
.expect("Cannot iterate query in parallel. No ComputeTaskPool initialized.")
873+
.scope(|scope| {
874+
if QF::IS_DENSE && <QueryFetch<'static, F>>::IS_DENSE {
875+
let tables = &world.storages().tables;
876+
for table_id in &self.matched_table_ids {
877+
let table = &tables[*table_id];
878+
let mut offset = 0;
879+
while offset < table.len() {
880+
let func = func.clone();
881+
let len = batch_size.min(table.len() - offset);
882+
let task = async move {
883+
let mut fetch = QF::init(
884+
world,
885+
&self.fetch_state,
886+
last_change_tick,
887+
change_tick,
888+
);
889+
let mut filter = <QueryFetch<F> as Fetch>::init(
890+
world,
891+
&self.filter_state,
892+
last_change_tick,
893+
change_tick,
894+
);
895+
let tables = &world.storages().tables;
896+
let table = &tables[*table_id];
897+
fetch.set_table(&self.fetch_state, table);
898+
filter.set_table(&self.filter_state, table);
899+
for table_index in offset..offset + len {
900+
if !filter.table_filter_fetch(table_index) {
901+
continue;
902+
}
903+
let item = fetch.table_fetch(table_index);
904+
func(item);
882905
}
883-
let item = fetch.table_fetch(table_index);
884-
func(item);
885-
}
886-
};
887-
#[cfg(feature = "trace")]
888-
let span = bevy_utils::tracing::info_span!(
889-
"par_for_each",
890-
query = std::any::type_name::<Q>(),
891-
filter = std::any::type_name::<F>(),
892-
count = len,
893-
);
894-
#[cfg(feature = "trace")]
895-
let task = task.instrument(span);
896-
scope.spawn(task);
897-
offset += batch_size;
898-
}
899-
}
900-
} else {
901-
let archetypes = &world.archetypes;
902-
for archetype_id in &self.matched_archetype_ids {
903-
let mut offset = 0;
904-
let archetype = &archetypes[*archetype_id];
905-
while offset < archetype.len() {
906-
let func = func.clone();
907-
let len = batch_size.min(archetype.len() - offset);
908-
let task = async move {
909-
let mut fetch =
910-
QF::init(world, &self.fetch_state, last_change_tick, change_tick);
911-
let mut filter = <QueryFetch<F> as Fetch>::init(
912-
world,
913-
&self.filter_state,
914-
last_change_tick,
915-
change_tick,
906+
};
907+
#[cfg(feature = "trace")]
908+
let span = bevy_utils::tracing::info_span!(
909+
"par_for_each",
910+
query = std::any::type_name::<Q>(),
911+
filter = std::any::type_name::<F>(),
912+
count = len,
916913
);
917-
let tables = &world.storages().tables;
918-
let archetype = &world.archetypes[*archetype_id];
919-
fetch.set_archetype(&self.fetch_state, archetype, tables);
920-
filter.set_archetype(&self.filter_state, archetype, tables);
921-
922-
for archetype_index in offset..offset + len {
923-
if !filter.archetype_filter_fetch(archetype_index) {
924-
continue;
914+
#[cfg(feature = "trace")]
915+
let task = task.instrument(span);
916+
scope.spawn(task);
917+
offset += batch_size;
918+
}
919+
}
920+
} else {
921+
let archetypes = &world.archetypes;
922+
for archetype_id in &self.matched_archetype_ids {
923+
let mut offset = 0;
924+
let archetype = &archetypes[*archetype_id];
925+
while offset < archetype.len() {
926+
let func = func.clone();
927+
let len = batch_size.min(archetype.len() - offset);
928+
let task = async move {
929+
let mut fetch = QF::init(
930+
world,
931+
&self.fetch_state,
932+
last_change_tick,
933+
change_tick,
934+
);
935+
let mut filter = <QueryFetch<F> as Fetch>::init(
936+
world,
937+
&self.filter_state,
938+
last_change_tick,
939+
change_tick,
940+
);
941+
let tables = &world.storages().tables;
942+
let archetype = &world.archetypes[*archetype_id];
943+
fetch.set_archetype(&self.fetch_state, archetype, tables);
944+
filter.set_archetype(&self.filter_state, archetype, tables);
945+
946+
for archetype_index in offset..offset + len {
947+
if !filter.archetype_filter_fetch(archetype_index) {
948+
continue;
949+
}
950+
func(fetch.archetype_fetch(archetype_index));
925951
}
926-
func(fetch.archetype_fetch(archetype_index));
927-
}
928-
};
929-
930-
#[cfg(feature = "trace")]
931-
let span = bevy_utils::tracing::info_span!(
932-
"par_for_each",
933-
query = std::any::type_name::<Q>(),
934-
filter = std::any::type_name::<F>(),
935-
count = len,
936-
);
937-
#[cfg(feature = "trace")]
938-
let task = task.instrument(span);
939-
940-
scope.spawn(task);
941-
offset += batch_size;
952+
};
953+
954+
#[cfg(feature = "trace")]
955+
let span = bevy_utils::tracing::info_span!(
956+
"par_for_each",
957+
query = std::any::type_name::<Q>(),
958+
filter = std::any::type_name::<F>(),
959+
count = len,
960+
);
961+
#[cfg(feature = "trace")]
962+
let task = task.instrument(span);
963+
964+
scope.spawn(task);
965+
offset += batch_size;
966+
}
942967
}
943968
}
944-
}
945-
});
969+
});
946970
}
947971

948972
/// Returns a single immutable query result when there is exactly one entity matching

0 commit comments

Comments
 (0)