Skip to content
This repository was archived by the owner on Jul 22, 2024. It is now read-only.

Refactor class hash #1095

Merged
merged 14 commits into from
Nov 17, 2023
Merged
21 changes: 9 additions & 12 deletions bench/internals.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,27 +18,24 @@ use starknet_in_rust::{
state::in_memory_state_reader::InMemoryStateReader,
state::{cached_state::CachedState, state_api::State},
transaction::{declare::Declare, Deploy, DeployAccount, InvokeFunction},
utils::Address,
utils::{Address, ClassHash},
};
use std::{collections::HashMap, hint::black_box, sync::Arc};

#[cfg(feature = "cairo-native")]
use {
starknet_in_rust::utils::ClassHash,
std::{cell::RefCell, rc::Rc},
};
use std::{cell::RefCell, rc::Rc};

lazy_static! {
// include_str! doesn't seem to work in CI
static ref CONTRACT_CLASS: ContractClass = ContractClass::from_path(
"starknet_programs/account_without_validation.json",
).unwrap();
static ref CLASS_HASH: Felt252 = compute_deprecated_class_hash(&CONTRACT_CLASS).unwrap();
static ref CLASS_HASH_BYTES: [u8; 32] = CLASS_HASH.clone().to_be_bytes();
static ref CLASS_HASH_FELT: Felt252 = compute_deprecated_class_hash(&CONTRACT_CLASS).unwrap();
static ref CLASS_HASH: ClassHash = ClassHash(CLASS_HASH_FELT.to_be_bytes());
static ref SALT: Felt252 = felt_str!(
"2669425616857739096022668060305620640217901643963991674344872184515580705509"
);
static ref CONTRACT_ADDRESS: Address = Address(calculate_contract_address(&SALT.clone(), &CLASS_HASH.clone(), &[], Address(0.into())).unwrap());
static ref CONTRACT_ADDRESS: Address = Address(calculate_contract_address(&SALT, &CLASS_HASH_FELT, &[], Address(0.into())).unwrap());
static ref SIGNATURE: Vec<Felt252> = vec![
felt_str!("3233776396904427614006684968846859029149676045084089832563834729503047027074"),
felt_str!("707039245213420890976709143988743108543645298941971188668773816813012281203"),
Expand Down Expand Up @@ -81,7 +78,7 @@ fn deploy_account(

state
.set_contract_class(
&CLASS_HASH_BYTES,
&CLASS_HASH,
&CompiledClass::Deprecated(Arc::new(CONTRACT_CLASS.clone())),
)
.unwrap();
Expand All @@ -90,7 +87,7 @@ fn deploy_account(

for _ in 0..RUNS {
let mut state_copy = state.clone();
let class_hash = *CLASS_HASH_BYTES;
let class_hash = *CLASS_HASH;
let signature = SIGNATURE.clone();
scope(|| {
// new consumes more execution time than raw struct instantiation
Expand Down Expand Up @@ -162,7 +159,7 @@ fn deploy(#[cfg(feature = "cairo-native")] program_cache: Rc<RefCell<ProgramCach

state
.set_contract_class(
&CLASS_HASH_BYTES,
&CLASS_HASH,
&CompiledClass::Deprecated(Arc::new(CONTRACT_CLASS.clone())),
)
.unwrap();
Expand Down Expand Up @@ -205,7 +202,7 @@ fn invoke(#[cfg(feature = "cairo-native")] program_cache: Rc<RefCell<ProgramCach

state
.set_contract_class(
&CLASS_HASH_BYTES,
&CLASS_HASH,
&CompiledClass::Deprecated(Arc::new(CONTRACT_CLASS.clone())),
)
.unwrap();
Expand Down
16 changes: 8 additions & 8 deletions bench/native_bench.rs
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ pub fn main() {
fn bench_fibo(executions: usize, native: bool) {
// Create state reader with class hash data
let mut contract_class_cache = HashMap::new();
static CASM_CLASS_HASH: ClassHash = [2; 32];
static CASM_CLASS_HASH: ClassHash = ClassHash([2; 32]);

let (contract_class, constructor_selector) = match native {
true => {
Expand Down Expand Up @@ -137,7 +137,7 @@ fn bench_fibo(executions: usize, native: bool) {
fn bench_fact(executions: usize, native: bool) {
// Create state reader with class hash data
let mut contract_class_cache = HashMap::new();
static CASM_CLASS_HASH: ClassHash = [2; 32];
static CASM_CLASS_HASH: ClassHash = ClassHash([2; 32]);

let (contract_class, constructor_selector) = match native {
true => {
Expand Down Expand Up @@ -216,9 +216,9 @@ fn bench_erc20(executions: usize, native: bool) {
let mut contract_class_cache = HashMap::new();

lazy_static! {
static ref ERC20_CLASS_HASH: ClassHash = felt_str!("2").to_be_bytes();
static ref DEPLOYER_CLASS_HASH: ClassHash = felt_str!("10").to_be_bytes();
static ref ACCOUNT1_CLASS_HASH: ClassHash = felt_str!("1").to_be_bytes();
static ref ERC20_CLASS_HASH: ClassHash = ClassHash::from(felt_str!("2"));
static ref DEPLOYER_CLASS_HASH: ClassHash = ClassHash::from(felt_str!("10"));
static ref ACCOUNT1_CLASS_HASH: ClassHash = ClassHash::from(felt_str!("1"));
static ref DEPLOYER_ADDRESS: Address = Address(1111.into());
static ref ERC20_NAME: Felt252 = Felt252::from_bytes_be(b"be");
static ref ERC20_SYMBOL: Felt252 = Felt252::from_bytes_be(b"be");
Expand All @@ -227,7 +227,7 @@ fn bench_erc20(executions: usize, native: bool) {
static ref ERC20_RECIPIENT: Felt252 = felt_str!("111");
static ref ERC20_SALT: Felt252 = felt_str!("1234");
static ref ERC20_DEPLOYER_CALLDATA: [Felt252; 7] = [
Felt252::from_bytes_be(&ERC20_CLASS_HASH.clone()),
Felt252::from_bytes_be(ERC20_CLASS_HASH.to_bytes_be()),
ERC20_SALT.clone(),
ERC20_RECIPIENT.clone(),
ERC20_NAME.clone(),
Expand Down Expand Up @@ -419,8 +419,8 @@ fn bench_erc20(executions: usize, native: bool) {
.unwrap();
state
.set_compiled_class_hash(
&Felt252::from_bytes_be(&ACCOUNT1_CLASS_HASH.clone()),
&Felt252::from_bytes_be(&ACCOUNT1_CLASS_HASH.clone()),
&Felt252::from_bytes_be(ACCOUNT1_CLASS_HASH.to_bytes_be()),
&Felt252::from_bytes_be(ACCOUNT1_CLASS_HASH.to_bytes_be()),
)
.unwrap();

Expand Down
8 changes: 5 additions & 3 deletions fuzzer/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ use cairo_vm::felt::Felt252;
use cairo_vm::vm::runners::cairo_runner::ExecutionResources;
use num_traits::Zero;
use starknet_in_rust::execution::execution_entry_point::ExecutionResult;
use starknet_in_rust::utils::ClassHash;
use starknet_in_rust::EntryPointType;
use starknet_in_rust::{
definitions::{block_context::BlockContext, constants::TRANSACTION_VERSION},
Expand Down Expand Up @@ -115,7 +116,7 @@ fn main() {
// ------------ contract data --------------------

let address = Address(1111.into());
let class_hash = [1; 32];
let class_hash: ClassHash = ClassHash([1; 32]);

contract_class_cache.insert(
class_hash,
Expand Down Expand Up @@ -166,7 +167,8 @@ fn main() {
);
let mut resources_manager = ExecutionResourcesManager::default();

let expected_key = calculate_sn_keccak("_counter".as_bytes());
let expected_key_bytes = calculate_sn_keccak("_counter".as_bytes());
let expected_key = ClassHash(expected_key_bytes);

let mut expected_accessed_storage_keys = HashSet::new();
expected_accessed_storage_keys.insert(expected_key);
Expand Down Expand Up @@ -205,7 +207,7 @@ fn main() {
state
.cache()
.storage_writes()
.get(&(address, expected_key))
.get(&(address, expected_key_bytes))
.cloned(),
Some(Felt252::from_bytes_be(data_to_ascii(data).as_bytes()))
);
Expand Down
10 changes: 5 additions & 5 deletions rpc_state_reader/tests/sir_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ pub struct RpcStateReader(RpcState);

impl StateReader for RpcStateReader {
fn get_contract_class(&self, class_hash: &ClassHash) -> Result<CompiledClass, StateError> {
let hash = SNClassHash(StarkHash::new(*class_hash).unwrap());
let hash = SNClassHash(StarkHash::new(class_hash.0).unwrap());
Ok(CompiledClass::from(
self.0.get_contract_class(&hash).unwrap(),
))
Expand All @@ -56,7 +56,7 @@ impl StateReader for RpcStateReader {
);
let mut bytes = [0u8; 32];
bytes.copy_from_slice(self.0.get_class_hash_at(&address).0.bytes());
Ok(bytes)
Ok(ClassHash(bytes))
}

fn get_nonce_at(&self, contract_address: &Address) -> Result<Felt252, StateError> {
Expand All @@ -83,7 +83,7 @@ impl StateReader for RpcStateReader {
Ok(Felt252::from_bytes_be(value.bytes()))
}

fn get_compiled_class_hash(&self, class_hash: &ClassHash) -> Result<[u8; 32], StateError> {
fn get_compiled_class_hash(&self, class_hash: &ClassHash) -> Result<ClassHash, StateError> {
Ok(*class_hash)
}
}
Expand Down Expand Up @@ -156,7 +156,7 @@ pub fn execute_tx_configurable(
RpcState::new_infura(network, (block_number.next()).into()).unwrap(),
);
let contract_class = next_block_state_reader
.get_contract_class(tx.class_hash().0.bytes().try_into().unwrap())
.get_contract_class(&ClassHash(tx.class_hash().0.bytes().try_into().unwrap()))
.unwrap();

if tx.version() != TransactionVersion(2_u8.into()) {
Expand All @@ -177,7 +177,7 @@ pub fn execute_tx_configurable(
.collect(),
Felt252::from_bytes_be(tx.nonce().0.bytes()),
Felt252::from_bytes_be(tx_hash.0.bytes()),
tx.class_hash().0.bytes().try_into().unwrap(),
ClassHash(tx.class_hash().0.bytes().try_into().unwrap()),
)
.unwrap();
declare.create_for_simulation(skip_validate, false, false, false, skip_nonce_check)
Expand Down
4 changes: 2 additions & 2 deletions src/bin/fibonacci.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ use starknet_in_rust::{
},
state::cached_state::CachedState,
state::{in_memory_state_reader::InMemoryStateReader, ExecutionResourcesManager},
utils::Address,
utils::{Address, ClassHash},
EntryPointType,
};

Expand All @@ -33,7 +33,7 @@ lazy_static! {

static ref CONTRACT_PATH: PathBuf = PathBuf::from("starknet_programs/fibonacci.json");

static ref CONTRACT_CLASS_HASH: [u8; 32] = [1; 32];
static ref CONTRACT_CLASS_HASH: ClassHash = ClassHash([1; 32]);

static ref CONTRACT_ADDRESS: Address = Address(1.into());

Expand Down
4 changes: 2 additions & 2 deletions src/bin/invoke.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ use starknet_in_rust::{
state::cached_state::CachedState,
state::in_memory_state_reader::InMemoryStateReader,
transaction::{InvokeFunction, Transaction},
utils::Address,
utils::{Address, ClassHash},
};

use lazy_static::lazy_static;
Expand All @@ -31,7 +31,7 @@ lazy_static! {

static ref CONTRACT_PATH: PathBuf = PathBuf::from("starknet_programs/first_contract.json");

static ref CONTRACT_CLASS_HASH: [u8; 32] = [1; 32];
static ref CONTRACT_CLASS_HASH: ClassHash = ClassHash([1; 32]);

static ref CONTRACT_ADDRESS: Address = Address(1.into());

Expand Down
4 changes: 2 additions & 2 deletions src/bin/invoke_with_cachedstate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ use starknet_in_rust::{
state::in_memory_state_reader::InMemoryStateReader,
state::{cached_state::CachedState, BlockInfo},
transaction::InvokeFunction,
utils::Address,
utils::{Address, ClassHash},
};

use lazy_static::lazy_static;
Expand All @@ -34,7 +34,7 @@ lazy_static! {

static ref CONTRACT_PATH: PathBuf = PathBuf::from("starknet_programs/first_contract.json");

static ref CONTRACT_CLASS_HASH: [u8; 32] = [5, 133, 114, 83, 104, 231, 159, 23, 87, 255, 235, 75, 170, 4, 84, 140, 49, 77, 101, 41, 147, 198, 201, 231, 38, 189, 215, 84, 231, 141, 140, 122];
static ref CONTRACT_CLASS_HASH: ClassHash = ClassHash([5, 133, 114, 83, 104, 231, 159, 23, 87, 255, 235, 75, 170, 4, 84, 140, 49, 77, 101, 41, 147, 198, 201, 231, 38, 189, 215, 84, 231, 141, 140, 122]);

static ref CONTRACT_ADDRESS: Address = Address(1.into());

Expand Down
32 changes: 15 additions & 17 deletions src/execution/execution_entry_point.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
use super::{
CallInfo, CallResult, CallType, OrderedEvent, OrderedL2ToL1Message, TransactionExecutionContext,
};
#[cfg(feature = "cairo-native")]
use crate::state::StateDiff;
use crate::{
definitions::{block_context::BlockContext, constants::DEFAULT_ENTRY_POINT_SELECTOR},
runner::StarknetRunner,
Expand All @@ -23,11 +25,9 @@ use crate::{
transaction::error::TransactionError,
utils::{
get_deployed_address_class_hash_at_address, parse_builtin_names,
validate_contract_deployed, Address,
validate_contract_deployed, Address, ClassHash,
},
};
#[cfg(feature = "cairo-native")]
use crate::{state::StateDiff, utils::ClassHash};
use cairo_lang_sierra::program::Program as SierraProgram;
use cairo_lang_starknet::casm_contract_class::{CasmContractClass, CasmContractEntryPoint};
use cairo_lang_starknet::contract_class::ContractEntryPoints;
Expand Down Expand Up @@ -74,7 +74,7 @@ pub struct ExecutionEntryPoint {
pub(crate) call_type: CallType,
pub(crate) contract_address: Address,
pub(crate) code_address: Option<Address>,
pub(crate) class_hash: Option<[u8; 32]>,
pub(crate) class_hash: Option<ClassHash>,
pub(crate) calldata: Vec<Felt252>,
pub(crate) caller_address: Address,
pub(crate) entry_point_selector: Felt252,
Expand All @@ -90,7 +90,7 @@ impl ExecutionEntryPoint {
caller_address: Address,
entry_point_type: EntryPointType,
call_type: Option<CallType>,
class_hash: Option<[u8; 32]>,
class_hash: Option<ClassHash>,
initial_gas: u128,
) -> Self {
ExecutionEntryPoint {
Expand Down Expand Up @@ -126,7 +126,7 @@ impl ExecutionEntryPoint {
T: StateReader,
{
// lookup the compiled class from the state.
let class_hash = self.get_code_class_hash(state)?;
let class_hash = self.get_class_hash(state)?;
let contract_class = state
.get_contract_class(&class_hash)
.map_err(|_| TransactionError::MissingCompiledClass)?;
Expand Down Expand Up @@ -236,7 +236,7 @@ impl ExecutionEntryPoint {
fn get_selected_entry_point_v0(
&self,
contract_class: &ContractClass,
_class_hash: [u8; 32],
_class_hash: ClassHash,
) -> Result<ContractEntryPoint, TransactionError> {
let entry_points = contract_class
.entry_points_by_type
Expand Down Expand Up @@ -267,7 +267,7 @@ impl ExecutionEntryPoint {
fn get_selected_entry_point(
&self,
contract_class: &CasmContractClass,
_class_hash: [u8; 32],
_class_hash: ClassHash,
) -> Result<CasmContractEntryPoint, TransactionError> {
let entry_points = match self.entry_point_type {
EntryPointType::External => &contract_class.entry_points_by_type.external,
Expand Down Expand Up @@ -312,7 +312,7 @@ impl ExecutionEntryPoint {
call_type: Some(self.call_type.clone()),
contract_address: self.contract_address.clone(),
code_address: self.code_address.clone(),
class_hash: Some(self.get_code_class_hash(starknet_storage_state.state)?),
class_hash: Some(self.get_class_hash(starknet_storage_state.state)?),
entry_point_selector: Some(self.entry_point_selector.clone()),
entry_point_type: Some(self.entry_point_type),
calldata: self.calldata.clone(),
Expand Down Expand Up @@ -345,7 +345,7 @@ impl ExecutionEntryPoint {
call_type: Some(self.call_type.clone()),
contract_address: self.contract_address.clone(),
code_address: self.code_address.clone(),
class_hash: Some(self.get_code_class_hash(starknet_storage_state.state)?),
class_hash: Some(self.get_class_hash(starknet_storage_state.state)?),
entry_point_selector: Some(self.entry_point_selector.clone()),
entry_point_type: Some(self.entry_point_type),
calldata: self.calldata.clone(),
Expand All @@ -366,7 +366,7 @@ impl ExecutionEntryPoint {
}

/// Returns the hash of the executed contract class.
fn get_code_class_hash<S: State>(&self, state: &mut S) -> Result<[u8; 32], TransactionError> {
fn get_class_hash<S: State>(&self, state: &mut S) -> Result<ClassHash, TransactionError> {
if let Some(class_hash) = self.class_hash {
match self.call_type {
CallType::Delegate => return Ok(class_hash),
Expand Down Expand Up @@ -394,7 +394,7 @@ impl ExecutionEntryPoint {
block_context: &BlockContext,
tx_execution_context: &mut TransactionExecutionContext,
contract_class: Arc<ContractClass>,
class_hash: [u8; 32],
class_hash: ClassHash,
) -> Result<CallInfo, TransactionError> {
let previous_cairo_usage = resources_manager.cairo_usage.clone();
// fetch selected entry point
Expand Down Expand Up @@ -499,7 +499,7 @@ impl ExecutionEntryPoint {
block_context: &BlockContext,
tx_execution_context: &mut TransactionExecutionContext,
contract_class: Arc<CasmContractClass>,
class_hash: [u8; 32],
class_hash: ClassHash,
support_reverted: bool,
) -> Result<CallInfo, TransactionError> {
let previous_cairo_usage = resources_manager.cairo_usage.clone();
Expand Down Expand Up @@ -666,7 +666,7 @@ impl ExecutionEntryPoint {
sierra_program_and_entrypoints: Arc<(SierraProgram, ContractEntryPoints)>,
tx_execution_context: &TransactionExecutionContext,
block_context: &BlockContext,
class_hash: &[u8; 32],
class_hash: &ClassHash,
program_cache: Rc<RefCell<ProgramCache<'_, ClassHash>>>,
) -> Result<CallInfo, TransactionError> {
use crate::{
Expand Down Expand Up @@ -826,9 +826,7 @@ impl ExecutionEntryPoint {
call_type: Some(self.call_type.clone()),
contract_address: self.contract_address.clone(),
code_address: self.code_address.clone(),
class_hash: Some(
self.get_code_class_hash(syscall_handler.starknet_storage_state.state)?,
),
class_hash: Some(self.get_class_hash(syscall_handler.starknet_storage_state.state)?),
entry_point_selector: Some(self.entry_point_selector.clone()),
entry_point_type: Some(self.entry_point_type),
calldata: self.calldata.clone(),
Expand Down
Loading