Skip to content
This repository was archived by the owner on Jul 22, 2024. It is now read-only.

Replace RefCell with RwLock #1052

Closed
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 23 additions & 19 deletions src/state/cached_state.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,23 +16,22 @@ use cairo_vm::felt::Felt252;
use getset::{Getters, MutGetters};
use num_traits::Zero;
use std::{
cell::RefCell,
collections::{HashMap, HashSet},
sync::Arc,
sync::{Arc, RwLock},
};

pub const UNINITIALIZED_CLASS_HASH: &ClassHash = &[0u8; 32];

/// Represents a cached state of contract classes with optional caches.
#[derive(Default, Clone, Debug, Getters, MutGetters)]
#[derive(Default, Debug, Getters, MutGetters)]
pub struct CachedState<T: StateReader, C: ContractClassCache> {
pub state_reader: Arc<T>,
#[getset(get = "pub", get_mut = "pub")]
pub(crate) cache: StateCache,

#[getset(get = "pub", get_mut = "pub")]
pub(crate) contract_class_cache: Arc<C>,
pub(crate) contract_class_cache_private: RefCell<HashMap<ClassHash, CompiledClass>>,
pub(crate) contract_class_cache_private: RwLock<HashMap<ClassHash, CompiledClass>>,

#[cfg(feature = "metrics")]
cache_hits: usize,
Expand Down Expand Up @@ -73,7 +72,7 @@ impl<T: StateReader, C: ContractClassCache> CachedState<T, C> {
cache: StateCache::default(),
state_reader,
contract_class_cache: contract_classes,
contract_class_cache_private: RefCell::new(HashMap::new()),
contract_class_cache_private: RwLock::new(HashMap::new()),

#[cfg(feature = "metrics")]
cache_hits: 0,
Expand All @@ -92,7 +91,7 @@ impl<T: StateReader, C: ContractClassCache> CachedState<T, C> {
cache,
state_reader,
contract_class_cache: contract_classes,
contract_class_cache_private: RefCell::new(HashMap::new()),
contract_class_cache_private: RwLock::new(HashMap::new()),

#[cfg(feature = "metrics")]
cache_hits: 0,
Expand All @@ -104,7 +103,11 @@ impl<T: StateReader, C: ContractClassCache> CachedState<T, C> {
pub fn drain_private_contract_class_cache(
&self,
) -> impl Iterator<Item = (ClassHash, CompiledClass)> {
self.contract_class_cache_private.take().into_iter()
self.contract_class_cache_private
.read()
.unwrap()
.clone()
.into_iter()
}

/// Creates a copy of this state with an empty cache for saving changes and applying them
Expand All @@ -115,8 +118,8 @@ impl<T: StateReader, C: ContractClassCache> CachedState<T, C> {
state_reader,
cache: self.cache.clone(),
contract_class_cache: self.contract_class_cache.clone(),
contract_class_cache_private: RefCell::new(
self.contract_class_cache_private.borrow().clone(),
contract_class_cache_private: RwLock::new(
self.contract_class_cache_private.read().unwrap().clone(),
),
#[cfg(feature = "metrics")]
cache_hits: 0,
Expand Down Expand Up @@ -177,7 +180,7 @@ impl<T: StateReader, C: ContractClassCache> StateReader for CachedState<T, C> {
}

// I: FETCHING FROM CACHE
let mut private_cache = self.contract_class_cache_private.borrow_mut();
let mut private_cache = self.contract_class_cache_private.write().unwrap();
if let Some(compiled_class) = private_cache.get(class_hash) {
return Ok(compiled_class.clone());
} else if let Some(compiled_class) =
Expand Down Expand Up @@ -221,6 +224,7 @@ impl<T: StateReader, C: ContractClassCache> State for CachedState<T, C> {
// have a mutable reference to the `RefCell` available.
self.contract_class_cache_private
.get_mut()
.unwrap()
.insert(*class_hash, contract_class.clone());

Ok(())
Expand Down Expand Up @@ -446,6 +450,7 @@ impl<T: StateReader, C: ContractClassCache> State for CachedState<T, C> {
if let Some(compiled_class) = self
.contract_class_cache_private
.get_mut()
.unwrap()
.get(class_hash)
.cloned()
{
Expand All @@ -457,6 +462,7 @@ impl<T: StateReader, C: ContractClassCache> State for CachedState<T, C> {
self.add_hit();
self.contract_class_cache_private
.get_mut()
.unwrap()
.insert(*class_hash, compiled_class.clone());
return Ok(compiled_class);
}
Expand All @@ -465,14 +471,11 @@ impl<T: StateReader, C: ContractClassCache> State for CachedState<T, C> {
if let Some(compiled_class_hash) =
self.cache.class_hash_to_compiled_class_hash.get(class_hash)
{
let write_guard = self.contract_class_cache_private.get_mut().unwrap();

// `RefCell::get_mut()` provides a mutable reference without the borrowing overhead when
// we have a mutable reference to the `RefCell` available.
if let Some(casm_class) = self
.contract_class_cache_private
.get_mut()
.get(compiled_class_hash)
.cloned()
{
if let Some(casm_class) = write_guard.get(compiled_class_hash).cloned() {
self.add_hit();
return Ok(casm_class);
} else if let Some(casm_class) = self
Expand All @@ -482,6 +485,7 @@ impl<T: StateReader, C: ContractClassCache> State for CachedState<T, C> {
self.add_hit();
self.contract_class_cache_private
.get_mut()
.unwrap()
.insert(*class_hash, casm_class.clone());
return Ok(casm_class);
}
Expand Down Expand Up @@ -517,7 +521,7 @@ pub type TransactionalCachedState<'a, T, C> =
/// In practice this will act as a way to access the parent state's cache and other fields,
/// without referencing the whole parent state, so there's no need to adapt state-modifying
/// functions in the case that a transactional state is needed.
#[derive(Debug, MutGetters, Getters, PartialEq, Clone)]
#[derive(Debug, MutGetters, Getters)]
pub struct TransactionalCachedStateReader<'a, T: StateReader, C: ContractClassCache> {
/// The parent state's state_reader
#[get(get = "pub")]
Expand All @@ -529,7 +533,7 @@ pub struct TransactionalCachedStateReader<'a, T: StateReader, C: ContractClassCa
/// The parent state's contract_classes
#[get(get = "pub")]
pub(crate) contract_class_cache: Arc<C>,
pub(crate) contract_class_cache_private: &'a RefCell<HashMap<ClassHash, CompiledClass>>,
pub(crate) contract_class_cache_private: &'a RwLock<HashMap<ClassHash, CompiledClass>>,
}

impl<'a, T: StateReader, C: ContractClassCache> TransactionalCachedStateReader<'a, T, C> {
Expand Down Expand Up @@ -602,7 +606,7 @@ impl<'a, T: StateReader, C: ContractClassCache> StateReader
}

// I: FETCHING FROM CACHE
let mut private_cache = self.contract_class_cache_private.borrow_mut();
let mut private_cache = self.contract_class_cache_private.write().unwrap();
if let Some(compiled_class) = private_cache.get(class_hash) {
return Ok(compiled_class.clone());
} else if let Some(compiled_class) =
Expand Down