Skip to content

Commit

Permalink
implements weighted shuffle using N-ary tree (#259)
Browse files Browse the repository at this point in the history
This is port of firedancer's implementation of weighted shuffle:
https://github.com/firedancer-io/firedancer/blob/3401bfc26/src/ballet/wsample/fd_wsample.c

#185
implemented weighted shuffle using binary tree. Though asymptotically a
binary tree has better performance, compared to a Fenwick tree, it has
less cache locality resulting in smaller improvements and in particular
slower WeightedShuffle::new.

In order to improve cache locality and reduce the overheads of
traversing the tree, this commit instead uses a generalized N-ary tree
with fanout of 16, showing significant improvements in both
WeightedShuffle::new and WeightedShuffle::shuffle.

With 4000 weights:

N-ary tree (fanout 16):

    test bench_weighted_shuffle_new     ... bench:      36,244 ns/iter (+/- 243)
    test bench_weighted_shuffle_shuffle ... bench:     149,082 ns/iter (+/- 1,474)

Binary tree:

    test bench_weighted_shuffle_new     ... bench:      58,514 ns/iter (+/- 229)
    test bench_weighted_shuffle_shuffle ... bench:     269,961 ns/iter (+/- 16,446)

Fenwick tree:

    test bench_weighted_shuffle_new     ... bench:      39,413 ns/iter (+/- 179)
    test bench_weighted_shuffle_shuffle ... bench:     364,771 ns/iter (+/- 2,078)

The improvements become even more significant as there are more items to
shuffle. With 20_000 weights:

N-ary tree (fanout 16):

    test bench_weighted_shuffle_new     ... bench:     200,659 ns/iter (+/- 4,395)
    test bench_weighted_shuffle_shuffle ... bench:     941,928 ns/iter (+/- 26,492)

Binary tree:

    test bench_weighted_shuffle_new     ... bench:     881,114 ns/iter (+/- 12,343)
    test bench_weighted_shuffle_shuffle ... bench:   1,822,257 ns/iter (+/- 12,772)

Fenwick tree:

    test bench_weighted_shuffle_new     ... bench:     276,936 ns/iter (+/- 14,692)
    test bench_weighted_shuffle_shuffle ... bench:   2,644,713 ns/iter (+/- 49,252)
  • Loading branch information
behzadnouri authored Mar 26, 2024
1 parent b01d792 commit 30eecd6
Showing 1 changed file with 77 additions and 55 deletions.
132 changes: 77 additions & 55 deletions gossip/src/weighted_shuffle.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,14 @@ use {
std::ops::{AddAssign, Sub, SubAssign},
};

// Each internal tree node has FANOUT many child nodes with indices:
// (index << BIT_SHIFT) + 1 ..= (index << BIT_SHIFT) + FANOUT
// Conversely, for each node, the parent node is obtained by:
// (index - 1) >> BIT_SHIFT
const BIT_SHIFT: usize = 4;
const FANOUT: usize = 1 << BIT_SHIFT;
const BIT_MASK: usize = FANOUT - 1;

/// Implements an iterator where indices are shuffled according to their
/// weights:
/// - Returned indices are unique in the range [0, weights.len()).
Expand All @@ -18,12 +26,13 @@ use {
/// non-zero weighted indices.
#[derive(Clone)]
pub struct WeightedShuffle<T> {
// Underlying array implementing binary tree.
// tree[i] is the sum of weights in the left sub-tree of node i.
tree: Vec<T>,
// Underlying array implementing the tree.
// tree[i][j] is the sum of all weights in the j'th sub-tree of node i.
tree: Vec<[T; FANOUT - 1]>,
// Current sum of all weights, excluding already sampled ones.
weight: T,
zeros: Vec<usize>, // Indices of zero weighted entries.
// Indices of zero weighted entries.
zeros: Vec<usize>,
}

impl<T> WeightedShuffle<T>
Expand All @@ -34,7 +43,7 @@ where
/// they are treated as zero.
pub fn new(name: &'static str, weights: &[T]) -> Self {
let zero = <T as Default>::default();
let mut tree = vec![zero; get_tree_size(weights.len())];
let mut tree = vec![[zero; FANOUT - 1]; get_tree_size(weights.len())];
let mut sum = zero;
let mut zeros = Vec::default();
let mut num_negative = 0;
Expand All @@ -59,12 +68,14 @@ where
continue;
}
};
let mut index = tree.len() + k;
// Traverse the tree from the leaf node upwards to the root,
// updating the sub-tree sums along the way.
let mut index = tree.len() + k; // leaf node
while index != 0 {
let offset = index & 1;
index = (index - 1) >> 1;
let offset = index & BIT_MASK;
index = (index - 1) >> BIT_SHIFT; // parent node
if offset > 0 {
tree[index] += weight;
tree[index][offset - 1] += weight;
}
}
}
Expand All @@ -88,54 +99,73 @@ where
{
// Removes given weight at index k.
fn remove(&mut self, k: usize, weight: T) {
debug_assert!(self.weight >= weight);
self.weight -= weight;
let mut index = self.tree.len() + k;
// Traverse the tree from the leaf node upwards to the root,
// updating the sub-tree sums along the way.
let mut index = self.tree.len() + k; // leaf node
while index != 0 {
let offset = index & 1;
index = (index - 1) >> 1;
let offset = index & BIT_MASK;
index = (index - 1) >> BIT_SHIFT; // parent node
if offset > 0 {
self.tree[index] -= weight;
debug_assert!(self.tree[index][offset - 1] >= weight);
self.tree[index][offset - 1] -= weight;
}
}
}

// Returns smallest index such that cumsum of weights[..=k] > val,
// Returns smallest index such that sum of weights[..=k] > val,
// along with its respective weight.
fn search(&self, mut val: T) -> (/*index:*/ usize, /*weight:*/ T) {
let zero = <T as Default>::default();
debug_assert!(val >= zero);
debug_assert!(val < self.weight);
let mut index = 0;
// Traverse the tree downwards from the root while maintaining the
// weight of the subtree which contains the target leaf node.
let mut index = 0; // root
let mut weight = self.weight;
while index < self.tree.len() {
if val < self.tree[index] {
weight = self.tree[index];
index = (index << 1) + 1;
} else {
weight -= self.tree[index];
val -= self.tree[index];
index = (index << 1) + 2;
'outer: while index < self.tree.len() {
for (j, &node) in self.tree[index].iter().enumerate() {
if val < node {
// Traverse to the j+1 subtree of self.tree[index].
weight = node;
index = (index << BIT_SHIFT) + j + 1;
continue 'outer;
} else {
debug_assert!(weight >= node);
weight -= node;
val -= node;
}
}
// Traverse to the right-most subtree of self.tree[index].
index = (index << BIT_SHIFT) + FANOUT;
}
(index - self.tree.len(), weight)
}

pub fn remove_index(&mut self, k: usize) {
let mut index = self.tree.len() + k;
// Traverse the tree from the leaf node upwards to the root, while
// maintaining the sum of weights of subtrees *not* containing the leaf
// node.
let mut index = self.tree.len() + k; // leaf node
let mut weight = <T as Default>::default(); // zero
while index != 0 {
let offset = index & 1;
index = (index - 1) >> 1;
let offset = index & BIT_MASK;
index = (index - 1) >> BIT_SHIFT; // parent node
if offset > 0 {
if self.tree[index] != weight {
self.remove(k, self.tree[index] - weight);
if self.tree[index][offset - 1] != weight {
self.remove(k, self.tree[index][offset - 1] - weight);
} else {
self.remove_zero(k);
}
return;
}
weight += self.tree[index];
// The leaf node is in the right-most subtree of self.tree[index].
for &node in &self.tree[index] {
weight += node;
}
}
// The leaf node is the right-most node of the whole tree.
if self.weight != weight {
self.remove(k, self.weight - weight);
} else {
Expand Down Expand Up @@ -193,17 +223,16 @@ where
}
}

// Maps number of items to the "internal" size of the binary tree "implicitly"
// holding those items on the leaves.
// Maps number of items to the "internal" size of the tree
// which "implicitly" holds those items on the leaves.
fn get_tree_size(count: usize) -> usize {
let shift = usize::BITS
- count.leading_zeros()
- if count.is_power_of_two() && count != 1 {
1
} else {
0
};
(1usize << shift) - 1
let mut size = if count == 1 { 1 } else { 0 };
let mut nodes = 1;
while nodes < count {
size += nodes;
nodes *= FANOUT;
}
size
}

#[cfg(test)]
Expand Down Expand Up @@ -251,25 +280,18 @@ mod tests {
#[test]
fn test_get_tree_size() {
assert_eq!(get_tree_size(0), 0);
assert_eq!(get_tree_size(1), 1);
assert_eq!(get_tree_size(2), 1);
assert_eq!(get_tree_size(3), 3);
assert_eq!(get_tree_size(4), 3);
for count in 5..9 {
assert_eq!(get_tree_size(count), 7);
for count in 1..=16 {
assert_eq!(get_tree_size(count), 1);
}
for count in 17..=256 {
assert_eq!(get_tree_size(count), 1 + 16);
}
for count in 9..17 {
assert_eq!(get_tree_size(count), 15);
for count in 257..=4096 {
assert_eq!(get_tree_size(count), 1 + 16 + 16 * 16);
}
for count in 17..33 {
assert_eq!(get_tree_size(count), 31);
for count in 4097..=65536 {
assert_eq!(get_tree_size(count), 1 + 16 + 16 * 16 + 16 * 16 * 16);
}
assert_eq!(get_tree_size((1 << 16) - 1), (1 << 16) - 1);
assert_eq!(get_tree_size(1 << 16), (1 << 16) - 1);
assert_eq!(get_tree_size((1 << 16) + 1), (1 << 17) - 1);
assert_eq!(get_tree_size((1 << 17) - 1), (1 << 17) - 1);
assert_eq!(get_tree_size(1 << 17), (1 << 17) - 1);
assert_eq!(get_tree_size((1 << 17) + 1), (1 << 18) - 1);
}

// Asserts that empty weights will return empty shuffle.
Expand Down

0 comments on commit 30eecd6

Please sign in to comment.