Skip to content
This repository was archived by the owner on Jul 22, 2024. It is now read-only.

Commit 6046d75

Browse files
xqftjuanbonoedg-l
authored andcommitted
Fix transactions bypassing the max_fee by introducing new revert logic (#901)
* Make tx fail when actual_fee exceeds max_fee * Changed test * Formatting * Fix logic * Leave fail only without charging * Change test * Fix test broken by better fee calc * Fixed test fee * Update fee on test_deploy_account * Remove comment * Added fee transfer * Test with invoke * Added revert logic for invoke * Modify tests, add fixes * Add revert error * Fix test_invoke_tx_account * Fixed test_invoke_tx_exceeded_max_fee * Fix test_get_nonce_at * Rely on another contract * Introduced transactional state (#917) * Introduced transactional state * WIP * Fixed the rest of tests * Replaced old revert logic from entrypoint exec * depl acc revert test * Remove update writes fix * WIP Fixed many tests * fix test * fix more tests * more fixes * fix another test * fix latest test * name * remove comment * merge * unignore * format * vis * need to be pub for tests * fix test * format * use the count_actual_storage_changes impl from cached state * fix bug * fix tests --------- Co-authored-by: Juan Bono <juanbono94@gmail.com> Co-authored-by: Edgar Luque <git@edgarluque.com>
1 parent affa012 commit 6046d75

21 files changed

+584
-138
lines changed

Cargo.lock

Lines changed: 1 addition & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@ serde_json_pythonic = { git = "https://github.com/xJonathanLEI/serde_json_python
6161
[dev-dependencies]
6262
assert_matches = "1.5.0"
6363
coverage-helper = "0.1.0"
64+
pretty_assertions_sorted = "1.2.3"
6465

6566
[[bench]]
6667
path = "bench/internals.rs"

cli/src/main.rs

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ use starknet_in_rust::{
2727
compiled_class::CompiledClass, deprecated_contract_class::ContractClass,
2828
},
2929
state::{cached_state::CachedState, state_api::State},
30-
state::{in_memory_state_reader::InMemoryStateReader, ExecutionResourcesManager},
30+
state::{in_memory_state_reader::InMemoryStateReader, ExecutionResourcesManager, StateDiff},
3131
transaction::{error::TransactionError, InvokeFunction},
3232
utils::{felt_to_hash, string_to_hash, Address},
3333
};
@@ -200,7 +200,9 @@ fn invoke_parser(
200200
Some(Felt252::zero()),
201201
transaction_hash.unwrap(),
202202
)?;
203-
let _tx_info = internal_invoke.apply(cached_state, &BlockContext::default(), 0)?;
203+
let mut transactional_state = cached_state.create_transactional();
204+
let _tx_info = internal_invoke.apply(&mut transactional_state, &BlockContext::default(), 0)?;
205+
cached_state.apply_state_update(&StateDiff::from_cached_state(transactional_state)?)?;
204206

205207
let tx_hash = calculate_transaction_hash_common(
206208
TransactionHashPrefix::Invoke,

rpc_state_reader/tests/blockifier_tests.rs

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -331,3 +331,29 @@ fn blockifier_test_case_tx(hash: &str, block_number: u64, chain: RpcChain) {
331331
.len()
332332
);
333333
}
334+
335+
#[test_case(
336+
"0x00b6d59c19d5178886b4c939656167db0660fe325345138025a3cc4175b21897",
337+
200303, // real block 200304
338+
RpcChain::MainNet
339+
)]
340+
#[test_case(
341+
"0x02b28b4846a756e0cec6385d6d13f811e745a88c7e75a3ebc5fead5b4af152a3",
342+
200302, // real block 200304
343+
RpcChain::MainNet
344+
=> ignore["broken on both due to a cairo-vm error"]
345+
)]
346+
fn blockifier_test_case_reverted_tx(hash: &str, block_number: u64, chain: RpcChain) {
347+
let (tx_info, trace, receipt) = execute_tx(hash, chain, BlockNumber(block_number));
348+
349+
assert_eq!(tx_info.revert_error.is_some(), trace.revert_error.is_some());
350+
351+
let diff = 100 * receipt.actual_fee.abs_diff(tx_info.actual_fee.0) / receipt.actual_fee;
352+
353+
if diff >= 5 {
354+
assert_eq!(
355+
tx_info.actual_fee.0, receipt.actual_fee,
356+
"actual_fee mismatch differs from the baseline by more than 5% ({diff}%)",
357+
);
358+
}
359+
}

