Skip to content
This repository was archived by the owner on Jul 22, 2024. It is now read-only.

perf: refactor substract_mappings and friends to avoid clones #1023

Merged
merged 6 commits into from
Sep 12, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 12 additions & 20 deletions src/state/cached_state.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@ use crate::{
services::api::contract_classes::compiled_class::CompiledClass,
state::StateDiff,
utils::{
get_erc20_balance_var_addresses, subtract_mappings, to_cache_state_storage_mapping,
Address, ClassHash,
get_erc20_balance_var_addresses, subtract_mappings, subtract_mappings_keys,
to_cache_state_storage_mapping, Address, ClassHash,
},
};
use cairo_vm::felt::Felt252;
Expand Down Expand Up @@ -281,32 +281,24 @@ impl<T: StateReader> State for CachedState<T> {
self.update_initial_values_of_write_only_accesses()?;

let mut storage_updates = subtract_mappings(
self.cache.storage_writes.clone(),
self.cache.storage_initial_values.clone(),
&self.cache.storage_writes,
&self.cache.storage_initial_values,
);

let storage_unique_updates = storage_updates.keys().map(|k| k.0.clone());

let class_hash_updates: Vec<_> = subtract_mappings(
self.cache.class_hash_writes.clone(),
self.cache.class_hash_initial_values.clone(),
)
.keys()
.cloned()
.collect();
let class_hash_updates = subtract_mappings_keys(
&self.cache.class_hash_writes,
&self.cache.class_hash_initial_values,
);

let nonce_updates: Vec<_> = subtract_mappings(
self.cache.nonce_writes.clone(),
self.cache.nonce_initial_values.clone(),
)
.keys()
.cloned()
.collect();
let nonce_updates =
subtract_mappings_keys(&self.cache.nonce_writes, &self.cache.nonce_initial_values);

let mut modified_contracts: HashSet<Address> = HashSet::new();
modified_contracts.extend(storage_unique_updates);
modified_contracts.extend(class_hash_updates);
modified_contracts.extend(nonce_updates);
modified_contracts.extend(class_hash_updates.cloned());
modified_contracts.extend(nonce_updates.cloned());

// Add fee transfer storage update before actually charging it, as it needs to be included in the
// calculation of the final fee.
Expand Down
27 changes: 12 additions & 15 deletions src/state/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -134,25 +134,23 @@ impl StateDiff {
let state_cache = cached_state.cache().to_owned();

let substracted_maps = subtract_mappings(
state_cache.storage_writes.clone(),
state_cache.storage_initial_values.clone(),
&state_cache.storage_writes,
&state_cache.storage_initial_values,
);

let storage_updates = to_state_diff_storage_mapping(substracted_maps);

let address_to_nonce = subtract_mappings(
state_cache.nonce_writes.clone(),
state_cache.nonce_initial_values.clone(),
);
let address_to_nonce =
subtract_mappings(&state_cache.nonce_writes, &state_cache.nonce_initial_values);

let class_hash_to_compiled_class = subtract_mappings(
state_cache.compiled_class_hash_writes.clone(),
state_cache.compiled_class_hash_initial_values.clone(),
&state_cache.compiled_class_hash_writes,
&state_cache.compiled_class_hash_initial_values,
);

let address_to_class_hash = subtract_mappings(
state_cache.class_hash_writes.clone(),
state_cache.class_hash_initial_values,
&state_cache.class_hash_writes,
&state_cache.class_hash_initial_values,
);

Ok(StateDiff {
Expand Down Expand Up @@ -193,23 +191,22 @@ impl StateDiff {

let mut storage_updates = HashMap::new();

let addresses: Vec<Address> =
get_keys(self.storage_updates.clone(), other.storage_updates.clone());
let addresses: Vec<&Address> = get_keys(&self.storage_updates, &other.storage_updates);

for address in addresses {
let default: HashMap<Felt252, Felt252> = HashMap::new();
let mut map_a = self
.storage_updates
.get(&address)
.get(address)
.unwrap_or(&default)
.to_owned();
let map_b = other
.storage_updates
.get(&address)
.get(address)
.unwrap_or(&default)
.to_owned();
map_a.extend(map_b);
storage_updates.insert(address, map_a.clone());
storage_updates.insert(address.clone(), map_a.clone());
}

StateDiff {
Expand Down
38 changes: 28 additions & 10 deletions src/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -230,20 +230,38 @@ where
V: PartialEq + Clone,
{
let val = map.get(key);
!(map.contains_key(key) && (Some(value) == val))
Some(value) != val
}

pub fn subtract_mappings<K, V>(map_a: HashMap<K, V>, map_b: HashMap<K, V>) -> HashMap<K, V>
pub fn subtract_mappings<'a, K, V>(
map_a: &'a HashMap<K, V>,
map_b: &'a HashMap<K, V>,
) -> HashMap<K, V>
where
K: Hash + Eq + Clone,
V: PartialEq + Clone,
{
map_a
.into_iter()
.filter(|(k, v)| contained_and_not_updated(k, v, &map_b))
.iter()
.filter(|(k, v)| contained_and_not_updated(*k, *v, map_b))
.map(|(k, v)| (k.clone(), v.clone()))
.collect()
}

pub fn subtract_mappings_keys<'a, K, V>(
map_a: &'a HashMap<K, V>,
map_b: &'a HashMap<K, V>,
) -> impl Iterator<Item = &'a K>
where
K: Hash + Eq + Clone,
V: PartialEq + Clone,
{
map_a
.iter()
.filter(|(k, v)| contained_and_not_updated(*k, *v, map_b))
.map(|x| x.0)
}

/// Converts StateDiff storage mapping (addresses map to a key-value mapping) to CachedState
/// storage mapping (Tuple of address and key map to the associated value).
pub fn to_cache_state_storage_mapping(
Expand All @@ -260,12 +278,12 @@ pub fn to_cache_state_storage_mapping(

// get a vector of keys from two hashmaps

pub fn get_keys<K, V>(map_a: HashMap<K, V>, map_b: HashMap<K, V>) -> Vec<K>
pub fn get_keys<'a, K, V>(map_a: &'a HashMap<K, V>, map_b: &'a HashMap<K, V>) -> Vec<&'a K>
where
K: Hash + Eq,
{
let mut keys1: HashSet<K> = map_a.into_keys().collect();
let keys2: HashSet<K> = map_b.into_keys().collect();
let mut keys1: HashSet<&K> = map_a.keys().collect();
let keys2: HashSet<&K> = map_b.keys().collect();

keys1.extend(keys2);

Expand Down Expand Up @@ -647,7 +665,7 @@ mod test {
.into_iter()
.collect::<HashMap<&str, i32>>();

assert_eq!(subtract_mappings(a, b), res);
assert_eq!(subtract_mappings(&a, &b), res);

let mut c = HashMap::new();
let mut d = HashMap::new();
Expand All @@ -664,7 +682,7 @@ mod test {
.into_iter()
.collect::<HashMap<i32, i32>>();

assert_eq!(subtract_mappings(c, d), res);
assert_eq!(subtract_mappings(&c, &d), res);

let mut e = HashMap::new();
let mut f = HashMap::new();
Expand All @@ -676,7 +694,7 @@ mod test {
f.insert(3, 4);
f.insert(6, 7);

assert_eq!(subtract_mappings(e, f), HashMap::new())
assert_eq!(subtract_mappings(&e, &f), HashMap::new())
}

#[test]
Expand Down