Skip to content

Commit

Permalink
New tree rollback method and wipe tree routine
Browse files Browse the repository at this point in the history
  • Loading branch information
EvgenKor committed Dec 7, 2022
1 parent 714457f commit 68e3a60
Showing 1 changed file with 132 additions and 28 deletions.
160 changes: 132 additions & 28 deletions libzkbob-rs/src/merkle.rs
Original file line number Diff line number Diff line change
Expand Up @@ -855,11 +855,13 @@ impl<D: KeyValueDB, P: PoolParams> MerkleTree<D, P> {
self.next_index
}

pub fn rollback(&mut self, rollback_index: u64) -> Option<u64> {
let mut result: Option<u64> = None;

// Returns tree root after rollback (None if lack of nodes)
// CAUTION: if this method returns None the Merkle tree is likely corrupted!
// You strongly need to rebuild or restore it in a some way
pub fn rollback(&mut self, rollback_index: u64) -> Option<Hash<P::Fr>> {
let rollback_index = if Some(rollback_index) <= self.first_index() {
0 // supporting rollback for the partial tree
// (if we rollback to the left of first_index - clear all tree)
} else {
rollback_index
};
Expand All @@ -883,14 +885,13 @@ impl<D: KeyValueDB, P: PoolParams> MerkleTree<D, P> {
index /= 2;
}
if nodes_request_index < clean_index {
result = Some(nodes_request_index)
return None;
}
}



// Update next_index.
let original_next_index = self.next_index;
let new_next_index = if rollback_index > 0 {
Self::calc_next_index(rollback_index - 1)
} else {
Expand All @@ -905,36 +906,103 @@ impl<D: KeyValueDB, P: PoolParams> MerkleTree<D, P> {
self.first_index = None;
}

// remove leaves
/* // OLD IMPLEMENTATION: VERY SLOW
for index in (rollback_index..original_next_index).rev() {
self.remove_leaf(index);
}*/

// NEW IMPLEMENTATION
let mut batch = self.db.transaction();

// remove unnecessary nodes
let mut remove_batch = self.db.transaction();
self.db
.iter(DbCols::Leaves as u32)
.filter_map(|(key, _)| {
.for_each(|(key, _)| {
let (height, index) = Self::parse_node_key(&key);
let left_index = index << height;
//let left_index = index << height;
let right_index = (index + 1) << height;
if right_index > rollback_index {
return Some(key);
} else {
return None;
remove_batch.delete(DbCols::Leaves as u32, &key);
remove_batch.delete(DbCols::TempLeaves as u32, &key);
//self.remove_batched(&mut remove_batch, height, index);
}
})
.for_each(|key| {
batch.delete(DbCols::Leaves as u32, &key);
//batch.delete(DbCols::TempLeaves as u32, &key);
});
self.db.write(remove_batch).unwrap();

// build the root node and restore intermediate ones
let mut update_batch = self.db.transaction();
let new_root = self.get_node_full_batched(constants::HEIGHT as u32, 0, &mut update_batch, new_next_index);
self.db.write(update_batch).unwrap();

self.db.write(batch).unwrap();
new_root
}

// The rollback helper: build node (e.g. root) within
// nodes before right_index (at level 0)
// It's supposed that the current tree has no nodes
// which subtree placed to the right of right_index
// Returns None in case of lack of nodes
// NOTE: as get_virtual_node_full routine, but not virtual one :)
// it operates only with the currently filled nodes and update
// the transaction batch to update the database
fn get_node_full_batched(
&mut self,
height: u32,
index: u64,
batch: &mut DBTransaction,
right_index: u64,
) -> Option<Hash<P::Fr>> {
let node_left = index * (1 << height);
//let node_right = (index + 1) * (1 << height);
if node_left >= right_index {
//print!("💬[{}.{}] ", height, index);
return Some(self.default_hashes[height as usize]);
}

match self.get_opt(height, index) {
Some(hash) => {
//print!("✅[{}.{}] ", height, index);
Some(hash)
},
None => {
if height > 0 {
let left_child = self.get_node_full_batched(
height - 1,
2 * index,
batch,
right_index,
);
let right_child = self.get_node_full_batched(
height - 1,
2 * index + 1,
batch,
right_index,
);

if left_child.is_some() && right_child.is_some() {
let pair = [left_child.unwrap(), right_child.unwrap()];
let hash = poseidon(pair.as_ref(), self.params.compress());

//print!("🧮[{}.{}]={}.{}+{}.{} ", height, index, height - 1, index << 1, height - 1, (index << 1) + 1);

self.set_batched(batch, height, index, hash, 0);

Some(hash)
} else {
//print!("❌[{}.{}] ", height, index);
None
}
} else {
//print!("❌[{}.{}] ", height, index);
None
}
}
}
}

result
pub fn wipe(&mut self) {
let mut wipe_batch = self.db.transaction();
wipe_batch.delete_prefix(DbCols::Leaves as u32, &[]);
wipe_batch.delete_prefix(DbCols::TempLeaves as u32, &[]);
self.db.write(wipe_batch).unwrap();

self.first_index = None;
self.next_index = 0;
self.write_index_to_database(FIRST_INDEX_KEY, None);
self.write_index_to_database(NEXT_INDEX_KEY, Some(self.next_index));
}

pub fn get_all_nodes(&self) -> Vec<Node<P::Fr>> {
Expand Down Expand Up @@ -1109,6 +1177,7 @@ impl<D: KeyValueDB, P: PoolParams> MerkleTree<D, P> {
batch.delete(DbCols::TempLeaves as u32, &key);
}

#[cfg(test)]
fn remove_leaf(&mut self, index: u64) {
let mut batch = self.db.transaction();

Expand Down Expand Up @@ -1514,16 +1583,18 @@ mod tests {

// 1st rollback
let rollback1_result = tree.rollback(first_part);
assert!(rollback1_result.is_none());
assert!(rollback1_result.is_some());
let rollback1_root = tree.get_root();
assert_eq!(rollback1_root, one_more_root);
assert_eq!(rollback1_root, rollback1_result.unwrap());
assert_eq!(tree.next_index, first_part);

// 2nd rollback
let rollback2_result = tree.rollback(0);
assert!(rollback2_result.is_none());
assert!(rollback2_result.is_some());
let rollback2_root = tree.get_root();
assert_eq!(rollback2_root, original_root);
assert_eq!(rollback2_root, rollback2_result.unwrap());
assert_eq!(tree.next_index, 0);
}

Expand All @@ -1546,9 +1617,10 @@ mod tests {
}

let rollback_result = tree.rollback(128);
assert!(rollback_result.is_none());
assert!(rollback_result.is_some());
let rollback_root = tree.get_root();
assert_eq!(rollback_root, original_root);
assert_eq!(rollback_root, rollback_result.unwrap());
assert_eq!(tree.next_index, 128);
}

Expand Down Expand Up @@ -1926,6 +1998,8 @@ mod tests {
let mut full_tree = MerkleTree::new(create(3), POOL_PARAMS.clone());
let mut partial_tree = MerkleTree::new(create(3), POOL_PARAMS.clone());

let zero_root = full_tree.get_root();

assert!(full_tree.first_index().is_none());
assert!(partial_tree.first_index().is_none());

Expand Down Expand Up @@ -1985,6 +2059,7 @@ mod tests {
partial_tree.rollback(partial_tree_start_index);
assert_eq!(partial_tree.next_index(), 0);
assert_eq!(partial_tree.first_index(), None);
assert_eq!(partial_tree.get_root(), zero_root);

}

Expand Down Expand Up @@ -2040,4 +2115,33 @@ mod tests {
assert_eq!(root_at_checkpoint.unwrap(), root_checkpoint);
assert_eq!(root_at_last.unwrap(), root_last);
}

#[test_case(0, 1)]
#[test_case(1, 1)]
#[test_case(10, 2)]
#[test_case(20, 128)]
#[test_case(42, 7)]
fn test_wipe(tx_count: u64, leaves_count: u64) {
let mut rng = CustomRng;
let mut tree = MerkleTree::new(create(3), POOL_PARAMS.clone());

let zero_root = tree.get_root();

let leafs: Vec<(u64, Vec<_>)> = (0..tx_count)
.map(|i| {
(i * (constants::OUT + 1) as u64, (0..leaves_count).map(|_| rng.gen()).collect())
})
.collect();

for (index, leafs) in leafs.clone().into_iter() {
tree.add_hashes(index, leafs)
}

tree.wipe();

assert_eq!(tree.get_root(), zero_root);
assert_eq!(tree.first_index(), None);
assert_eq!(tree.next_index(), 0);
assert_eq!(tree.get_leaves().len(), 0);
}
}

0 comments on commit 68e3a60

Please sign in to comment.