rpc_state_reader/tests/sir_tests.rs

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -293,3 +293,29 @@ fn starknet_in_rust_test_case_tx(hash: &str, block_number: u64, chain: RpcChain)
293293
}
294294
}
295295
}
296+
297+
#[test_case(
298+
"0x00b6d59c19d5178886b4c939656167db0660fe325345138025a3cc4175b21897",
299+
200303, // real block 200304
300+
RpcChain::MainNet
301+
)]
302+
#[test_case(
303+
"0x02b28b4846a756e0cec6385d6d13f811e745a88c7e75a3ebc5fead5b4af152a3",
304+
200302, // real block 200304
305+
RpcChain::MainNet
306+
=> ignore["broken on both due to a cairo-vm error"]
307+
)]
308+
fn starknet_in_rust_test_case_reverted_tx(hash: &str, block_number: u64, chain: RpcChain) {
309+
let (tx_info, trace, receipt) = execute_tx(hash, chain, BlockNumber(block_number));
310+
311+
assert_eq!(tx_info.revert_error.is_some(), trace.revert_error.is_some());
312+
313+
let diff = 100 * receipt.actual_fee.abs_diff(tx_info.actual_fee) / receipt.actual_fee;
314+
315+
if diff >= 5 {
316+
assert_eq!(
317+
tx_info.actual_fee, receipt.actual_fee,
318+
"actual_fee mismatch differs from the baseline by more than 5% ({diff}%)",
319+
);
320+
}
321+
}

src/execution/execution_entry_point.rs

