diff --git a/src/lib.rs b/src/lib.rs index 7336ed57..1834c387 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -90,6 +90,16 @@ pub async fn initialize(cli_args: cli_args::Args) -> Result<()> { DataDirectory::create_dir_if_not_exists(&data_dir.root_dir_path()).await?; info!("Data directory is {}", data_dir); + // Get wallet object, create various wallet secret files + let wallet_dir = data_dir.wallet_directory_path(); + DataDirectory::create_dir_if_not_exists(&wallet_dir).await?; + let (wallet_secret, _) = + WalletSecret::read_from_file_or_create(&data_dir.wallet_directory_path())?; + info!("Now getting wallet state. This may take a while if the database needs pruning."); + let wallet_state = + WalletState::new_from_wallet_secret(&data_dir, wallet_secret, &cli_args).await; + info!("Got wallet state."); + // Connect to or create databases for block index, peers, mutator set, block sync let block_index_db = ArchivalState::initialize_block_index_database(&data_dir).await?; info!("Got block index database"); @@ -101,7 +111,7 @@ pub async fn initialize(cli_args: cli_args::Args) -> Result<()> { info!("Got archival mutator set"); let archival_state = ArchivalState::new( - data_dir.clone(), + data_dir, block_index_db, archival_mutator_set, cli_args.network, @@ -155,16 +165,6 @@ pub async fn initialize(cli_args: cli_args::Args) -> Result<()> { latest_block.hash(), ); - // Get wallet object, create various wallet secret files - let wallet_dir = data_dir.wallet_directory_path(); - DataDirectory::create_dir_if_not_exists(&wallet_dir).await?; - let (wallet_secret, _) = - WalletSecret::read_from_file_or_create(&data_dir.wallet_directory_path())?; - info!("Now getting wallet state. This may take a while if the database needs pruning."); - let wallet_state = - WalletState::new_from_wallet_secret(&data_dir, wallet_secret, &cli_args).await; - info!("Got wallet state."); - let mut global_state_lock = GlobalStateLock::new( wallet_state, blockchain_state, @@ -192,11 +192,8 @@ pub async fn initialize(cli_args: cli_args::Args) -> Result<()> { .await?; info!("UTXO restoration check complete"); - let mut task_join_handles = vec![]; - - task_join_handles.push(spawn_wallet_task(global_state_lock.clone()).await?); - // Connect to peers, and provide each peer task with a thread-safe copy of the state + let mut task_join_handles = vec![]; for peer_address in global_state_lock.cli().peers.clone() { let peer_state_var = global_state_lock.clone(); // bump arc refcount let main_to_peer_broadcast_rx_clone: broadcast::Receiver = @@ -303,33 +300,6 @@ pub async fn initialize(cli_args: cli_args::Args) -> Result<()> { .await } -pub(crate) async fn spawn_wallet_task( - mut global_state_lock: GlobalStateLock, -) -> Result> { - let mut mempool_subscriber = global_state_lock.lock_guard().await.mempool.subscribe(); - - let wallet_join_handle = tokio::task::Builder::new() - .name("wallet_mempool_listener") - .spawn(async move { - let mut events: std::collections::VecDeque<_> = Default::default(); - - while let Ok(event) = mempool_subscriber.recv().await { - events.push_back(event); - - if let Ok(mut gs) = global_state_lock.try_lock_guard_mut() { - while let Some(e) = events.pop_front() { - gs.wallet_state - .handle_mempool_event(e) - .await - .expect("Wallet should handle mempool event without error"); - } - } - } - })?; - - Ok(wallet_join_handle) -} - /// Time a fn call. Duration is returned as a float in seconds. pub fn time_fn_call(f: impl FnOnce() -> O) -> (O, f64) { let start = Instant::now(); diff --git a/src/locks/tokio/atomic_rw.rs b/src/locks/tokio/atomic_rw.rs index ef883457..b2d9d53b 100644 --- a/src/locks/tokio/atomic_rw.rs +++ b/src/locks/tokio/atomic_rw.rs @@ -241,6 +241,9 @@ impl AtomicRw { AtomicRwWriteGuard::new(guard, &self.lock_callback_info) } + /// Attempt to acquire write lock immediately. + /// + /// If the lock cannot be acquired without waiting, an error is returned. pub fn try_lock_guard_mut(&mut self) -> Result, TryLockError> { self.try_acquire_write_cb(); let guard = self.inner.try_write()?; diff --git a/src/main_loop.rs b/src/main_loop.rs index 5797dcf1..bec4a3f1 100644 --- a/src/main_loop.rs +++ b/src/main_loop.rs @@ -539,8 +539,8 @@ impl MainLoopHandler { // Insert into mempool global_state_mut - .mempool - .insert(pt2m_transaction.transaction.to_owned())?; + .mempool_insert(pt2m_transaction.transaction.to_owned()) + .await?; // send notification to peers let transaction_notification: TransactionNotification = @@ -998,7 +998,7 @@ impl MainLoopHandler { // Handle mempool cleanup, i.e. removing stale/too old txs from mempool _ = &mut mempool_cleanup_timer => { debug!("Timer: mempool-cleaner job"); - self.global_state_lock.lock_guard_mut().await.mempool.prune_stale_transactions()?; + self.global_state_lock.lock_guard_mut().await.mempool_prune_stale_transactions().await?; // Reset the timer to run this branch again in P seconds mempool_cleanup_timer.as_mut().reset(tokio::time::Instant::now() + mempool_cleanup_timer_interval); @@ -1056,8 +1056,8 @@ impl MainLoopHandler { self.global_state_lock .lock_guard_mut() .await - .mempool - .insert(*transaction)?; + .mempool_insert(*transaction) + .await?; // do not shut down Ok(false) diff --git a/src/mine_loop.rs b/src/mine_loop.rs index 37924377..18f21160 100644 --- a/src/mine_loop.rs +++ b/src/mine_loop.rs @@ -636,12 +636,11 @@ mod mine_loop_tests { // no need to inform wallet of expected utxos; block template validity // is what is being tested - alice - .lock_guard_mut() - .await - .mempool - .insert(tx_by_preminer)?; - assert_eq!(1, alice.lock_guard().await.mempool.len()); + { + let mut alice_gsm = alice.lock_guard_mut().await; + alice_gsm.mempool_insert(tx_by_preminer).await?; + assert_eq!(1, alice_gsm.mempool.len()); + } // Build transaction for block let (transaction_non_empty_mempool, _new_coinbase_sender_randomness) = { diff --git a/src/models/state/mempool.rs b/src/models/state/mempool.rs index 64699d39..233dddc2 100644 --- a/src/models/state/mempool.rs +++ b/src/models/state/mempool.rs @@ -61,10 +61,21 @@ pub const TRANSACTION_NOTIFICATION_AGE_LIMIT_IN_SECS: u64 = 60 * 60 * 24; type LookupItem<'a> = (TransactionKernelId, &'a Transaction); +/// Represents a mempool state change. +/// +/// For purpose of notifying interested parties #[derive(Debug, Clone)] pub enum MempoolEvent { + /// a transaction was added to the mempool AddTx(Transaction), + + /// a transaction was removed from the mempool RemoveTx(Transaction), + + /// the mutator-set of a transaction was updated in the mempool. + /// + /// (Digest of Tx before update, Tx after mutator-set updated) + UpdateTxMutatorSet(TransactionKernelId, Transaction), } #[derive(Debug, GetSize)] @@ -88,15 +99,14 @@ pub struct Mempool { /// Records the digest of the block that the transactions were synced to. /// Used to discover reorganizations. tip_digest: Digest, - - /// a mpmc channel for interested parties to listen to mempool events - #[get_size(ignore)] // does not impl GetSize - event_channel: ( - tokio::sync::broadcast::Sender, - tokio::sync::broadcast::Receiver, - ), } +/// note that all methods that modify state and result in a MempoolEvent +/// notification are private or pub(super). This enforces that these methods +/// can only be called from/via GlobalState. +/// +/// Mempool updates must go through GlobalState so that it can +/// forward mempool events to the wallet in atomic fashion. impl Mempool { /// instantiate a new, empty `Mempool` pub fn new( @@ -113,12 +123,11 @@ impl Mempool { tx_dictionary: table, queue, tip_digest, - event_channel: tokio::sync::broadcast::channel(100), } } /// Update the block digest to which all transactions are synced. - fn set_tip_digest_sync_label(&mut self, tip_digest: Digest) { + pub(super) fn set_tip_digest_sync_label(&mut self, tip_digest: Digest) { self.tip_digest = tip_digest; } @@ -170,28 +179,36 @@ impl Mempool { /// The caller must also ensure that the transaction does not have a timestamp /// in the too distant future. /// + /// this method may return: + /// 2 events: RemoveTx,AddTx. tx replaces an older one with lower fee. + /// 1 event: AddTx. tx does not replace an older one. + /// 0 events: tx not added because an older matching tx has a higher fee. + /// /// # Panics /// /// Panics if the transaction's proof is of the wrong type. - pub fn insert(&mut self, transaction: Transaction) -> Result { + pub(super) fn insert(&mut self, transaction: Transaction) -> Result> { + let mut events = vec![]; + match transaction.proof { TransactionProof::Invalid => panic!("cannot insert invalid transaction into mempool"), TransactionProof::Witness(_) => {} TransactionProof::SingleProof(_) => {} TransactionProof::ProofCollection(_) => {} }; - // If transaction to be inserted conflicts with a transaction that's already // in the mempool we preserve only the one with the highest fee density. if let Some((txid, tx)) = self.transaction_conflicts_with(&transaction) { if tx.fee_density() < transaction.fee_density() { // If new transaction has a higher fee density than the one previously seen // remove the old one. - self.remove(txid)?; + if let Some(e) = self.remove(txid)? { + events.push(e); + } } else { // If new transaction has a lower fee density than the one previous seen, // ignore it. Stop execution here. - return Ok(txid); + return Ok(events); } }; @@ -212,26 +229,36 @@ impl Mempool { "mempool's table and queue length must agree after shrink" ); - self.sender().send(MempoolEvent::AddTx(transaction))?; + events.push(MempoolEvent::AddTx(transaction)); - Ok(txid) + Ok(events) } /// remove a transaction from the `Mempool` - pub fn remove(&mut self, transaction_id: TransactionKernelId) -> Result { + pub(super) fn remove( + &mut self, + transaction_id: TransactionKernelId, + ) -> Result> { match self.tx_dictionary.remove(&transaction_id) { Some(tx) => { self.queue.remove(&transaction_id); debug_assert_eq!(self.tx_dictionary.len(), self.queue.len()); - self.sender().send(MempoolEvent::RemoveTx(tx))?; - Ok(true) + Ok(Some(MempoolEvent::RemoveTx(tx))) } - None => Ok(false), + None => Ok(None), } } /// Delete all transactions from the mempool. - pub fn clear(&mut self) -> Result<()> { + /// + /// note that this will return a MempoolEvent for every removed Tx. + /// In the case of a full block, that could be a lot of Tx and + /// significant memory usage. Of course the mempool itself will + /// be emptied at the same time. + /// + /// If the mem usage ever becomes a problem we could accept a closure + /// to handle the events individually as each Tx is removed. + pub(super) fn clear(&mut self) -> Result> { // note: this causes event listeners to be notified of each removed tx. self.retain(|_| false) } @@ -285,15 +312,14 @@ impl Mempool { /// /// Computes in θ(lg N) #[allow(dead_code)] - pub fn pop_max(&mut self) -> Result> { + fn pop_max(&mut self) -> Result> { if let Some((transaction_digest, fee_density)) = self.queue.pop_max() { if let Some(transaction) = self.tx_dictionary.remove(&transaction_digest) { debug_assert_eq!(self.tx_dictionary.len(), self.queue.len()); - self.sender() - .send(MempoolEvent::RemoveTx(transaction.clone()))?; + let event = MempoolEvent::RemoveTx(transaction); - return Ok(Some((transaction, fee_density))); + return Ok(Some((event, fee_density))); } } Ok(None) @@ -303,15 +329,14 @@ impl Mempool { /// Returns the removed value. /// /// Computes in θ(lg N) - pub fn pop_min(&mut self) -> Result> { + fn pop_min(&mut self) -> Result> { if let Some((transaction_digest, fee_density)) = self.queue.pop_min() { if let Some(transaction) = self.tx_dictionary.remove(&transaction_digest) { debug_assert_eq!(self.tx_dictionary.len(), self.queue.len()); - self.sender() - .send(MempoolEvent::RemoveTx(transaction.clone()))?; + let event = MempoolEvent::RemoveTx(transaction); - return Ok(Some((transaction, fee_density))); + return Ok(Some((event, fee_density))); } } Ok(None) @@ -322,7 +347,7 @@ impl Mempool { /// Modelled after [HashMap::retain](std::collections::HashMap::retain()) /// /// Computes in O(capacity) >= O(N) - pub fn retain(&mut self, mut predicate: F) -> Result<()> + fn retain(&mut self, mut predicate: F) -> Result> where F: FnMut(LookupItem) -> bool, { @@ -335,21 +360,24 @@ impl Mempool { } } + let mut events = Vec::with_capacity(victims.len()); for t in victims { - self.remove(t)?; + if let Some(e) = self.remove(t)? { + events.push(e); + } } debug_assert_eq!(self.tx_dictionary.len(), self.queue.len()); self.shrink_to_fit(); - Ok(()) + Ok(events) } /// Remove transactions from mempool that are older than the specified /// timestamp. Prunes base on the transaction's timestamp. /// /// Computes in O(n) - pub fn prune_stale_transactions(&mut self) -> Result<()> { + pub(super) fn prune_stale_transactions(&mut self) -> Result> { let cutoff = Timestamp::now() - Timestamp::seconds(MEMPOOL_TX_THRESHOLD_AGE_IN_SECS); let keep = |(_transaction_id, transaction): LookupItem| -> bool { @@ -362,11 +390,11 @@ impl Mempool { /// Remove from the mempool all transactions that become invalid because /// of a newly received block. Also update all mutator set data for mempool /// transactions that were not removed. - pub async fn update_with_block( + pub(super) async fn update_with_block( &mut self, previous_mutator_set_accumulator: MutatorSetAccumulator, block: &Block, - ) -> Result<()> { + ) -> Result> { // If we discover a reorganization, we currently just clear the mempool, // as we don't have the ability to roll transaction removal record integrity // proofs back to previous blocks. It would be nice if we could handle a @@ -414,20 +442,21 @@ impl Mempool { }; // Remove the transactions that become invalid with this block - self.retain(keep)?; + let mut events = self.retain(keep)?; // Update the remaining transactions so their mutator set data is still valid - // But kick out those transactions that we were unable to update. - let mut kick_outs = vec![]; + let mut kick_outs = Vec::with_capacity(self.tx_dictionary.len()); for (tx_id, tx) in self.tx_dictionary.iter_mut() { if let Ok(new_tx) = tx .clone() .new_with_updated_mutator_set_records(&previous_mutator_set_accumulator, block) { *tx = new_tx; + events.push(MempoolEvent::UpdateTxMutatorSet(*tx_id, (*tx).clone())); } else { error!("Failed to update transaction {tx_id}. Removing from mempool."); kick_outs.push(*tx_id); + events.push(MempoolEvent::RemoveTx(tx.clone())); } } @@ -442,7 +471,7 @@ impl Mempool { let current_block_digest = block.hash(); self.set_tip_digest_sync_label(current_block_digest); - Ok(()) + Ok(events) } /// Shrink the memory pool to the value of its `max_size` field. @@ -507,15 +536,6 @@ impl Mempool { let dpq_clone = self.queue.clone(); dpq_clone.into_sorted_iter().rev() } - - pub fn subscribe(&self) -> tokio::sync::broadcast::Receiver { - self.sender().subscribe() - } - - fn sender(&self) -> &tokio::sync::broadcast::Sender { - let (sender, _) = &self.event_channel; - sender - } } #[cfg(test)] @@ -568,13 +588,13 @@ mod tests { assert!(mempool.contains(transaction_digests[0])); assert!(!mempool.contains(transaction_digests[1])); - assert!(mempool.remove(transaction_digests[0])?); + assert!(mempool.remove(transaction_digests[0])?.is_some()); assert!(!mempool.contains(transaction_digests[0])); for tx_id in transaction_digests.iter() { assert!(!mempool.contains(*tx_id)); } - assert!(mempool.remove(transaction_digests[0])?); + assert!(mempool.remove(transaction_digests[0])?.is_some()); assert!(!mempool.contains(transaction_digests[0])); let transaction_second_get_option = mempool.get(transaction_digests[0]); diff --git a/src/models/state/mod.rs b/src/models/state/mod.rs index 76958d66..42b8c274 100644 --- a/src/models/state/mod.rs +++ b/src/models/state/mod.rs @@ -1382,6 +1382,24 @@ impl GlobalState { pub fn cli(&self) -> &cli_args::Args { &self.cli } + + /// clears all Tx from mempool and notifies wallet of changes. + pub async fn mempool_clear(&mut self) -> Result<()> { + let events = self.mempool.clear()?; + self.wallet_state.handle_mempool_events(events).await + } + + /// adds Tx to mempool and notifies wallet of change. + pub async fn mempool_insert(&mut self, transaction: Transaction) -> Result<()> { + let events = self.mempool.insert(transaction)?; + self.wallet_state.handle_mempool_events(events).await + } + + /// prunes stale tx in mempool and notifies wallet of changes. + pub async fn mempool_prune_stale_transactions(&mut self) -> Result<()> { + let events = self.mempool.prune_stale_transactions()?; + self.wallet_state.handle_mempool_events(events).await + } } #[cfg(test)] diff --git a/src/models/state/wallet/rusty_wallet_database.rs b/src/models/state/wallet/rusty_wallet_database.rs index 70e1c2b2..92239b4a 100644 --- a/src/models/state/wallet/rusty_wallet_database.rs +++ b/src/models/state/wallet/rusty_wallet_database.rs @@ -17,9 +17,6 @@ pub struct RustyWalletDatabase { // list of utxos we have already received in a block monitored_utxos: DbtVec, - // list of utxos presently in the mempool - // monitored_mempool_utxos: DbtVec, - // list of off-chain utxos we are expecting to receive in a future block expected_utxos: DbtVec, diff --git a/src/models/state/wallet/wallet_state.rs b/src/models/state/wallet/wallet_state.rs index d143e51a..acda9aec 100644 --- a/src/models/state/wallet/wallet_state.rs +++ b/src/models/state/wallet/wallet_state.rs @@ -18,6 +18,7 @@ use tokio::io::BufWriter; use tracing::debug; use tracing::error; use tracing::info; +use tracing::trace; use tracing::warn; use twenty_first::math::bfield_codec::BFieldCodec; use twenty_first::math::digest::Digest; @@ -68,6 +69,8 @@ pub struct WalletState { pub number_of_mps_per_utxo: usize, wallet_directory_path: PathBuf, + /// these two fields are for monitoring wallet-affecting utxos in the mempool. + /// key is Tx hash. for removing watched utxos when a tx is removed from mempool. mempool_spent_utxos: HashMap>, mempool_unspent_utxos: HashMap>, } @@ -269,10 +272,30 @@ impl WalletState { .collect_vec() } - pub async fn handle_mempool_event(&mut self, event: MempoolEvent) -> Result<()> { + /// handles a list of mempool events + pub(in crate::models::state) async fn handle_mempool_events( + &mut self, + events: impl IntoIterator, + ) -> Result<()> { + for event in events { + self.handle_mempool_event(event).await? + } + Ok(()) + } + + /// handles a single mempool event. + /// + /// note: the wallet watches the mempool in order to keep track of + /// unconfirmed utxos sent from or to the wallet. This enables + /// calculation of unconfirmed balance. It also lays foundation for + /// spending unconfirmed utxos. (issue #189) + pub(in crate::models::state) async fn handle_mempool_event( + &mut self, + event: MempoolEvent, + ) -> Result<()> { match event { MempoolEvent::AddTx(tx) => { - debug!("handling mempool AddTx event."); + trace!("handling mempool AddTx event."); let spent_utxos = self.scan_for_spent_utxos(&tx.kernel).await; @@ -286,11 +309,14 @@ impl WalletState { self.mempool_unspent_utxos.insert(tx_hash, announced_utxos); } MempoolEvent::RemoveTx(tx) => { - debug!("handling mempool RemoveTx event."); + trace!("handling mempool RemoveTx event."); let tx_hash = Hash::hash(&tx); self.mempool_spent_utxos.remove(&tx_hash); self.mempool_unspent_utxos.remove(&tx_hash); } + MempoolEvent::UpdateTxMutatorSet(_tx_hash_pre_update, _tx_post_update) => { + // Utxos are not affected by MutatorSet update, so this is a no-op. + } } Ok(()) } @@ -1488,6 +1514,15 @@ mod tests { use crate::models::state::wallet::address::ReceivingAddress; use crate::tests::shared::mine_block_to_wallet; + /// basic test for confirmed and unconfirmed balance. + /// + /// This test: + /// 1. mines a block to self worth 100 + /// 2. sends 5 to a 3rd party, and 95 change back to self. + /// 3. verifies that confirmed balance is 100 + /// 4. verifies that unconfirmed balance is 95 + /// 5. empties the mempool (removing our unconfirmed tx) + /// 6. verifies that unconfirmed balance is 100 #[traced_test] #[tokio::test] async fn confirmed_and_unconfirmed_balance() -> Result<()> { @@ -1495,18 +1530,20 @@ mod tests { let network = Network::RegTest; let mut global_state_lock = mock_genesis_global_state(network, 0, WalletSecret::new_random()).await; - let _wallet_task_jh = crate::spawn_wallet_task(global_state_lock.clone()).await?; let change_key = global_state_lock .lock_guard_mut() .await .wallet_state .next_unused_spending_key(KeyType::Generation); + let coinbase_amt = NeptuneCoins::new(100); let send_amt = NeptuneCoins::new(5); + // mine a block to our wallet. we should have 100 coins after. let tip_digest = mine_block_to_wallet(&mut global_state_lock).await?.hash(); let tx = { + // verify that confirmed and unconfirmed balance are both 100. let gs = global_state_lock.lock_guard().await; assert_eq!( gs.wallet_state @@ -1521,7 +1558,7 @@ mod tests { coinbase_amt ); - // --- Setup. generate an output that our wallet cannot claim. --- + // generate an output that our wallet cannot claim. let outputs = vec![( ReceivingAddress::from(GenerationReceivingAddress::derive_from_seed(rng.gen())), send_amt, @@ -1541,16 +1578,16 @@ mod tests { tx }; + // add the tx to the mempool. + // note that the wallet should be notified of these changes. global_state_lock .lock_guard_mut() .await - .mempool - .insert(tx)?; - - // we must yield so the wallet task can process the mempool events - tokio::task::yield_now().await; + .mempool_insert(tx) + .await?; { + // verify that confirmed balance is still 100 let gs = global_state_lock.lock_guard().await; assert_eq!( gs.wallet_state @@ -1558,21 +1595,23 @@ mod tests { .await, coinbase_amt ); - debug!("calculated confirmed balance"); + // verify that unconfirmed balance is now 95. assert_eq!( gs.wallet_state .unconfirmed_balance(tip_digest, Timestamp::now()) .await, coinbase_amt.checked_sub(&send_amt).unwrap() ); - debug!("calculated unconfirmed balance"); } - global_state_lock.lock_guard_mut().await.mempool.clear()?; - - // we must yield so the wallet task can process the mempool events - tokio::task::yield_now().await; + // clear the mempool, which drops our unconfirmed tx. + global_state_lock + .lock_guard_mut() + .await + .mempool_clear() + .await?; + // verify that wallet's unconfirmed balance is 100 again. assert_eq!( global_state_lock .lock_guard() diff --git a/src/peer_loop.rs b/src/peer_loop.rs index 822c2352..9208009b 100644 --- a/src/peer_loop.rs +++ b/src/peer_loop.rs @@ -2574,8 +2574,8 @@ mod peer_loop_tests { state_lock .lock_guard_mut() .await - .mempool - .insert(transaction_1.clone())?; + .mempool_insert(transaction_1.clone()) + .await?; assert!( !state_lock.lock_guard().await.mempool.is_empty(), "Mempool must be non-empty after insertion"