Skip to content
This repository was archived by the owner on Jul 22, 2024. It is now read-only.

BREAKING: StateReader::get_class_hash_at return zero by default #1012

Merged
merged 5 commits into from
Sep 6, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 15 additions & 30 deletions src/state/cached_state.rs
Original file line number Diff line number Diff line change
Expand Up @@ -72,25 +72,12 @@ impl<T: StateReader> CachedState<T> {

impl<T: StateReader> StateReader for CachedState<T> {
/// Returns the class hash for a given contract address.
/// Returns zero as default value if missing
fn get_class_hash_at(&self, contract_address: &Address) -> Result<ClassHash, StateError> {
if self.cache.get_class_hash(contract_address).is_none() {
match self.state_reader.get_class_hash_at(contract_address) {
Ok(class_hash) => {
return Ok(class_hash);
}
Err(StateError::NoneContractState(_)) => {
return Ok([0; 32]);
}
Err(e) => {
return Err(e);
}
}
}

self.cache
.get_class_hash(contract_address)
.ok_or_else(|| StateError::NoneClassHash(contract_address.clone()))
.cloned()
.map(|a| Ok(*a))
.unwrap_or_else(|| self.state_reader.get_class_hash_at(contract_address))
}

/// Returns the nonce for a given contract address.
Expand Down Expand Up @@ -305,22 +292,20 @@ impl<T: StateReader> State for CachedState<T> {
Ok((n_modified_contracts, storage_updates.len()))
}

/// Returns the class hash for a given contract address.
/// Returns zero as default value if missing
/// Adds the value to the cache's inital_values if not present
fn get_class_hash_at(&mut self, contract_address: &Address) -> Result<ClassHash, StateError> {
if self.cache.get_class_hash(contract_address).is_none() {
let class_hash = match self.state_reader.get_class_hash_at(contract_address) {
Ok(class_hash) => class_hash,
Err(StateError::NoneContractState(_)) => [0; 32],
Err(e) => return Err(e),
};
self.cache
.class_hash_initial_values
.insert(contract_address.clone(), class_hash);
match self.cache.get_class_hash(contract_address) {
Some(class_hash) => Ok(*class_hash),
None => {
let class_hash = self.state_reader.get_class_hash_at(contract_address)?;
self.cache
.class_hash_initial_values
.insert(contract_address.clone(), class_hash);
Ok(class_hash)
}
}

self.cache
.get_class_hash(contract_address)
.ok_or_else(|| StateError::NoneClassHash(contract_address.clone()))
.cloned()
}

fn get_nonce_at(&mut self, contract_address: &Address) -> Result<Felt252, StateError> {
Expand Down
27 changes: 18 additions & 9 deletions src/state/in_memory_state_reader.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ use crate::{
};
use cairo_vm::felt::Felt252;
use getset::{Getters, MutGetters};
use num_traits::Zero;
use std::collections::HashMap;

/// A [StateReader] that holds all the data in memory.
Expand Down Expand Up @@ -80,20 +79,19 @@ impl InMemoryStateReader {

impl StateReader for InMemoryStateReader {
fn get_class_hash_at(&self, contract_address: &Address) -> Result<ClassHash, StateError> {
let class_hash = self
Ok(self
.address_to_class_hash
.get(contract_address)
.ok_or_else(|| StateError::NoneContractState(contract_address.clone()));
class_hash.cloned()
.cloned()
.unwrap_or_default())
}

fn get_nonce_at(&self, contract_address: &Address) -> Result<Felt252, StateError> {
let default = Felt252::zero();
let nonce = self
Ok(self
.address_to_nonce
.get(contract_address)
.unwrap_or(&default);
Ok(nonce.clone())
.cloned()
.unwrap_or_default())
}

fn get_storage_at(&self, storage_entry: &StorageEntry) -> Result<Felt252, StateError> {
Expand Down Expand Up @@ -132,12 +130,23 @@ impl StateReader for InMemoryStateReader {

#[cfg(test)]
mod tests {
use num_traits::One;
use num_traits::{One, Zero};

use super::*;
use crate::services::api::contract_classes::deprecated_contract_class::ContractClass;
use std::sync::Arc;

#[test]
fn get_class_hash_at_returns_zero_if_missing() {
let state_reader = InMemoryStateReader::default();
assert!(Felt252::from_bytes_be(
&state_reader
.get_class_hash_at(&Address(Felt252::one()))
.unwrap()
)
.is_zero())
}

#[test]
fn get_storage_returns_zero_if_missing() {
let state_reader = InMemoryStateReader::default();
Expand Down
3 changes: 3 additions & 0 deletions src/state/state_api.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ pub trait StateReader {
/// Returns the contract class of the given class hash or compiled class hash.
fn get_contract_class(&self, class_hash: &ClassHash) -> Result<CompiledClass, StateError>;
/// Returns the class hash of the contract class at the given address.
/// Returns zero by default if the value is not present
fn get_class_hash_at(&self, contract_address: &Address) -> Result<ClassHash, StateError>;
/// Returns the nonce of the given contract instance.
fn get_nonce_at(&self, contract_address: &Address) -> Result<Felt252, StateError>;
Expand Down Expand Up @@ -61,6 +62,8 @@ pub trait State {
fee_token_and_sender_address: Option<(&Address, &Address)>,
) -> Result<(usize, usize), FromByteArrayError>;

/// Returns the class hash of the contract class at the given address.
/// Returns zero by default if the value is not present
fn get_class_hash_at(&mut self, contract_address: &Address) -> Result<ClassHash, StateError>;

/// Default: 0 for an uninitialized contract address.
Expand Down