Skip to content

Commit a6b8197

Browse files
author
Abutalib Aghayev
committed
rt: add a method to retrieve task id
1 parent f464360 commit a6b8197

File tree

6 files changed

+117
-3
lines changed

6 files changed

+117
-3
lines changed

tokio/src/runtime/context.rs

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
use crate::runtime::coop;
2+
use crate::runtime::task::Id;
23

34
use std::cell::Cell;
45

@@ -17,6 +18,7 @@ struct Context {
1718
/// Handle to the runtime scheduler running on the current thread.
1819
#[cfg(feature = "rt")]
1920
scheduler: RefCell<Option<scheduler::Handle>>,
21+
current_task_id: Cell<Option<Id>>,
2022

2123
#[cfg(any(feature = "rt", feature = "macros"))]
2224
rng: FastRand,
@@ -31,6 +33,7 @@ tokio_thread_local! {
3133
Context {
3234
#[cfg(feature = "rt")]
3335
scheduler: RefCell::new(None),
36+
current_task_id: Cell::new(None),
3437

3538
#[cfg(any(feature = "rt", feature = "macros"))]
3639
rng: FastRand::new(RngSeed::new()),
@@ -85,6 +88,17 @@ cfg_rt! {
8588

8689
pub(crate) struct DisallowBlockInPlaceGuard(bool);
8790

91+
pub(crate) fn set_current_task_id(id: Option<Id>) {
92+
CONTEXT.with(|ctx| ctx.current_task_id.replace(id));
93+
}
94+
95+
pub(crate) fn current_task_id() -> Id {
96+
match CONTEXT.try_with(|ctx| ctx.current_task_id.get()) {
97+
Ok(Some(id)) => id,
98+
_ => panic!("tried to get task id from outside of the task"),
99+
}
100+
}
101+
88102
pub(crate) fn try_current() -> Result<scheduler::Handle, TryCurrentError> {
89103
match CONTEXT.try_with(|ctx| ctx.scheduler.borrow().clone()) {
90104
Ok(Some(handle)) => Ok(handle),

tokio/src/runtime/task/harness.rs

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
use crate::future::Future;
2+
use crate::runtime::context;
23
use crate::runtime::task::core::{Cell, Core, CoreStage, Header, Trailer};
34
use crate::runtime::task::state::Snapshot;
45
use crate::runtime::task::waker::waker_ref;
5-
use crate::runtime::task::{JoinError, Notified, Schedule, Task};
6+
use crate::runtime::task::{Id, JoinError, Notified, Schedule, Task};
67

78
use std::mem;
89
use std::mem::ManuallyDrop;
@@ -476,6 +477,19 @@ fn poll_future<T: Future, S: Schedule>(
476477
self.core.drop_future_or_output();
477478
}
478479
}
480+
struct TaskIdGuard {}
481+
impl TaskIdGuard {
482+
fn new(id: Id) -> Self {
483+
context::set_current_task_id(Some(id));
484+
TaskIdGuard {}
485+
}
486+
}
487+
impl Drop for TaskIdGuard {
488+
fn drop(&mut self) {
489+
context::set_current_task_id(None);
490+
}
491+
}
492+
let _task_id_guard = TaskIdGuard::new(id);
479493
let guard = Guard { core };
480494
let res = guard.core.poll(cx);
481495
mem::forget(guard);

tokio/src/runtime/task/mod.rs

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -201,10 +201,21 @@ use std::{fmt, mem};
201201
/// [unstable]: crate#unstable-features
202202
#[cfg_attr(docsrs, doc(cfg(all(feature = "rt", tokio_unstable))))]
203203
#[cfg_attr(not(tokio_unstable), allow(unreachable_pub))]
204-
// TODO(eliza): there's almost certainly no reason not to make this `Copy` as well...
205-
#[derive(Clone, Debug, Hash, Eq, PartialEq)]
204+
#[derive(Copy, Clone, Debug, Hash, Eq, PartialEq)]
206205
pub struct Id(u64);
207206

207+
/// Returns the `Id` of the task.
208+
///
209+
/// # Panics
210+
///
211+
/// This function panics if called from outside a task.
212+
///
213+
#[allow(unreachable_pub)]
214+
pub fn id() -> Id {
215+
use crate::runtime::context;
216+
context::current_task_id()
217+
}
218+
208219
/// An owned handle to the task, tracked by ref count.
209220
#[repr(transparent)]
210221
pub(crate) struct Task<S: 'static> {

tokio/src/task/mod.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -319,6 +319,7 @@ cfg_rt! {
319319

320320
cfg_unstable! {
321321
pub use crate::runtime::task::Id;
322+
pub use crate::runtime::task::id;
322323
}
323324

324325
cfg_trace! {

tokio/tests/task_local.rs

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -116,3 +116,64 @@ async fn task_local_available_on_completion_drop() {
116116
assert_eq!(rx.await.unwrap(), 42);
117117
h.await.unwrap();
118118
}
119+
120+
#[tokio::test(flavor = "current_thread")]
121+
async fn task_id() {
122+
use tokio::task;
123+
124+
let handle = tokio::spawn(async { println!("task id: {}", task::id()) });
125+
126+
handle.await.unwrap();
127+
}
128+
129+
#[cfg(tokio_unstable)]
130+
#[tokio::test(flavor = "multi_thread")]
131+
async fn task_id_collision_multi_thread() {
132+
use tokio::task;
133+
134+
let handle1 = tokio::spawn(async { task::id() });
135+
let handle2 = tokio::spawn(async { task::id() });
136+
137+
let (id1, id2) = tokio::join!(handle1, handle2);
138+
assert_ne!(id1.unwrap(), id2.unwrap());
139+
}
140+
141+
#[cfg(tokio_unstable)]
142+
#[tokio::test(flavor = "current_thread")]
143+
async fn task_id_collision_current_thread() {
144+
use tokio::task;
145+
146+
let handle1 = tokio::spawn(async { task::id() });
147+
let handle2 = tokio::spawn(async { task::id() });
148+
149+
let (id1, id2) = tokio::join!(handle1, handle2);
150+
assert_ne!(id1.unwrap(), id2.unwrap());
151+
}
152+
153+
#[cfg(tokio_unstable)]
154+
#[tokio::test(flavor = "current_thread")]
155+
async fn task_ids_match_current_thread() {
156+
use tokio::{sync::oneshot, task};
157+
158+
let (tx, rx) = oneshot::channel();
159+
let handle = tokio::spawn(async {
160+
let id = rx.await.unwrap();
161+
assert_eq!(id, task::id());
162+
});
163+
tx.send(handle.id()).unwrap();
164+
handle.await.unwrap();
165+
}
166+
167+
#[cfg(tokio_unstable)]
168+
#[tokio::test(flavor = "multi_thread")]
169+
async fn task_ids_match_multi_thread() {
170+
use tokio::{sync::oneshot, task};
171+
172+
let (tx, rx) = oneshot::channel();
173+
let handle = tokio::spawn(async {
174+
let id = rx.await.unwrap();
175+
assert_eq!(id, task::id());
176+
});
177+
tx.send(handle.id()).unwrap();
178+
handle.await.unwrap();
179+
}

tokio/tests/task_panic.rs

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -120,3 +120,16 @@ fn local_key_get_panic_caller() -> Result<(), Box<dyn Error>> {
120120

121121
Ok(())
122122
}
123+
124+
#[cfg(tokio_unstable)]
125+
#[test]
126+
fn task_id_handle_panic_caller() -> Result<(), Box<dyn Error>> {
127+
let panic_location_file = test_panic(|| {
128+
let _ = task::id();
129+
});
130+
131+
// The panic location should be in this file
132+
assert_eq!(&panic_location_file.unwrap(), file!());
133+
134+
Ok(())
135+
}

0 commit comments

Comments
 (0)