From 8809ba033b67c4cbc325026d18220bcfa1598eb7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mat=C3=ADas=20Ignacio=20Gonz=C3=A1lez?= Date: Mon, 17 Jul 2023 13:45:52 -0300 Subject: [PATCH] Add reverted transactions (#813) * Remove mutable ref from StateReader * Fix tests * Fix tests * Fix cairo2 tests * Fix clippy * Add tmp_state * Fix tmp_state * Remove unused max steps * Revert "Remove unused max steps" This reverts commit a8c6374655064b22e193394dcafb10bb52ea7388. * Add reverted Transactions * Add revert_error to TransactionExecutionInfo * Test test_reverted_transaction_wrong_entry_point * Disable reverted transaction for declare verify * Replace empty for default * Update CachedState * Fix tests * Fix test * Remove unused imports --- bench/internals.rs | 10 +- cli/src/main.rs | 29 +- fuzzer/src/main.rs | 28 +- src/bin/fibonacci.rs | 4 +- src/bin/invoke.rs | 4 +- src/bin/invoke_with_cachedstate.rs | 4 +- src/execution/execution_entry_point.rs | 134 +++++---- src/execution/mod.rs | 6 + src/lib.rs | 40 +-- src/state/cached_state.rs | 265 ++++++++++++------ src/state/contract_storage_state.rs | 13 +- src/state/in_memory_state_reader.rs | 12 +- src/state/mod.rs | 14 +- src/state/state_api.rs | 27 +- .../business_logic_syscall_handler.rs | 53 ++-- ...precated_business_logic_syscall_handler.rs | 71 +++-- src/syscalls/deprecated_syscall_handler.rs | 42 ++- src/syscalls/deprecated_syscall_response.rs | 6 +- src/syscalls/syscall_handler.rs | 16 +- src/testing/erc20.rs | 7 +- src/testing/mod.rs | 4 +- src/testing/state.rs | 36 +-- src/transaction/declare.rs | 57 ++-- src/transaction/declare_v2.rs | 37 +-- src/transaction/deploy.rs | 37 ++- src/transaction/deploy_account.rs | 83 +++--- src/transaction/fee.rs | 29 +- src/transaction/invoke_function.rs | 87 +++--- src/transaction/l1_handler.rs | 39 ++- src/transaction/mod.rs | 6 +- src/utils.rs | 9 +- tests/cairo_1_syscalls.rs | 168 +++++++---- tests/complex_contracts/amm_contracts/amm.rs | 19 +- .../amm_contracts/amm_proxy.rs | 11 +- tests/complex_contracts/nft/erc721.rs | 29 +- tests/complex_contracts/utils.rs | 10 +- tests/delegate_call.rs | 4 +- tests/delegate_l1_handler.rs | 4 +- tests/deploy_account.rs | 8 +- tests/fibonacci.rs | 11 +- tests/increase_balance.rs | 6 +- tests/internal_calls.rs | 7 +- tests/internals.rs | 47 ++-- tests/storage.rs | 6 +- tests/syscalls.rs | 29 +- tests/syscalls_errors.rs | 4 +- 46 files changed, 941 insertions(+), 631 deletions(-) diff --git a/bench/internals.rs b/bench/internals.rs index f116ec94e..075f216a2 100644 --- a/bench/internals.rs +++ b/bench/internals.rs @@ -17,7 +17,7 @@ use starknet_in_rust::{ transaction::{declare::Declare, Deploy, DeployAccount, InvokeFunction}, utils::Address, }; -use std::hint::black_box; +use std::{hint::black_box, sync::Arc}; lazy_static! { // include_str! doesn't seem to work in CI @@ -60,7 +60,7 @@ fn main() { fn deploy_account() { const RUNS: usize = 500; - let state_reader = InMemoryStateReader::default(); + let state_reader = Arc::new(InMemoryStateReader::default()); let mut state = CachedState::new(state_reader, Some(Default::default()), None); state @@ -96,7 +96,7 @@ fn deploy_account() { fn declare() { const RUNS: usize = 5; - let state_reader = InMemoryStateReader::default(); + let state_reader = Arc::new(InMemoryStateReader::default()); let state = CachedState::new(state_reader, Some(Default::default()), None); let block_context = &Default::default(); @@ -128,7 +128,7 @@ fn declare() { fn deploy() { const RUNS: usize = 8; - let state_reader = InMemoryStateReader::default(); + let state_reader = Arc::new(InMemoryStateReader::default()); let mut state = CachedState::new(state_reader, Some(Default::default()), None); state @@ -163,7 +163,7 @@ fn deploy() { fn invoke() { const RUNS: usize = 100; - let state_reader = InMemoryStateReader::default(); + let state_reader = Arc::new(InMemoryStateReader::default()); let mut state = CachedState::new(state_reader, Some(Default::default()), None); state diff --git a/cli/src/main.rs b/cli/src/main.rs index 3b809bc71..f37bdf544 100644 --- a/cli/src/main.rs +++ b/cli/src/main.rs @@ -16,20 +16,24 @@ use starknet_in_rust::{ block_context::BlockContext, constants::{DECLARE_VERSION, TRANSACTION_VERSION}, }, - execution::{execution_entry_point::ExecutionEntryPoint, TransactionExecutionContext}, + execution::{ + execution_entry_point::{ExecutionEntryPoint, ExecutionResult}, + TransactionExecutionContext, + }, hash_utils::calculate_contract_address, parser_errors::ParserError, serde_structs::read_abi, services::api::contract_classes::deprecated_contract_class::ContractClass, - state::{ - cached_state::CachedState, - state_api::{State, StateReader}, - }, + state::{cached_state::CachedState, state_api::State}, state::{in_memory_state_reader::InMemoryStateReader, ExecutionResourcesManager}, - transaction::InvokeFunction, + transaction::{error::TransactionError, InvokeFunction}, utils::{felt_to_hash, string_to_hash, Address}, }; -use std::{collections::HashMap, path::PathBuf, sync::Mutex}; +use std::{ + collections::HashMap, + path::PathBuf, + sync::{Arc, Mutex}, +}; #[derive(Parser)] struct Cli { @@ -248,13 +252,18 @@ fn call_parser( None, 0, ); - let call_info = execution_entry_point.execute( + let block_context = BlockContext::default(); + let ExecutionResult { call_info, .. } = execution_entry_point.execute( cached_state, - &BlockContext::default(), + &block_context, &mut ExecutionResourcesManager::default(), &mut TransactionExecutionContext::default(), false, + block_context.invoke_tx_max_n_steps(), )?; + + let call_info = call_info.ok_or(TransactionError::CallInfoIsNone)?; + Ok(call_info.retdata) } @@ -303,7 +312,7 @@ async fn call_req(data: web::Data, args: web::Json) -> HttpR pub async fn start_devnet(port: u16) -> Result<(), std::io::Error> { let cached_state = web::Data::new(AppState { cached_state: Mutex::new(CachedState::::new( - InMemoryStateReader::default(), + Arc::new(InMemoryStateReader::default()), Some(HashMap::new()), None, )), diff --git a/fuzzer/src/main.rs b/fuzzer/src/main.rs index 2431f5f46..edbfb1cf5 100644 --- a/fuzzer/src/main.rs +++ b/fuzzer/src/main.rs @@ -6,6 +6,7 @@ extern crate honggfuzz; use cairo_vm::felt::Felt252; use cairo_vm::vm::runners::cairo_runner::ExecutionResources; use num_traits::Zero; +use starknet_in_rust::execution::execution_entry_point::ExecutionResult; use starknet_in_rust::EntryPointType; use starknet_in_rust::{ definitions::{block_context::BlockContext, constants::TRANSACTION_VERSION}, @@ -18,6 +19,7 @@ use starknet_in_rust::{ utils::{calculate_sn_keccak, Address}, }; +use std::sync::Arc; use std::{ collections::{HashMap, HashSet}, path::PathBuf, @@ -124,7 +126,8 @@ fn main() { //* Create state with previous data //* --------------------------------------- - let mut state = CachedState::new(state_reader, Some(contract_class_cache), None); + let mut state = + CachedState::new(Arc::new(state_reader), Some(contract_class_cache), None); //* ------------------------------------ //* Create execution entry point @@ -180,18 +183,17 @@ fn main() { ..Default::default() }; - assert_eq!( - exec_entry_point - .execute( - &mut state, - &block_context, - &mut resources_manager, - &mut tx_execution_context, - false, - ) - .unwrap(), - expected_call_info - ); + let ExecutionResult { call_info, .. } = exec_entry_point + .execute( + &mut state, + &block_context, + &mut resources_manager, + &mut tx_execution_context, + false, + block_context.invoke_tx_max_n_steps(), + ) + .unwrap(); + assert_eq!(call_info.unwrap(), expected_call_info); assert!(!state.cache().storage_writes().is_empty()); assert_eq!( diff --git a/src/bin/fibonacci.rs b/src/bin/fibonacci.rs index fb7f9535f..b14be188a 100644 --- a/src/bin/fibonacci.rs +++ b/src/bin/fibonacci.rs @@ -1,4 +1,4 @@ -use std::{collections::HashMap, path::PathBuf}; +use std::{collections::HashMap, path::PathBuf, sync::Arc}; use cairo_vm::felt::{felt_str, Felt252}; use num_traits::Zero; @@ -85,7 +85,7 @@ fn create_initial_state() -> CachedState { state_reader .address_to_storage_mut() .insert((CONTRACT_ADDRESS.clone(), [0; 32]), Felt252::zero()); - state_reader + Arc::new(state_reader) }, Some(HashMap::new()), None, diff --git a/src/bin/invoke.rs b/src/bin/invoke.rs index 4b344e049..afec929fa 100644 --- a/src/bin/invoke.rs +++ b/src/bin/invoke.rs @@ -1,4 +1,4 @@ -use std::{collections::HashMap, path::PathBuf}; +use std::{collections::HashMap, path::PathBuf, sync::Arc}; use cairo_vm::felt::{felt_str, Felt252}; use num_traits::Zero; @@ -99,7 +99,7 @@ fn create_initial_state() -> CachedState { state_reader .address_to_storage_mut() .insert((CONTRACT_ADDRESS.clone(), [0; 32]), Felt252::zero()); - state_reader + Arc::new(state_reader) }, Some(HashMap::new()), None, diff --git a/src/bin/invoke_with_cachedstate.rs b/src/bin/invoke_with_cachedstate.rs index 0175dd6af..ebfe9202e 100644 --- a/src/bin/invoke_with_cachedstate.rs +++ b/src/bin/invoke_with_cachedstate.rs @@ -1,4 +1,4 @@ -use std::{collections::HashMap, path::PathBuf}; +use std::{collections::HashMap, path::PathBuf, sync::Arc}; use cairo_vm::felt::{felt_str, Felt252}; use num_traits::Zero; @@ -106,7 +106,7 @@ fn create_initial_state() -> CachedState { state_reader .address_to_storage_mut() .insert((CONTRACT_ADDRESS.clone(), [0; 32]), Felt252::zero()); - state_reader + Arc::new(state_reader) }, Some(HashMap::new()), None, diff --git a/src/execution/execution_entry_point.rs b/src/execution/execution_entry_point.rs index e19ba0a0a..b7116d1f1 100644 --- a/src/execution/execution_entry_point.rs +++ b/src/execution/execution_entry_point.rs @@ -1,6 +1,8 @@ use crate::services::api::contract_classes::deprecated_contract_class::{ ContractEntryPoint, EntryPointType, }; +use crate::state::cached_state::CachedState; +use crate::state::StateDiff; use crate::{ definitions::{block_context::BlockContext, constants::DEFAULT_ENTRY_POINT_SELECTOR}, runner::StarknetRunner, @@ -39,6 +41,13 @@ use super::{ CallInfo, CallResult, CallType, OrderedEvent, OrderedL2ToL1Message, TransactionExecutionContext, }; +#[derive(Debug, Default)] +pub struct ExecutionResult { + pub call_info: Option, + pub revert_error: Option, + pub n_reverted_steps: usize, +} + /// Represents a Cairo entry point execution of a StarkNet contract. // TODO:initial_gas is a new field added in the current changes, it should be checked if we delete it once the new execution entry point is done @@ -85,14 +94,15 @@ impl ExecutionEntryPoint { /// Returns a CallInfo object that represents the execution. pub fn execute( &self, - state: &mut T, + state: &mut CachedState, block_context: &BlockContext, resources_manager: &mut ExecutionResourcesManager, tx_execution_context: &mut TransactionExecutionContext, support_reverted: bool, - ) -> Result + max_steps: u64, + ) -> Result where - T: State + StateReader, + T: StateReader, { // lookup the compiled class from the state. let class_hash = self.get_code_class_hash(state)?; @@ -100,23 +110,62 @@ impl ExecutionEntryPoint { .get_contract_class(&class_hash) .map_err(|_| TransactionError::MissingCompiledClass)?; match contract_class { - CompiledClass::Deprecated(contract_class) => self._execute_version0_class( - state, - resources_manager, - block_context, - tx_execution_context, - contract_class, - class_hash, - ), - CompiledClass::Casm(contract_class) => self._execute( - state, - resources_manager, - block_context, - tx_execution_context, - contract_class, - class_hash, - support_reverted, - ), + CompiledClass::Deprecated(contract_class) => { + let call_info = self._execute_version0_class( + state, + resources_manager, + block_context, + tx_execution_context, + contract_class, + class_hash, + )?; + Ok(ExecutionResult { + call_info: Some(call_info), + revert_error: None, + n_reverted_steps: 0, + }) + } + CompiledClass::Casm(contract_class) => { + let mut tmp_state = CachedState::new( + state.state_reader.clone(), + state.contract_classes.clone(), + state.casm_contract_classes.clone(), + ); + tmp_state.cache = state.cache.clone(); + + match self._execute( + &mut tmp_state, + resources_manager, + block_context, + tx_execution_context, + contract_class, + class_hash, + support_reverted, + ) { + Ok(call_info) => { + let state_diff = StateDiff::from_cached_state(tmp_state)?; + state.apply_state_update(&state_diff)?; + Ok(ExecutionResult { + call_info: Some(call_info), + revert_error: None, + n_reverted_steps: 0, + }) + } + Err(e) => { + if !support_reverted { + return Err(e); + } + + let n_reverted_steps = + (max_steps as usize) - resources_manager.cairo_usage.n_steps; + Ok(ExecutionResult { + call_info: None, + revert_error: Some(e.to_string()), + n_reverted_steps, + }) + } + } + } } } @@ -184,7 +233,7 @@ impl ExecutionEntryPoint { .ok_or(TransactionError::EntryPointNotFound) } - fn build_call_info_deprecated( + fn build_call_info_deprecated( &self, previous_cairo_usage: ExecutionResources, resources_manager: &ExecutionResourcesManager, @@ -193,10 +242,7 @@ impl ExecutionEntryPoint { l2_to_l1_messages: Vec, internal_calls: Vec, retdata: Vec, - ) -> Result - where - S: State + StateReader, - { + ) -> Result { let execution_resources = &resources_manager.cairo_usage - &previous_cairo_usage; Ok(CallInfo { @@ -220,7 +266,7 @@ impl ExecutionEntryPoint { }) } - fn build_call_info( + fn build_call_info( &self, previous_cairo_usage: ExecutionResources, resources_manager: &ExecutionResourcesManager, @@ -229,10 +275,7 @@ impl ExecutionEntryPoint { l2_to_l1_messages: Vec, internal_calls: Vec, call_result: CallResult, - ) -> Result - where - S: State + StateReader, - { + ) -> Result { let execution_resources = &resources_manager.cairo_usage - &previous_cairo_usage; Ok(CallInfo { @@ -261,10 +304,7 @@ impl ExecutionEntryPoint { } /// Returns the hash of the executed contract class. - fn get_code_class_hash( - &self, - state: &mut S, - ) -> Result<[u8; 32], TransactionError> { + fn get_code_class_hash(&self, state: &mut S) -> Result<[u8; 32], TransactionError> { if self.class_hash.is_some() { match self.call_type { CallType::Delegate => return Ok(self.class_hash.unwrap()), @@ -285,18 +325,15 @@ impl ExecutionEntryPoint { get_deployed_address_class_hash_at_address(state, &code_address.unwrap()) } - fn _execute_version0_class( + fn _execute_version0_class( &self, - state: &mut T, + state: &mut CachedState, resources_manager: &mut ExecutionResourcesManager, block_context: &BlockContext, tx_execution_context: &mut TransactionExecutionContext, contract_class: Box, class_hash: [u8; 32], - ) -> Result - where - T: State + StateReader, - { + ) -> Result { let previous_cairo_usage = resources_manager.cairo_usage.clone(); // fetch selected entry point let entry_point = self.get_selected_entry_point_v0(&contract_class, class_hash)?; @@ -311,7 +348,7 @@ impl ExecutionEntryPoint { // prepare OS context //let os_context = runner.prepare_os_context(); let os_context = - StarknetRunner::>::prepare_os_context_cairo0( + StarknetRunner::>::prepare_os_context_cairo0( &cairo_runner, &mut vm, ); @@ -382,7 +419,7 @@ impl ExecutionEntryPoint { let retdata = runner.get_return_values()?; - self.build_call_info_deprecated::( + self.build_call_info_deprecated::( previous_cairo_usage, resources_manager, runner.hint_processor.syscall_handler.starknet_storage_state, @@ -393,19 +430,16 @@ impl ExecutionEntryPoint { ) } - fn _execute( + fn _execute( &self, - state: &mut T, + state: &mut CachedState, resources_manager: &mut ExecutionResourcesManager, block_context: &BlockContext, tx_execution_context: &mut TransactionExecutionContext, contract_class: Box, class_hash: [u8; 32], support_reverted: bool, - ) -> Result - where - T: State + StateReader, - { + ) -> Result { let previous_cairo_usage = resources_manager.cairo_usage.clone(); // fetch selected entry point @@ -424,7 +458,7 @@ impl ExecutionEntryPoint { )?; validate_contract_deployed(state, &self.contract_address)?; // prepare OS context - let os_context = StarknetRunner::>::prepare_os_context_cairo1( + let os_context = StarknetRunner::>::prepare_os_context_cairo1( &cairo_runner, &mut vm, self.initial_gas.into(), @@ -535,7 +569,7 @@ impl ExecutionEntryPoint { resources_manager.cairo_usage += &runner.get_execution_resources()?; let call_result = runner.get_call_result(self.initial_gas)?; - self.build_call_info::( + self.build_call_info::( previous_cairo_usage, resources_manager, runner.hint_processor.syscall_handler.starknet_storage_state, diff --git a/src/execution/mod.rs b/src/execution/mod.rs index fcb782f1e..a86df2146 100644 --- a/src/execution/mod.rs +++ b/src/execution/mod.rs @@ -434,6 +434,7 @@ impl TxInfoStruct { pub struct TransactionExecutionInfo { pub validate_info: Option, pub call_info: Option, + pub revert_error: Option, pub fee_transfer_info: Option, pub actual_fee: u128, pub actual_resources: HashMap, @@ -444,6 +445,7 @@ impl TransactionExecutionInfo { pub fn new( validate_info: Option, call_info: Option, + revert_error: Option, fee_transfer_info: Option, actual_fee: u128, actual_resources: HashMap, @@ -452,6 +454,7 @@ impl TransactionExecutionInfo { TransactionExecutionInfo { validate_info, call_info, + revert_error, fee_transfer_info, actual_fee, actual_resources, @@ -490,6 +493,7 @@ impl TransactionExecutionInfo { TransactionExecutionInfo { validate_info, call_info: execute_call_info, + revert_error: None, fee_transfer_info, actual_fee: 0, actual_resources: HashMap::new(), @@ -500,12 +504,14 @@ impl TransactionExecutionInfo { pub fn new_without_fee_info( validate_info: Option, call_info: Option, + revert_error: Option, actual_resources: HashMap, tx_type: Option, ) -> Self { TransactionExecutionInfo { validate_info, call_info, + revert_error, fee_transfer_info: None, actual_fee: 0, actual_resources, diff --git a/src/lib.rs b/src/lib.rs index bee36a39e..4e7e06b32 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,7 +1,7 @@ #![deny(warnings)] #![forbid(unsafe_code)] #![cfg_attr(coverage_nightly, feature(no_coverage))] -use std::collections::HashMap; +use std::{collections::HashMap, sync::Arc}; use crate::{ execution::{ @@ -17,6 +17,7 @@ use crate::{ use cairo_vm::felt::Felt252; use definitions::block_context::BlockContext; +use execution::execution_entry_point::ExecutionResult; use state::cached_state::CachedState; use transaction::{fee::calculate_tx_fee, L1Handler}; use utils::Address; @@ -57,7 +58,7 @@ pub fn simulate_transaction( skip_execute: bool, skip_fee_transfer: bool, ) -> Result, TransactionError> { - let mut cache_state = CachedState::new(state, None, Some(HashMap::new())); + let mut cache_state = CachedState::new(Arc::new(state), None, Some(HashMap::new())); let mut result = Vec::with_capacity(transactions.len()); for transaction in transactions { let tx_for_simulation = @@ -80,7 +81,7 @@ where T: StateReader, { // This is used as a copy of the original state, we can update this cached state freely. - let mut cached_state = CachedState::::new(state, None, None); + let mut cached_state = CachedState::::new(Arc::new(state), None, None); let mut result = Vec::with_capacity(transactions.len()); for transaction in transactions { @@ -103,11 +104,11 @@ where Ok(result) } -pub fn call_contract( +pub fn call_contract( contract_address: Felt252, entrypoint_selector: Felt252, calldata: Vec, - state: &mut T, + state: &mut CachedState, block_context: BlockContext, caller_address: Address, ) -> Result, TransactionError> { @@ -143,14 +144,16 @@ pub fn call_contract( version.into(), ); - let call_info = execution_entrypoint.execute( + let ExecutionResult { call_info, .. } = execution_entrypoint.execute( state, &block_context, &mut ExecutionResourcesManager::default(), &mut tx_execution_context, false, + block_context.invoke_tx_max_n_steps, )?; + let call_info = call_info.ok_or(TransactionError::CallInfoIsNone)?; Ok(call_info.retdata) } @@ -164,7 +167,7 @@ where T: StateReader, { // This is used as a copy of the original state, we can update this cached state freely. - let mut cached_state = CachedState::::new(state, None, None); + let mut cached_state = CachedState::::new(Arc::new(state), None, None); // Check if the contract is deployed. cached_state.get_class_hash_at(l1_handler.contract_address())?; @@ -183,9 +186,9 @@ where } } -pub fn execute_transaction( +pub fn execute_transaction( tx: Transaction, - state: &mut T, + state: &mut CachedState, block_context: BlockContext, remaining_gas: u128, ) -> Result { @@ -196,6 +199,7 @@ pub fn execute_transaction( mod test { use std::collections::HashMap; use std::path::PathBuf; + use std::sync::Arc; use crate::core::contract_address::{compute_deprecated_class_hash, compute_sierra_class_hash}; use crate::definitions::constants::INITIAL_GAS_COST; @@ -312,7 +316,7 @@ mod test { .address_to_nonce_mut() .insert(address.clone(), nonce); - let mut state = CachedState::new(state_reader, None, Some(contract_class_cache)); + let mut state = CachedState::new(Arc::new(state_reader), None, Some(contract_class_cache)); let calldata = [1.into(), 1.into(), 10.into()].to_vec(); let retdata = call_contract( @@ -364,7 +368,7 @@ mod test { .address_to_nonce .insert(contract_address, nonce); - let mut state = CachedState::new(state_reader.clone(), None, None); + let mut state = CachedState::new(Arc::new(state_reader), None, None); // Initialize state.contract_classes let contract_classes = HashMap::from([(class_hash, contract_class)]); @@ -407,7 +411,7 @@ mod test { .address_to_nonce_mut() .insert(address.clone(), nonce); - let mut state = CachedState::new(state_reader, None, Some(contract_class_cache)); + let mut state = CachedState::new(Arc::new(state_reader), None, Some(contract_class_cache)); let calldata = [1.into(), 1.into(), 10.into()].to_vec(); let invoke = InvokeFunction::new( @@ -658,7 +662,7 @@ mod test { #[test] fn test_simulate_deploy() { - let state_reader = InMemoryStateReader::default(); + let state_reader = Arc::new(InMemoryStateReader::default()); let mut state = CachedState::new(state_reader, Some(Default::default()), None); state @@ -695,7 +699,7 @@ mod test { #[test] fn test_simulate_declare() { - let state_reader = InMemoryStateReader::default(); + let state_reader = Arc::new(InMemoryStateReader::default()); let state = CachedState::new(state_reader, Some(Default::default()), None); let block_context = &Default::default(); @@ -730,7 +734,7 @@ mod test { #[test] fn test_simulate_invoke() { - let state_reader = InMemoryStateReader::default(); + let state_reader = Arc::new(InMemoryStateReader::default()); let mut state = CachedState::new(state_reader, Some(Default::default()), None); state @@ -789,7 +793,7 @@ mod test { #[test] fn test_simulate_deploy_account() { - let state_reader = InMemoryStateReader::default(); + let state_reader = Arc::new(InMemoryStateReader::default()); let mut state = CachedState::new(state_reader, Some(Default::default()), None); state @@ -907,7 +911,7 @@ mod test { .address_to_nonce .insert(contract_address, nonce); - let mut state = CachedState::new(state_reader.clone(), None, None); + let mut state = CachedState::new(Arc::new(state_reader), None, None); // Initialize state.contract_classes state.set_contract_classes(HashMap::new()).unwrap(); @@ -938,7 +942,7 @@ mod test { #[test] fn test_deploy_and_invoke_simulation() { - let state_reader = InMemoryStateReader::default(); + let state_reader = Arc::new(InMemoryStateReader::default()); let mut state = CachedState::new(state_reader, Some(Default::default()), None); state diff --git a/src/state/cached_state.rs b/src/state/cached_state.rs index 5f6fb753c..53cbc1905 100644 --- a/src/state/cached_state.rs +++ b/src/state/cached_state.rs @@ -14,7 +14,7 @@ use cairo_lang_starknet::casm_contract_class::CasmContractClass; use cairo_vm::felt::Felt252; use getset::{Getters, MutGetters}; use num_traits::Zero; -use std::collections::HashMap; +use std::{collections::HashMap, sync::Arc}; // K: class_hash V: ContractClass pub type ContractClassCache = HashMap; @@ -24,8 +24,7 @@ pub const UNINITIALIZED_CLASS_HASH: &ClassHash = b"\x00\x00\x00\x00\x00\x00\x00\ #[derive(Default, Clone, Debug, Eq, Getters, MutGetters, PartialEq)] pub struct CachedState { - #[get = "pub"] - pub(crate) state_reader: T, + pub state_reader: Arc, #[getset(get = "pub", get_mut = "pub")] pub(crate) cache: StateCache, #[get = "pub"] @@ -36,7 +35,7 @@ pub struct CachedState { impl CachedState { pub fn new( - state_reader: T, + state_reader: Arc, contract_class_cache: Option, casm_class_cache: Option, ) -> Self { @@ -49,7 +48,7 @@ impl CachedState { } pub fn new_for_testing( - state_reader: T, + state_reader: Arc, contract_classes: Option, cache: StateCache, casm_contract_classes: Option, @@ -82,16 +81,19 @@ impl CachedState { } impl StateReader for CachedState { - fn get_class_hash_at(&mut self, contract_address: &Address) -> Result { + fn get_class_hash_at(&self, contract_address: &Address) -> Result { if self.cache.get_class_hash(contract_address).is_none() { - let class_hash = match self.state_reader.get_class_hash_at(contract_address) { - Ok(x) => x, - Err(StateError::NoneContractState(_)) => [0; 32], - Err(e) => return Err(e), - }; - self.cache - .class_hash_initial_values - .insert(contract_address.clone(), class_hash); + match self.state_reader.get_class_hash_at(contract_address) { + Ok(class_hash) => { + return Ok(class_hash); + } + Err(StateError::NoneContractState(_)) => { + return Ok([0; 32]); + } + Err(e) => { + return Err(e); + } + } } self.cache @@ -100,12 +102,9 @@ impl StateReader for CachedState { .cloned() } - fn get_nonce_at(&mut self, contract_address: &Address) -> Result { + fn get_nonce_at(&self, contract_address: &Address) -> Result { if self.cache.get_nonce(contract_address).is_none() { - let nonce = self.state_reader.get_nonce_at(contract_address)?; - self.cache - .nonce_initial_values - .insert(contract_address.clone(), nonce); + return self.state_reader.get_nonce_at(contract_address); } self.cache .get_nonce(contract_address) @@ -113,21 +112,22 @@ impl StateReader for CachedState { .cloned() } - fn get_storage_at(&mut self, storage_entry: &StorageEntry) -> Result { + fn get_storage_at(&self, storage_entry: &StorageEntry) -> Result { if self.cache.get_storage(storage_entry).is_none() { - let value = match self.state_reader.get_storage_at(storage_entry) { - Ok(x) => x, + match self.state_reader.get_storage_at(storage_entry) { + Ok(storage) => { + return Ok(storage); + } Err( StateError::EmptyKeyInStorage | StateError::NoneStoragLeaf(_) | StateError::NoneStorage(_) | StateError::NoneContractState(_), - ) => Felt252::zero(), - Err(e) => return Err(e), - }; - self.cache - .storage_initial_values - .insert(storage_entry.clone(), value); + ) => return Ok(Felt252::zero()), + Err(e) => { + return Err(e); + } + } } self.cache @@ -137,21 +137,23 @@ impl StateReader for CachedState { } // TODO: check if that the proper way to store it (converting hash to address) - fn get_compiled_class_hash(&mut self, class_hash: &ClassHash) -> Result { - let hash = self.cache.class_hash_to_compiled_class_hash.get(class_hash); - if let Some(hash) = hash { - Ok(*hash) - } else { - let compiled_class_hash = self.state_reader.get_compiled_class_hash(class_hash)?; - let address = Address(Felt252::from_bytes_be(&compiled_class_hash)); - self.cache - .class_hash_initial_values - .insert(address, compiled_class_hash); - Ok(compiled_class_hash) + fn get_compiled_class_hash(&self, class_hash: &ClassHash) -> Result { + if self + .cache + .class_hash_to_compiled_class_hash + .get(class_hash) + .is_none() + { + return self.state_reader.get_compiled_class_hash(class_hash); } + self.cache + .class_hash_to_compiled_class_hash + .get(class_hash) + .ok_or_else(|| StateError::NoneCompiledClass(*class_hash)) + .cloned() } - fn get_contract_class(&mut self, class_hash: &ClassHash) -> Result { + fn get_contract_class(&self, class_hash: &ClassHash) -> Result { // This method can receive both compiled_class_hash & class_hash and return both casm and deprecated contract classes //, which can be on the cache or on the state_reader, different cases will be described below: if class_hash == UNINITIALIZED_CLASS_HASH { @@ -188,20 +190,7 @@ impl StateReader for CachedState { } } // II: FETCHING FROM STATE_READER - let contract = self.state_reader.get_contract_class(class_hash)?; - match contract { - CompiledClass::Casm(ref class) => { - // We call this method instead of state_reader's in order to update the cache's class_hash_initial_values map - let compiled_class_hash = self.get_compiled_class_hash(class_hash)?; - self.casm_contract_classes - .as_mut() - .and_then(|m| m.insert(compiled_class_hash, *class.clone())); - } - CompiledClass::Deprecated(ref contract) => { - self.set_contract_class(class_hash, &contract.clone())? - } - } - Ok(contract) + self.state_reader.get_contract_class(class_hash) } } @@ -332,6 +321,128 @@ impl State for CachedState { let modified_contracts = storage_updates.keys().map(|k| k.0.clone()).len(); (modified_contracts, storage_updates.len()) } + + fn get_class_hash_at(&mut self, contract_address: &Address) -> Result { + if self.cache.get_class_hash(contract_address).is_none() { + let class_hash = match self.state_reader.get_class_hash_at(contract_address) { + Ok(class_hash) => class_hash, + Err(StateError::NoneContractState(_)) => [0; 32], + Err(e) => return Err(e), + }; + self.cache + .class_hash_initial_values + .insert(contract_address.clone(), class_hash); + } + + self.cache + .get_class_hash(contract_address) + .ok_or_else(|| StateError::NoneClassHash(contract_address.clone())) + .cloned() + } + + fn get_nonce_at(&mut self, contract_address: &Address) -> Result { + if self.cache.get_nonce(contract_address).is_none() { + let nonce = self.state_reader.get_nonce_at(contract_address)?; + self.cache + .nonce_initial_values + .insert(contract_address.clone(), nonce); + } + self.cache + .get_nonce(contract_address) + .ok_or_else(|| StateError::NoneNonce(contract_address.clone())) + .cloned() + } + + fn get_storage_at(&mut self, storage_entry: &StorageEntry) -> Result { + if self.cache.get_storage(storage_entry).is_none() { + let value = match self.state_reader.get_storage_at(storage_entry) { + Ok(value) => value, + Err( + StateError::EmptyKeyInStorage + | StateError::NoneStoragLeaf(_) + | StateError::NoneStorage(_) + | StateError::NoneContractState(_), + ) => Felt252::zero(), + Err(e) => return Err(e), + }; + self.cache + .storage_initial_values + .insert(storage_entry.clone(), value); + } + + self.cache + .get_storage(storage_entry) + .ok_or_else(|| StateError::NoneStorage(storage_entry.clone())) + .cloned() + } + + // TODO: check if that the proper way to store it (converting hash to address) + fn get_compiled_class_hash(&mut self, class_hash: &ClassHash) -> Result { + let hash = self.cache.class_hash_to_compiled_class_hash.get(class_hash); + if let Some(hash) = hash { + Ok(*hash) + } else { + let compiled_class_hash = self.state_reader.get_compiled_class_hash(class_hash)?; + let address = Address(Felt252::from_bytes_be(&compiled_class_hash)); + self.cache + .class_hash_initial_values + .insert(address, compiled_class_hash); + Ok(compiled_class_hash) + } + } + + fn get_contract_class(&mut self, class_hash: &ClassHash) -> Result { + // This method can receive both compiled_class_hash & class_hash and return both casm and deprecated contract classes + //, which can be on the cache or on the state_reader, different cases will be described below: + if class_hash == UNINITIALIZED_CLASS_HASH { + return Err(StateError::UninitiaizedClassHash); + } + // I: FETCHING FROM CACHE + // I: DEPRECATED CONTRACT CLASS + // deprecated contract classes dont have compiled class hashes, so we only have one case + if let Some(compiled_class) = self + .contract_classes + .as_ref() + .and_then(|x| x.get(class_hash)) + { + return Ok(CompiledClass::Deprecated(Box::new(compiled_class.clone()))); + } + // I: CASM CONTRACT CLASS : COMPILED_CLASS_HASH + if let Some(compiled_class) = self + .casm_contract_classes + .as_ref() + .and_then(|x| x.get(class_hash)) + { + return Ok(CompiledClass::Casm(Box::new(compiled_class.clone()))); + } + // I: CASM CONTRACT CLASS : CLASS_HASH + if let Some(compiled_class_hash) = + self.cache.class_hash_to_compiled_class_hash.get(class_hash) + { + if let Some(casm_class) = &mut self + .casm_contract_classes + .as_ref() + .and_then(|m| m.get(compiled_class_hash)) + { + return Ok(CompiledClass::Casm(Box::new(casm_class.clone()))); + } + } + // II: FETCHING FROM STATE_READER + let contract = self.state_reader.get_contract_class(class_hash)?; + match contract { + CompiledClass::Casm(ref class) => { + // We call this method instead of state_reader's in order to update the cache's class_hash_initial_values map + let compiled_class_hash = self.get_compiled_class_hash(class_hash)?; + self.casm_contract_classes + .as_mut() + .and_then(|m| m.insert(compiled_class_hash, *class.clone())); + } + CompiledClass::Deprecated(ref contract) => { + self.set_contract_class(class_hash, &contract.clone())? + } + } + Ok(contract) + } } #[cfg(test)] @@ -369,7 +480,7 @@ mod tests { .address_to_storage_mut() .insert(storage_entry, storage_value); - let mut cached_state = CachedState::new(state_reader, None, None); + let mut cached_state = CachedState::new(Arc::new(state_reader), None, None); assert_eq!( cached_state.get_class_hash_at(&contract_address).unwrap(), @@ -401,7 +512,7 @@ mod tests { .class_hash_to_contract_class .insert([1; 32], contract_class); - let mut cached_state = CachedState::new(state_reader, None, None); + let mut cached_state = CachedState::new(Arc::new(state_reader), None, None); cached_state.set_contract_classes(HashMap::new()).unwrap(); assert!(cached_state.contract_classes.is_some()); @@ -417,18 +528,8 @@ mod tests { #[test] fn cached_state_storage_test() { - let mut cached_state = CachedState::new( - InMemoryStateReader::new( - HashMap::new(), - HashMap::new(), - HashMap::new(), - HashMap::new(), - HashMap::new(), - HashMap::new(), - ), - None, - None, - ); + let mut cached_state = + CachedState::new(Arc::new(InMemoryStateReader::default()), None, None); let storage_entry: StorageEntry = (Address(31.into()), [0; 32]); let value = Felt252::new(10); @@ -445,14 +546,7 @@ mod tests { #[test] fn cached_state_deploy_contract_test() { - let state_reader = InMemoryStateReader::new( - HashMap::new(), - HashMap::new(), - HashMap::new(), - HashMap::new(), - HashMap::new(), - HashMap::new(), - ); + let state_reader = Arc::new(InMemoryStateReader::default()); let contract_address = Address(32123.into()); @@ -465,14 +559,7 @@ mod tests { #[test] fn get_and_set_storage() { - let state_reader = InMemoryStateReader::new( - HashMap::new(), - HashMap::new(), - HashMap::new(), - HashMap::new(), - HashMap::new(), - HashMap::new(), - ); + let state_reader = Arc::new(InMemoryStateReader::default()); let contract_address = Address(31.into()); let storage_key = [18; 32]; @@ -506,7 +593,7 @@ mod tests { HashMap::new(), HashMap::new(), ); - let mut cached_state = CachedState::new(state_reader, None, None); + let mut cached_state = CachedState::new(Arc::new(state_reader), None, None); cached_state.set_contract_classes(HashMap::new()).unwrap(); let result = cached_state @@ -529,7 +616,7 @@ mod tests { let contract_address = Address(0.into()); - let mut cached_state = CachedState::new(state_reader, None, None); + let mut cached_state = CachedState::new(Arc::new(state_reader), None, None); let result = cached_state .deploy_contract(contract_address.clone(), [10; 32]) @@ -554,7 +641,7 @@ mod tests { let contract_address = Address(42.into()); - let mut cached_state = CachedState::new(state_reader, None, None); + let mut cached_state = CachedState::new(Arc::new(state_reader), None, None); cached_state .deploy_contract(contract_address.clone(), [10; 32]) @@ -582,7 +669,7 @@ mod tests { let contract_address = Address(32123.into()); - let mut cached_state = CachedState::new(state_reader, None, None); + let mut cached_state = CachedState::new(Arc::new(state_reader), None, None); cached_state .deploy_contract(contract_address.clone(), [10; 32]) @@ -611,7 +698,7 @@ mod tests { let address_one = Address(Felt252::one()); - let mut cached_state = CachedState::new(state_reader, None, None); + let mut cached_state = CachedState::new(Arc::new(state_reader), None, None); let state_diff = StateDiff { address_to_class_hash: HashMap::from([( diff --git a/src/state/contract_storage_state.rs b/src/state/contract_storage_state.rs index a63c29f00..795bc339a 100644 --- a/src/state/contract_storage_state.rs +++ b/src/state/contract_storage_state.rs @@ -1,4 +1,7 @@ -use super::state_api::{State, StateReader}; +use super::{ + cached_state::CachedState, + state_api::{State, StateReader}, +}; use crate::{ core::errors::state_errors::StateError, utils::{Address, ClassHash}, @@ -7,16 +10,16 @@ use cairo_vm::felt::Felt252; use std::collections::HashSet; #[derive(Debug)] -pub(crate) struct ContractStorageState<'a, T: State + StateReader> { - pub(crate) state: &'a mut T, +pub(crate) struct ContractStorageState<'a, S: StateReader> { + pub(crate) state: &'a mut CachedState, pub(crate) contract_address: Address, /// Maintain all read request values in chronological order pub(crate) read_values: Vec, pub(crate) accessed_keys: HashSet, } -impl<'a, T: State + StateReader> ContractStorageState<'a, T> { - pub(crate) fn new(state: &'a mut T, contract_address: Address) -> Self { +impl<'a, S: StateReader> ContractStorageState<'a, S> { + pub(crate) fn new(state: &'a mut CachedState, contract_address: Address) -> Self { Self { state, contract_address, diff --git a/src/state/in_memory_state_reader.rs b/src/state/in_memory_state_reader.rs index e9ece9d90..9a553f212 100644 --- a/src/state/in_memory_state_reader.rs +++ b/src/state/in_memory_state_reader.rs @@ -75,7 +75,7 @@ impl InMemoryStateReader { /// # Returns /// The [CompiledClass] with the given [CompiledClassHash]. fn get_compiled_class( - &mut self, + &self, compiled_class_hash: &CompiledClassHash, ) -> Result { if let Some(compiled_class) = self.casm_contract_classes.get(compiled_class_hash) { @@ -89,7 +89,7 @@ impl InMemoryStateReader { } impl StateReader for InMemoryStateReader { - fn get_class_hash_at(&mut self, contract_address: &Address) -> Result { + fn get_class_hash_at(&self, contract_address: &Address) -> Result { let class_hash = self .address_to_class_hash .get(contract_address) @@ -97,7 +97,7 @@ impl StateReader for InMemoryStateReader { class_hash.cloned() } - fn get_nonce_at(&mut self, contract_address: &Address) -> Result { + fn get_nonce_at(&self, contract_address: &Address) -> Result { let nonce = self .address_to_nonce .get(contract_address) @@ -105,7 +105,7 @@ impl StateReader for InMemoryStateReader { nonce.cloned() } - fn get_storage_at(&mut self, storage_entry: &StorageEntry) -> Result { + fn get_storage_at(&self, storage_entry: &StorageEntry) -> Result { let storage = self .address_to_storage .get(storage_entry) @@ -114,7 +114,7 @@ impl StateReader for InMemoryStateReader { } fn get_compiled_class_hash( - &mut self, + &self, class_hash: &ClassHash, ) -> Result { self.class_hash_to_compiled_class_hash @@ -123,7 +123,7 @@ impl StateReader for InMemoryStateReader { .copied() } - fn get_contract_class(&mut self, class_hash: &ClassHash) -> Result { + fn get_contract_class(&self, class_hash: &ClassHash) -> Result { // Deprecated contract classes dont have a compiled_class_hash, we dont need to fetch it if let Some(compiled_class) = self.class_hash_to_contract_class.get(class_hash) { return Ok(CompiledClass::Deprecated(Box::new(compiled_class.clone()))); diff --git a/src/state/mod.rs b/src/state/mod.rs index 824527351..7698c3bf9 100644 --- a/src/state/mod.rs +++ b/src/state/mod.rs @@ -13,7 +13,7 @@ use crate::{ }; use cairo_vm::{felt::Felt252, vm::runners::cairo_runner::ExecutionResources}; use getset::Getters; -use std::collections::HashMap; +use std::{collections::HashMap, sync::Arc}; use crate::{ transaction::error::TransactionError, @@ -166,7 +166,7 @@ impl StateDiff { }) } - pub fn to_cached_state(&self, state_reader: T) -> Result, StateError> + pub fn to_cached_state(&self, state_reader: Arc) -> Result, StateError> where T: StateReader + Clone, { @@ -238,7 +238,7 @@ fn test_validate_legal_progress() { #[cfg(test)] mod test { - use std::collections::HashMap; + use std::{collections::HashMap, sync::Arc}; use super::StateDiff; use crate::{ @@ -267,7 +267,7 @@ mod test { .address_to_nonce .insert(contract_address, nonce); - let cached_state = CachedState::new(state_reader, None, None); + let cached_state = CachedState::new(Arc::new(state_reader), None, None); let diff = StateDiff::from_cached_state(cached_state).unwrap(); @@ -327,11 +327,11 @@ mod test { .address_to_nonce .insert(contract_address.clone(), nonce); - let mut cached_state_original = CachedState::new(state_reader.clone(), None, None); + let cached_state_original = CachedState::new(Arc::new(state_reader.clone()), None, None); let diff = StateDiff::from_cached_state(cached_state_original.clone()).unwrap(); - let mut cached_state = diff.to_cached_state(state_reader).unwrap(); + let cached_state = diff.to_cached_state(Arc::new(state_reader)).unwrap(); assert_eq!( cached_state_original.contract_classes(), @@ -375,7 +375,7 @@ mod test { HashMap::new(), ); let cached_state = CachedState::new_for_testing( - state_reader, + Arc::new(state_reader), Some(ContractClassCache::new()), cache, None, diff --git a/src/state/state_api.rs b/src/state/state_api.rs index 873a3413c..2dbfbe8b3 100644 --- a/src/state/state_api.rs +++ b/src/state/state_api.rs @@ -12,16 +12,16 @@ use cairo_vm::felt::Felt252; pub trait StateReader { /// Returns the contract class of the given class hash or compiled class hash. - fn get_contract_class(&mut self, class_hash: &ClassHash) -> Result; + fn get_contract_class(&self, class_hash: &ClassHash) -> Result; /// Returns the class hash of the contract class at the given address. - fn get_class_hash_at(&mut self, contract_address: &Address) -> Result; + fn get_class_hash_at(&self, contract_address: &Address) -> Result; /// Returns the nonce of the given contract instance. - fn get_nonce_at(&mut self, contract_address: &Address) -> Result; + fn get_nonce_at(&self, contract_address: &Address) -> Result; /// Returns the storage value under the given key in the given contract instance. - fn get_storage_at(&mut self, storage_entry: &StorageEntry) -> Result; + fn get_storage_at(&self, storage_entry: &StorageEntry) -> Result; /// Return the class hash of the given casm contract class fn get_compiled_class_hash( - &mut self, + &self, class_hash: &ClassHash, ) -> Result; } @@ -32,29 +32,46 @@ pub trait State { class_hash: &ClassHash, contract_class: &ContractClass, ) -> Result<(), StateError>; + fn deploy_contract( &mut self, contract_address: Address, class_hash: ClassHash, ) -> Result<(), StateError>; + fn increment_nonce(&mut self, contract_address: &Address) -> Result<(), StateError>; + fn set_storage_at(&mut self, storage_entry: &StorageEntry, value: Felt252); + fn set_class_hash_at( &mut self, contract_address: Address, class_hash: ClassHash, ) -> Result<(), StateError>; + fn set_compiled_class( &mut self, compiled_class_hash: &Felt252, casm_class: CasmContractClass, ) -> Result<(), StateError>; + fn set_compiled_class_hash( &mut self, class_hash: &Felt252, compiled_class_hash: &Felt252, ) -> Result<(), StateError>; fn apply_state_update(&mut self, sate_updates: &StateDiff) -> Result<(), StateError>; + /// Counts the amount of modified contracts and the updates to the storage fn count_actual_storage_changes(&mut self) -> (usize, usize); + + fn get_class_hash_at(&mut self, contract_address: &Address) -> Result; + + fn get_nonce_at(&mut self, contract_address: &Address) -> Result; + + fn get_storage_at(&mut self, storage_entry: &StorageEntry) -> Result; + + fn get_compiled_class_hash(&mut self, class_hash: &ClassHash) -> Result; + + fn get_contract_class(&mut self, class_hash: &ClassHash) -> Result; } diff --git a/src/syscalls/business_logic_syscall_handler.rs b/src/syscalls/business_logic_syscall_handler.rs index 835976525..e0d9ab0bd 100644 --- a/src/syscalls/business_logic_syscall_handler.rs +++ b/src/syscalls/business_logic_syscall_handler.rs @@ -21,7 +21,9 @@ use super::{ }; use crate::definitions::block_context::BlockContext; use crate::definitions::constants::BLOCK_HASH_CONTRACT_ADDRESS; +use crate::execution::execution_entry_point::ExecutionResult; use crate::services::api::contract_classes::compiled_class::CompiledClass; +use crate::state::cached_state::CachedState; use crate::state::BlockInfo; use crate::transaction::error::TransactionError; use crate::utils::calculate_sn_keccak; @@ -113,7 +115,7 @@ lazy_static! { } #[derive(Debug)] -pub struct BusinessLogicSyscallHandler<'a, T: State + StateReader> { +pub struct BusinessLogicSyscallHandler<'a, S: StateReader> { pub(crate) events: Vec, pub(crate) expected_syscall_ptr: Relocatable, pub(crate) resources_manager: ExecutionResourcesManager, @@ -124,7 +126,7 @@ pub struct BusinessLogicSyscallHandler<'a, T: State + StateReader> { pub(crate) read_only_segments: Vec<(Relocatable, MaybeRelocatable)>, pub(crate) internal_calls: Vec, pub(crate) block_context: BlockContext, - pub(crate) starknet_storage_state: ContractStorageState<'a, T>, + pub(crate) starknet_storage_state: ContractStorageState<'a, S>, pub(crate) support_reverted: bool, pub(crate) entry_point_selector: Felt252, pub(crate) selector_to_syscall: &'a HashMap, @@ -132,11 +134,11 @@ pub struct BusinessLogicSyscallHandler<'a, T: State + StateReader> { // TODO: execution entry point may no be a parameter field, but there is no way to generate a default for now -impl<'a, T: State + StateReader> BusinessLogicSyscallHandler<'a, T> { +impl<'a, S: StateReader> BusinessLogicSyscallHandler<'a, S> { #[allow(clippy::too_many_arguments)] pub fn new( tx_execution_context: TransactionExecutionContext, - state: &'a mut T, + state: &'a mut CachedState, resources_manager: ExecutionResourcesManager, caller_address: Address, contract_address: Address, @@ -168,7 +170,7 @@ impl<'a, T: State + StateReader> BusinessLogicSyscallHandler<'a, T> { selector_to_syscall: &SELECTOR_TO_SYSCALL, } } - pub fn default_with_state(state: &'a mut T) -> Self { + pub fn default_with_state(state: &'a mut CachedState) -> Self { BusinessLogicSyscallHandler::new_for_testing( BlockInfo::default(), Default::default(), @@ -179,7 +181,7 @@ impl<'a, T: State + StateReader> BusinessLogicSyscallHandler<'a, T> { pub fn new_for_testing( block_info: BlockInfo, _contract_address: Address, - state: &'a mut T, + state: &'a mut CachedState, ) -> Self { let syscalls = Vec::from([ "emit_event".to_string(), @@ -237,17 +239,26 @@ impl<'a, T: State + StateReader> BusinessLogicSyscallHandler<'a, T> { remaining_gas: u128, execution_entry_point: ExecutionEntryPoint, ) -> Result { - let result = execution_entry_point + let ExecutionResult { + call_info, + revert_error, + .. + } = execution_entry_point .execute( self.starknet_storage_state.state, &self.block_context, &mut self.resources_manager, &mut self.tx_execution_context, - self.support_reverted, + false, + self.block_context.invoke_tx_max_n_steps, ) .map_err(|err| SyscallHandlerError::ExecutionError(err.to_string()))?; - let retdata_maybe_reloc = result + let call_info = call_info.ok_or(SyscallHandlerError::ExecutionError( + revert_error.unwrap_or("Execution error".to_string()), + ))?; + + let retdata_maybe_reloc = call_info .retdata .clone() .into_iter() @@ -255,12 +266,12 @@ impl<'a, T: State + StateReader> BusinessLogicSyscallHandler<'a, T> { .collect::>(); let retdata_start = self.allocate_segment(vm, retdata_maybe_reloc)?; - let retdata_end = (retdata_start + result.retdata.len())?; + let retdata_end = (retdata_start + call_info.retdata.len())?; - let remaining_gas = remaining_gas.saturating_sub(result.gas_consumed); + let remaining_gas = remaining_gas.saturating_sub(call_info.gas_consumed); let gas = remaining_gas; - let body = if result.failure_flag { + let body = if call_info.failure_flag { Some(ResponseBody::Failure(FailureReason { retdata_start, retdata_end, @@ -272,7 +283,7 @@ impl<'a, T: State + StateReader> BusinessLogicSyscallHandler<'a, T> { })) }; - self.internal_calls.push(result); + self.internal_calls.push(call_info); Ok(SyscallResponse { gas, body }) } @@ -338,16 +349,25 @@ impl<'a, T: State + StateReader> BusinessLogicSyscallHandler<'a, T> { remainig_gas, ); - let call_info = call + let ExecutionResult { + call_info, + revert_error, + .. + } = call .execute( self.starknet_storage_state.state, &self.block_context, &mut self.resources_manager, &mut self.tx_execution_context, self.support_reverted, + self.block_context.invoke_tx_max_n_steps, ) .map_err(|_| StateError::ExecutionEntryPoint())?; + let call_info = call_info.ok_or(StateError::CustomError( + revert_error.unwrap_or("Execution error".to_string()), + ))?; + self.internal_calls.push(call_info.clone()); Ok(call_info.result()) @@ -521,10 +541,7 @@ impl<'a, T: State + StateReader> BusinessLogicSyscallHandler<'a, T> { } } -impl<'a, T> BusinessLogicSyscallHandler<'a, T> -where - T: State + StateReader, -{ +impl<'a, S: StateReader> BusinessLogicSyscallHandler<'a, S> { fn emit_event( &mut self, vm: &VirtualMachine, diff --git a/src/syscalls/deprecated_business_logic_syscall_handler.rs b/src/syscalls/deprecated_business_logic_syscall_handler.rs index 35b4b7d02..19c55b002 100644 --- a/src/syscalls/deprecated_business_logic_syscall_handler.rs +++ b/src/syscalls/deprecated_business_logic_syscall_handler.rs @@ -10,14 +10,16 @@ use super::{ syscall_handler_errors::SyscallHandlerError, syscall_info::get_deprecated_syscall_size_from_name, }; -use crate::services::api::contract_classes::deprecated_contract_class::EntryPointType; use crate::{ core::errors::state_errors::StateError, definitions::{ block_context::BlockContext, constants::{CONSTRUCTOR_ENTRY_POINT_SELECTOR, INITIAL_GAS_COST}, }, - execution::{execution_entry_point::ExecutionEntryPoint, *}, + execution::{ + execution_entry_point::{ExecutionEntryPoint, ExecutionResult}, + *, + }, hash_utils::calculate_contract_address, services::api::{ contract_class_errors::ContractClassError, contract_classes::compiled_class::CompiledClass, @@ -31,20 +33,23 @@ use crate::{ transaction::error::TransactionError, utils::*, }; +use crate::{ + services::api::contract_classes::deprecated_contract_class::EntryPointType, + state::cached_state::CachedState, +}; use cairo_vm::felt::Felt252; use cairo_vm::{ types::relocatable::{MaybeRelocatable, Relocatable}, vm::vm_core::VirtualMachine, }; use num_traits::{One, ToPrimitive, Zero}; -use std::borrow::{Borrow, BorrowMut}; //* ----------------------------------- //* DeprecatedBLSyscallHandler implementation //* ----------------------------------- /// Deprecated version of BusinessLogicSyscallHandler. #[derive(Debug)] -pub struct DeprecatedBLSyscallHandler<'a, T: State + StateReader> { +pub struct DeprecatedBLSyscallHandler<'a, S: StateReader> { pub(crate) tx_execution_context: TransactionExecutionContext, /// Events emitted by the current contract call. pub(crate) events: Vec, @@ -56,15 +61,15 @@ pub struct DeprecatedBLSyscallHandler<'a, T: State + StateReader> { pub(crate) l2_to_l1_messages: Vec, pub(crate) block_context: BlockContext, pub(crate) tx_info_ptr: Option, - pub(crate) starknet_storage_state: ContractStorageState<'a, T>, + pub(crate) starknet_storage_state: ContractStorageState<'a, S>, pub(crate) internal_calls: Vec, pub(crate) expected_syscall_ptr: Relocatable, } -impl<'a, T: State + StateReader> DeprecatedBLSyscallHandler<'a, T> { +impl<'a, S: StateReader> DeprecatedBLSyscallHandler<'a, S> { pub fn new( tx_execution_context: TransactionExecutionContext, - state: &'a mut T, + state: &'a mut CachedState, resources_manager: ExecutionResourcesManager, caller_address: Address, contract_address: Address, @@ -95,7 +100,7 @@ impl<'a, T: State + StateReader> DeprecatedBLSyscallHandler<'a, T> { } } - pub fn default_with(state: &'a mut T) -> Self { + pub fn default_with(state: &'a mut CachedState) -> Self { DeprecatedBLSyscallHandler::new_for_testing(BlockInfo::default(), Default::default(), state) } @@ -108,7 +113,7 @@ impl<'a, T: State + StateReader> DeprecatedBLSyscallHandler<'a, T> { pub fn new_for_testing( block_info: BlockInfo, _contract_address: Address, - state: &'a mut T, + state: &'a mut CachedState, ) -> Self { let syscalls = Vec::from([ "emit_event".to_string(), @@ -234,34 +239,14 @@ impl<'a, T: State + StateReader> DeprecatedBLSyscallHandler<'a, T> { &mut self.resources_manager, &mut self.tx_execution_context, false, + self.block_context.invoke_tx_max_n_steps, ) .map_err(|_| StateError::ExecutionEntryPoint())?; Ok(()) } } -impl<'a, T> Borrow for DeprecatedBLSyscallHandler<'a, T> -where - T: State + StateReader, -{ - fn borrow(&self) -> &T { - self.starknet_storage_state.state - } -} - -impl<'a, T> BorrowMut for DeprecatedBLSyscallHandler<'a, T> -where - T: State + StateReader, -{ - fn borrow_mut(&mut self) -> &mut T { - self.starknet_storage_state.state - } -} - -impl<'a, T> DeprecatedBLSyscallHandler<'a, T> -where - T: State + StateReader, -{ +impl<'a, S: StateReader> DeprecatedBLSyscallHandler<'a, S> { pub(crate) fn emit_event( &mut self, vm: &VirtualMachine, @@ -451,21 +436,29 @@ where ); entry_point.code_address = code_address; - entry_point + let ExecutionResult { + call_info, + revert_error, + .. + } = entry_point .execute( self.starknet_storage_state.state, &self.block_context, &mut self.resources_manager, &mut self.tx_execution_context, false, + self.block_context.invoke_tx_max_n_steps, ) - .map(|x| { - let retdata = x.retdata.clone(); - self.internal_calls.push(x); + .map_err(|e| SyscallHandlerError::ExecutionError(e.to_string()))?; + + let call_info = call_info.ok_or(SyscallHandlerError::ExecutionError( + revert_error.unwrap_or("Execution error".to_string()), + ))?; + + let retdata = call_info.retdata.clone(); + self.internal_calls.push(call_info); - retdata - }) - .map_err(|e| SyscallHandlerError::ExecutionError(e.to_string())) + Ok(retdata) } pub(crate) fn get_block_info(&self) -> &BlockInfo { @@ -946,7 +939,7 @@ mod tests { use std::{any::Any, borrow::Cow, collections::HashMap}; type DeprecatedBLSyscallHandler<'a> = - super::DeprecatedBLSyscallHandler<'a, CachedState>; + super::DeprecatedBLSyscallHandler<'a, InMemoryStateReader>; #[test] fn run_alloc_hint_ap_is_not_empty() { diff --git a/src/syscalls/deprecated_syscall_handler.rs b/src/syscalls/deprecated_syscall_handler.rs index 149e972d6..ba10df2a8 100644 --- a/src/syscalls/deprecated_syscall_handler.rs +++ b/src/syscalls/deprecated_syscall_handler.rs @@ -2,10 +2,7 @@ use super::{ deprecated_business_logic_syscall_handler::DeprecatedBLSyscallHandler, hint_code::*, other_syscalls, syscall_handler::HintProcessorPostRun, }; -use crate::{ - state::state_api::{State, StateReader}, - syscalls::syscall_handler_errors::SyscallHandlerError, -}; +use crate::{state::state_api::StateReader, syscalls::syscall_handler_errors::SyscallHandlerError}; use cairo_vm::{ felt::Felt252, hint_processor::hint_processor_definition::HintProcessorLogic, @@ -25,15 +22,15 @@ use cairo_vm::{ }; use std::{any::Any, collections::HashMap}; -pub(crate) struct DeprecatedSyscallHintProcessor<'a, T: State + StateReader> { +pub(crate) struct DeprecatedSyscallHintProcessor<'a, S: StateReader> { pub(crate) builtin_hint_processor: BuiltinHintProcessor, - pub(crate) syscall_handler: DeprecatedBLSyscallHandler<'a, T>, + pub(crate) syscall_handler: DeprecatedBLSyscallHandler<'a, S>, run_resources: RunResources, } -impl<'a, T: State + StateReader> DeprecatedSyscallHintProcessor<'a, T> { +impl<'a, S: StateReader> DeprecatedSyscallHintProcessor<'a, S> { pub fn new( - syscall_handler: DeprecatedBLSyscallHandler<'a, T>, + syscall_handler: DeprecatedBLSyscallHandler<'a, S>, run_resources: RunResources, ) -> Self { DeprecatedSyscallHintProcessor { @@ -152,7 +149,7 @@ impl<'a, T: State + StateReader> DeprecatedSyscallHintProcessor<'a, T> { } } -impl<'a, T: State + StateReader> HintProcessorLogic for DeprecatedSyscallHintProcessor<'a, T> { +impl<'a, S: StateReader> HintProcessorLogic for DeprecatedSyscallHintProcessor<'a, S> { fn execute_hint( &mut self, vm: &mut VirtualMachine, @@ -174,7 +171,7 @@ impl<'a, T: State + StateReader> HintProcessorLogic for DeprecatedSyscallHintPro } } -impl<'a, T: State + StateReader> ResourceTracker for DeprecatedSyscallHintProcessor<'a, T> { +impl<'a, S: StateReader> ResourceTracker for DeprecatedSyscallHintProcessor<'a, S> { fn consumed(&self) -> bool { self.run_resources.consumed() } @@ -192,7 +189,7 @@ impl<'a, T: State + StateReader> ResourceTracker for DeprecatedSyscallHintProces } } -impl<'a, T: State + StateReader> HintProcessorPostRun for DeprecatedSyscallHintProcessor<'a, T> { +impl<'a, S: StateReader> HintProcessorPostRun for DeprecatedSyscallHintProcessor<'a, S> { fn post_run( &self, runner: &mut VirtualMachine, @@ -214,6 +211,8 @@ fn get_syscall_ptr( #[cfg(test)] mod tests { + use std::sync::Arc; + use super::*; use crate::services::api::contract_classes::deprecated_contract_class::EntryPointType; use crate::{ @@ -226,10 +225,7 @@ mod tests { memory_insert, services::api::contract_classes::deprecated_contract_class::ContractClass, state::in_memory_state_reader::InMemoryStateReader, - state::{ - cached_state::CachedState, - state_api::{State, StateReader}, - }, + state::{cached_state::CachedState, state_api::State}, syscalls::deprecated_syscall_request::{ DeprecatedDeployRequest, DeprecatedSendMessageToL1SysCallRequest, DeprecatedSyscallRequest, @@ -247,7 +243,7 @@ mod tests { type DeprecatedBLSyscallHandler<'a> = crate::syscalls::deprecated_business_logic_syscall_handler::DeprecatedBLSyscallHandler< 'a, - CachedState, + InMemoryStateReader, >; type SyscallHintProcessor<'a, T> = super::DeprecatedSyscallHintProcessor<'a, T>; @@ -712,7 +708,7 @@ mod tests { ] ); - let mut state = CachedState::::default(); + let mut state = CachedState::new(Arc::new(InMemoryStateReader::default()), None, None); let mut hint_processor = SyscallHintProcessor::new( DeprecatedBLSyscallHandler::default_with(&mut state), RunResources::default(), @@ -747,7 +743,7 @@ mod tests { let hint_data = HintProcessorData::new_default(GET_CONTRACT_ADDRESS.to_string(), ids_data); // invoke syscall - let mut state = CachedState::::default(); + let mut state = CachedState::new(Arc::new(InMemoryStateReader::default()), None, None); let mut hint_processor = SyscallHintProcessor::new( DeprecatedBLSyscallHandler::default_with(&mut state), RunResources::default(), @@ -788,7 +784,7 @@ mod tests { let hint_data = HintProcessorData::new_default(GET_TX_SIGNATURE.to_string(), ids_data); // invoke syscall - let mut state = CachedState::::default(); + let mut state = CachedState::new(Arc::new(InMemoryStateReader::default()), None, None); let mut syscall_handler_hint_processor = SyscallHintProcessor::new( DeprecatedBLSyscallHandler::default_with(&mut state), RunResources::default(), @@ -855,7 +851,7 @@ mod tests { let hint_data = HintProcessorData::new_default(STORAGE_READ.to_string(), ids_data); - let mut state = CachedState::::default(); + let mut state = CachedState::new(Arc::new(InMemoryStateReader::default()), None, None); let mut syscall_handler_hint_processor = SyscallHintProcessor::new( DeprecatedBLSyscallHandler::default_with(&mut state), RunResources::default(), @@ -919,7 +915,7 @@ mod tests { let hint_data = HintProcessorData::new_default(STORAGE_WRITE.to_string(), ids_data); - let mut state = CachedState::::default(); + let mut state = CachedState::new(Arc::new(InMemoryStateReader::default()), None, None); let mut syscall_handler_hint_processor = SyscallHintProcessor::new( DeprecatedBLSyscallHandler::default_with(&mut state), RunResources::default(), @@ -992,7 +988,7 @@ mod tests { let hint_data = HintProcessorData::new_default(DEPLOY.to_string(), ids_data); // Create SyscallHintProcessor - let mut state = CachedState::::default(); + let mut state = CachedState::new(Arc::new(InMemoryStateReader::default()), None, None); let mut syscall_handler_hint_processor = SyscallHintProcessor::new( DeprecatedBLSyscallHandler::default_with(&mut state), RunResources::default(), @@ -1089,7 +1085,7 @@ mod tests { ); // Create SyscallHintProcessor - let mut state = CachedState::::default(); + let mut state = CachedState::new(Arc::new(InMemoryStateReader::default()), None, None); let mut syscall_handler_hint_processor = SyscallHintProcessor::new( DeprecatedBLSyscallHandler::default_with(&mut state), RunResources::default(), diff --git a/src/syscalls/deprecated_syscall_response.rs b/src/syscalls/deprecated_syscall_response.rs index 5456dccae..1851f25ef 100644 --- a/src/syscalls/deprecated_syscall_response.rs +++ b/src/syscalls/deprecated_syscall_response.rs @@ -300,6 +300,8 @@ impl DeprecatedWriteSyscallResponse for DeprecatedStorageReadResponse { #[cfg(test)] mod tests { + use std::sync::Arc; + use super::*; use crate::{ add_segments, @@ -312,12 +314,12 @@ mod tests { type DeprecatedBLSyscallHandler<'a> = crate::syscalls::deprecated_business_logic_syscall_handler::DeprecatedBLSyscallHandler< 'a, - CachedState, + InMemoryStateReader, >; #[test] fn write_get_caller_address_response() { - let mut state = CachedState::::default(); + let mut state = CachedState::new(Arc::new(InMemoryStateReader::default()), None, None); let syscall = DeprecatedBLSyscallHandler::default_with(&mut state); let mut vm = vm!(); diff --git a/src/syscalls/syscall_handler.rs b/src/syscalls/syscall_handler.rs index c1761362e..67db9ebad 100644 --- a/src/syscalls/syscall_handler.rs +++ b/src/syscalls/syscall_handler.rs @@ -1,5 +1,5 @@ use super::business_logic_syscall_handler::BusinessLogicSyscallHandler; -use crate::state::state_api::{State, StateReader}; +use crate::state::state_api::StateReader; use crate::transaction::error::TransactionError; use cairo_lang_casm::{ hints::{Hint, StarknetHint}, @@ -32,15 +32,15 @@ pub(crate) trait HintProcessorPostRun { } #[allow(unused)] -pub(crate) struct SyscallHintProcessor<'a, T: State + StateReader> { +pub(crate) struct SyscallHintProcessor<'a, S: StateReader> { pub(crate) cairo1_hint_processor: Cairo1HintProcessor, - pub(crate) syscall_handler: BusinessLogicSyscallHandler<'a, T>, + pub(crate) syscall_handler: BusinessLogicSyscallHandler<'a, S>, pub(crate) run_resources: RunResources, } -impl<'a, T: State + StateReader> SyscallHintProcessor<'a, T> { +impl<'a, S: StateReader> SyscallHintProcessor<'a, S> { pub fn new( - syscall_handler: BusinessLogicSyscallHandler<'a, T>, + syscall_handler: BusinessLogicSyscallHandler<'a, S>, hints: &[(usize, Vec)], run_resources: RunResources, ) -> Self { @@ -52,7 +52,7 @@ impl<'a, T: State + StateReader> SyscallHintProcessor<'a, T> { } } -impl<'a, T: State + StateReader> HintProcessorLogic for SyscallHintProcessor<'a, T> { +impl<'a, S: StateReader> HintProcessorLogic for SyscallHintProcessor<'a, S> { fn execute_hint( &mut self, vm: &mut VirtualMachine, @@ -111,7 +111,7 @@ impl<'a, T: State + StateReader> HintProcessorLogic for SyscallHintProcessor<'a, } } -impl<'a, T: State + StateReader> ResourceTracker for SyscallHintProcessor<'a, T> { +impl<'a, S: StateReader> ResourceTracker for SyscallHintProcessor<'a, S> { fn consumed(&self) -> bool { self.run_resources.consumed() } @@ -129,7 +129,7 @@ impl<'a, T: State + StateReader> ResourceTracker for SyscallHintProcessor<'a, T> } } -impl<'a, T: State + StateReader> HintProcessorPostRun for SyscallHintProcessor<'a, T> { +impl<'a, S: StateReader> HintProcessorPostRun for SyscallHintProcessor<'a, S> { fn post_run( &self, runner: &mut VirtualMachine, diff --git a/src/testing/erc20.rs b/src/testing/erc20.rs index 4dd0eefe4..eab540352 100644 --- a/src/testing/erc20.rs +++ b/src/testing/erc20.rs @@ -1,5 +1,5 @@ #![allow(unused_imports)] -use std::{collections::HashMap, io::Bytes, path::Path, vec}; +use std::{collections::HashMap, io::Bytes, path::Path, sync::Arc}; use crate::{ call_contract, @@ -70,7 +70,7 @@ fn test_erc20_cairo2() { .insert(address.clone(), nonce); // Create state from the state_reader and contract cache. - let mut state = CachedState::new(state_reader, None, Some(contract_class_cache)); + let mut state = CachedState::new(Arc::new(state_reader), None, Some(contract_class_cache)); let name_ = Felt252::from_bytes_be(b"some-token"); let symbol_ = Felt252::from_bytes_be(b"my-super-awesome-token"); @@ -125,9 +125,10 @@ fn test_erc20_cairo2() { &mut resources_manager, &mut tx_execution_context, false, + block_context.invoke_tx_max_n_steps, ) .unwrap(); - let erc20_address = call_info.retdata.get(0).unwrap().clone(); + let erc20_address = call_info.call_info.unwrap().retdata.get(0).unwrap().clone(); // ACCOUNT 1 let program_data_account = diff --git a/src/testing/mod.rs b/src/testing/mod.rs index 1e7141d92..6d784bab6 100644 --- a/src/testing/mod.rs +++ b/src/testing/mod.rs @@ -3,7 +3,7 @@ pub mod state; pub mod state_error; pub mod type_utils; -use std::collections::HashMap; +use std::{collections::HashMap, sync::Arc}; use cairo_vm::felt::{felt_str, Felt252}; use lazy_static::lazy_static; @@ -153,7 +153,7 @@ pub fn create_account_tx_test_state( .class_hash_to_contract_class_mut() .insert(class_hash, contract_class); } - state_reader + Arc::new(state_reader) }, Some(HashMap::new()), Some(HashMap::new()), diff --git a/src/testing/state.rs b/src/testing/state.rs index 2fb1a12c2..dacde9d66 100644 --- a/src/testing/state.rs +++ b/src/testing/state.rs @@ -1,4 +1,5 @@ use super::{state_error::StarknetStateError, type_utils::ExecutionInfo}; +use crate::execution::execution_entry_point::ExecutionResult; use crate::services::api::contract_classes::deprecated_contract_class::EntryPointType; use crate::{ definitions::{block_context::BlockContext, constants::TRANSACTION_VERSION}, @@ -22,6 +23,7 @@ use crate::{ use cairo_vm::felt::Felt252; use num_traits::{One, Zero}; use std::collections::HashMap; +use std::sync::Arc; // --------------------------------------------------------------------- /// StarkNet testing object. Represents a state of a StarkNet network. @@ -36,7 +38,7 @@ pub struct StarknetState { impl StarknetState { pub fn new(context: Option) -> Self { let block_context = context.unwrap_or_default(); - let state_reader = InMemoryStateReader::default(); + let state_reader = Arc::new(InMemoryStateReader::default()); let state = CachedState::new(state_reader, Some(HashMap::new()), Some(HashMap::new())); @@ -146,14 +148,19 @@ impl StarknetState { let mut resources_manager = ExecutionResourcesManager::default(); let mut tx_execution_context = TransactionExecutionContext::default(); - let call_info = call.execute( + let ExecutionResult { call_info, .. } = call.execute( &mut self.state, &self.block_context, &mut resources_manager, &mut tx_execution_context, false, + self.block_context.invoke_tx_max_n_steps, )?; + let call_info = call_info.ok_or(StarknetStateError::Transaction( + TransactionError::CallInfoIsNone, + ))?; + let exec_info = ExecutionInfo::Call(Box::new(call_info.clone())); self.add_messages_and_events(&exec_info)?; @@ -364,6 +371,7 @@ mod tests { entry_point_type: Some(EntryPointType::Constructor), ..Default::default() }), + revert_error: None, fee_transfer_info: None, actual_fee: 0, actual_resources, @@ -435,7 +443,7 @@ mod tests { .class_hash_to_contract_class_mut() .insert(class_hash, contract_class.clone()); - let state = CachedState::new(state_reader, Some(contract_class_cache), None); + let state = CachedState::new(Arc::new(state_reader), Some(contract_class_cache), None); //* -------------------------------------------- //* Create starknet state with previous data @@ -446,25 +454,21 @@ mod tests { starknet_state.state = state; starknet_state .state - .state_reader - .address_to_class_hash_mut() - .insert(sender_address.clone(), class_hash); + .set_class_hash_at(sender_address.clone(), class_hash) + .unwrap(); starknet_state .state - .state_reader - .address_to_nonce_mut() + .cache + .nonce_writes .insert(sender_address.clone(), nonce); + + starknet_state.state.set_storage_at(&storage_entry, storage); + starknet_state .state - .state_reader - .address_to_storage_mut() - .insert(storage_entry, storage); - starknet_state - .state - .state_reader - .class_hash_to_contract_class_mut() - .insert(class_hash, contract_class); + .set_contract_class(&class_hash, &contract_class) + .unwrap(); // -------------------------------------------- // Test declare with starknet state diff --git a/src/transaction/declare.rs b/src/transaction/declare.rs index d3dbd58d4..e9a73ea5d 100644 --- a/src/transaction/declare.rs +++ b/src/transaction/declare.rs @@ -1,5 +1,7 @@ use crate::definitions::constants::QUERY_VERSION_BASE; +use crate::execution::execution_entry_point::ExecutionResult; use crate::services::api::contract_classes::deprecated_contract_class::EntryPointType; +use crate::state::cached_state::CachedState; use crate::{ core::{ contract_address::compute_deprecated_class_hash, @@ -153,9 +155,9 @@ impl Declare { /// Executes a call to the cairo-vm using the accounts_validation.cairo contract to validate /// the contract that is being declared. Then it returns the transaction execution info of the run. - pub fn apply( + pub fn apply( &self, - state: &mut S, + state: &mut CachedState, block_context: &BlockContext, ) -> Result { verify_version(&self.version, self.max_fee, &self.nonce, &self.signature)?; @@ -180,6 +182,7 @@ impl Declare { Ok(TransactionExecutionInfo::new_without_fee_info( validate_info, None, + None, actual_resources, Some(self.tx_type), )) @@ -200,9 +203,9 @@ impl Declare { ) } - pub fn run_validate_entrypoint( + pub fn run_validate_entrypoint( &self, - state: &mut S, + state: &mut CachedState, resources_manager: &mut ExecutionResourcesManager, block_context: &BlockContext, ) -> Result, TransactionError> { @@ -223,14 +226,17 @@ impl Declare { 0, ); - let call_info = entry_point.execute( + let ExecutionResult { call_info, .. } = entry_point.execute( state, block_context, resources_manager, &mut self.get_execution_context(block_context.invoke_tx_max_n_steps), false, + block_context.validate_max_n_steps, )?; + let call_info = call_info.ok_or(TransactionError::CallInfoIsNone)?; + verify_no_calls_to_other_contracts(&call_info) .map_err(|_| TransactionError::UnauthorizedActionOnValidate)?; @@ -238,9 +244,9 @@ impl Declare { } /// Calculates and charges the actual fee. - pub fn charge_fee( + pub fn charge_fee( &self, - state: &mut S, + state: &mut CachedState, resources: &HashMap, block_context: &BlockContext, ) -> Result { @@ -290,9 +296,9 @@ impl Declare { /// Calculates actual fee used by the transaction using the execution /// info returned by apply(), then updates the transaction execution info with the data of the fee. - pub fn execute( + pub fn execute( &self, - state: &mut S, + state: &mut CachedState, block_context: &BlockContext, ) -> Result { let mut tx_exec_info = self.apply(state, block_context)?; @@ -337,7 +343,7 @@ mod tests { vm::runners::cairo_runner::ExecutionResources, }; use num_traits::{One, Zero}; - use std::{collections::HashMap, path::PathBuf}; + use std::{collections::HashMap, path::PathBuf, sync::Arc}; use crate::{ definitions::{ @@ -383,7 +389,7 @@ mod tests { .address_to_nonce_mut() .insert(sender_address, Felt252::new(1)); - let mut state = CachedState::new(state_reader, Some(contract_class_cache), None); + let mut state = CachedState::new(Arc::new(state_reader), Some(contract_class_cache), None); //* --------------------------------------- //* Test declare with previous data @@ -445,6 +451,7 @@ mod tests { let transaction_exec_info = TransactionExecutionInfo { validate_info, call_info: None, + revert_error: None, fee_transfer_info: None, actual_fee: 0, actual_resources, @@ -477,22 +484,6 @@ mod tests { contract_class_cache.insert(class_hash, contract_class); - // store sender_address - let sender_address = Address(1.into()); - // this is not conceptually correct as the sender address would be an - // Account contract (not the contract that we are currently declaring) - // but for testing reasons its ok - - let mut state_reader = InMemoryStateReader::default(); - state_reader - .address_to_class_hash_mut() - .insert(sender_address.clone(), class_hash); - state_reader - .address_to_nonce_mut() - .insert(sender_address, Felt252::new(1)); - - let _state = CachedState::new(state_reader, Some(contract_class_cache), None); - //* --------------------------------------- //* Test declare with previous data //* --------------------------------------- @@ -554,7 +545,7 @@ mod tests { .address_to_nonce_mut() .insert(sender_address, Felt252::new(1)); - let _state = CachedState::new(state_reader, Some(contract_class_cache), None); + let _state = CachedState::new(Arc::new(state_reader), Some(contract_class_cache), None); //* --------------------------------------- //* Test declare with previous data @@ -617,7 +608,7 @@ mod tests { .address_to_nonce_mut() .insert(sender_address, Felt252::new(1)); - let _state = CachedState::new(state_reader, Some(contract_class_cache), None); + let _state = CachedState::new(Arc::new(state_reader), Some(contract_class_cache), None); //* --------------------------------------- //* Test declare with previous data @@ -679,7 +670,7 @@ mod tests { .address_to_nonce_mut() .insert(sender_address, Felt252::zero()); - let mut state = CachedState::new(state_reader, Some(contract_class_cache), None); + let mut state = CachedState::new(Arc::new(state_reader), Some(contract_class_cache), None); //* --------------------------------------- //* Test declare with previous data @@ -755,7 +746,7 @@ mod tests { .address_to_nonce_mut() .insert(sender_address, Felt252::zero()); - let mut state = CachedState::new(state_reader, Some(contract_class_cache), None); + let mut state = CachedState::new(Arc::new(state_reader), Some(contract_class_cache), None); //* --------------------------------------- //* Test declare with previous data @@ -800,7 +791,7 @@ mod tests { // Instantiate CachedState let contract_class_cache = HashMap::new(); - let state_reader = InMemoryStateReader::default(); + let state_reader = Arc::new(InMemoryStateReader::default()); let mut state = CachedState::new(state_reader, Some(contract_class_cache), None); @@ -859,7 +850,7 @@ mod tests { .address_to_nonce_mut() .insert(sender_address, Felt252::zero()); - let mut state = CachedState::new(state_reader, Some(contract_class_cache), None); + let mut state = CachedState::new(Arc::new(state_reader), Some(contract_class_cache), None); //* --------------------------------------- //* Test declare with previous data diff --git a/src/transaction/declare_v2.rs b/src/transaction/declare_v2.rs index 99ddeef66..3ab61e2be 100644 --- a/src/transaction/declare_v2.rs +++ b/src/transaction/declare_v2.rs @@ -1,8 +1,10 @@ use super::{verify_version, Transaction}; use crate::core::contract_address::{compute_casm_class_hash, compute_sierra_class_hash}; use crate::definitions::constants::QUERY_VERSION_BASE; +use crate::execution::execution_entry_point::ExecutionResult; use crate::services::api::contract_classes::deprecated_contract_class::EntryPointType; +use crate::state::cached_state::CachedState; use crate::{ core::transaction_hash::calculate_declare_v2_transaction_hash, definitions::{ @@ -277,9 +279,9 @@ impl DeclareV2 { /// - state: An state that implements the State and StateReader traits. /// - resources: the resources that are in use by the contract /// - block_context: The block that contains the execution context - pub fn charge_fee( + pub fn charge_fee( &self, - state: &mut S, + state: &mut CachedState, resources: &HashMap, block_context: &BlockContext, ) -> Result { @@ -334,9 +336,9 @@ impl DeclareV2 { /// ## Parameter: /// - state: An state that implements the State and StateReader traits. /// - block_context: The block that contains the execution context - pub fn execute( + pub fn execute( &self, - state: &mut S, + state: &mut CachedState, block_context: &BlockContext, ) -> Result { verify_version(&self.version, self.max_fee, &self.nonce, &self.signature)?; @@ -373,6 +375,7 @@ impl DeclareV2 { let mut tx_exec_info = TransactionExecutionInfo::new_without_fee_info( validate_info, None, + None, actual_resources, Some(self.tx_type), ); @@ -406,10 +409,10 @@ impl DeclareV2 { Ok(()) } - fn run_validate_entrypoint( + fn run_validate_entrypoint( &self, mut remaining_gas: u128, - state: &mut S, + state: &mut CachedState, resources_manager: &mut ExecutionResourcesManager, block_context: &BlockContext, ) -> Result<(CallInfo, u128), TransactionError> { @@ -430,16 +433,17 @@ impl DeclareV2 { let mut tx_execution_context = self.get_execution_context(block_context.validate_max_n_steps); - let call_info = if self.skip_execute { - None + let ExecutionResult { call_info, .. } = if self.skip_execute { + ExecutionResult::default() } else { - Some(entry_point.execute( + entry_point.execute( state, block_context, resources_manager, &mut tx_execution_context, - false, - )?) + true, + block_context.validate_max_n_steps, + )? }; let call_info = verify_no_calls_to_other_contracts(&call_info)?; remaining_gas -= call_info.gas_consumed; @@ -469,6 +473,7 @@ impl DeclareV2 { #[cfg(test)] mod tests { + use std::sync::Arc; use std::{collections::HashMap, fs::File, io::BufReader, path::PathBuf}; use super::DeclareV2; @@ -527,7 +532,7 @@ mod tests { // crate state to store casm contract class let casm_contract_class_cache = HashMap::new(); - let state_reader = InMemoryStateReader::default(); + let state_reader = Arc::new(InMemoryStateReader::default()); let mut state = CachedState::new(state_reader, None, Some(casm_contract_class_cache)); // call compile and store @@ -596,7 +601,7 @@ mod tests { // crate state to store casm contract class let casm_contract_class_cache = HashMap::new(); - let state_reader = InMemoryStateReader::default(); + let state_reader = Arc::new(InMemoryStateReader::default()); let mut state = CachedState::new(state_reader, None, Some(casm_contract_class_cache)); // call compile and store @@ -667,7 +672,7 @@ mod tests { // crate state to store casm contract class let casm_contract_class_cache = HashMap::new(); - let state_reader = InMemoryStateReader::default(); + let state_reader = Arc::new(InMemoryStateReader::default()); let mut state = CachedState::new(state_reader, None, Some(casm_contract_class_cache)); // call compile and store @@ -736,7 +741,7 @@ mod tests { // crate state to store casm contract class let casm_contract_class_cache = HashMap::new(); - let state_reader = InMemoryStateReader::default(); + let state_reader = Arc::new(InMemoryStateReader::default()); let mut state = CachedState::new(state_reader, None, Some(casm_contract_class_cache)); // call compile and store @@ -806,7 +811,7 @@ mod tests { // crate state to store casm contract class let casm_contract_class_cache = HashMap::new(); - let state_reader = InMemoryStateReader::default(); + let state_reader = Arc::new(InMemoryStateReader::default()); let mut state = CachedState::new(state_reader, None, Some(casm_contract_class_cache)); let expected_err = format!( diff --git a/src/transaction/deploy.rs b/src/transaction/deploy.rs index b1d602447..4d5fd1a56 100644 --- a/src/transaction/deploy.rs +++ b/src/transaction/deploy.rs @@ -1,4 +1,6 @@ +use crate::execution::execution_entry_point::ExecutionResult; use crate::services::api::contract_classes::deprecated_contract_class::EntryPointType; +use crate::state::cached_state::CachedState; use crate::{ core::{ contract_address::compute_deprecated_class_hash, errors::state_errors::StateError, @@ -139,9 +141,9 @@ impl Deploy { /// ## Parameters /// - state: A state that implements the [`State`] and [`StateReader`] traits. /// - block_context: The block's execution context. - pub fn apply( + pub fn apply( &self, - state: &mut S, + state: &mut CachedState, block_context: &BlockContext, ) -> Result { state.deploy_contract(self.contract_address.clone(), self.contract_hash)?; @@ -188,6 +190,7 @@ impl Deploy { Ok(TransactionExecutionInfo::new_without_fee_info( None, Some(call_info), + None, actual_resources, Some(self.tx_type), )) @@ -197,9 +200,9 @@ impl Deploy { /// ## Parameters /// - state: A state that implements the [`State`] and [`StateReader`] traits. /// - block_context: The block's execution context. - pub fn invoke_constructor( + pub fn invoke_constructor( &self, - state: &mut S, + state: &mut CachedState, block_context: &BlockContext, ) -> Result { let call = ExecutionEntryPoint::new( @@ -224,18 +227,23 @@ impl Deploy { ); let mut resources_manager = ExecutionResourcesManager::default(); - let call_info = call.execute( + let ExecutionResult { + call_info, + revert_error, + .. + } = call.execute( state, block_context, &mut resources_manager, &mut tx_execution_context, - false, + true, + block_context.validate_max_n_steps, )?; let changes = state.count_actual_storage_changes(); let actual_resources = calculate_tx_resources( resources_manager, - &[Some(call_info.clone())], + &[call_info.clone()], self.tx_type, changes, None, @@ -243,7 +251,8 @@ impl Deploy { Ok(TransactionExecutionInfo::new_without_fee_info( None, - Some(call_info), + call_info, + revert_error, actual_resources, Some(self.tx_type), )) @@ -254,9 +263,9 @@ impl Deploy { /// ## Parameters /// - state: A state that implements the [`State`] and [`StateReader`] traits. /// - block_context: The block's execution context. - pub fn execute( + pub fn execute( &self, - state: &mut S, + state: &mut CachedState, block_context: &BlockContext, ) -> Result { let mut tx_exec_info = self.apply(state, block_context)?; @@ -290,7 +299,7 @@ impl Deploy { #[cfg(test)] mod tests { - use std::collections::HashMap; + use std::{collections::HashMap, sync::Arc}; use super::*; use crate::{ @@ -301,7 +310,7 @@ mod tests { #[test] fn invoke_constructor_test() { // Instantiate CachedState - let state_reader = InMemoryStateReader::default(); + let state_reader = Arc::new(InMemoryStateReader::default()); let mut state = CachedState::new(state_reader, Some(Default::default()), None); // Set contract_class @@ -348,7 +357,7 @@ mod tests { #[test] fn invoke_constructor_no_calldata_should_fail() { // Instantiate CachedState - let state_reader = InMemoryStateReader::default(); + let state_reader = Arc::new(InMemoryStateReader::default()); let mut state = CachedState::new(state_reader, Some(Default::default()), None); let contract_class = @@ -374,7 +383,7 @@ mod tests { #[test] fn deploy_contract_without_constructor_should_fail() { // Instantiate CachedState - let state_reader = InMemoryStateReader::default(); + let state_reader = Arc::new(InMemoryStateReader::default()); let mut state = CachedState::new(state_reader, Some(Default::default()), None); let contract_path = "starknet_programs/amm.json"; diff --git a/src/transaction/deploy_account.rs b/src/transaction/deploy_account.rs index bc132bc5c..7b161bf87 100644 --- a/src/transaction/deploy_account.rs +++ b/src/transaction/deploy_account.rs @@ -1,6 +1,8 @@ use super::{invoke_function::verify_no_calls_to_other_contracts, Transaction}; use crate::definitions::constants::QUERY_VERSION_BASE; +use crate::execution::execution_entry_point::ExecutionResult; use crate::services::api::contract_classes::deprecated_contract_class::EntryPointType; +use crate::state::cached_state::CachedState; use crate::{ core::{ errors::state_errors::StateError, @@ -151,14 +153,11 @@ impl DeployAccount { } } - pub fn execute( + pub fn execute( &self, - state: &mut S, + state: &mut CachedState, block_context: &BlockContext, - ) -> Result - where - S: State + StateReader, - { + ) -> Result { let mut tx_info = self.apply(state, block_context)?; self.handle_nonce(state)?; @@ -185,14 +184,11 @@ impl DeployAccount { /// Execute a call to the cairo-vm using the accounts_validation.cairo contract to validate /// the contract that is being declared. Then it returns the transaction execution info of the run. - fn apply( + fn apply( &self, - state: &mut S, + state: &mut CachedState, block_context: &BlockContext, - ) -> Result - where - S: State + StateReader, - { + ) -> Result { let contract_class = state.get_contract_class(&self.class_hash)?; state.deploy_contract(self.contract_address.clone(), self.class_hash)?; @@ -219,21 +215,19 @@ impl DeployAccount { Ok(TransactionExecutionInfo::new_without_fee_info( validate_info, Some(constructor_call_info), + None, actual_resources, Some(TransactionType::DeployAccount), )) } - pub fn handle_constructor( + pub fn handle_constructor( &self, contract_class: CompiledClass, - state: &mut S, + state: &mut CachedState, block_context: &BlockContext, resources_manager: &mut ExecutionResourcesManager, - ) -> Result - where - S: State + StateReader, - { + ) -> Result { if self.constructor_entry_points_empty(contract_class)? { if !self.constructor_calldata.is_empty() { return Err(TransactionError::EmptyConstructorCalldata); @@ -267,15 +261,12 @@ impl DeployAccount { Ok(()) } - pub fn run_constructor_entrypoint( + pub fn run_constructor_entrypoint( &self, - state: &mut S, + state: &mut CachedState, block_context: &BlockContext, resources_manager: &mut ExecutionResourcesManager, - ) -> Result - where - S: State + StateReader, - { + ) -> Result { let entry_point = ExecutionEntryPoint::new( self.contract_address.clone(), self.constructor_calldata.clone(), @@ -287,16 +278,17 @@ impl DeployAccount { INITIAL_GAS_COST, ); - let call_info = if self.skip_execute { - None + let ExecutionResult { call_info, .. } = if self.skip_execute { + ExecutionResult::default() } else { - Some(entry_point.execute( + entry_point.execute( state, block_context, resources_manager, &mut self.get_execution_context(block_context.validate_max_n_steps), false, - )?) + block_context.validate_max_n_steps, + )? }; let call_info = verify_no_calls_to_other_contracts(&call_info) @@ -316,15 +308,12 @@ impl DeployAccount { ) } - pub fn run_validate_entrypoint( + pub fn run_validate_entrypoint( &self, - state: &mut S, + state: &mut CachedState, resources_manager: &mut ExecutionResourcesManager, block_context: &BlockContext, - ) -> Result, TransactionError> - where - S: State + StateReader, - { + ) -> Result, TransactionError> { if self.version.is_zero() || self.version == *QUERY_VERSION_BASE { return Ok(None); } @@ -346,16 +335,17 @@ impl DeployAccount { INITIAL_GAS_COST, ); - let call_info = if self.skip_execute { - None + let ExecutionResult { call_info, .. } = if self.skip_execute { + ExecutionResult::default() } else { - Some(call.execute( + call.execute( state, block_context, resources_manager, &mut self.get_execution_context(block_context.validate_max_n_steps), false, - )?) + block_context.validate_max_n_steps, + )? }; verify_no_calls_to_other_contracts(&call_info) @@ -364,15 +354,12 @@ impl DeployAccount { Ok(call_info) } - fn charge_fee( + fn charge_fee( &self, - state: &mut S, + state: &mut CachedState, resources: &HashMap, block_context: &BlockContext, - ) -> Result - where - S: State + StateReader, - { + ) -> Result { if self.max_fee.is_zero() { return Ok((None, 0)); } @@ -418,7 +405,7 @@ impl DeployAccount { #[cfg(test)] mod tests { - use std::path::PathBuf; + use std::{path::PathBuf, sync::Arc}; use super::*; use crate::{ @@ -440,7 +427,7 @@ mod tests { let block_context = BlockContext::default(); let mut _state = CachedState::new( - InMemoryStateReader::default(), + Arc::new(InMemoryStateReader::default()), Some(Default::default()), None, ); @@ -476,7 +463,7 @@ mod tests { let block_context = BlockContext::default(); let mut state = CachedState::new( - InMemoryStateReader::default(), + Arc::new(InMemoryStateReader::default()), Some(Default::default()), None, ); @@ -528,7 +515,7 @@ mod tests { let block_context = BlockContext::default(); let mut state = CachedState::new( - InMemoryStateReader::default(), + Arc::new(InMemoryStateReader::default()), Some(Default::default()), None, ); diff --git a/src/transaction/fee.rs b/src/transaction/fee.rs index 780bcde08..a42e5c8ce 100644 --- a/src/transaction/fee.rs +++ b/src/transaction/fee.rs @@ -1,5 +1,7 @@ use super::error::TransactionError; +use crate::execution::execution_entry_point::ExecutionResult; use crate::services::api::contract_classes::deprecated_contract_class::EntryPointType; +use crate::state::cached_state::CachedState; use crate::{ definitions::{ block_context::BlockContext, @@ -8,7 +10,7 @@ use crate::{ execution::{ execution_entry_point::ExecutionEntryPoint, CallInfo, TransactionExecutionContext, }, - state::state_api::{State, StateReader}, + state::state_api::StateReader, state::ExecutionResourcesManager, }; use cairo_vm::felt::Felt252; @@ -20,8 +22,8 @@ pub type FeeInfo = (Option, u128); /// Transfers the amount actual_fee from the caller account to the sequencer. /// Returns the resulting CallInfo of the transfer call. -pub(crate) fn execute_fee_transfer( - state: &mut S, +pub(crate) fn execute_fee_transfer( + state: &mut CachedState, block_context: &BlockContext, tx_execution_context: &mut TransactionExecutionContext, actual_fee: u128, @@ -54,15 +56,18 @@ pub(crate) fn execute_fee_transfer( ); let mut resources_manager = ExecutionResourcesManager::default(); - let fee_transfer_exec = fee_transfer_call.execute( - state, - block_context, - &mut resources_manager, - tx_execution_context, - false, - ); - // TODO: Avoid masking the error from the fee transfer. - fee_transfer_exec.map_err(|e| TransactionError::FeeTransferError(Box::new(e))) + let ExecutionResult { call_info, .. } = fee_transfer_call + .execute( + state, + block_context, + &mut resources_manager, + tx_execution_context, + false, + block_context.invoke_tx_max_n_steps, + ) + .map_err(|e| TransactionError::FeeTransferError(Box::new(e)))?; + + call_info.ok_or(TransactionError::CallInfoIsNone) } // ---------------------------------------------------------------------------------------- diff --git a/src/transaction/invoke_function.rs b/src/transaction/invoke_function.rs index 6edfc1804..bf5b44a3e 100644 --- a/src/transaction/invoke_function.rs +++ b/src/transaction/invoke_function.rs @@ -10,11 +10,11 @@ use crate::{ transaction_type::TransactionType, }, execution::{ - execution_entry_point::ExecutionEntryPoint, CallInfo, TransactionExecutionContext, - TransactionExecutionInfo, + execution_entry_point::{ExecutionEntryPoint, ExecutionResult}, + CallInfo, TransactionExecutionContext, TransactionExecutionInfo, }, state::state_api::{State, StateReader}, - state::ExecutionResourcesManager, + state::{cached_state::CachedState, ExecutionResourcesManager}, transaction::{ error::TransactionError, fee::{calculate_tx_fee, execute_fee_transfer, FeeInfo}, @@ -147,15 +147,12 @@ impl InvokeFunction { /// - state: A state that implements the [`State`] and [`StateReader`] traits. /// - resources_manager: the resources that are in use by the contract /// - block_context: The block's execution context - pub(crate) fn run_validate_entrypoint( + pub(crate) fn run_validate_entrypoint( &self, - state: &mut T, + state: &mut CachedState, resources_manager: &mut ExecutionResourcesManager, block_context: &BlockContext, - ) -> Result, TransactionError> - where - T: State + StateReader, - { + ) -> Result, TransactionError> { if self.entry_point_selector != *EXECUTE_ENTRY_POINT_SELECTOR { return Ok(None); } @@ -177,13 +174,14 @@ impl InvokeFunction { 0, ); - let call_info = Some(call.execute( + let ExecutionResult { call_info, .. } = call.execute( state, block_context, resources_manager, &mut self.get_execution_context(block_context.validate_max_n_steps)?, false, - )?); + block_context.validate_max_n_steps, + )?; let call_info = verify_no_calls_to_other_contracts(&call_info) .map_err(|_| TransactionError::InvalidContractCall)?; @@ -193,16 +191,13 @@ impl InvokeFunction { /// Builds the transaction execution context and executes the entry point. /// Returns the CallInfo. - fn run_execute_entrypoint( + fn run_execute_entrypoint( &self, - state: &mut T, + state: &mut CachedState, block_context: &BlockContext, resources_manager: &mut ExecutionResourcesManager, remaining_gas: u128, - ) -> Result - where - T: State + StateReader, - { + ) -> Result { let call = ExecutionEntryPoint::new( self.contract_address.clone(), self.calldata.clone(), @@ -218,7 +213,8 @@ impl InvokeFunction { block_context, resources_manager, &mut self.get_execution_context(block_context.invoke_tx_max_n_steps)?, - false, + true, + block_context.invoke_tx_max_n_steps, ) } @@ -228,28 +224,29 @@ impl InvokeFunction { /// - state: A state that implements the [`State`] and [`StateReader`] traits. /// - block_context: The block's execution context. /// - remaining_gas: The amount of gas that the transaction disposes. - pub fn apply( + pub fn apply( &self, - state: &mut S, + state: &mut CachedState, block_context: &BlockContext, remaining_gas: u128, - ) -> Result - where - S: State + StateReader, - { + ) -> Result { let mut resources_manager = ExecutionResourcesManager::default(); let validate_info = self.run_validate_entrypoint(state, &mut resources_manager, block_context)?; // Execute transaction - let call_info = if self.skip_execute { - None + let ExecutionResult { + call_info, + revert_error, + .. + } = if self.skip_execute { + ExecutionResult::default() } else { - Some(self.run_execute_entrypoint( + self.run_execute_entrypoint( state, block_context, &mut resources_manager, remaining_gas, - )?) + )? }; let changes = state.count_actual_storage_changes(); let actual_resources = calculate_tx_resources( @@ -262,21 +259,19 @@ impl InvokeFunction { let transaction_execution_info = TransactionExecutionInfo::new_without_fee_info( validate_info, call_info, + revert_error, actual_resources, Some(self.tx_type), ); Ok(transaction_execution_info) } - fn charge_fee( + fn charge_fee( &self, - state: &mut S, + state: &mut CachedState, resources: &HashMap, block_context: &BlockContext, - ) -> Result - where - S: State + StateReader, - { + ) -> Result { if self.max_fee.is_zero() { return Ok((None, 0)); } @@ -308,9 +303,9 @@ impl InvokeFunction { /// - state: A state that implements the [`State`] and [`StateReader`] traits. /// - block_context: The block's execution context. /// - remaining_gas: The amount of gas that the transaction disposes. - pub fn execute( + pub fn execute( &self, - state: &mut S, + state: &mut CachedState, block_context: &BlockContext, remaining_gas: u128, ) -> Result { @@ -424,7 +419,7 @@ mod tests { state::cached_state::CachedState, state::in_memory_state_reader::InMemoryStateReader, }; use num_traits::Num; - use std::collections::HashMap; + use std::{collections::HashMap, sync::Arc}; #[test] fn test_invoke_apply_without_fees() { @@ -465,7 +460,7 @@ mod tests { .address_to_nonce .insert(contract_address, nonce); - let mut state = CachedState::new(state_reader.clone(), None, None); + let mut state = CachedState::new(Arc::new(state_reader), None, None); // Initialize state.contract_classes state.set_contract_classes(HashMap::new()).unwrap(); @@ -533,7 +528,7 @@ mod tests { .address_to_nonce .insert(contract_address, nonce); - let mut state = CachedState::new(state_reader.clone(), None, None); + let mut state = CachedState::new(Arc::new(state_reader), None, None); // Initialize state.contract_classes state.set_contract_classes(HashMap::new()).unwrap(); @@ -597,7 +592,7 @@ mod tests { .address_to_nonce .insert(contract_address, nonce); - let mut state = CachedState::new(state_reader.clone(), None, None); + let mut state = CachedState::new(Arc::new(state_reader), None, None); // Initialize state.contract_classes state.set_contract_classes(HashMap::new()).unwrap(); @@ -655,7 +650,7 @@ mod tests { .address_to_nonce .insert(contract_address, nonce); - let mut state = CachedState::new(state_reader.clone(), None, None); + let mut state = CachedState::new(Arc::new(state_reader), None, None); // Initialize state.contract_classes state.set_contract_classes(HashMap::new()).unwrap(); @@ -719,7 +714,7 @@ mod tests { .address_to_nonce .insert(contract_address, nonce); - let mut state = CachedState::new(state_reader.clone(), None, None); + let mut state = CachedState::new(Arc::new(state_reader), None, None); // Initialize state.contract_classes state.set_contract_classes(HashMap::new()).unwrap(); @@ -777,7 +772,7 @@ mod tests { skip_fee_transfer: false, }; - let mut state = CachedState::new(state_reader.clone(), None, None); + let mut state = CachedState::new(Arc::new(state_reader), None, None); // Initialize state.contract_classes state.set_contract_classes(HashMap::new()).unwrap(); @@ -837,7 +832,7 @@ mod tests { .address_to_nonce .insert(contract_address, nonce); - let mut state = CachedState::new(state_reader.clone(), None, None); + let mut state = CachedState::new(Arc::new(state_reader), None, None); // Initialize state.contract_classes state.set_contract_classes(HashMap::new()).unwrap(); @@ -901,7 +896,7 @@ mod tests { .address_to_nonce .insert(contract_address, nonce); - let mut state = CachedState::new(state_reader.clone(), None, None); + let mut state = CachedState::new(Arc::new(state_reader), None, None); // Initialize state.contract_classes state.set_contract_classes(HashMap::new()).unwrap(); @@ -963,7 +958,7 @@ mod tests { .address_to_nonce .insert(contract_address, nonce); - let mut state = CachedState::new(state_reader.clone(), None, None); + let mut state = CachedState::new(Arc::new(state_reader), None, None); // Initialize state.contract_classes state.set_contract_classes(HashMap::new()).unwrap(); diff --git a/src/transaction/l1_handler.rs b/src/transaction/l1_handler.rs index 4ee334684..f6f8b0ad2 100644 --- a/src/transaction/l1_handler.rs +++ b/src/transaction/l1_handler.rs @@ -1,4 +1,8 @@ -use crate::services::api::contract_classes::deprecated_contract_class::EntryPointType; +use crate::{ + execution::execution_entry_point::ExecutionResult, + services::api::contract_classes::deprecated_contract_class::EntryPointType, + state::cached_state::CachedState, +}; use cairo_vm::felt::Felt252; use getset::Getters; use num_traits::Zero; @@ -89,15 +93,12 @@ impl L1Handler { } /// Applies self to 'state' by executing the L1-handler entry point. - pub fn execute( + pub fn execute( &self, - state: &mut S, + state: &mut CachedState, block_context: &BlockContext, remaining_gas: u128, - ) -> Result - where - S: State + StateReader, - { + ) -> Result { let mut resources_manager = ExecutionResourcesManager::default(); let entrypoint = ExecutionEntryPoint::new( self.contract_address.clone(), @@ -110,16 +111,21 @@ impl L1Handler { remaining_gas, ); - let call_info = if self.skip_execute { - None + let ExecutionResult { + call_info, + revert_error, + .. + } = if self.skip_execute { + ExecutionResult::default() } else { - Some(entrypoint.execute( + entrypoint.execute( state, block_context, &mut resources_manager, &mut self.get_execution_context(block_context.invoke_tx_max_n_steps)?, - false, - )?) + true, + block_context.invoke_tx_max_n_steps, + )? }; let changes = state.count_actual_storage_changes(); @@ -154,6 +160,7 @@ impl L1Handler { Ok(TransactionExecutionInfo::new_without_fee_info( None, call_info, + revert_error, actual_resources, Some(TransactionType::L1Handler), )) @@ -198,7 +205,10 @@ impl L1Handler { #[cfg(test)] mod test { - use std::collections::{HashMap, HashSet}; + use std::{ + collections::{HashMap, HashSet}, + sync::Arc, + }; use crate::services::api::contract_classes::deprecated_contract_class::EntryPointType; use cairo_vm::{ @@ -255,7 +265,7 @@ mod test { .address_to_nonce .insert(contract_address, nonce); - let mut state = CachedState::new(state_reader.clone(), None, None); + let mut state = CachedState::new(Arc::new(state_reader), None, None); // Initialize state.contract_classes state.set_contract_classes(HashMap::new()).unwrap(); @@ -321,6 +331,7 @@ mod test { gas_consumed: 0, failure_flag: false, }), + revert_error: None, fee_transfer_info: None, actual_fee: 0, actual_resources: HashMap::from([ diff --git a/src/transaction/mod.rs b/src/transaction/mod.rs index fdd924e1a..6729c40af 100644 --- a/src/transaction/mod.rs +++ b/src/transaction/mod.rs @@ -19,7 +19,7 @@ pub use verify_version::verify_version; use crate::{ definitions::block_context::BlockContext, execution::TransactionExecutionInfo, - state::state_api::{State, StateReader}, + state::{cached_state::CachedState, state_api::StateReader}, utils::Address, }; use error::TransactionError; @@ -66,9 +66,9 @@ impl Transaction { ///- state: a structure that implements State and StateReader traits. ///- block_context: The block context of the transaction that is about to be executed. ///- remaining_gas: The gas supplied to execute the transaction. - pub fn execute( + pub fn execute( &self, - state: &mut S, + state: &mut CachedState, block_context: &BlockContext, remaining_gas: u128, ) -> Result { diff --git a/src/utils.rs b/src/utils.rs index 0333baca5..f43218b28 100644 --- a/src/utils.rs +++ b/src/utils.rs @@ -1,14 +1,13 @@ use crate::core::errors::hash_errors::HashError; use crate::services::api::contract_classes::deprecated_contract_class::EntryPointType; +use crate::state::state_api::State; use crate::{ definitions::transaction_type::TransactionType, execution::{ gas_usage::calculate_tx_gas_usage, os_usage::get_additional_os_resources, CallInfo, }, state::ExecutionResourcesManager, - state::{ - cached_state::UNINITIALIZED_CLASS_HASH, state_api::StateReader, state_cache::StorageEntry, - }, + state::{cached_state::UNINITIALIZED_CLASS_HASH, state_cache::StorageEntry}, syscalls::syscall_handler_errors::SyscallHandlerError, transaction::error::TransactionError, }; @@ -260,7 +259,7 @@ where //* Execution entry point utils //* ---------------------------- -pub fn get_deployed_address_class_hash_at_address( +pub fn get_deployed_address_class_hash_at_address( state: &mut S, contract_address: &Address, ) -> Result { @@ -277,7 +276,7 @@ pub fn get_deployed_address_class_hash_at_address( Ok(class_hash) } -pub fn validate_contract_deployed( +pub fn validate_contract_deployed( state: &mut S, contract_address: &Address, ) -> Result { diff --git a/tests/cairo_1_syscalls.rs b/tests/cairo_1_syscalls.rs index 9b8a3fe35..cdf711386 100644 --- a/tests/cairo_1_syscalls.rs +++ b/tests/cairo_1_syscalls.rs @@ -1,4 +1,7 @@ -use std::collections::{HashMap, HashSet}; +use std::{ + collections::{HashMap, HashSet}, + sync::Arc, +}; use cairo_lang_starknet::casm_contract_class::CasmContractClass; use cairo_vm::{ @@ -53,7 +56,7 @@ fn storage_write_read() { .insert(address.clone(), nonce); // Create state from the state_reader and contract cache. - let mut state = CachedState::new(state_reader, None, Some(contract_class_cache)); + let mut state = CachedState::new(Arc::new(state_reader), None, Some(contract_class_cache)); let block_context = BlockContext::default(); let mut tx_execution_context = TransactionExecutionContext::new( @@ -101,6 +104,7 @@ fn storage_write_read() { &mut resources_manager, &mut tx_execution_context, false, + block_context.invoke_tx_max_n_steps(), ) .unwrap(); @@ -121,9 +125,10 @@ fn storage_write_read() { &mut resources_manager, &mut tx_execution_context, false, + block_context.invoke_tx_max_n_steps(), ) .unwrap(); - assert_eq!(call_info.retdata, [25.into()]); + assert_eq!(call_info.call_info.unwrap().retdata, [25.into()]); // RUN INCREASE_BALANCE // Create an execution entry point @@ -142,6 +147,7 @@ fn storage_write_read() { &mut resources_manager, &mut tx_execution_context, false, + block_context.invoke_tx_max_n_steps(), ) .unwrap(); @@ -162,9 +168,10 @@ fn storage_write_read() { &mut resources_manager, &mut tx_execution_context, false, + block_context.invoke_tx_max_n_steps(), ) .unwrap(); - assert_eq!(call_info.retdata, [125.into()]) + assert_eq!(call_info.call_info.unwrap().retdata, [125.into()]) } #[test] @@ -217,7 +224,7 @@ fn library_call() { .insert(lib_address, lib_nonce); // Create state from the state_reader and contract cache. - let mut state = CachedState::new(state_reader, None, Some(contract_class_cache)); + let mut state = CachedState::new(Arc::new(state_reader), None, Some(contract_class_cache)); // Create an execution entry point let calldata = [25.into(), Felt252::from_bytes_be(&lib_class_hash)].to_vec(); @@ -314,7 +321,10 @@ fn library_call() { &mut resources_manager, &mut tx_execution_context, false, + block_context.invoke_tx_max_n_steps() ) + .unwrap() + .call_info .unwrap(), expected_call_info ); @@ -380,7 +390,7 @@ fn call_contract_storage_write_read() { .insert(simple_wallet_address.clone(), simple_wallet_nonce); // Create state from the state_reader and contract cache. - let mut state = CachedState::new(state_reader, None, Some(contract_class_cache)); + let mut state = CachedState::new(Arc::new(state_reader), None, Some(contract_class_cache)); let block_context = BlockContext::default(); let mut tx_execution_context = TransactionExecutionContext::new( @@ -432,6 +442,7 @@ fn call_contract_storage_write_read() { &mut resources_manager, &mut tx_execution_context, false, + block_context.invoke_tx_max_n_steps(), ) .unwrap(); @@ -454,9 +465,10 @@ fn call_contract_storage_write_read() { &mut resources_manager, &mut tx_execution_context, false, + block_context.invoke_tx_max_n_steps(), ) .unwrap(); - assert_eq!(call_info.retdata, [25.into()]); + assert_eq!(call_info.call_info.unwrap().retdata, [25.into()]); // RUN INCREASE_BALANCE // Create an execution entry point @@ -477,6 +489,7 @@ fn call_contract_storage_write_read() { &mut resources_manager, &mut tx_execution_context, false, + block_context.invoke_tx_max_n_steps(), ) .unwrap(); @@ -499,9 +512,10 @@ fn call_contract_storage_write_read() { &mut resources_manager, &mut tx_execution_context, false, + block_context.invoke_tx_max_n_steps(), ) .unwrap(); - assert_eq!(call_info.retdata, [125.into()]) + assert_eq!(call_info.call_info.unwrap().retdata, [125.into()]) } #[test] @@ -532,7 +546,7 @@ fn emit_event() { .insert(address.clone(), nonce); // Create state from the state_reader and contract cache. - let mut state = CachedState::new(state_reader, None, Some(contract_class_cache)); + let mut state = CachedState::new(Arc::new(state_reader), None, Some(contract_class_cache)); // Create an execution entry point let calldata = [].to_vec(); @@ -569,10 +583,11 @@ fn emit_event() { &mut resources_manager, &mut tx_execution_context, false, + block_context.invoke_tx_max_n_steps(), ) .unwrap(); assert_eq!( - call_info.events, + call_info.call_info.unwrap().events, vec![ OrderedEvent { order: 0, @@ -645,7 +660,7 @@ fn deploy_cairo1_from_cairo1() { .insert(address.clone(), nonce); // Create state from the state_reader and contract cache. - let mut state = CachedState::new(state_reader, None, Some(contract_class_cache)); + let mut state = CachedState::new(Arc::new(state_reader), None, Some(contract_class_cache)); // arguments of deploy contract let calldata: Vec<_> = [test_felt_hash, salt].to_vec(); @@ -685,6 +700,7 @@ fn deploy_cairo1_from_cairo1() { &mut resources_manager, &mut tx_execution_context, false, + block_context.invoke_tx_max_n_steps(), ); assert!(call_info.is_ok()); @@ -743,7 +759,7 @@ fn deploy_cairo0_from_cairo1_without_constructor() { // Create state from the state_reader and contract cache. let mut state = CachedState::new( - state_reader, + Arc::new(state_reader), Some(contract_class_cache), Some(casm_contract_class_cache), ); @@ -786,6 +802,7 @@ fn deploy_cairo0_from_cairo1_without_constructor() { &mut resources_manager, &mut tx_execution_context, false, + block_context.invoke_tx_max_n_steps(), ); assert!(call_info.is_ok()); @@ -843,7 +860,7 @@ fn deploy_cairo0_from_cairo1_with_constructor() { // Create state from the state_reader and contract cache. let mut state = CachedState::new( - state_reader, + Arc::new(state_reader), Some(contract_class_cache), Some(casm_contract_class_cache), ); @@ -886,6 +903,7 @@ fn deploy_cairo0_from_cairo1_with_constructor() { &mut resources_manager, &mut tx_execution_context, false, + block_context.invoke_tx_max_n_steps(), ); assert!(call_info.is_ok()); @@ -943,8 +961,8 @@ fn deploy_cairo0_and_invoke() { .insert(address.clone(), nonce); // Create state from the state_reader and contract cache. - let mut state = CachedState::new( - state_reader, + let mut state: CachedState<_> = CachedState::new( + Arc::new(state_reader), Some(contract_class_cache), Some(casm_contract_class_cache), ); @@ -987,6 +1005,7 @@ fn deploy_cairo0_and_invoke() { &mut resources_manager, &mut tx_execution_context, false, + block_context.invoke_tx_max_n_steps(), ); assert!(call_info.is_ok()); @@ -1027,10 +1046,11 @@ fn deploy_cairo0_and_invoke() { &mut resources_manager, &mut tx_execution_context, false, + block_context.invoke_tx_max_n_steps(), ) .unwrap(); - let retdata = call_info.retdata; + let retdata = call_info.call_info.unwrap().retdata; // expected result 3! = 6 assert_eq!(retdata, [6.into()].to_vec()); @@ -1065,7 +1085,7 @@ fn test_send_message_to_l1_syscall() { .insert(address.clone(), nonce); // Create state from the state_reader and contract cache. - let mut state = CachedState::new(state_reader, None, Some(contract_class_cache)); + let mut state = CachedState::new(Arc::new(state_reader), None, Some(contract_class_cache)); let create_execute_extrypoint = |selector: &BigUint, calldata: Vec, @@ -1112,6 +1132,7 @@ fn test_send_message_to_l1_syscall() { &mut resources_manager, &mut tx_execution_context, false, + block_context.invoke_tx_max_n_steps(), ) .unwrap(); @@ -1140,7 +1161,7 @@ fn test_send_message_to_l1_syscall() { ..Default::default() }; - assert_eq!(call_info, expected_call_info); + assert_eq!(call_info.call_info.unwrap(), expected_call_info); } #[test] @@ -1171,7 +1192,7 @@ fn test_get_execution_info() { .insert(address.clone(), nonce); // Create state from the state_reader and contract cache. - let mut state = CachedState::new(state_reader, None, Some(contract_class_cache)); + let mut state = CachedState::new(Arc::new(state_reader), None, Some(contract_class_cache)); let block_context = BlockContext::default(); let mut tx_execution_context = TransactionExecutionContext::new( @@ -1218,6 +1239,7 @@ fn test_get_execution_info() { &mut resources_manager, &mut tx_execution_context, false, + block_context.invoke_tx_max_n_steps(), ) .unwrap(); @@ -1247,7 +1269,7 @@ fn test_get_execution_info() { ..Default::default() }; - assert_eq!(call_info, expected_call_info); + assert_eq!(call_info.call_info.unwrap(), expected_call_info); } #[test] @@ -1290,7 +1312,7 @@ fn replace_class_internal() { contract_class_cache.insert(class_hash_b, contract_class_b.clone()); // Create state from the state_reader and contract cache. - let mut state = CachedState::new(state_reader, None, Some(contract_class_cache)); + let mut state = CachedState::new(Arc::new(state_reader), None, Some(contract_class_cache)); // Run upgrade entrypoint and check that the storage was updated with the new contract class // Create an execution entry point @@ -1329,6 +1351,7 @@ fn replace_class_internal() { &mut resources_manager, &mut tx_execution_context, false, + block_context.invoke_tx_max_n_steps(), ) .unwrap(); // Check that the class was indeed replaced in storage @@ -1411,7 +1434,7 @@ fn replace_class_contract_call() { .insert(wrapper_address, nonce); // Create state from the state_reader and contract cache. - let mut state = CachedState::new(state_reader, None, Some(contract_class_cache)); + let mut state = CachedState::new(Arc::new(state_reader), None, Some(contract_class_cache)); // INITIALIZE STARKNET CONFIG let block_context = BlockContext::default(); @@ -1450,9 +1473,10 @@ fn replace_class_contract_call() { &mut resources_manager, &mut tx_execution_context, false, + block_context.invoke_tx_max_n_steps(), ) .unwrap(); - assert_eq!(result.retdata, vec![25.into()]); + assert_eq!(result.call_info.unwrap().retdata, vec![25.into()]); // REPLACE_CLASS @@ -1476,6 +1500,7 @@ fn replace_class_contract_call() { &mut resources_manager, &mut tx_execution_context, false, + block_context.invoke_tx_max_n_steps(), ) .unwrap(); @@ -1501,9 +1526,10 @@ fn replace_class_contract_call() { &mut resources_manager, &mut tx_execution_context, false, + block_context.invoke_tx_max_n_steps(), ) .unwrap(); - assert_eq!(result.retdata, vec![17.into()]); + assert_eq!(result.call_info.unwrap().retdata, vec![17.into()]); } #[test] @@ -1574,7 +1600,7 @@ fn replace_class_contract_call_same_transaction() { .insert(wrapper_address, nonce); // Create state from the state_reader and contract cache. - let mut state = CachedState::new(state_reader, None, Some(contract_class_cache)); + let mut state = CachedState::new(Arc::new(state_reader), None, Some(contract_class_cache)); // INITIALIZE STARKNET CONFIG let block_context = BlockContext::default(); @@ -1613,9 +1639,13 @@ fn replace_class_contract_call_same_transaction() { &mut resources_manager, &mut tx_execution_context, false, + block_context.invoke_tx_max_n_steps(), ) .unwrap(); - assert_eq!(result.retdata, vec![25.into(), 17.into()]); + assert_eq!( + result.call_info.unwrap().retdata, + vec![25.into(), 17.into()] + ); } #[test] @@ -1686,7 +1716,7 @@ fn call_contract_upgrade_cairo_0_to_cairo_1_same_transaction() { // Create state from the state_reader and contract cache. let mut state = CachedState::new( - state_reader, + Arc::new(state_reader), Some(deprecated_contract_class_cache), Some(casm_contract_class_cache), ); @@ -1728,9 +1758,13 @@ fn call_contract_upgrade_cairo_0_to_cairo_1_same_transaction() { &mut resources_manager, &mut tx_execution_context, false, + block_context.invoke_tx_max_n_steps(), ) .unwrap(); - assert_eq!(result.retdata, vec![33.into(), 17.into()]); + assert_eq!( + result.call_info.unwrap().retdata, + vec![33.into(), 17.into()] + ); } #[test] @@ -1799,7 +1833,7 @@ fn call_contract_downgrade_cairo_1_to_cairo_0_same_transaction() { // Create state from the state_reader and contract cache. let mut state = CachedState::new( - state_reader, + Arc::new(state_reader), Some(deprecated_contract_class_cache), Some(casm_contract_class_cache), ); @@ -1841,9 +1875,13 @@ fn call_contract_downgrade_cairo_1_to_cairo_0_same_transaction() { &mut resources_manager, &mut tx_execution_context, false, + block_context.invoke_tx_max_n_steps(), ) .unwrap(); - assert_eq!(result.retdata, vec![17.into(), 33.into()]); + assert_eq!( + result.call_info.unwrap().retdata, + vec![17.into(), 33.into()] + ); } #[test] @@ -1908,7 +1946,7 @@ fn call_contract_replace_class_cairo_0() { // Create state from the state_reader and contract cache. let mut state = CachedState::new( - state_reader, + Arc::new(state_reader), Some(deprecated_contract_class_cache), Some(casm_contract_class_cache), ); @@ -1950,9 +1988,13 @@ fn call_contract_replace_class_cairo_0() { &mut resources_manager, &mut tx_execution_context, false, + block_context.invoke_tx_max_n_steps(), ) .unwrap(); - assert_eq!(result.retdata, vec![64.into(), 33.into()]); + assert_eq!( + result.call_info.unwrap().retdata, + vec![64.into(), 33.into()] + ); } #[test] @@ -1983,7 +2025,7 @@ fn test_out_of_gas_failure() { .insert(address.clone(), nonce); // Create state from the state_reader and contract cache. - let mut state = CachedState::new(state_reader, None, Some(contract_class_cache)); + let mut state = CachedState::new(Arc::new(state_reader), None, Some(contract_class_cache)); // Create an execution entry point let calldata = [].to_vec(); @@ -2021,8 +2063,10 @@ fn test_out_of_gas_failure() { &mut resources_manager, &mut tx_execution_context, false, + block_context.invoke_tx_max_n_steps(), ) .unwrap(); + let call_info = call_info.call_info.unwrap(); assert_eq!( call_info.retdata, vec![Felt252::from_bytes_be("Out of gas".as_bytes())] @@ -2058,7 +2102,7 @@ fn deploy_syscall_failure_uninitialized_class_hash() { .insert(address.clone(), nonce); // Create state from the state_reader and contract cache. - let mut state = CachedState::new(state_reader, None, Some(contract_class_cache)); + let mut state = CachedState::new(Arc::new(state_reader), None, Some(contract_class_cache)); // Create an execution entry point let calldata = [Felt252::zero()].to_vec(); @@ -2095,10 +2139,11 @@ fn deploy_syscall_failure_uninitialized_class_hash() { &mut resources_manager, &mut tx_execution_context, false, + block_context.invoke_tx_max_n_steps(), ) .unwrap(); assert_eq!( - std::str::from_utf8(&call_info.retdata[0].to_be_bytes()) + std::str::from_utf8(&call_info.call_info.unwrap().retdata[0].to_be_bytes()) .unwrap() .trim_start_matches('\0'), "CLASS_HASH_NOT_FOUND" @@ -2142,7 +2187,7 @@ fn deploy_syscall_failure_in_constructor() { contract_class_cache.insert(f_c_class_hash.to_be_bytes(), f_c_contract_class); // Create state from the state_reader and contract cache. - let mut state = CachedState::new(state_reader, None, Some(contract_class_cache)); + let mut state = CachedState::new(Arc::new(state_reader), None, Some(contract_class_cache)); // Create an execution entry point let calldata = [f_c_class_hash].to_vec(); @@ -2179,12 +2224,13 @@ fn deploy_syscall_failure_in_constructor() { &mut resources_manager, &mut tx_execution_context, false, + block_context.invoke_tx_max_n_steps(), ) .unwrap(); // Check that we get the error from the constructor // assert( 1 == 0 , 'Oops'); assert_eq!( - std::str::from_utf8(&call_info.retdata[0].to_be_bytes()) + std::str::from_utf8(&call_info.call_info.unwrap().retdata[0].to_be_bytes()) .unwrap() .trim_start_matches('\0'), "Oops" @@ -2219,7 +2265,7 @@ fn storage_read_no_value() { .insert(address.clone(), nonce); // Create state from the state_reader and contract cache. - let mut state = CachedState::new(state_reader, None, Some(contract_class_cache)); + let mut state = CachedState::new(Arc::new(state_reader), None, Some(contract_class_cache)); let block_context = BlockContext::default(); let mut tx_execution_context = TransactionExecutionContext::new( @@ -2267,10 +2313,11 @@ fn storage_read_no_value() { &mut resources_manager, &mut tx_execution_context, false, + block_context.invoke_tx_max_n_steps(), ) .unwrap(); // As the value doesn't exist in storage, it's value will be 0 - assert_eq!(call_info.retdata, [0.into()]); + assert_eq!(call_info.call_info.unwrap().retdata, [0.into()]); } #[test] @@ -2303,7 +2350,7 @@ fn storage_read_unavailable_address_domain() { .insert(address.clone(), nonce); // Create state from the state_reader and contract cache. - let mut state = CachedState::new(state_reader, None, Some(contract_class_cache)); + let mut state = CachedState::new(Arc::new(state_reader), None, Some(contract_class_cache)); let block_context = BlockContext::default(); let mut tx_execution_context = TransactionExecutionContext::new( @@ -2351,11 +2398,12 @@ fn storage_read_unavailable_address_domain() { &mut resources_manager, &mut tx_execution_context, false, + block_context.invoke_tx_max_n_steps(), ) .unwrap(); assert_eq!( - call_info.retdata[0], + call_info.call_info.unwrap().retdata[0], Felt252::from_bytes_be(b"Unsupported address domain") ); } @@ -2390,7 +2438,7 @@ fn storage_write_unavailable_address_domain() { .insert(address.clone(), nonce); // Create state from the state_reader and contract cache. - let mut state = CachedState::new(state_reader, None, Some(contract_class_cache)); + let mut state = CachedState::new(Arc::new(state_reader), None, Some(contract_class_cache)); let block_context = BlockContext::default(); let mut tx_execution_context = TransactionExecutionContext::new( @@ -2438,11 +2486,12 @@ fn storage_write_unavailable_address_domain() { &mut resources_manager, &mut tx_execution_context, false, + block_context.invoke_tx_max_n_steps(), ) .unwrap(); assert_eq!( - call_info.retdata[0], + call_info.call_info.unwrap().retdata[0], Felt252::from_bytes_be(b"Unsupported address domain") ); } @@ -2495,7 +2544,7 @@ fn library_call_failure() { .insert(lib_address, lib_nonce); // Create state from the state_reader and contract cache. - let mut state = CachedState::new(state_reader, None, Some(contract_class_cache)); + let mut state = CachedState::new(Arc::new(state_reader), None, Some(contract_class_cache)); // Create an execution entry point let calldata = [25.into(), Felt252::from_bytes_be(&lib_class_hash)].to_vec(); @@ -2538,8 +2587,12 @@ fn library_call_failure() { &mut resources_manager, &mut tx_execution_context, false, + block_context.invoke_tx_max_n_steps(), ) .unwrap(); + + let call_info = call_info.call_info.unwrap(); + assert_eq!( std::str::from_utf8(&call_info.retdata[0].to_be_bytes()) .unwrap() @@ -2600,7 +2653,7 @@ fn send_messages_to_l1_different_contract_calls() { .insert(send_msg_address, send_msg_nonce); // Create state from the state_reader and contract cache. - let mut state = CachedState::new(state_reader, None, Some(contract_class_cache)); + let mut state = CachedState::new(Arc::new(state_reader), None, Some(contract_class_cache)); // Create an execution entry point let calldata = [25.into(), 50.into(), 75.into()].to_vec(); @@ -2638,9 +2691,14 @@ fn send_messages_to_l1_different_contract_calls() { &mut resources_manager, &mut tx_execution_context, false, + block_context.invoke_tx_max_n_steps(), ) .unwrap(); - let l1_to_l2_messages = call_info.get_sorted_l2_to_l1_messages().unwrap(); + let l1_to_l2_messages = call_info + .call_info + .unwrap() + .get_sorted_l2_to_l1_messages() + .unwrap(); assert_eq!( l1_to_l2_messages, vec![ @@ -2714,7 +2772,7 @@ fn send_messages_to_l1_different_contract_calls_cairo1_to_cairo0() { // Create state from the state_reader and contract cache. let mut state = CachedState::new( - state_reader, + Arc::new(state_reader), Some(deprecated_contract_class_cache), Some(contract_class_cache), ); @@ -2755,9 +2813,14 @@ fn send_messages_to_l1_different_contract_calls_cairo1_to_cairo0() { &mut resources_manager, &mut tx_execution_context, false, + block_context.invoke_tx_max_n_steps(), ) .unwrap(); - let l1_to_l2_messages = call_info.get_sorted_l2_to_l1_messages().unwrap(); + let l1_to_l2_messages = call_info + .call_info + .unwrap() + .get_sorted_l2_to_l1_messages() + .unwrap(); assert_eq!( l1_to_l2_messages, vec![ @@ -2829,7 +2892,7 @@ fn send_messages_to_l1_different_contract_calls_cairo0_to_cairo1() { // Create state from the state_reader and contract cache. let mut state = CachedState::new( - state_reader, + Arc::new(state_reader), Some(deprecated_contract_class_cache), Some(contract_class_cache), ); @@ -2870,9 +2933,14 @@ fn send_messages_to_l1_different_contract_calls_cairo0_to_cairo1() { &mut resources_manager, &mut tx_execution_context, false, + block_context.invoke_tx_max_n_steps(), ) .unwrap(); - let l1_to_l2_messages = call_info.get_sorted_l2_to_l1_messages().unwrap(); + let l1_to_l2_messages = call_info + .call_info + .unwrap() + .get_sorted_l2_to_l1_messages() + .unwrap(); assert_eq!( l1_to_l2_messages, vec![ diff --git a/tests/complex_contracts/amm_contracts/amm.rs b/tests/complex_contracts/amm_contracts/amm.rs index a6e52408b..bbab48eb9 100644 --- a/tests/complex_contracts/amm_contracts/amm.rs +++ b/tests/complex_contracts/amm_contracts/amm.rs @@ -1,4 +1,5 @@ use std::collections::{HashMap, HashSet}; +use std::sync::Arc; use cairo_vm::vm::runners::builtin_runner::HASH_BUILTIN_NAME; use cairo_vm::vm::runners::cairo_runner::ExecutionResources; @@ -54,7 +55,7 @@ fn swap(calldata: &[Felt252], call_config: &mut CallConfig) -> Result CachedState { let state_cache = ContractClassCache::new(); CachedState::new( - in_memory_state_reader, + Arc::new(in_memory_state_reader), Some(state_cache), Some(HashMap::new()), ) @@ -230,7 +231,7 @@ fn expected_state_after_tx(fee: u128) -> CachedState { ]); CachedState::new_for_testing( - in_memory_state_reader, + Arc::new(in_memory_state_reader), Some(contract_classes_cache), state_cache_after_invoke_tx(fee), Some(HashMap::new()), @@ -514,7 +515,7 @@ fn validate_final_balances( #[test] fn test_create_account_tx_test_state() { - let (block_context, mut state) = create_account_tx_test_state().unwrap(); + let (block_context, state) = create_account_tx_test_state().unwrap(); assert_eq!(state, expected_state_before_tx()); @@ -845,6 +846,7 @@ fn test_declare_tx() { ..Default::default() }), None, + None, Some(expected_declare_fee_transfer_info(fee)), fee, resources, @@ -901,6 +903,7 @@ fn test_declarev2_tx() { ..Default::default() }), None, + None, Some(expected_declare_fee_transfer_info(fee)), fee, resources, @@ -1096,6 +1099,7 @@ fn expected_transaction_execution_info(block_context: &BlockContext) -> Transact TransactionExecutionInfo::new( Some(expected_validate_call_info_2()), Some(expected_execute_call_info()), + None, Some(expected_fee_transfer_info(fee)), fee, resources, @@ -1115,6 +1119,7 @@ fn expected_fib_transaction_execution_info( TransactionExecutionInfo::new( Some(expected_fib_validate_call_info_2()), Some(expected_fib_execute_call_info()), + None, Some(expected_fib_fee_transfer_info(fee)), fee, resources, @@ -1243,7 +1248,13 @@ fn test_deploy_account() { .execute(&mut state, &block_context) .unwrap(); - assert_eq!(state, state_after); + assert_eq!(state.contract_classes(), state_after.contract_classes()); + assert_eq!( + state.casm_contract_classes(), + state_after.casm_contract_classes() + ); + assert_eq!(state.state_reader, state_after.state_reader); + assert_eq!(state.cache(), state_after.cache()); let expected_validate_call_info = expected_validate_call_info( VALIDATE_DEPLOY_ENTRY_POINT_SELECTOR.clone(), @@ -1288,6 +1299,7 @@ fn test_deploy_account() { let expected_execution_info = TransactionExecutionInfo::new( expected_validate_call_info.into(), expected_execute_call_info.into(), + None, expected_fee_transfer_call_info.into(), fee, // Entry **not** in blockifier. @@ -1318,7 +1330,7 @@ fn expected_deploy_account_states() -> ( ) { let fee = Felt252::from(3684); let mut state_before = CachedState::new( - InMemoryStateReader::new( + Arc::new(InMemoryStateReader::new( HashMap::from([ (Address(0x101.into()), felt_to_hash(&0x111.into())), (Address(0x100.into()), felt_to_hash(&0x110.into())), @@ -1352,7 +1364,7 @@ fn expected_deploy_account_states() -> ( ]), HashMap::new(), HashMap::new(), - ), + )), Some(ContractClassCache::new()), Some(HashMap::new()), ); @@ -1484,11 +1496,11 @@ fn test_state_for_declare_tx() { .is_one()); // Check state.state_reader - let mut state_reader = state.state_reader().clone(); + let state_reader = state.state_reader.clone(); assert_eq!( - state_reader.address_to_class_hash_mut(), - &mut HashMap::from([ + state_reader.address_to_class_hash, + HashMap::from([ ( TEST_ERC20_CONTRACT_ADDRESS.clone(), felt_to_hash(&TEST_ERC20_CONTRACT_CLASS_HASH) @@ -1505,8 +1517,8 @@ fn test_state_for_declare_tx() { ); assert_eq!( - state_reader.address_to_nonce_mut(), - &mut HashMap::from([ + state_reader.address_to_nonce, + HashMap::from([ (TEST_ERC20_CONTRACT_ADDRESS.clone(), Felt252::zero()), (TEST_CONTRACT_ADDRESS.clone(), Felt252::zero()), (TEST_ACCOUNT_CONTRACT_ADDRESS.clone(), Felt252::zero()), @@ -1514,8 +1526,8 @@ fn test_state_for_declare_tx() { ); assert_eq!( - state_reader.address_to_storage_mut(), - &mut HashMap::from([( + state_reader.address_to_storage, + HashMap::from([( ( TEST_ERC20_CONTRACT_ADDRESS.clone(), felt_to_hash(&TEST_ERC20_ACCOUNT_BALANCE_KEY) @@ -1525,8 +1537,8 @@ fn test_state_for_declare_tx() { ); assert_eq!( - state_reader.class_hash_to_contract_class_mut(), - &mut HashMap::from([ + state_reader.class_hash_to_contract_class, + HashMap::from([ ( felt_to_hash(&TEST_ERC20_CONTRACT_CLASS_HASH), ContractClass::from_path(ERC20_CONTRACT_PATH).unwrap() @@ -1837,6 +1849,7 @@ fn test_library_call_with_declare_v2() { &mut resources_manager, &mut tx_execution_context, false, + block_context.invoke_tx_max_n_steps(), ) .unwrap(); @@ -1899,5 +1912,5 @@ fn test_library_call_with_declare_v2() { ..Default::default() }; - assert_eq!(call_info, expected_call_info); + assert_eq!(call_info.call_info.unwrap(), expected_call_info); } diff --git a/tests/storage.rs b/tests/storage.rs index 4428980c2..19d03359e 100644 --- a/tests/storage.rs +++ b/tests/storage.rs @@ -12,6 +12,7 @@ use starknet_in_rust::{ state::{in_memory_state_reader::InMemoryStateReader, ExecutionResourcesManager}, utils::{calculate_sn_keccak, Address}, }; +use std::sync::Arc; use std::{ collections::{HashMap, HashSet}, path::PathBuf, @@ -65,7 +66,7 @@ fn integration_storage_test() { //* Create state with previous data //* --------------------------------------- - let mut state = CachedState::new(state_reader, Some(contract_class_cache), None); + let mut state = CachedState::new(Arc::new(state_reader), Some(contract_class_cache), None); //* ------------------------------------ //* Create execution entry point @@ -132,7 +133,10 @@ fn integration_storage_test() { &mut resources_manager, &mut tx_execution_context, false, + block_context.invoke_tx_max_n_steps() ) + .unwrap() + .call_info .unwrap(), expected_call_info ); diff --git a/tests/syscalls.rs b/tests/syscalls.rs index cef39b230..3cbc0417b 100644 --- a/tests/syscalls.rs +++ b/tests/syscalls.rs @@ -34,6 +34,7 @@ use std::{ collections::{HashMap, HashSet}, iter::empty, path::{Path, PathBuf}, + sync::Arc, }; #[allow(clippy::too_many_arguments)] @@ -115,7 +116,7 @@ fn test_contract<'a>( Some(contract_class_cache) }; - let mut state = CachedState::new(state_reader, contract_class_cache, None); + let mut state = CachedState::new(Arc::new(state_reader), contract_class_cache, None); storage_entries .into_iter() .for_each(|(a, b, c)| state.set_storage_at(&(a, b), c)); @@ -143,8 +144,11 @@ fn test_contract<'a>( &mut resources_manager, &mut tx_execution_context, false, + block_context.invoke_tx_max_n_steps(), ) - .expect("Could not execute contract"); + .expect("Could not execute contract") + .call_info + .unwrap(); assert_eq!(result.contract_address, contract_address); assert_eq!(result.contract_address, contract_address); @@ -1096,7 +1100,7 @@ fn deploy_cairo1_from_cairo0_with_constructor() { // Create state from the state_reader and contract cache. let mut state = CachedState::new( - state_reader, + Arc::new(state_reader), Some(contract_class_cache), Some(casm_contract_class_cache), ); @@ -1139,6 +1143,7 @@ fn deploy_cairo1_from_cairo0_with_constructor() { &mut resources_manager, &mut tx_execution_context, false, + block_context.invoke_tx_max_n_steps(), ); assert!(call_info.is_ok()); @@ -1196,7 +1201,7 @@ fn deploy_cairo1_from_cairo0_without_constructor() { // Create state from the state_reader and contract cache. let mut state = CachedState::new( - state_reader, + Arc::new(state_reader), Some(contract_class_cache), Some(casm_contract_class_cache), ); @@ -1240,6 +1245,7 @@ fn deploy_cairo1_from_cairo0_without_constructor() { &mut resources_manager, &mut tx_execution_context, false, + block_context.invoke_tx_max_n_steps(), ) .unwrap(); @@ -1298,7 +1304,7 @@ fn deploy_cairo1_and_invoke() { // Create state from the state_reader and contract cache. let mut state = CachedState::new( - state_reader, + Arc::new(state_reader), Some(contract_class_cache), Some(casm_contract_class_cache), ); @@ -1341,6 +1347,7 @@ fn deploy_cairo1_and_invoke() { &mut resources_manager, &mut tx_execution_context, false, + block_context.invoke_tx_max_n_steps(), ); assert!(call_info.is_ok()); @@ -1379,7 +1386,10 @@ fn deploy_cairo1_and_invoke() { &mut resources_manager, &mut tx_execution_context, false, + block_context.invoke_tx_max_n_steps(), ) + .unwrap() + .call_info .unwrap(); let retdata = call_info.retdata; @@ -1431,7 +1441,11 @@ fn send_messages_to_l1_different_contract_calls() { .insert(send_msg_address, send_msg_nonce); // Create state from the state_reader and contract cache. - let mut state = CachedState::new(state_reader, Some(deprecated_contract_class_cache), None); + let mut state = CachedState::new( + Arc::new(state_reader), + Some(deprecated_contract_class_cache), + None, + ); // Create an execution entry point let calldata = [25.into(), 50.into(), 75.into()].to_vec(); @@ -1469,7 +1483,10 @@ fn send_messages_to_l1_different_contract_calls() { &mut resources_manager, &mut tx_execution_context, false, + block_context.invoke_tx_max_n_steps(), ) + .unwrap() + .call_info .unwrap(); let l1_to_l2_messages = call_info.get_sorted_l2_to_l1_messages().unwrap(); assert_eq!( diff --git a/tests/syscalls_errors.rs b/tests/syscalls_errors.rs index efbef0b5f..17d85bb1f 100644 --- a/tests/syscalls_errors.rs +++ b/tests/syscalls_errors.rs @@ -18,6 +18,7 @@ use starknet_in_rust::{ utils::{calculate_sn_keccak, Address, ClassHash}, }; use std::path::Path; +use std::sync::Arc; use assert_matches::assert_matches; @@ -96,7 +97,7 @@ fn test_contract<'a>( Some(contract_class_cache) }; - let mut state = CachedState::new(state_reader, contract_class_cache, None); + let mut state = CachedState::new(Arc::new(state_reader), contract_class_cache, None); storage_entries .into_iter() .for_each(|(a, b, c)| state.set_storage_at(&(a, b), c)); @@ -123,6 +124,7 @@ fn test_contract<'a>( &mut resources_manager, &mut tx_execution_context, false, + block_context.invoke_tx_max_n_steps(), ); assert_matches!(result, Err(e) if e.to_string().contains(error_msg));