diff --git a/gossip/src/push_active_set.rs b/gossip/src/push_active_set.rs index 1e7e3cbb22844c..83f84b8a0624ed 100644 --- a/gossip/src/push_active_set.rs +++ b/gossip/src/push_active_set.rs @@ -29,14 +29,15 @@ impl PushActiveSet { pub(crate) fn get_nodes<'a>( &'a self, - pubkey: &Pubkey, // This node. + pubkey: &'a Pubkey, // This node. origin: &'a Pubkey, // CRDS value owner. // If true forces gossip push even if the node has pruned the origin. should_force_push: impl FnMut(&Pubkey) -> bool + 'a, stakes: &HashMap, ) -> impl Iterator + 'a { let stake = stakes.get(pubkey).min(stakes.get(origin)); - self.get_entry(stake).get_nodes(origin, should_force_push) + self.get_entry(stake) + .get_nodes(pubkey, origin, should_force_push) } // Prunes origins for the given gossip node. @@ -110,14 +111,20 @@ impl PushActiveSetEntry { fn get_nodes<'a>( &'a self, - origin: &'a Pubkey, + pubkey: &'a Pubkey, // This node. + origin: &'a Pubkey, // CRDS value owner. // If true forces gossip push even if the node has pruned the origin. mut should_force_push: impl FnMut(&Pubkey) -> bool + 'a, ) -> impl Iterator + 'a { + let pubkey_eq_origin = pubkey == origin; self.0 .iter() .filter(move |(node, bloom_filter)| { - !bloom_filter.contains(origin) || should_force_push(node) + // Bloom filter can return false positive for origin == pubkey + // but a node should always be able to push its own values. + !bloom_filter.contains(origin) + || (pubkey_eq_origin && &pubkey != node) + || should_force_push(node) }) .map(|(node, _bloom_filter)| node) } @@ -175,7 +182,10 @@ fn get_stake_bucket(stake: Option<&u64>) -> usize { #[cfg(test)] mod tests { - use {super::*, rand::SeedableRng, rand_chacha::ChaChaRng, std::iter::repeat_with}; + use { + super::*, itertools::iproduct, rand::SeedableRng, rand_chacha::ChaChaRng, + std::iter::repeat_with, + }; #[test] fn test_get_stake_bucket() { @@ -274,13 +284,13 @@ mod tests { assert_eq!(entry.0.len(), 5); let keys = [&nodes[16], &nodes[11], &nodes[17], &nodes[14], &nodes[5]]; assert!(entry.0.keys().eq(keys)); - for origin in &nodes { + for (pubkey, origin) in iproduct!(&nodes, &nodes) { if !keys.contains(&origin) { - assert!(entry.get_nodes(origin, |_| false).eq(keys)); + assert!(entry.get_nodes(pubkey, origin, |_| false).eq(keys)); } else { - assert!(entry.get_nodes(origin, |_| true).eq(keys)); + assert!(entry.get_nodes(pubkey, origin, |_| true).eq(keys)); assert!(entry - .get_nodes(origin, |_| false) + .get_nodes(pubkey, origin, |_| false) .eq(keys.into_iter().filter(|&key| key != origin))); } } @@ -288,10 +298,10 @@ mod tests { for (node, filter) in entry.0.iter() { assert!(filter.contains(node)); } - for origin in keys { - assert!(entry.get_nodes(origin, |_| true).eq(keys)); + for (pubkey, origin) in iproduct!(&nodes, keys) { + assert!(entry.get_nodes(pubkey, origin, |_| true).eq(keys)); assert!(entry - .get_nodes(origin, |_| false) + .get_nodes(pubkey, origin, |_| false) .eq(keys.into_iter().filter(|&node| node != origin))); } // Assert that prune excludes node from get. @@ -299,10 +309,12 @@ mod tests { entry.prune(&nodes[11], origin); entry.prune(&nodes[14], origin); entry.prune(&nodes[19], origin); - assert!(entry.get_nodes(origin, |_| true).eq(keys)); - assert!(entry.get_nodes(origin, |_| false).eq(keys - .into_iter() - .filter(|&&node| node != nodes[11] && node != nodes[14]))); + for pubkey in &nodes { + assert!(entry.get_nodes(pubkey, origin, |_| true).eq(keys)); + assert!(entry.get_nodes(pubkey, origin, |_| false).eq(keys + .into_iter() + .filter(|&&node| pubkey == origin || (node != nodes[11] && node != nodes[14])))); + } // Assert that rotate adds new nodes. entry.rotate(&mut rng, 5, NUM_BLOOM_FILTER_ITEMS, &nodes, &weights); let keys = [&nodes[11], &nodes[17], &nodes[14], &nodes[5], &nodes[7]];