Skip to content

Commit

Permalink
compute merkle root on chunks of fanout^3 (solana-labs#15344)
Browse files Browse the repository at this point in the history
* compute merkle root on chunks of fanout^3

* improve test_accountsdb_compute_merkle_root_large
  • Loading branch information
jeffwashington authored Feb 16, 2021
1 parent ba02452 commit 8367740
Showing 1 changed file with 114 additions and 64 deletions.
178 changes: 114 additions & 64 deletions runtime/src/accounts_db.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3624,6 +3624,7 @@ impl AccountsDB {
fn compute_merkle_root_from_slices<'a, F>(
total_hashes: usize,
fanout: usize,
max_levels_per_pass: Option<usize>,
get_hashes: F,
) -> Hash
where
Expand All @@ -3635,7 +3636,17 @@ impl AccountsDB {

let mut time = Measure::start("time");

let chunks = Self::div_ceil(total_hashes, fanout);
const THREE_LEVEL_OPTIMIZATION: usize = 3; // this '3' is dependent on the code structure below where we manually unroll
let target = fanout.pow(THREE_LEVEL_OPTIMIZATION as u32);

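// Only use the 3 level optimization if we have at least 4 levels of data (total_hashes >= fanout^4).
// Otherwise, there would be too few chunks of fanout^3 hashes to spread across threads, and we'd be serializing a parallel operation.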
let threshold = target * fanout;
let three_level = max_levels_per_pass.unwrap_or(usize::MAX) >= THREE_LEVEL_OPTIMIZATION
&& total_hashes >= threshold;
let num_hashes_per_chunk = if three_level { target } else { fanout };

let chunks = Self::div_ceil(total_hashes, num_hashes_per_chunk);

// initial fetch - could return entire slice
let data: &[Hash] = get_hashes(0);
Expand All @@ -3644,24 +3655,54 @@ impl AccountsDB {
let result: Vec<_> = (0..chunks)
.into_par_iter()
.map(|i| {
let start_index = i * fanout;
let end_index = std::cmp::min(start_index + fanout, total_hashes);
let start_index = i * num_hashes_per_chunk;
let end_index = std::cmp::min(start_index + num_hashes_per_chunk, total_hashes);

let mut hasher = Hasher::default();
let mut data_index = start_index;
let mut data = data;
let mut data_len = data_len;

for i in start_index..end_index {
if data_index >= data_len {
// fetch next slice
data = get_hashes(i);
data_len = data.len();
data_index = 0;
if !three_level {
// 1 group of fanout
// The result of this loop is a single hash value from fanout input hashes.
for i in start_index..end_index {
if data_index >= data_len {
// fetch next slice
data = get_hashes(i);
data_len = data.len();
data_index = 0;
}
hasher.hash(data[data_index].as_ref());
data_index += 1;
}
} else {
// hash 3 levels of fanout simultaneously.
// The result of this loop is a single hash value from fanout^3 input hashes.
let mut i = start_index;
while i < end_index {
let mut hasher_j = Hasher::default();
for _j in 0..fanout {
let mut hasher_k = Hasher::default();
let end = std::cmp::min(end_index - i, fanout);
for _k in 0..end {
if data_index >= data_len {
// fetch next slice
data = get_hashes(i);
data_len = data.len();
data_index = 0;
}
hasher_k.hash(data[data_index].as_ref());
data_index += 1;
i += 1;
}
hasher_j.hash(hasher_k.result().as_ref());
if i >= end_index {
break;
}
}
hasher.hash(hasher_j.result().as_ref());
}

hasher.hash(data[data_index].as_ref());
data_index += 1;
}

hasher.result()
Expand Down Expand Up @@ -3823,10 +3864,12 @@ impl AccountsDB {
let hash_total = cumulative_offsets.total_count;
let total_lamports = *total_lamports.lock().unwrap();
let mut hash_time = Measure::start("hash");
let accumulated_hash =
Self::compute_merkle_root_from_slices(hash_total, MERKLE_FANOUT, |start: usize| {
cumulative_offsets.get_slice(&hashes, start)
});
let accumulated_hash = Self::compute_merkle_root_from_slices(
hash_total,
MERKLE_FANOUT,
None,
|start: usize| cumulative_offsets.get_slice(&hashes, start),
);
hash_time.stop();
datapoint_info!(
"update_accounts_hash",
Expand Down Expand Up @@ -4118,7 +4161,8 @@ impl AccountsDB {
let offsets = CumulativeOffsets::from_raw_2d(&hashes);

let get_slice = |start: usize| -> &[Hash] { offsets.get_slice_2d(&hashes, start) };
let hash = Self::compute_merkle_root_from_slices(offsets.total_count, fanout, get_slice);
let hash =
Self::compute_merkle_root_from_slices(offsets.total_count, fanout, None, get_slice);
hash_time.stop();
stats.hash_time_total_us += hash_time.as_us();
stats.hash_total = offsets.total_count;
Expand Down Expand Up @@ -6459,55 +6503,44 @@ pub mod tests {

fn test_hashing_larger(hashes: Vec<(Pubkey, Hash)>, fanout: usize) -> Hash {
let result = AccountsDB::compute_merkle_root(hashes.clone(), fanout);
if hashes.len() >= fanout * fanout * fanout {
let reduced: Vec<_> = hashes.iter().map(|x| x.1).collect();
let result2 =
AccountsDB::compute_merkle_root_from_slices(hashes.len(), fanout, |start| {
&reduced[start..]
});
assert_eq!(result, result2);

let reduced2: Vec<_> = hashes.iter().map(|x| vec![x.1]).collect();
let result2 = AccountsDB::flatten_hashes_and_hash(
vec![reduced2],
fanout,
&mut HashStats::default(),
);
assert_eq!(result, result2);

for left in 0..reduced.len() {
for right in left + 1..reduced.len() {
let src = vec![
vec![reduced[0..left].to_vec(), reduced[left..right].to_vec()],
vec![reduced[right..].to_vec()],
];
let result2 =
AccountsDB::flatten_hashes_and_hash(src, fanout, &mut HashStats::default());
assert_eq!(result, result2);
}
}
}
let reduced: Vec<_> = hashes.iter().map(|x| x.1).collect();
let result2 = test_hashing(reduced, fanout);
assert_eq!(result, result2, "len: {}", hashes.len());
result
}

fn test_hashing(hashes: Vec<Hash>, fanout: usize) -> Hash {
let temp: Vec<_> = hashes.iter().map(|h| (Pubkey::default(), *h)).collect();
let result = AccountsDB::compute_merkle_root(temp, fanout);
if hashes.len() >= fanout * fanout * fanout {
let reduced: Vec<_> = hashes.clone();
let result2 =
AccountsDB::compute_merkle_root_from_slices(hashes.len(), fanout, |start| {
&reduced[start..]
});
assert_eq!(result, result2, "len: {}", hashes.len());
let reduced: Vec<_> = hashes.clone();
let result2 =
AccountsDB::compute_merkle_root_from_slices(hashes.len(), fanout, None, |start| {
&reduced[start..]
});
assert_eq!(result, result2, "len: {}", hashes.len());

let reduced2: Vec<_> = hashes.iter().map(|x| vec![*x]).collect();
let result2 = AccountsDB::flatten_hashes_and_hash(
vec![reduced2],
fanout,
&mut HashStats::default(),
);
assert_eq!(result, result2, "len: {}", hashes.len());
let result2 =
AccountsDB::compute_merkle_root_from_slices(hashes.len(), fanout, Some(1), |start| {
&reduced[start..]
});
assert_eq!(result, result2, "len: {}", hashes.len());

let reduced2: Vec<_> = hashes.iter().map(|x| vec![*x]).collect();
let result2 =
AccountsDB::flatten_hashes_and_hash(vec![reduced2], fanout, &mut HashStats::default());
assert_eq!(result, result2, "len: {}", hashes.len());

let max = std::cmp::min(reduced.len(), fanout * 2);
for left in 0..max {
for right in left + 1..max {
let src = vec![
vec![reduced[0..left].to_vec(), reduced[left..right].to_vec()],
vec![reduced[right..].to_vec()],
];
let result2 =
AccountsDB::flatten_hashes_and_hash(src, fanout, &mut HashStats::default());
assert_eq!(result, result2);
}
}
result
}
Expand All @@ -6516,12 +6549,29 @@ pub mod tests {
fn test_accountsdb_compute_merkle_root_large() {
solana_logger::setup();

let mut num = 100;
for _pass in 0..2 {
num *= 10;
let hashes: Vec<_> = (0..num).into_iter().map(|_| Hash::new_unique()).collect();
// cover hash counts of fanout^x - 1, fanout^x, and fanout^x + 1 for each x in 1..6
const FANOUT: usize = 3;
let mut hash_counts: Vec<_> = (1..6)
.map(|x| {
let mark = FANOUT.pow(x);
vec![mark - 1, mark, mark + 1]
})
.flatten()
.collect();

// exhaustively test every hash count from threshold - 1 through threshold + target
// this starts right before the 3 level optimization kicks in and continues through every possible partial last chunk
let target = FANOUT.pow(3);
let threshold = target * FANOUT;
hash_counts.extend(threshold - 1..=threshold + target);

for hash_count in hash_counts {
let hashes: Vec<_> = (0..hash_count)
.into_iter()
.map(|_| Hash::new_unique())
.collect();

test_hashing(hashes, MERKLE_FANOUT);
test_hashing(hashes, FANOUT);
}
}

Expand Down

0 comments on commit 8367740

Please sign in to comment.