Lines changed: 6 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@ use crate::services::api::contract_classes::deprecated_contract_class::{
44
ContractEntryPoint, EntryPointType,
55
};
66
use crate::state::cached_state::CachedState;
7-
use crate::state::StateDiff;
87
use crate::{
98
definitions::{block_context::BlockContext, constants::DEFAULT_ENTRY_POINT_SELECTOR},
109
runner::StarknetRunner,
@@ -128,28 +127,20 @@ impl ExecutionEntryPoint {
128127
})
129128
}
130129
CompiledClass::Casm(contract_class) => {
131-
let mut tmp_state =
132-
CachedState::new(state.state_reader.clone(), state.contract_classes.clone());
133-
tmp_state.cache = state.cache.clone();
134-
135130
match self._execute(
136-
&mut tmp_state,
131+
state,
137132
resources_manager,
138133
block_context,
139134
tx_execution_context,
140135
contract_class,
141136
class_hash,
142137
support_reverted,
143138
) {
144-
Ok(call_info) => {
145-
let state_diff = StateDiff::from_cached_state(tmp_state)?;
146-
state.apply_state_update(&state_diff)?;
147-
Ok(ExecutionResult {
148-
call_info: Some(call_info),
149-
revert_error: None,
150-
n_reverted_steps: 0,
151-
})
152-
}
139+
Ok(call_info) => Ok(ExecutionResult {
140+
call_info: Some(call_info),
141+
revert_error: None,
142+
n_reverted_steps: 0,
143+
}),
153144
Err(e) => {
154145
if !support_reverted {
155146
return Err(e);

src/execution/mod.rs

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -609,6 +609,16 @@ impl TransactionExecutionInfo {
609609

610610
Ok(sorted_messages)
611611
}
612+
613+
pub fn to_revert_error(self, revert_error: &str) -> Self {
614+
TransactionExecutionInfo {
615+
validate_info: None,
616+
call_info: None,
617+
revert_error: Some(revert_error.to_string()),
618+
fee_transfer_info: None,
619+
..self
620+
}
621+
}
612622
}
613623

614624
// --------------------

src/lib.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -249,6 +249,7 @@ mod test {
249249

250250
use crate::services::api::contract_classes::compiled_class::CompiledClass;
251251
use lazy_static::lazy_static;
252+
use pretty_assertions_sorted::assert_eq;
252253

253254
lazy_static! {
254255
// include_str! doesn't seem to work in CI

src/state/cached_state.rs

Lines changed: 126 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,19 @@ impl<T: StateReader> CachedState<T> {
9999
self.contract_classes = contract_classes;
100100
Ok(())
101101
}
102+
103+
/// Creates a copy of this state with an empty cache for saving changes and applying them
104+
/// later.
105+
pub fn create_transactional(&self) -> TransactionalCachedState<T> {
106+
let state_reader = Arc::new(TransactionalCachedStateReader::new(self));
107+
CachedState {
108+
state_reader,
109+
cache: self.cache.clone(),
110+
contract_classes: self.contract_classes.clone(),
111+
cache_hits: 0,
112+
cache_misses: 0,
113+
}
114+
}
102115
}
103116

104117
impl<T: StateReader> StateReader for CachedState<T> {
@@ -134,19 +147,13 @@ impl<T: StateReader> StateReader for CachedState<T> {
134147
// TODO: check if that the proper way to store it (converting hash to address)
135148
/// Returned the compiled class hash for a given class hash.
136149
fn get_compiled_class_hash(&self, class_hash: &ClassHash) -> Result<ClassHash, StateError> {
137-
if self
138-
.cache
139-
.class_hash_to_compiled_class_hash
140-
.get(class_hash)
141-
.is_none()
150+
if let Some(compiled_class_hash) =
151+
self.cache.class_hash_to_compiled_class_hash.get(class_hash)
142152
{
143-
return self.state_reader.get_compiled_class_hash(class_hash);
153+
Ok(*compiled_class_hash)
154+
} else {
155+
self.state_reader.get_compiled_class_hash(class_hash)
144156
}
145-
self.cache
146-
.class_hash_to_compiled_class_hash
147-
.get(class_hash)
148-
.ok_or_else(|| StateError::NoneCompiledClass(*class_hash))
149-
.cloned()
150157
}
151158

152159
/// Returns the contract class for a given class hash.
@@ -438,6 +445,114 @@ impl<T: StateReader> State for CachedState<T> {
438445
}
439446
}
440447

448+
/// A CachedState which has access to another, "parent" state, used for executing transactions
449+
/// without commiting changes to the parent.
450+
pub type TransactionalCachedState<'a, T> = CachedState<TransactionalCachedStateReader<'a, T>>;
451+
452+
/// State reader used for transactional states which allows to check the parent state's cache and
453+
/// state reader if a transactional cache miss happens.
454+
///
455+
/// In practice this will act as a way to access the parent state's cache and other fields,
456+
/// without referencing the whole parent state, so there's no need to adapt state-modifying
457+
/// functions in the case that a transactional state is needed.
458+
#[derive(Debug, MutGetters, Getters, PartialEq, Clone)]
459+
pub struct TransactionalCachedStateReader<'a, T: StateReader> {
460+
/// The parent state's state_reader
461+
#[get(get = "pub")]
462+
pub(crate) state_reader: Arc<T>,
463+
/// The parent state's cache
464+
#[get(get = "pub")]
465+
pub(crate) cache: &'a StateCache,
466+
/// The parent state's contract_classes
467+
#[get(get = "pub")]
468+
pub(crate) contract_classes: ContractClassCache,
469+
}
470+
471+
impl<'a, T: StateReader> TransactionalCachedStateReader<'a, T> {
472+
fn new(state: &'a CachedState<T>) -> Self {
473+
Self {
474+
state_reader: state.state_reader.clone(),
475+
cache: &state.cache,
476+
contract_classes: state.contract_classes.clone(),
477+
}
478+
}
479+
}
480+
481+
impl<'a, T: StateReader> StateReader for TransactionalCachedStateReader<'a, T> {
482+
/// Returns the class hash for a given contract address.
483+
/// Returns zero as default value if missing
484+
fn get_class_hash_at(&self, contract_address: &Address) -> Result<ClassHash, StateError> {
485+
self.cache
486+
.get_class_hash(contract_address)
487+
.map(|a| Ok(*a))
488+
.unwrap_or_else(|| self.state_reader.get_class_hash_at(contract_address))
489+
}
490+
491+
/// Returns the nonce for a given contract address.
492+
fn get_nonce_at(&self, contract_address: &Address) -> Result<Felt252, StateError> {
493+
if self.cache.get_nonce(contract_address).is_none() {
494+
return self.state_reader.get_nonce_at(contract_address);
495+
}
496+
self.cache
497+
.get_nonce(contract_address)
498+
.ok_or_else(|| StateError::NoneNonce(contract_address.clone()))
499+
.cloned()
500+
}
501+
502+
/// Returns storage data for a given storage entry.
503+
/// Returns zero as default value if missing
504+
fn get_storage_at(&self, storage_entry: &StorageEntry) -> Result<Felt252, StateError> {
505+
self.cache
506+
.get_storage(storage_entry)
507+
.map(|v| Ok(v.clone()))
508+
.unwrap_or_else(|| self.state_reader.get_storage_at(storage_entry))
509+
}
510+
511+
// TODO: check if that the proper way to store it (converting hash to address)
512+
/// Returned the compiled class hash for a given class hash.
513+
fn get_compiled_class_hash(&self, class_hash: &ClassHash) -> Result<ClassHash, StateError> {
514+
if self
515+
.cache
516+
.class_hash_to_compiled_class_hash
517+
.get(class_hash)
518+
.is_none()
519+
{
520+
return self.state_reader.get_compiled_class_hash(class_hash);
521+
}
522+
self.cache
523+
.class_hash_to_compiled_class_hash
524+
.get(class_hash)
525+
.ok_or_else(|| StateError::NoneCompiledClass(*class_hash))
526+
.cloned()
527+
}
528+
529+
/// Returns the contract class for a given class hash.
530+
fn get_contract_class(&self, class_hash: &ClassHash) -> Result<CompiledClass, StateError> {
531+
// This method can receive both compiled_class_hash & class_hash and return both casm and deprecated contract classes
532+
//, which can be on the cache or on the state_reader, different cases will be described below:
533+
if class_hash == UNINITIALIZED_CLASS_HASH {
534+
return Err(StateError::UninitiaizedClassHash);
535+
}
536+
537+
// I: FETCHING FROM CACHE
538+
if let Some(compiled_class) = self.contract_classes.get(class_hash) {
539+
return Ok(compiled_class.clone());
540+
}
541+
542+
// I: CASM CONTRACT CLASS : CLASS_HASH
543+
if let Some(compiled_class_hash) =
544+
self.cache.class_hash_to_compiled_class_hash.get(class_hash)
545+
{
546+
if let Some(casm_class) = self.contract_classes.get(compiled_class_hash) {
547+
return Ok(casm_class.clone());
548+
}
549+
}
550+
551+
// II: FETCHING FROM STATE_READER
552+
self.state_reader.get_contract_class(class_hash)
553+
}
554+
}
555+
441556
impl<T: StateReader> CachedState<T> {
442557
// Updates the cache's storage_initial_values according to those in storage_writes
443558
// If a key is present in the storage_writes but not in storage_initial_values,

src/state/in_memory_state_reader.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ pub struct InMemoryStateReader {
2525
#[getset(get_mut = "pub")]
2626
pub class_hash_to_compiled_class: HashMap<ClassHash, CompiledClass>,
2727
#[getset(get_mut = "pub")]
28-
pub(crate) class_hash_to_compiled_class_hash: HashMap<ClassHash, CompiledClassHash>,
28+
pub class_hash_to_compiled_class_hash: HashMap<ClassHash, CompiledClassHash>,
2929
}
3030

3131
impl InMemoryStateReader {

0 commit comments

Comments
 (0)