diff --git a/console-subscriber/tests/framework.rs b/console-subscriber/tests/framework.rs index 855f778ac..719adc3e8 100644 --- a/console-subscriber/tests/framework.rs +++ b/console-subscriber/tests/framework.rs @@ -182,3 +182,24 @@ fn fail_1_of_2_expected_tasks() { assert_tasks(expected_tasks, future); } + +#[test] +fn polls() { + let expected_task = ExpectedTask::default().match_default_name().expect_polls(2); + + let future = async { task::yield_now().await }; + + assert_task(expected_task, future); +} + +#[test] +#[should_panic(expected = "Test failed: Task validation failed: + - Task { name=console-test::main }: expected `polls` to be 2, but actual was 1 +")] +fn fail_polls() { + let expected_task = ExpectedTask::default().match_default_name().expect_polls(2); + + let future = async {}; + + assert_task(expected_task, future); +} diff --git a/console-subscriber/tests/poll.rs b/console-subscriber/tests/poll.rs new file mode 100644 index 000000000..11593c201 --- /dev/null +++ b/console-subscriber/tests/poll.rs @@ -0,0 +1,41 @@ +use std::time::Duration; + +use tokio::time::sleep; + +mod support; +use support::{assert_task, ExpectedTask}; + +#[test] +fn single_poll() { + let expected_task = ExpectedTask::default().match_default_name().expect_polls(1); + + let future = futures::future::ready(()); + + assert_task(expected_task, future); +} + +#[test] +fn two_polls() { + let expected_task = ExpectedTask::default().match_default_name().expect_polls(2); + + let future = async { + sleep(Duration::ZERO).await; + }; + + assert_task(expected_task, future); +} + +#[test] +fn many_polls() { + let expected_task = ExpectedTask::default() + .match_default_name() + .expect_polls(11); + + let future = async { + for _ in 0..10 { + sleep(Duration::ZERO).await; + } + }; + + assert_task(expected_task, future); +} diff --git a/console-subscriber/tests/spawn.rs b/console-subscriber/tests/spawn.rs new file mode 100644 index 000000000..21ecaad41 --- /dev/null +++ b/console-subscriber/tests/spawn.rs @@ -0,0 +1,36 @@ +use std::time::Duration; + +use tokio::time::sleep; + +mod support; +use support::{assert_tasks, spawn_named, ExpectedTask}; + +/// This test asserts the behavior that was fixed in #440. Before that fix, +/// the polls of a child were also counted towards the parent (the task which +/// spawned the child task). In this scenario, that would result in the parent +/// having 3 polls counted, when it should really be 1. +#[test] +fn child_polls_dont_count_towards_parent_polls() { + let expected_tasks = vec![ + ExpectedTask::default() + .match_name("parent".into()) + .expect_polls(1), + ExpectedTask::default() + .match_name("child".into()) + .expect_polls(2), + ]; + + let future = async { + let child_join_handle = spawn_named("parent", async { + spawn_named("child", async { + sleep(Duration::ZERO).await; + }) + }) + .await + .expect("joining parent failed"); + + child_join_handle.await.expect("joining child failed"); + }; + + assert_tasks(expected_tasks, future); +} diff --git a/console-subscriber/tests/support/mod.rs b/console-subscriber/tests/support/mod.rs index 4937aff6a..3a42583a2 100644 --- a/console-subscriber/tests/support/mod.rs +++ b/console-subscriber/tests/support/mod.rs @@ -8,6 +8,7 @@ use subscriber::run_test; pub(crate) use subscriber::MAIN_TASK_NAME; pub(crate) use task::ExpectedTask; +use tokio::task::JoinHandle; /// Assert that an `expected_task` is recorded by a console-subscriber /// when driving the provided `future` to completion. @@ -45,3 +46,19 @@ where { run_test(expected_tasks, future) } + +/// Spawn a named task and unwrap. +/// +/// This is a convenience function to create a task with a name and then spawn +/// it directly (unwrapping the `Result` which the task builder API returns). +#[allow(dead_code)] +pub(crate) fn spawn_named(name: &str, f: Fut) -> JoinHandle<::Output> +where + Fut: Future + Send + 'static, + Fut::Output: Send + 'static, +{ + tokio::task::Builder::new() + .name(name) + .spawn(f) + .expect(&format!("spawning task '{name}' failed")) +} diff --git a/console-subscriber/tests/support/subscriber.rs b/console-subscriber/tests/support/subscriber.rs index ace48397d..54e5c995c 100644 --- a/console-subscriber/tests/support/subscriber.rs +++ b/console-subscriber/tests/support/subscriber.rs @@ -283,8 +283,7 @@ async fn record_actual_tasks( for (id, stats) in &task_update.stats_update { if let Some(task) = tasks.get_mut(id) { - task.wakes = stats.wakes; - task.self_wakes = stats.self_wakes; + task.update_from_stats(stats); } } } diff --git a/console-subscriber/tests/support/task.rs b/console-subscriber/tests/support/task.rs index 63814d016..b6952cf61 100644 --- a/console-subscriber/tests/support/task.rs +++ b/console-subscriber/tests/support/task.rs @@ -1,5 +1,7 @@ use std::{error, fmt}; +use console_api::tasks; + use super::MAIN_TASK_NAME; /// An actual task @@ -13,6 +15,7 @@ pub(super) struct ActualTask { pub(super) name: Option, pub(super) wakes: u64, pub(super) self_wakes: u64, + pub(super) polls: u64, } impl ActualTask { @@ -22,6 +25,15 @@ impl ActualTask { name: None, wakes: 0, self_wakes: 0, + polls: 0, + } + } + + pub(super) fn update_from_stats(&mut self, stats: &tasks::Stats) { + self.wakes = stats.wakes; + self.self_wakes = stats.self_wakes; + if let Some(poll_stats) = &stats.poll_stats { + self.polls = poll_stats.polls; } } } @@ -78,6 +90,7 @@ pub(crate) struct ExpectedTask { expect_present: Option, expect_wakes: Option, expect_self_wakes: Option, + expect_polls: Option, } impl Default for ExpectedTask { @@ -87,6 +100,7 @@ impl Default for ExpectedTask { expect_present: None, expect_wakes: None, expect_self_wakes: None, + expect_polls: None, } } } @@ -164,6 +178,21 @@ impl ExpectedTask { } } + if let Some(expected_polls) = self.expect_polls { + no_expectations = false; + if expected_polls != actual_task.polls { + return Err(TaskValidationFailure { + expected: self.clone(), + actual: Some(actual_task.clone()), + failure: format!( + "{self}: expected `polls` to be {expected_polls}, but \ + actual was {actual_polls}", + actual_polls = actual_task.polls, + ), + }); + } + } + if no_expectations { return Err(TaskValidationFailure { expected: self.clone(), @@ -229,6 +258,16 @@ impl ExpectedTask { self.expect_self_wakes = Some(self_wakes); self } + + /// Expects taht a task has a specific value for `polls`. + /// + /// To validate, the actual task must have a count of polls (on + /// `PollStats`) equal to `polls`. + #[allow(dead_code)] + pub(crate) fn expect_polls(mut self, polls: u64) -> Self { + self.expect_polls = Some(polls); + self + } } impl fmt::Display for ExpectedTask {