Skip to content
This repository was archived by the owner on Jul 22, 2024. It is now read-only.

Commit ec0036b

Browse files
authored
Introduced transactional state (#917)
* Introduced transactional state * WIP * Fixed the rest of tests * Replaced old revert logic from entrypoint exec * depl acc revert test * Remove update writes fix
1 parent 21df2c7 commit ec0036b

File tree

12 files changed

+467
-81
lines changed

12 files changed

+467
-81
lines changed

cli/src/main.rs

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ use starknet_in_rust::{
2525
serde_structs::read_abi,
2626
services::api::contract_classes::deprecated_contract_class::ContractClass,
2727
state::{cached_state::CachedState, state_api::State},
28-
state::{in_memory_state_reader::InMemoryStateReader, ExecutionResourcesManager},
28+
state::{in_memory_state_reader::InMemoryStateReader, ExecutionResourcesManager, StateDiff},
2929
transaction::{error::TransactionError, InvokeFunction},
3030
utils::{felt_to_hash, string_to_hash, Address},
3131
};
@@ -195,7 +195,9 @@ fn invoke_parser(
195195
Some(Felt252::zero()),
196196
transaction_hash.unwrap(),
197197
)?;
198-
let _tx_info = internal_invoke.apply(cached_state, &BlockContext::default(), 0)?;
198+
let mut transactional_state = cached_state.create_transactional();
199+
let _tx_info = internal_invoke.apply(&mut transactional_state, &BlockContext::default(), 0)?;
200+
cached_state.apply_state_update(&StateDiff::from_cached_state(transactional_state)?)?;
199201

200202
let tx_hash = calculate_transaction_hash_common(
201203
TransactionHashPrefix::Invoke,

rpc_state_reader/src/lib.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -530,6 +530,7 @@ mod tests {
530530
rpc_state.get_transaction(tx_hash);
531531
}
532532

533+
#[ignore]
533534
#[test]
534535
fn test_get_block_info() {
535536
let rpc_state = RpcState::new(

src/execution/execution_entry_point.rs

Lines changed: 6 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@ use crate::services::api::contract_classes::deprecated_contract_class::{
22
ContractEntryPoint, EntryPointType,
33
};
44
use crate::state::cached_state::CachedState;
5-
use crate::state::StateDiff;
65
use crate::{
76
definitions::{block_context::BlockContext, constants::DEFAULT_ENTRY_POINT_SELECTOR},
87
runner::StarknetRunner,
@@ -126,31 +125,20 @@ impl ExecutionEntryPoint {
126125
})
127126
}
128127
CompiledClass::Casm(contract_class) => {
129-
let mut tmp_state = CachedState::new(
130-
state.state_reader.clone(),
131-
state.contract_classes.clone(),
132-
state.casm_contract_classes.clone(),
133-
);
134-
tmp_state.cache = state.cache.clone();
135-
136128
match self._execute(
137-
&mut tmp_state,
129+
state,
138130
resources_manager,
139131
block_context,
140132
tx_execution_context,
141133
contract_class,
142134
class_hash,
143135
support_reverted,
144136
) {
145-
Ok(call_info) => {
146-
let state_diff = StateDiff::from_cached_state(tmp_state)?;
147-
state.apply_state_update(&state_diff)?;
148-
Ok(ExecutionResult {
149-
call_info: Some(call_info),
150-
revert_error: None,
151-
n_reverted_steps: 0,
152-
})
153-
}
137+
Ok(call_info) => Ok(ExecutionResult {
138+
call_info: Some(call_info),
139+
revert_error: None,
140+
n_reverted_steps: 0,
141+
}),
154142
Err(e) => {
155143
if !support_reverted {
156144
return Err(e);

src/execution/mod.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -610,11 +610,11 @@ impl TransactionExecutionInfo {
610610
Ok(sorted_messages)
611611
}
612612

613-
pub fn to_revert_error(self, revert_error: String) -> Self {
613+
pub fn to_revert_error(self, revert_error: &str) -> Self {
614614
TransactionExecutionInfo {
615615
validate_info: None,
616616
call_info: None,
617-
revert_error: Some(revert_error),
617+
revert_error: Some(revert_error.to_string()),
618618
fee_transfer_info: None,
619619
..self
620620
}

src/lib.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -297,7 +297,7 @@ mod test {
297297
let transaction = Transaction::InvokeFunction(invoke_function);
298298

299299
let estimated_fee = estimate_fee(&[transaction], state, &block_context).unwrap();
300-
assert_eq!(estimated_fee[0], (2483, 2448));
300+
assert_eq!(estimated_fee[0], (1259, 1224));
301301
}
302302

303303
#[test]
@@ -1014,7 +1014,7 @@ mod test {
10141014

10151015
assert_eq!(
10161016
estimate_fee(&[deploy, invoke_tx], state, block_context,).unwrap(),
1017-
[(0, 3672), (0, 2448)]
1017+
[(0, 3672), (0, 1224)]
10181018
);
10191019
}
10201020

src/state/cached_state.rs

Lines changed: 202 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -82,15 +82,16 @@ impl<T: StateReader> CachedState<T> {
8282
.ok_or(StateError::MissingCasmClassCache)
8383
}
8484

85-
pub fn create_copy(&self) -> Self {
86-
let mut state = CachedState::new(
87-
self.state_reader.clone(),
88-
self.contract_classes.clone(),
89-
self.casm_contract_classes.clone(),
90-
);
91-
state.cache = self.cache.clone();
92-
93-
state
85+
/// Creates a copy of this state with an empty cache for saving changes and applying them
86+
/// later.
87+
pub fn create_transactional(&self) -> TransactionalCachedState<T> {
88+
let state_reader = Arc::new(TransactionalCachedStateReader::new(self));
89+
CachedState {
90+
state_reader,
91+
cache: Default::default(),
92+
contract_classes: Default::default(),
93+
casm_contract_classes: Default::default(),
94+
}
9495
}
9596
}
9697

@@ -471,10 +472,10 @@ impl<T: StateReader> State for CachedState<T> {
471472
match contract {
472473
CompiledClass::Casm(ref class) => {
473474
// We call this method instead of state_reader's in order to update the cache's class_hash_initial_values map
474-
let compiled_class_hash = self.get_compiled_class_hash(class_hash)?;
475+
//let compiled_class_hash = self.get_compiled_class_hash(class_hash)?;
475476
self.casm_contract_classes
476477
.as_mut()
477-
.and_then(|m| m.insert(compiled_class_hash, *class.clone()));
478+
.and_then(|m| m.insert(*class_hash, *class.clone()));
478479
}
479480
CompiledClass::Deprecated(ref contract) => {
480481
self.set_contract_class(class_hash, &contract.clone())?
@@ -484,6 +485,196 @@ impl<T: StateReader> State for CachedState<T> {
484485
}
485486
}
486487

488+
/// A CachedState which has access to another, "parent" state, used for executing transactions
489+
/// without commiting changes to the parent.
490+
pub type TransactionalCachedState<'a, T> = CachedState<TransactionalCachedStateReader<'a, T>>;
491+
492+
impl<'a, T: StateReader> TransactionalCachedState<'a, T> {
493+
pub fn count_actual_storage_changes(&mut self) -> Result<(usize, usize), StateError> {
494+
let storage_updates = subtract_mappings(
495+
self.cache.storage_writes.clone(),
496+
self.cache.storage_initial_values.clone(),
497+
);
498+
499+
let n_modified_contracts = {
500+
let storage_unique_updates = storage_updates.keys().map(|k| k.0.clone());
501+
502+
let class_hash_updates: Vec<_> = subtract_mappings(
503+
self.cache.class_hash_writes.clone(),
504+
self.cache.class_hash_initial_values.clone(),
505+
)
506+
.keys()
507+
.cloned()
508+
.collect();
509+
510+
let nonce_updates: Vec<_> = subtract_mappings(
511+
self.cache.nonce_writes.clone(),
512+
self.cache.nonce_initial_values.clone(),
513+
)
514+
.keys()
515+
.cloned()
516+
.collect();
517+
518+
let mut modified_contracts: HashSet<Address> = HashSet::new();
519+
modified_contracts.extend(storage_unique_updates);
520+
modified_contracts.extend(class_hash_updates);
521+
modified_contracts.extend(nonce_updates);
522+
523+
modified_contracts.len()
524+
};
525+
526+
Ok((n_modified_contracts, storage_updates.len()))
527+
}
528+
}
529+
530+
/// State reader used for transactional states which allows to check the parent state's cache and
531+
/// state reader if a transactional cache miss happens.
532+
///
533+
/// In practice this will act as a way to access the parent state's cache and other fields,
534+
/// without referencing the whole parent state, so there's no need to adapt state-modifying
535+
/// functions in the case that a transactional state is needed.
536+
#[derive(Debug, MutGetters, Getters, PartialEq, Clone)]
537+
pub struct TransactionalCachedStateReader<'a, T: StateReader> {
538+
/// The parent state's state_reader
539+
#[get(get = "pub")]
540+
pub(crate) state_reader: Arc<T>,
541+
/// The parent state's cache
542+
#[get(get = "pub")]
543+
pub(crate) cache: &'a StateCache,
544+
/// The parent state's contract_classes
545+
#[get(get = "pub")]
546+
pub(crate) contract_classes: Option<ContractClassCache>,
547+
/// The parent state's casm_contract_classes
548+
#[get(get = "pub")]
549+
pub(crate) casm_contract_classes: Option<CasmClassCache>,
550+
}
551+
552+
impl<'a, T: StateReader> TransactionalCachedStateReader<'a, T> {
553+
fn new(state: &'a CachedState<T>) -> Self {
554+
Self {
555+
state_reader: state.state_reader.clone(),
556+
cache: &state.cache,
557+
contract_classes: state.contract_classes.clone(),
558+
casm_contract_classes: state.casm_contract_classes.clone(),
559+
}
560+
}
561+
}
562+
563+
impl<'a, T: StateReader> StateReader for TransactionalCachedStateReader<'a, T> {
564+
fn get_class_hash_at(&self, contract_address: &Address) -> Result<ClassHash, StateError> {
565+
if self.cache.get_class_hash(contract_address).is_none() {
566+
match self.state_reader.get_class_hash_at(contract_address) {
567+
Ok(class_hash) => {
568+
return Ok(class_hash);
569+
}
570+
Err(StateError::NoneContractState(_)) => {
571+
return Ok([0; 32]);
572+
}
573+
Err(e) => {
574+
return Err(e);
575+
}
576+
}
577+
}
578+
579+
self.cache
580+
.get_class_hash(contract_address)
581+
.ok_or_else(|| StateError::NoneClassHash(contract_address.clone()))
582+
.cloned()
583+
}
584+
585+
fn get_nonce_at(&self, contract_address: &Address) -> Result<Felt252, StateError> {
586+
if self.cache.get_nonce(contract_address).is_none() {
587+
return self.state_reader.get_nonce_at(contract_address);
588+
}
589+
self.cache
590+
.get_nonce(contract_address)
591+
.ok_or_else(|| StateError::NoneNonce(contract_address.clone()))
592+
.cloned()
593+
}
594+
595+
fn get_storage_at(&self, storage_entry: &StorageEntry) -> Result<Felt252, StateError> {
596+
if self.cache.get_storage(storage_entry).is_none() {
597+
match self.state_reader.get_storage_at(storage_entry) {
598+
Ok(storage) => {
599+
return Ok(storage);
600+
}
601+
Err(
602+
StateError::EmptyKeyInStorage
603+
| StateError::NoneStoragLeaf(_)
604+
| StateError::NoneStorage(_)
605+
| StateError::NoneContractState(_),
606+
) => return Ok(Felt252::zero()),
607+
Err(e) => {
608+
return Err(e);
609+
}
610+
}
611+
}
612+
613+
self.cache
614+
.get_storage(storage_entry)
615+
.ok_or_else(|| StateError::NoneStorage(storage_entry.clone()))
616+
.cloned()
617+
}
618+
619+
// TODO: check if that the proper way to store it (converting hash to address)
620+
fn get_compiled_class_hash(&self, class_hash: &ClassHash) -> Result<ClassHash, StateError> {
621+
if self
622+
.cache
623+
.class_hash_to_compiled_class_hash
624+
.get(class_hash)
625+
.is_none()
626+
{
627+
return self.state_reader.get_compiled_class_hash(class_hash);
628+
}
629+
self.cache
630+
.class_hash_to_compiled_class_hash
631+
.get(class_hash)
632+
.ok_or_else(|| StateError::NoneCompiledClass(*class_hash))
633+
.cloned()
634+
}
635+
636+
fn get_contract_class(&self, class_hash: &ClassHash) -> Result<CompiledClass, StateError> {
637+
// This method can receive both compiled_class_hash & class_hash and return both casm and deprecated contract classes
638+
//, which can be on the cache or on the state_reader, different cases will be described below:
639+
if class_hash == UNINITIALIZED_CLASS_HASH {
640+
return Err(StateError::UninitiaizedClassHash);
641+
}
642+
// I: FETCHING FROM CACHE
643+
// I: DEPRECATED CONTRACT CLASS
644+
// deprecated contract classes dont have compiled class hashes, so we only have one case
645+
if let Some(compiled_class) = self
646+
.contract_classes
647+
.as_ref()
648+
.and_then(|x| x.get(class_hash))
649+
{
650+
return Ok(CompiledClass::Deprecated(Box::new(compiled_class.clone())));
651+
}
652+
// I: CASM CONTRACT CLASS : COMPILED_CLASS_HASH
653+
if let Some(compiled_class) = self
654+
.casm_contract_classes
655+
.as_ref()
656+
.and_then(|x| x.get(class_hash))
657+
{
658+
return Ok(CompiledClass::Casm(Box::new(compiled_class.clone())));
659+
}
660+
// I: CASM CONTRACT CLASS : CLASS_HASH
661+
if let Some(compiled_class_hash) =
662+
self.cache.class_hash_to_compiled_class_hash.get(class_hash)
663+
{
664+
if let Some(casm_class) = &mut self
665+
.casm_contract_classes
666+
.as_ref()
667+
.and_then(|m| m.get(compiled_class_hash))
668+
{
669+
return Ok(CompiledClass::Casm(Box::new(casm_class.clone())));
670+
}
671+
}
672+
// II: FETCHING FROM STATE_READER
673+
let contract = self.state_reader.get_contract_class(class_hash)?;
674+
Ok(contract)
675+
}
676+
}
677+
487678
#[cfg(test)]
488679
mod tests {
489680
use super::*;

src/state/in_memory_state_reader.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ pub struct InMemoryStateReader {
3232
#[getset(get_mut = "pub")]
3333
pub(crate) casm_contract_classes: CasmClassCache,
3434
#[getset(get_mut = "pub")]
35-
pub(crate) class_hash_to_compiled_class_hash: HashMap<ClassHash, CompiledClassHash>,
35+
pub class_hash_to_compiled_class_hash: HashMap<ClassHash, CompiledClassHash>,
3636
}
3737

3838
impl InMemoryStateReader {

src/syscalls/deprecated_syscall_handler.rs

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -228,6 +228,7 @@ mod tests {
228228

229229
use super::*;
230230
use crate::services::api::contract_classes::deprecated_contract_class::EntryPointType;
231+
use crate::state::StateDiff;
231232
use crate::{
232233
add_segments, allocate_selector, any_box,
233234
definitions::{
@@ -1188,9 +1189,14 @@ mod tests {
11881189
)
11891190
.unwrap();
11901191

1192+
let mut transactional = state.create_transactional();
11911193
// Invoke result
11921194
let result = internal_invoke_function
1193-
.apply(&mut state, &BlockContext::default(), 0)
1195+
.apply(&mut transactional, &BlockContext::default(), 0)
1196+
.unwrap();
1197+
1198+
state
1199+
.apply_state_update(&StateDiff::from_cached_state(transactional).unwrap())
11941200
.unwrap();
11951201

11961202
let result_call_info = result.call_info.unwrap();

src/testing/state.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -574,7 +574,7 @@ mod tests {
574574
.unwrap();
575575
let actual_resources = HashMap::from([
576576
("n_steps".to_string(), 3457),
577-
("l1_gas_usage".to_string(), 2448),
577+
("l1_gas_usage".to_string(), 1224),
578578
("range_check_builtin".to_string(), 80),
579579
("pedersen_builtin".to_string(), 16),
580580
]);

0 commit comments

Comments
 (0)