@@ -82,15 +82,16 @@ impl<T: StateReader> CachedState<T> {
82
82
. ok_or ( StateError :: MissingCasmClassCache )
83
83
}
84
84
85
- pub fn create_copy ( & self ) -> Self {
86
- let mut state = CachedState :: new (
87
- self . state_reader . clone ( ) ,
88
- self . contract_classes . clone ( ) ,
89
- self . casm_contract_classes . clone ( ) ,
90
- ) ;
91
- state. cache = self . cache . clone ( ) ;
92
-
93
- state
85
+ /// Creates a copy of this state with an empty cache for saving changes and applying them
86
+ /// later.
87
+ pub fn create_transactional ( & self ) -> TransactionalCachedState < T > {
88
+ let state_reader = Arc :: new ( TransactionalCachedStateReader :: new ( self ) ) ;
89
+ CachedState {
90
+ state_reader,
91
+ cache : Default :: default ( ) ,
92
+ contract_classes : Default :: default ( ) ,
93
+ casm_contract_classes : Default :: default ( ) ,
94
+ }
94
95
}
95
96
}
96
97
@@ -471,10 +472,10 @@ impl<T: StateReader> State for CachedState<T> {
471
472
match contract {
472
473
CompiledClass :: Casm ( ref class) => {
473
474
// We call this method instead of state_reader's in order to update the cache's class_hash_initial_values map
474
- let compiled_class_hash = self . get_compiled_class_hash ( class_hash) ?;
475
+ // let compiled_class_hash = self.get_compiled_class_hash(class_hash)?;
475
476
self . casm_contract_classes
476
477
. as_mut ( )
477
- . and_then ( |m| m. insert ( compiled_class_hash , * class. clone ( ) ) ) ;
478
+ . and_then ( |m| m. insert ( * class_hash , * class. clone ( ) ) ) ;
478
479
}
479
480
CompiledClass :: Deprecated ( ref contract) => {
480
481
self . set_contract_class ( class_hash, & contract. clone ( ) ) ?
@@ -484,6 +485,196 @@ impl<T: StateReader> State for CachedState<T> {
484
485
}
485
486
}
486
487
488
+ /// A CachedState which has access to another, "parent" state, used for executing transactions
489
+ /// without commiting changes to the parent.
490
+ pub type TransactionalCachedState < ' a , T > = CachedState < TransactionalCachedStateReader < ' a , T > > ;
491
+
492
+ impl < ' a , T : StateReader > TransactionalCachedState < ' a , T > {
493
+ pub fn count_actual_storage_changes ( & mut self ) -> Result < ( usize , usize ) , StateError > {
494
+ let storage_updates = subtract_mappings (
495
+ self . cache . storage_writes . clone ( ) ,
496
+ self . cache . storage_initial_values . clone ( ) ,
497
+ ) ;
498
+
499
+ let n_modified_contracts = {
500
+ let storage_unique_updates = storage_updates. keys ( ) . map ( |k| k. 0 . clone ( ) ) ;
501
+
502
+ let class_hash_updates: Vec < _ > = subtract_mappings (
503
+ self . cache . class_hash_writes . clone ( ) ,
504
+ self . cache . class_hash_initial_values . clone ( ) ,
505
+ )
506
+ . keys ( )
507
+ . cloned ( )
508
+ . collect ( ) ;
509
+
510
+ let nonce_updates: Vec < _ > = subtract_mappings (
511
+ self . cache . nonce_writes . clone ( ) ,
512
+ self . cache . nonce_initial_values . clone ( ) ,
513
+ )
514
+ . keys ( )
515
+ . cloned ( )
516
+ . collect ( ) ;
517
+
518
+ let mut modified_contracts: HashSet < Address > = HashSet :: new ( ) ;
519
+ modified_contracts. extend ( storage_unique_updates) ;
520
+ modified_contracts. extend ( class_hash_updates) ;
521
+ modified_contracts. extend ( nonce_updates) ;
522
+
523
+ modified_contracts. len ( )
524
+ } ;
525
+
526
+ Ok ( ( n_modified_contracts, storage_updates. len ( ) ) )
527
+ }
528
+ }
529
+
530
+ /// State reader used for transactional states which allows to check the parent state's cache and
531
+ /// state reader if a transactional cache miss happens.
532
+ ///
533
+ /// In practice this will act as a way to access the parent state's cache and other fields,
534
+ /// without referencing the whole parent state, so there's no need to adapt state-modifying
535
+ /// functions in the case that a transactional state is needed.
536
+ #[ derive( Debug , MutGetters , Getters , PartialEq , Clone ) ]
537
+ pub struct TransactionalCachedStateReader < ' a , T : StateReader > {
538
+ /// The parent state's state_reader
539
+ #[ get( get = "pub" ) ]
540
+ pub ( crate ) state_reader : Arc < T > ,
541
+ /// The parent state's cache
542
+ #[ get( get = "pub" ) ]
543
+ pub ( crate ) cache : & ' a StateCache ,
544
+ /// The parent state's contract_classes
545
+ #[ get( get = "pub" ) ]
546
+ pub ( crate ) contract_classes : Option < ContractClassCache > ,
547
+ /// The parent state's casm_contract_classes
548
+ #[ get( get = "pub" ) ]
549
+ pub ( crate ) casm_contract_classes : Option < CasmClassCache > ,
550
+ }
551
+
552
+ impl < ' a , T : StateReader > TransactionalCachedStateReader < ' a , T > {
553
+ fn new ( state : & ' a CachedState < T > ) -> Self {
554
+ Self {
555
+ state_reader : state. state_reader . clone ( ) ,
556
+ cache : & state. cache ,
557
+ contract_classes : state. contract_classes . clone ( ) ,
558
+ casm_contract_classes : state. casm_contract_classes . clone ( ) ,
559
+ }
560
+ }
561
+ }
562
+
563
+ impl < ' a , T : StateReader > StateReader for TransactionalCachedStateReader < ' a , T > {
564
+ fn get_class_hash_at ( & self , contract_address : & Address ) -> Result < ClassHash , StateError > {
565
+ if self . cache . get_class_hash ( contract_address) . is_none ( ) {
566
+ match self . state_reader . get_class_hash_at ( contract_address) {
567
+ Ok ( class_hash) => {
568
+ return Ok ( class_hash) ;
569
+ }
570
+ Err ( StateError :: NoneContractState ( _) ) => {
571
+ return Ok ( [ 0 ; 32 ] ) ;
572
+ }
573
+ Err ( e) => {
574
+ return Err ( e) ;
575
+ }
576
+ }
577
+ }
578
+
579
+ self . cache
580
+ . get_class_hash ( contract_address)
581
+ . ok_or_else ( || StateError :: NoneClassHash ( contract_address. clone ( ) ) )
582
+ . cloned ( )
583
+ }
584
+
585
+ fn get_nonce_at ( & self , contract_address : & Address ) -> Result < Felt252 , StateError > {
586
+ if self . cache . get_nonce ( contract_address) . is_none ( ) {
587
+ return self . state_reader . get_nonce_at ( contract_address) ;
588
+ }
589
+ self . cache
590
+ . get_nonce ( contract_address)
591
+ . ok_or_else ( || StateError :: NoneNonce ( contract_address. clone ( ) ) )
592
+ . cloned ( )
593
+ }
594
+
595
+ fn get_storage_at ( & self , storage_entry : & StorageEntry ) -> Result < Felt252 , StateError > {
596
+ if self . cache . get_storage ( storage_entry) . is_none ( ) {
597
+ match self . state_reader . get_storage_at ( storage_entry) {
598
+ Ok ( storage) => {
599
+ return Ok ( storage) ;
600
+ }
601
+ Err (
602
+ StateError :: EmptyKeyInStorage
603
+ | StateError :: NoneStoragLeaf ( _)
604
+ | StateError :: NoneStorage ( _)
605
+ | StateError :: NoneContractState ( _) ,
606
+ ) => return Ok ( Felt252 :: zero ( ) ) ,
607
+ Err ( e) => {
608
+ return Err ( e) ;
609
+ }
610
+ }
611
+ }
612
+
613
+ self . cache
614
+ . get_storage ( storage_entry)
615
+ . ok_or_else ( || StateError :: NoneStorage ( storage_entry. clone ( ) ) )
616
+ . cloned ( )
617
+ }
618
+
619
+ // TODO: check if that the proper way to store it (converting hash to address)
620
+ fn get_compiled_class_hash ( & self , class_hash : & ClassHash ) -> Result < ClassHash , StateError > {
621
+ if self
622
+ . cache
623
+ . class_hash_to_compiled_class_hash
624
+ . get ( class_hash)
625
+ . is_none ( )
626
+ {
627
+ return self . state_reader . get_compiled_class_hash ( class_hash) ;
628
+ }
629
+ self . cache
630
+ . class_hash_to_compiled_class_hash
631
+ . get ( class_hash)
632
+ . ok_or_else ( || StateError :: NoneCompiledClass ( * class_hash) )
633
+ . cloned ( )
634
+ }
635
+
636
+ fn get_contract_class ( & self , class_hash : & ClassHash ) -> Result < CompiledClass , StateError > {
637
+ // This method can receive both compiled_class_hash & class_hash and return both casm and deprecated contract classes
638
+ //, which can be on the cache or on the state_reader, different cases will be described below:
639
+ if class_hash == UNINITIALIZED_CLASS_HASH {
640
+ return Err ( StateError :: UninitiaizedClassHash ) ;
641
+ }
642
+ // I: FETCHING FROM CACHE
643
+ // I: DEPRECATED CONTRACT CLASS
644
+ // deprecated contract classes dont have compiled class hashes, so we only have one case
645
+ if let Some ( compiled_class) = self
646
+ . contract_classes
647
+ . as_ref ( )
648
+ . and_then ( |x| x. get ( class_hash) )
649
+ {
650
+ return Ok ( CompiledClass :: Deprecated ( Box :: new ( compiled_class. clone ( ) ) ) ) ;
651
+ }
652
+ // I: CASM CONTRACT CLASS : COMPILED_CLASS_HASH
653
+ if let Some ( compiled_class) = self
654
+ . casm_contract_classes
655
+ . as_ref ( )
656
+ . and_then ( |x| x. get ( class_hash) )
657
+ {
658
+ return Ok ( CompiledClass :: Casm ( Box :: new ( compiled_class. clone ( ) ) ) ) ;
659
+ }
660
+ // I: CASM CONTRACT CLASS : CLASS_HASH
661
+ if let Some ( compiled_class_hash) =
662
+ self . cache . class_hash_to_compiled_class_hash . get ( class_hash)
663
+ {
664
+ if let Some ( casm_class) = & mut self
665
+ . casm_contract_classes
666
+ . as_ref ( )
667
+ . and_then ( |m| m. get ( compiled_class_hash) )
668
+ {
669
+ return Ok ( CompiledClass :: Casm ( Box :: new ( casm_class. clone ( ) ) ) ) ;
670
+ }
671
+ }
672
+ // II: FETCHING FROM STATE_READER
673
+ let contract = self . state_reader . get_contract_class ( class_hash) ?;
674
+ Ok ( contract)
675
+ }
676
+ }
677
+
487
678
#[ cfg( test) ]
488
679
mod tests {
489
680
use super :: * ;
0 commit comments