Skip to content

Commit ee09e04

Browse files
authored
sync: drop wakers after unlocking the mutex in Notify (#5471)
1 parent d07027f commit ee09e04

File tree

1 file changed

+19
-4
lines changed

1 file changed

+19
-4
lines changed

tokio/src/sync/notify.rs

Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -917,10 +917,14 @@ impl Notified<'_> {
917917
}
918918
}
919919

920+
let mut old_waker = None;
920921
if waker.is_some() {
921922
// Safety: called while locked.
923+
//
924+
// The use of `old_waiter` here is not necessary, as the field is always
925+
// None when we reach this line.
922926
unsafe {
923-
(*waiter.get()).waker = waker;
927+
old_waker = std::mem::replace(&mut (*waiter.get()).waker, waker);
924928
}
925929
}
926930

@@ -931,6 +935,9 @@ impl Notified<'_> {
931935

932936
*state = Waiting;
933937

938+
drop(waiters);
939+
drop(old_waker);
940+
934941
return Poll::Pending;
935942
}
936943
Waiting => {
@@ -945,12 +952,13 @@ impl Notified<'_> {
945952

946953
// Safety: called while locked
947954
let w = unsafe { &mut *waiter.get() };
955+
let mut old_waker = None;
948956

949957
if w.notified.is_some() {
950958
// Our waker has been notified and our waiter is already removed from
951959
// the list. Reset the notification and convert to `Done`.
960+
old_waker = std::mem::take(&mut w.waker);
952961
w.notified = None;
953-
w.waker = None;
954962
*state = Done;
955963
} else if get_num_notify_waiters_calls(curr) != *notify_waiters_calls {
956964
// Before we add a waiter to the list we check if these numbers are
@@ -960,7 +968,7 @@ impl Notified<'_> {
960968
// We can treat the waiter as notified and remove it from the list, as
961969
// it would have been notified in the `notify_waiters` call anyways.
962970

963-
w.waker = None;
971+
old_waker = std::mem::take(&mut w.waker);
964972

965973
// Safety: we hold the lock, so we have an exclusive access to the list.
966974
// The list is used in `notify_waiters`, so it must be guarded.
@@ -975,10 +983,14 @@ impl Notified<'_> {
975983
None => true,
976984
};
977985
if should_update {
978-
w.waker = Some(waker.clone());
986+
old_waker = std::mem::replace(&mut w.waker, Some(waker.clone()));
979987
}
980988
}
981989

990+
// Drop the old waker after releasing the lock.
991+
drop(waiters);
992+
drop(old_waker);
993+
982994
return Poll::Pending;
983995
}
984996

@@ -988,6 +1000,9 @@ impl Notified<'_> {
9881000
// is helpful to visualize the scope of the critical
9891001
// section.
9901002
drop(waiters);
1003+
1004+
// Drop the old waker after releasing the lock.
1005+
drop(old_waker);
9911006
}
9921007
Done => {
9931008
return Poll::Ready(());

0 commit comments

Comments
 (0)