diff --git a/lightning-background-processor/src/lib.rs b/lightning-background-processor/src/lib.rs index 541813c7e42..2b150cb3025 100644 --- a/lightning-background-processor/src/lib.rs +++ b/lightning-background-processor/src/lib.rs @@ -38,6 +38,8 @@ use lightning::routing::router::Router; use lightning::routing::scoring::{Score, WriteableScore}; use lightning::util::logger::Logger; use lightning::util::persist::Persister; +#[cfg(feature = "std")] +use lightning::util::wakers::Sleeper; use lightning_rapid_gossip_sync::RapidGossipSync; use core::ops::Deref; @@ -114,6 +116,13 @@ const FIRST_NETWORK_PRUNE_TIMER: u64 = 60; #[cfg(test)] const FIRST_NETWORK_PRUNE_TIMER: u64 = 1; +#[cfg(feature = "futures")] +/// core::cmp::min is not currently const, so we define a trivial (and equivalent) replacement +const fn min_u64(a: u64, b: u64) -> u64 { if a < b { a } else { b } } +#[cfg(feature = "futures")] +const FASTEST_TIMER: u64 = min_u64(min_u64(FRESHNESS_TIMER, PING_TIMER), + min_u64(SCORER_PERSIST_TIMER, FIRST_NETWORK_PRUNE_TIMER)); + /// Either [`P2PGossipSync`] or [`RapidGossipSync`]. pub enum GossipSync< P: Deref>, @@ -256,7 +265,8 @@ macro_rules! define_run_body { ($persister: ident, $chain_monitor: ident, $process_chain_monitor_events: expr, $channel_manager: ident, $process_channel_manager_events: expr, $gossip_sync: ident, $peer_manager: ident, $logger: ident, $scorer: ident, - $loop_exit_check: expr, $await: expr, $get_timer: expr, $timer_elapsed: expr) + $loop_exit_check: expr, $await: expr, $get_timer: expr, $timer_elapsed: expr, + $check_slow_await: expr) => { { log_trace!($logger, "Calling ChannelManager's timer_tick_occurred on startup"); $channel_manager.timer_tick_occurred(); @@ -286,9 +296,10 @@ macro_rules! define_run_body { // We wait up to 100ms, but track how long it takes to detect being put to sleep, // see `await_start`'s use below. - let mut await_start = $get_timer(1); + let mut await_start = None; + if $check_slow_await { await_start = Some($get_timer(1)); } let updates_available = $await; - let await_slow = $timer_elapsed(&mut await_start, 1); + let await_slow = if $check_slow_await { $timer_elapsed(&mut await_start.unwrap(), 1) } else { false }; if updates_available { log_trace!($logger, "Persisting ChannelManager..."); @@ -388,15 +399,20 @@ pub(crate) mod futures_util { use core::task::{Poll, Waker, RawWaker, RawWakerVTable}; use core::pin::Pin; use core::marker::Unpin; - pub(crate) struct Selector + Unpin, B: Future + Unpin> { + pub(crate) struct Selector< + A: Future + Unpin, B: Future + Unpin, C: Future + Unpin + > { pub a: A, pub b: B, + pub c: C, } pub(crate) enum SelectorOutput { - A, B(bool), + A, B, C(bool), } - impl + Unpin, B: Future + Unpin> Future for Selector { + impl< + A: Future + Unpin, B: Future + Unpin, C: Future + Unpin + > Future for Selector { type Output = SelectorOutput; fn poll(mut self: Pin<&mut Self>, ctx: &mut core::task::Context<'_>) -> Poll { match Pin::new(&mut self.a).poll(ctx) { @@ -404,7 +420,11 @@ pub(crate) mod futures_util { Poll::Pending => {}, } match Pin::new(&mut self.b).poll(ctx) { - Poll::Ready(res) => { return Poll::Ready(SelectorOutput::B(res)); }, + Poll::Ready(()) => { return Poll::Ready(SelectorOutput::B); }, + Poll::Pending => {}, + } + match Pin::new(&mut self.c).poll(ctx) { + Poll::Ready(res) => { return Poll::Ready(SelectorOutput::C(res)); }, Poll::Pending => {}, } Poll::Pending @@ -438,6 +458,11 @@ use core::task; /// feature, doing so will skip calling [`NetworkGraph::remove_stale_channels_and_tracking`], /// you should call [`NetworkGraph::remove_stale_channels_and_tracking_with_time`] regularly /// manually instead. +/// +/// The `mobile_interruptable_platform` flag should be set if we're currently running on a +/// mobile device, where we may need to check for interruption of the application regularly. If you +/// are unsure, you should set the flag, as the performance impact of it is minimal unless there +/// are hundreds or thousands of simultaneous process calls running. #[cfg(feature = "futures")] pub async fn process_events_async< 'a, @@ -473,7 +498,7 @@ pub async fn process_events_async< >( persister: PS, event_handler: EventHandler, chain_monitor: M, channel_manager: CM, gossip_sync: GossipSync, peer_manager: PM, logger: L, scorer: Option, - sleeper: Sleeper, + sleeper: Sleeper, mobile_interruptable_platform: bool, ) -> Result<(), lightning::io::Error> where UL::Target: 'static + UtxoLookup, @@ -514,11 +539,13 @@ where gossip_sync, peer_manager, logger, scorer, should_break, { let fut = Selector { a: channel_manager.get_persistable_update_future(), - b: sleeper(Duration::from_millis(100)), + b: chain_monitor.get_update_future(), + c: sleeper(if mobile_interruptable_platform { Duration::from_millis(100) } else { Duration::from_secs(FASTEST_TIMER) }), }; match fut.await { SelectorOutput::A => true, - SelectorOutput::B(exit) => { + SelectorOutput::B => false, + SelectorOutput::C(exit) => { should_break = exit; false } @@ -528,7 +555,7 @@ where let mut waker = dummy_waker(); let mut ctx = task::Context::from_waker(&mut waker); core::pin::Pin::new(fut).poll(&mut ctx).is_ready() - }) + }, mobile_interruptable_platform) } #[cfg(feature = "std")] @@ -643,8 +670,11 @@ impl BackgroundProcessor { define_run_body!(persister, chain_monitor, chain_monitor.process_pending_events(&event_handler), channel_manager, channel_manager.process_pending_events(&event_handler), gossip_sync, peer_manager, logger, scorer, stop_thread.load(Ordering::Acquire), - channel_manager.await_persistable_update_timeout(Duration::from_millis(100)), - |_| Instant::now(), |time: &Instant, dur| time.elapsed().as_secs() > dur) + Sleeper::from_two_futures( + channel_manager.get_persistable_update_future(), + chain_monitor.get_update_future() + ).wait_timeout(Duration::from_millis(100)), + |_| Instant::now(), |time: &Instant, dur| time.elapsed().as_secs() > dur, false) }); Self { stop_thread: stop_thread_clone, thread_handle: Some(handle) } } diff --git a/lightning-net-tokio/src/lib.rs b/lightning-net-tokio/src/lib.rs index aeb5c5b7a87..37c9ddad762 100644 --- a/lightning-net-tokio/src/lib.rs +++ b/lightning-net-tokio/src/lib.rs @@ -8,64 +8,19 @@ // licenses. //! A socket handling library for those running in Tokio environments who wish to use -//! rust-lightning with native TcpStreams. +//! rust-lightning with native [`TcpStream`]s. //! //! Designed to be as simple as possible, the high-level usage is almost as simple as "hand over a -//! TcpStream and a reference to a PeerManager and the rest is handled", except for the -//! [Event](../lightning/util/events/enum.Event.html) handling mechanism; see example below. +//! [`TcpStream`] and a reference to a [`PeerManager`] and the rest is handled". //! -//! The PeerHandler, due to the fire-and-forget nature of this logic, must be an Arc, and must use -//! the SocketDescriptor provided here as the PeerHandler's SocketDescriptor. +//! The [`PeerManager`], due to the fire-and-forget nature of this logic, must be a reference, +//! (e.g. an [`Arc`]) and must use the [`SocketDescriptor`] provided here as the [`PeerManager`]'s +//! `SocketDescriptor` implementation. //! -//! Three methods are exposed to register a new connection for handling in tokio::spawn calls; see -//! their individual docs for details. +//! Three methods are exposed to register a new connection for handling in [`tokio::spawn`] calls; +//! see their individual docs for details. //! -//! # Example -//! ``` -//! use std::net::TcpStream; -//! use bitcoin::secp256k1::PublicKey; -//! use lightning::events::{Event, EventHandler, EventsProvider}; -//! use std::net::SocketAddr; -//! use std::sync::Arc; -//! -//! // Define concrete types for our high-level objects: -//! type TxBroadcaster = dyn lightning::chain::chaininterface::BroadcasterInterface + Send + Sync; -//! type FeeEstimator = dyn lightning::chain::chaininterface::FeeEstimator + Send + Sync; -//! type Logger = dyn lightning::util::logger::Logger + Send + Sync; -//! type NodeSigner = dyn lightning::chain::keysinterface::NodeSigner + Send + Sync; -//! type UtxoLookup = dyn lightning::routing::utxo::UtxoLookup + Send + Sync; -//! type ChainFilter = dyn lightning::chain::Filter + Send + Sync; -//! type DataPersister = dyn lightning::chain::chainmonitor::Persist + Send + Sync; -//! type ChainMonitor = lightning::chain::chainmonitor::ChainMonitor, Arc, Arc, Arc, Arc>; -//! type ChannelManager = Arc>; -//! type PeerManager = Arc>; -//! -//! // Connect to node with pubkey their_node_id at addr: -//! async fn connect_to_node(peer_manager: PeerManager, chain_monitor: Arc, channel_manager: ChannelManager, their_node_id: PublicKey, addr: SocketAddr) { -//! lightning_net_tokio::connect_outbound(peer_manager, their_node_id, addr).await; -//! loop { -//! let event_handler = |event: Event| { -//! // Handle the event! -//! }; -//! channel_manager.await_persistable_update(); -//! channel_manager.process_pending_events(&event_handler); -//! chain_monitor.process_pending_events(&event_handler); -//! } -//! } -//! -//! // Begin reading from a newly accepted socket and talk to the peer: -//! async fn accept_socket(peer_manager: PeerManager, chain_monitor: Arc, channel_manager: ChannelManager, socket: TcpStream) { -//! lightning_net_tokio::setup_inbound(peer_manager, socket); -//! loop { -//! let event_handler = |event: Event| { -//! // Handle the event! -//! }; -//! channel_manager.await_persistable_update(); -//! channel_manager.process_pending_events(&event_handler); -//! chain_monitor.process_pending_events(&event_handler); -//! } -//! } -//! ``` +//! [`PeerManager`]: lightning::ln::peer_handler::PeerManager // Prefix these with `rustdoc::` when we update our MSRV to be >= 1.52 to remove warnings. #![deny(broken_intra_doc_links)] diff --git a/lightning/src/chain/chainmonitor.rs b/lightning/src/chain/chainmonitor.rs index 4bfb47de402..f4109ac173d 100644 --- a/lightning/src/chain/chainmonitor.rs +++ b/lightning/src/chain/chainmonitor.rs @@ -37,6 +37,7 @@ use crate::events::{Event, EventHandler}; use crate::util::atomic_counter::AtomicCounter; use crate::util::logger::Logger; use crate::util::errors::APIError; +use crate::util::wakers::{Future, Notifier}; use crate::ln::channelmanager::ChannelDetails; use crate::prelude::*; @@ -240,6 +241,8 @@ pub struct ChainMonitor, Option)>>, /// The best block height seen, used as a proxy for the passage of time. highest_chain_height: AtomicUsize, + + event_notifier: Notifier, } impl ChainMonitor @@ -300,6 +303,7 @@ where C::Target: chain::Filter, ChannelMonitorUpdateStatus::PermanentFailure => { monitor_state.channel_perm_failed.store(true, Ordering::Release); self.pending_monitor_events.lock().unwrap().push((*funding_outpoint, vec![MonitorEvent::UpdateFailed(*funding_outpoint)], monitor.get_counterparty_node_id())); + self.event_notifier.notify(); }, ChannelMonitorUpdateStatus::InProgress => { log_debug!(self.logger, "Channel Monitor sync for channel {} in progress, holding events until completion!", log_funding_info!(monitor)); @@ -345,6 +349,7 @@ where C::Target: chain::Filter, persister, pending_monitor_events: Mutex::new(Vec::new()), highest_chain_height: AtomicUsize::new(0), + event_notifier: Notifier::new(), } } @@ -472,6 +477,7 @@ where C::Target: chain::Filter, } }, } + self.event_notifier.notify(); Ok(()) } @@ -486,6 +492,7 @@ where C::Target: chain::Filter, funding_txo, monitor_update_id, }], counterparty_node_id)); + self.event_notifier.notify(); } #[cfg(any(test, fuzzing, feature = "_test_utils"))] @@ -514,6 +521,18 @@ where C::Target: chain::Filter, handler(event).await; } } + + /// Gets a [`Future`] that completes when an event is available either via + /// [`chain::Watch::release_pending_monitor_events`] or + /// [`EventsProvider::process_pending_events`]. + /// + /// Note that callbacks registered on the [`Future`] MUST NOT call back into this + /// [`ChainMonitor`] and should instead register actions to be taken later. + /// + /// [`EventsProvider::process_pending_events`]: crate::events::EventsProvider::process_pending_events + pub fn get_update_future(&self) -> Future { + self.event_notifier.get_future() + } } impl diff --git a/lightning/src/ln/channelmanager.rs b/lightning/src/ln/channelmanager.rs index e18c133636d..c2ff9da67f3 100644 --- a/lightning/src/ln/channelmanager.rs +++ b/lightning/src/ln/channelmanager.rs @@ -6170,34 +6170,11 @@ where } } - /// Blocks until ChannelManager needs to be persisted or a timeout is reached. It returns a bool - /// indicating whether persistence is necessary. Only one listener on - /// [`await_persistable_update`], [`await_persistable_update_timeout`], or a future returned by - /// [`get_persistable_update_future`] is guaranteed to be woken up. + /// Gets a [`Future`] that completes when this [`ChannelManager`] needs to be persisted. /// - /// Note that this method is not available with the `no-std` feature. + /// Note that callbacks registered on the [`Future`] MUST NOT call back into this + /// [`ChannelManager`] and should instead register actions to be taken later. /// - /// [`await_persistable_update`]: Self::await_persistable_update - /// [`await_persistable_update_timeout`]: Self::await_persistable_update_timeout - /// [`get_persistable_update_future`]: Self::get_persistable_update_future - #[cfg(any(test, feature = "std"))] - pub fn await_persistable_update_timeout(&self, max_wait: Duration) -> bool { - self.persistence_notifier.wait_timeout(max_wait) - } - - /// Blocks until ChannelManager needs to be persisted. Only one listener on - /// [`await_persistable_update`], `await_persistable_update_timeout`, or a future returned by - /// [`get_persistable_update_future`] is guaranteed to be woken up. - /// - /// [`await_persistable_update`]: Self::await_persistable_update - /// [`get_persistable_update_future`]: Self::get_persistable_update_future - pub fn await_persistable_update(&self) { - self.persistence_notifier.wait() - } - - /// Gets a [`Future`] that completes when a persistable update is available. Note that - /// callbacks registered on the [`Future`] MUST NOT call back into this [`ChannelManager`] and - /// should instead register actions to be taken later. pub fn get_persistable_update_future(&self) -> Future { self.persistence_notifier.get_future() } @@ -7952,6 +7929,7 @@ mod tests { use bitcoin::hashes::Hash; use bitcoin::hashes::sha256::Hash as Sha256; use bitcoin::secp256k1::{PublicKey, Secp256k1, SecretKey}; + #[cfg(feature = "std")] use core::time::Duration; use core::sync::atomic::Ordering; use crate::events::{Event, HTLCDestination, MessageSendEvent, MessageSendEventsProvider, ClosureReason}; @@ -7977,9 +7955,9 @@ mod tests { // All nodes start with a persistable update pending as `create_network` connects each node // with all other nodes to make most tests simpler. - assert!(nodes[0].node.await_persistable_update_timeout(Duration::from_millis(1))); - assert!(nodes[1].node.await_persistable_update_timeout(Duration::from_millis(1))); - assert!(nodes[2].node.await_persistable_update_timeout(Duration::from_millis(1))); + assert!(nodes[0].node.get_persistable_update_future().poll_is_complete()); + assert!(nodes[1].node.get_persistable_update_future().poll_is_complete()); + assert!(nodes[2].node.get_persistable_update_future().poll_is_complete()); let mut chan = create_announced_chan_between_nodes(&nodes, 0, 1); @@ -7993,19 +7971,19 @@ mod tests { &nodes[0].node.get_our_node_id()).pop().unwrap(); // The first two nodes (which opened a channel) should now require fresh persistence - assert!(nodes[0].node.await_persistable_update_timeout(Duration::from_millis(1))); - assert!(nodes[1].node.await_persistable_update_timeout(Duration::from_millis(1))); + assert!(nodes[0].node.get_persistable_update_future().poll_is_complete()); + assert!(nodes[1].node.get_persistable_update_future().poll_is_complete()); // ... but the last node should not. - assert!(!nodes[2].node.await_persistable_update_timeout(Duration::from_millis(1))); + assert!(!nodes[2].node.get_persistable_update_future().poll_is_complete()); // After persisting the first two nodes they should no longer need fresh persistence. - assert!(!nodes[0].node.await_persistable_update_timeout(Duration::from_millis(1))); - assert!(!nodes[1].node.await_persistable_update_timeout(Duration::from_millis(1))); + assert!(!nodes[0].node.get_persistable_update_future().poll_is_complete()); + assert!(!nodes[1].node.get_persistable_update_future().poll_is_complete()); // Node 3, unrelated to the only channel, shouldn't care if it receives a channel_update // about the channel. nodes[2].node.handle_channel_update(&nodes[1].node.get_our_node_id(), &chan.0); nodes[2].node.handle_channel_update(&nodes[1].node.get_our_node_id(), &chan.1); - assert!(!nodes[2].node.await_persistable_update_timeout(Duration::from_millis(1))); + assert!(!nodes[2].node.get_persistable_update_future().poll_is_complete()); // The nodes which are a party to the channel should also ignore messages from unrelated // parties. @@ -8013,8 +7991,8 @@ mod tests { nodes[0].node.handle_channel_update(&nodes[2].node.get_our_node_id(), &chan.1); nodes[1].node.handle_channel_update(&nodes[2].node.get_our_node_id(), &chan.0); nodes[1].node.handle_channel_update(&nodes[2].node.get_our_node_id(), &chan.1); - assert!(!nodes[0].node.await_persistable_update_timeout(Duration::from_millis(1))); - assert!(!nodes[1].node.await_persistable_update_timeout(Duration::from_millis(1))); + assert!(!nodes[0].node.get_persistable_update_future().poll_is_complete()); + assert!(!nodes[1].node.get_persistable_update_future().poll_is_complete()); // At this point the channel info given by peers should still be the same. assert_eq!(nodes[0].node.list_channels()[0], node_a_chan_info); @@ -8031,8 +8009,8 @@ mod tests { // persisted and that its channel info remains the same. nodes[0].node.handle_channel_update(&nodes[1].node.get_our_node_id(), &as_update); nodes[1].node.handle_channel_update(&nodes[0].node.get_our_node_id(), &bs_update); - assert!(!nodes[0].node.await_persistable_update_timeout(Duration::from_millis(1))); - assert!(!nodes[1].node.await_persistable_update_timeout(Duration::from_millis(1))); + assert!(!nodes[0].node.get_persistable_update_future().poll_is_complete()); + assert!(!nodes[1].node.get_persistable_update_future().poll_is_complete()); assert_eq!(nodes[0].node.list_channels()[0], node_a_chan_info); assert_eq!(nodes[1].node.list_channels()[0], node_b_chan_info); @@ -8040,8 +8018,8 @@ mod tests { // the channel info has updated. nodes[0].node.handle_channel_update(&nodes[1].node.get_our_node_id(), &bs_update); nodes[1].node.handle_channel_update(&nodes[0].node.get_our_node_id(), &as_update); - assert!(nodes[0].node.await_persistable_update_timeout(Duration::from_millis(1))); - assert!(nodes[1].node.await_persistable_update_timeout(Duration::from_millis(1))); + assert!(nodes[0].node.get_persistable_update_future().poll_is_complete()); + assert!(nodes[1].node.get_persistable_update_future().poll_is_complete()); assert_ne!(nodes[0].node.list_channels()[0], node_a_chan_info); assert_ne!(nodes[1].node.list_channels()[0], node_b_chan_info); } diff --git a/lightning/src/sync/debug_sync.rs b/lightning/src/sync/debug_sync.rs index 11557be82af..b9f015af656 100644 --- a/lightning/src/sync/debug_sync.rs +++ b/lightning/src/sync/debug_sync.rs @@ -12,6 +12,8 @@ use std::sync::RwLockReadGuard as StdRwLockReadGuard; use std::sync::RwLockWriteGuard as StdRwLockWriteGuard; use std::sync::Condvar as StdCondvar; +pub use std::sync::WaitTimeoutResult; + use crate::prelude::HashMap; use super::{LockTestExt, LockHeldState}; @@ -35,15 +37,19 @@ impl Condvar { Condvar { inner: StdCondvar::new() } } - pub fn wait<'a, T>(&'a self, guard: MutexGuard<'a, T>) -> LockResult> { + pub fn wait_while<'a, T, F: FnMut(&mut T) -> bool>(&'a self, guard: MutexGuard<'a, T>, condition: F) + -> LockResult> { let mutex: &'a Mutex = guard.mutex; - self.inner.wait(guard.into_inner()).map(|lock| MutexGuard { mutex, lock }).map_err(|_| ()) + self.inner.wait_while(guard.into_inner(), condition).map(|lock| MutexGuard { mutex, lock }) + .map_err(|_| ()) } #[allow(unused)] - pub fn wait_timeout<'a, T>(&'a self, guard: MutexGuard<'a, T>, dur: Duration) -> LockResult<(MutexGuard<'a, T>, ())> { + pub fn wait_timeout_while<'a, T, F: FnMut(&mut T) -> bool>(&'a self, guard: MutexGuard<'a, T>, dur: Duration, condition: F) + -> LockResult<(MutexGuard<'a, T>, WaitTimeoutResult)> { let mutex = guard.mutex; - self.inner.wait_timeout(guard.into_inner(), dur).map(|(lock, _)| (MutexGuard { mutex, lock }, ())).map_err(|_| ()) + self.inner.wait_timeout_while(guard.into_inner(), dur, condition).map_err(|_| ()) + .map(|(lock, e)| (MutexGuard { mutex, lock }, e)) } pub fn notify_all(&self) { self.inner.notify_all(); } diff --git a/lightning/src/sync/nostd_sync.rs b/lightning/src/sync/nostd_sync.rs index 17307997d81..08d54a939be 100644 --- a/lightning/src/sync/nostd_sync.rs +++ b/lightning/src/sync/nostd_sync.rs @@ -1,30 +1,10 @@ pub use ::alloc::sync::Arc; use core::ops::{Deref, DerefMut}; -use core::time::Duration; use core::cell::{RefCell, Ref, RefMut}; use super::{LockTestExt, LockHeldState}; pub type LockResult = Result; -pub struct Condvar {} - -impl Condvar { - pub fn new() -> Condvar { - Condvar { } - } - - pub fn wait<'a, T>(&'a self, guard: MutexGuard<'a, T>) -> LockResult> { - Ok(guard) - } - - #[allow(unused)] - pub fn wait_timeout<'a, T>(&'a self, guard: MutexGuard<'a, T>, _dur: Duration) -> LockResult<(MutexGuard<'a, T>, ())> { - Ok((guard, ())) - } - - pub fn notify_all(&self) {} -} - pub struct Mutex { inner: RefCell } diff --git a/lightning/src/util/wakers.rs b/lightning/src/util/wakers.rs index 1e41b2daee5..602c2ee04b7 100644 --- a/lightning/src/util/wakers.rs +++ b/lightning/src/util/wakers.rs @@ -15,12 +15,14 @@ use alloc::sync::Arc; use core::mem; -use crate::sync::{Condvar, Mutex, MutexGuard}; +use crate::sync::Mutex; use crate::prelude::*; -#[cfg(any(test, feature = "std"))] -use std::time::{Duration, Instant}; +#[cfg(feature = "std")] +use crate::sync::Condvar; +#[cfg(feature = "std")] +use std::time::Duration; use core::future::Future as StdFuture; use core::task::{Context, Poll}; @@ -30,74 +32,12 @@ use core::pin::Pin; /// Used to signal to one of many waiters that the condition they're waiting on has happened. pub(crate) struct Notifier { notify_pending: Mutex<(bool, Option>>)>, - condvar: Condvar, -} - -macro_rules! check_woken { - ($guard: expr, $retval: expr) => { { - if $guard.0 { - $guard.0 = false; - if $guard.1.as_ref().map(|l| l.lock().unwrap().complete).unwrap_or(false) { - // If we're about to return as woken, and the future state is marked complete, wipe - // the future state and let the next future wait until we get a new notify. - $guard.1.take(); - } - return $retval; - } - } } } impl Notifier { pub(crate) fn new() -> Self { Self { notify_pending: Mutex::new((false, None)), - condvar: Condvar::new(), - } - } - - fn propagate_future_state_to_notify_flag(&self) -> MutexGuard<(bool, Option>>)> { - let mut lock = self.notify_pending.lock().unwrap(); - if let Some(existing_state) = &lock.1 { - if existing_state.lock().unwrap().callbacks_made { - // If the existing `FutureState` has completed and actually made callbacks, - // consider the notification flag to have been cleared and reset the future state. - lock.1.take(); - lock.0 = false; - } - } - lock - } - - pub(crate) fn wait(&self) { - loop { - let mut guard = self.propagate_future_state_to_notify_flag(); - check_woken!(guard, ()); - guard = self.condvar.wait(guard).unwrap(); - check_woken!(guard, ()); - } - } - - #[cfg(any(test, feature = "std"))] - pub(crate) fn wait_timeout(&self, max_wait: Duration) -> bool { - let current_time = Instant::now(); - loop { - let mut guard = self.propagate_future_state_to_notify_flag(); - check_woken!(guard, true); - guard = self.condvar.wait_timeout(guard, max_wait).unwrap().0; - check_woken!(guard, true); - // Due to spurious wakeups that can happen on `wait_timeout`, here we need to check if the - // desired wait time has actually passed, and if not then restart the loop with a reduced wait - // time. Note that this logic can be highly simplified through the use of - // `Condvar::wait_while` and `Condvar::wait_timeout_while`, if and when our MSRV is raised to - // 1.42.0. - let elapsed = current_time.elapsed(); - if elapsed >= max_wait { - return false; - } - match max_wait.checked_sub(elapsed) { - None => return false, - Some(_) => continue - } } } @@ -111,13 +51,19 @@ impl Notifier { } } lock.0 = true; - mem::drop(lock); - self.condvar.notify_all(); } /// Gets a [`Future`] that will get woken up with any waiters pub(crate) fn get_future(&self) -> Future { - let mut lock = self.propagate_future_state_to_notify_flag(); + let mut lock = self.notify_pending.lock().unwrap(); + if let Some(existing_state) = &lock.1 { + if existing_state.lock().unwrap().callbacks_made { + // If the existing `FutureState` has completed and actually made callbacks, + // consider the notification flag to have been cleared and reset the future state. + lock.1.take(); + lock.0 = false; + } + } if let Some(existing_state) = &lock.1 { Future { state: Arc::clone(&existing_state) } } else { @@ -137,6 +83,7 @@ impl Notifier { } } +macro_rules! define_callback { ($($bounds: path),*) => { /// A callback which is called when a [`Future`] completes. /// /// Note that this MUST NOT call back into LDK directly, it must instead schedule actions to be @@ -145,14 +92,20 @@ impl Notifier { /// /// Note that the [`std::future::Future`] implementation may only work for runtimes which schedule /// futures when they receive a wake, rather than immediately executing them. -pub trait FutureCallback : Send { +pub trait FutureCallback : $($bounds +)* { /// The method which is called. fn call(&self); } -impl FutureCallback for F { +impl FutureCallback for F { fn call(&self) { (self)(); } } +} } + +#[cfg(feature = "std")] +define_callback!(Send); +#[cfg(not(feature = "std"))] +define_callback!(); pub(crate) struct FutureState { // When we're tracking whether a callback counts as having woken the user's code, we check the @@ -175,6 +128,9 @@ impl FutureState { } /// A simple future which can complete once, and calls some callback(s) when it does so. +/// +/// Clones can be made and all futures cloned from the same source will complete at the same time. +#[derive(Clone)] pub struct Future { state: Arc>, } @@ -204,6 +160,29 @@ impl Future { pub fn register_callback_fn(&self, callback: F) { self.register_callback(Box::new(callback)); } + + /// Waits until this [`Future`] completes. + #[cfg(feature = "std")] + pub fn wait(self) { + Sleeper::from_single_future(self).wait(); + } + + /// Waits until this [`Future`] completes or the given amount of time has elapsed. + /// + /// Returns true if the [`Future`] completed, false if the time elapsed. + #[cfg(feature = "std")] + pub fn wait_timeout(self, max_wait: Duration) -> bool { + Sleeper::from_single_future(self).wait_timeout(max_wait) + } + + #[cfg(test)] + pub fn poll_is_complete(&self) -> bool { + let mut state = self.state.lock().unwrap(); + if state.complete { + state.callbacks_made = true; + true + } else { false } + } } use core::task::Waker; @@ -229,6 +208,78 @@ impl<'a> StdFuture for Future { } } +/// A struct which can be used to select across many [`Future`]s at once without relying on a full +/// async context. +#[cfg(feature = "std")] +pub struct Sleeper { + notifiers: Vec>>, +} + +#[cfg(feature = "std")] +impl Sleeper { + /// Constructs a new sleeper from one future, allowing blocking on it. + pub fn from_single_future(future: Future) -> Self { + Self { notifiers: vec![future.state] } + } + /// Constructs a new sleeper from two futures, allowing blocking on both at once. + // Note that this is the common case - a ChannelManager and ChainMonitor. + pub fn from_two_futures(fut_a: Future, fut_b: Future) -> Self { + Self { notifiers: vec![fut_a.state, fut_b.state] } + } + /// Constructs a new sleeper on many futures, allowing blocking on all at once. + pub fn new(futures: Vec) -> Self { + Self { notifiers: futures.into_iter().map(|f| f.state).collect() } + } + /// Prepares to go into a wait loop body, creating a condition variable which we can block on + /// and an `Arc>>` which gets set to the waking `Future`'s state prior to the + /// condition variable being woken. + fn setup_wait(&self) -> (Arc, Arc>>>>) { + let cv = Arc::new(Condvar::new()); + let notified_fut_mtx = Arc::new(Mutex::new(None)); + { + for notifier_mtx in self.notifiers.iter() { + let cv_ref = Arc::clone(&cv); + let notified_fut_ref = Arc::clone(¬ified_fut_mtx); + let notifier_ref = Arc::clone(¬ifier_mtx); + let mut notifier = notifier_mtx.lock().unwrap(); + if notifier.complete { + *notified_fut_mtx.lock().unwrap() = Some(notifier_ref); + break; + } + notifier.callbacks.push((false, Box::new(move || { + *notified_fut_ref.lock().unwrap() = Some(Arc::clone(¬ifier_ref)); + cv_ref.notify_all(); + }))); + } + } + (cv, notified_fut_mtx) + } + + /// Wait until one of the [`Future`]s registered with this [`Sleeper`] has completed. + pub fn wait(&self) { + let (cv, notified_fut_mtx) = self.setup_wait(); + let notified_fut = cv.wait_while(notified_fut_mtx.lock().unwrap(), |fut_opt| fut_opt.is_none()) + .unwrap().take().expect("CV wait shouldn't have returned until the notifying future was set"); + notified_fut.lock().unwrap().callbacks_made = true; + } + + /// Wait until one of the [`Future`]s registered with this [`Sleeper`] has completed or the + /// given amount of time has elapsed. Returns true if a [`Future`] completed, false if the time + /// elapsed. + pub fn wait_timeout(&self, max_wait: Duration) -> bool { + let (cv, notified_fut_mtx) = self.setup_wait(); + let notified_fut = + match cv.wait_timeout_while(notified_fut_mtx.lock().unwrap(), max_wait, |fut_opt| fut_opt.is_none()) { + Ok((_, e)) if e.timed_out() => return false, + Ok((mut notified_fut, _)) => + notified_fut.take().expect("CV wait shouldn't have returned until the notifying future was set"), + Err(_) => panic!("Previous panic while a lock was held led to a lock panic"), + }; + notified_fut.lock().unwrap().callbacks_made = true; + true + } +} + #[cfg(test)] mod tests { use super::*; @@ -327,10 +378,7 @@ mod tests { let exit_thread_clone = exit_thread.clone(); thread::spawn(move || { loop { - let mut lock = thread_notifier.notify_pending.lock().unwrap(); - lock.0 = true; - thread_notifier.condvar.notify_all(); - + thread_notifier.notify(); if exit_thread_clone.load(Ordering::SeqCst) { break } @@ -338,12 +386,12 @@ mod tests { }); // Check that we can block indefinitely until updates are available. - let _ = persistence_notifier.wait(); + let _ = persistence_notifier.get_future().wait(); // Check that the Notifier will return after the given duration if updates are // available. loop { - if persistence_notifier.wait_timeout(Duration::from_millis(100)) { + if persistence_notifier.get_future().wait_timeout(Duration::from_millis(100)) { break } } @@ -353,7 +401,7 @@ mod tests { // Check that the Notifier will return after the given duration even if no updates // are available. loop { - if !persistence_notifier.wait_timeout(Duration::from_millis(100)) { + if !persistence_notifier.get_future().wait_timeout(Duration::from_millis(100)) { break } } @@ -443,6 +491,7 @@ mod tests { } #[test] + #[cfg(feature = "std")] fn test_dropped_future_doesnt_count() { // Tests that if a Future gets drop'd before it is poll()ed `Ready` it doesn't count as // having been woken, leaving the notify-required flag set. @@ -451,8 +500,8 @@ mod tests { // If we get a future and don't touch it we're definitely still notify-required. notifier.get_future(); - assert!(notifier.wait_timeout(Duration::from_millis(1))); - assert!(!notifier.wait_timeout(Duration::from_millis(1))); + assert!(notifier.get_future().wait_timeout(Duration::from_millis(1))); + assert!(!notifier.get_future().wait_timeout(Duration::from_millis(1))); // Even if we poll'd once but didn't observe a `Ready`, we should be notify-required. let mut future = notifier.get_future(); @@ -461,7 +510,7 @@ mod tests { notifier.notify(); assert!(woken.load(Ordering::SeqCst)); - assert!(notifier.wait_timeout(Duration::from_millis(1))); + assert!(notifier.get_future().wait_timeout(Duration::from_millis(1))); // However, once we do poll `Ready` it should wipe the notify-required flag. let mut future = notifier.get_future(); @@ -471,7 +520,7 @@ mod tests { notifier.notify(); assert!(woken.load(Ordering::SeqCst)); assert_eq!(Pin::new(&mut future).poll(&mut Context::from_waker(&waker)), Poll::Ready(())); - assert!(!notifier.wait_timeout(Duration::from_millis(1))); + assert!(!notifier.get_future().wait_timeout(Duration::from_millis(1))); } #[test] @@ -532,4 +581,67 @@ mod tests { assert!(woken.load(Ordering::SeqCst)); assert_eq!(Pin::new(&mut future).poll(&mut Context::from_waker(&waker)), Poll::Ready(())); } + + #[test] + #[cfg(feature = "std")] + fn test_multi_future_sleep() { + // Tests the `Sleeper` with multiple futures. + let notifier_a = Notifier::new(); + let notifier_b = Notifier::new(); + + // Set both notifiers as woken without sleeping yet. + notifier_a.notify(); + notifier_b.notify(); + Sleeper::from_two_futures(notifier_a.get_future(), notifier_b.get_future()).wait(); + + // One future has woken us up, but the other should still have a pending notification. + Sleeper::from_two_futures(notifier_a.get_future(), notifier_b.get_future()).wait(); + + // However once we've slept twice, we should no longer have any pending notifications + assert!(!Sleeper::from_two_futures(notifier_a.get_future(), notifier_b.get_future()) + .wait_timeout(Duration::from_millis(10))); + + // Test ordering somewhat more. + notifier_a.notify(); + Sleeper::from_two_futures(notifier_a.get_future(), notifier_b.get_future()).wait(); + } + + #[test] + #[cfg(feature = "std")] + fn sleeper_with_pending_callbacks() { + // This is similar to the above `test_multi_future_sleep` test, but in addition registers + // "normal" callbacks which will cause the futures to assume notification has occurred, + // rather than waiting for a woken sleeper. + let notifier_a = Notifier::new(); + let notifier_b = Notifier::new(); + + // Set both notifiers as woken without sleeping yet. + notifier_a.notify(); + notifier_b.notify(); + + // After sleeping one future (not guaranteed which one, however) will have its notification + // bit cleared. + Sleeper::from_two_futures(notifier_a.get_future(), notifier_b.get_future()).wait(); + + // By registering a callback on the futures for both notifiers, one will complete + // immediately, but one will remain tied to the notifier, and will complete once the + // notifier is next woken, which will be considered the completion of the notification. + let callback_a = Arc::new(AtomicBool::new(false)); + let callback_b = Arc::new(AtomicBool::new(false)); + let callback_a_ref = Arc::clone(&callback_a); + let callback_b_ref = Arc::clone(&callback_b); + notifier_a.get_future().register_callback(Box::new(move || assert!(!callback_a_ref.fetch_or(true, Ordering::SeqCst)))); + notifier_b.get_future().register_callback(Box::new(move || assert!(!callback_b_ref.fetch_or(true, Ordering::SeqCst)))); + assert!(callback_a.load(Ordering::SeqCst) ^ callback_b.load(Ordering::SeqCst)); + + // If we now notify both notifiers again, the other callback will fire, completing the + // notification, and we'll be back to one pending notification. + notifier_a.notify(); + notifier_b.notify(); + + assert!(callback_a.load(Ordering::SeqCst) && callback_b.load(Ordering::SeqCst)); + Sleeper::from_two_futures(notifier_a.get_future(), notifier_b.get_future()).wait(); + assert!(!Sleeper::from_two_futures(notifier_a.get_future(), notifier_b.get_future()) + .wait_timeout(Duration::from_millis(10))); + } }