Skip to content
This repository was archived by the owner on Jul 22, 2024. It is now read-only.

Fix/Refactor State::count actual storage changes #1086

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/definitions/constants.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ use std::collections::HashMap;

pub(crate) const L2_TO_L1_MSG_HEADER_SIZE: usize = 3;
pub(crate) const L1_TO_L2_MSG_HEADER_SIZE: usize = 5;
pub(crate) const DEPLOYMENT_INFO_SIZE: usize = 1;
pub(crate) const CLASS_UPDATE_SIZE: usize = 1;
pub(crate) const CONSUMED_MSG_TO_L2_N_TOPICS: usize = 3;
pub(crate) const LOG_MSG_TO_L1_N_TOPICS: usize = 2;
pub(crate) const N_DEFAULT_TOPICS: usize = 1; // Events have one default topic.
Expand Down
47 changes: 25 additions & 22 deletions src/execution/gas_usage.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use crate::definitions::constants::*;
use crate::execution::L2toL1MessageInfo;
use crate::services::eth_definitions::eth_gas_constants::*;
use crate::state::state_api::StateChangesCount;

/// Estimates L1 gas usage by Starknet's update state and the verifier
///
Expand All @@ -19,16 +20,13 @@ use crate::services::eth_definitions::eth_gas_constants::*;
/// The estimation of L1 gas usage as a `usize` value.
pub fn calculate_tx_gas_usage(
l2_to_l1_messages: Vec<L2toL1MessageInfo>,
n_modified_contracts: usize,
n_storage_changes: usize,
state_changes: &StateChangesCount,
l1_handler_payload_size: Option<usize>,
n_deployments: usize,
) -> usize {
let residual_message_segment_length =
get_message_segment_lenght(&l2_to_l1_messages, l1_handler_payload_size);

let residual_onchain_data_segment_length =
get_onchain_data_segment_length(n_modified_contracts, n_storage_changes, n_deployments);
let residual_onchain_data_segment_length = get_onchain_data_segment_length(state_changes);

let n_l2_to_l1_messages = l2_to_l1_messages.len();
let n_l1_to_l2_messages = match l1_handler_payload_size {
Expand Down Expand Up @@ -95,22 +93,18 @@ pub fn get_message_segment_lenght(
}

/// Calculates the amount of `felt252` added to the output message's segment by the given operations.
///
/// # Parameters:
///
/// - `n_modified_contracts`: The number of contracts modified by the transaction.
/// - `n_storage_changes`: The number of storage changes made by the transaction.
/// - `n_deployments`: The number of contracts deployed by the transaction.
///
/// # Returns:
///
/// The on-chain data segment length
pub const fn get_onchain_data_segment_length(
n_modified_contracts: usize,
n_storage_changes: usize,
n_deployments: usize,
) -> usize {
n_modified_contracts * 2 + n_deployments * DEPLOYMENT_INFO_SIZE + n_storage_changes * 2
pub const fn get_onchain_data_segment_length(state_changes: &StateChangesCount) -> usize {
// For each newly modified contract:
// contract address (1 word).
// + 1 word with the following info: A flag indicating whether the class hash was updated, the
// number of entry updates, and the new nonce.
state_changes.n_modified_contracts * 2
// For each class updated (through a deploy or a class replacement).
+ state_changes.n_class_hash_updates * CLASS_UPDATE_SIZE
// For each modified storage cell: key, new value.
+ state_changes.n_storage_updates * 2
// For each compiled class updated (through declare): class_hash, compiled_class_hash
+ state_changes.n_compiled_class_hash_updates * 2
}

/// Calculates the cost of ConsumedMessageToL2 event emissions caused by an L1 handler with the given
Expand Down Expand Up @@ -261,7 +255,16 @@ mod test {
let message2 = L2toL1MessageInfo::new(ord_ev2, Address(1235.into()));

assert_eq!(
calculate_tx_gas_usage(vec![message1, message2], 2, 2, Some(2), 1),
calculate_tx_gas_usage(
vec![message1, message2],
&StateChangesCount {
n_storage_updates: 2,
n_class_hash_updates: 1,
n_compiled_class_hash_updates: 0,
n_modified_contracts: 2
},
Some(2)
),
76439
)
}
Expand Down
39 changes: 26 additions & 13 deletions src/state/cached_state.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use super::{
state_api::{State, StateReader},
state_api::{State, StateChangesCount, StateReader},
state_cache::{StateCache, StorageEntry},
};
use crate::{
Expand Down Expand Up @@ -283,10 +283,10 @@ impl<T: StateReader> State for CachedState<T> {
Ok(())
}

fn count_actual_storage_changes(
fn count_actual_state_changes(
&mut self,
fee_token_and_sender_address: Option<(&Address, &Address)>,
) -> Result<(usize, usize), StateError> {
) -> Result<StateChangesCount, StateError> {
self.update_initial_values_of_write_only_accesses()?;

let mut storage_updates = subtract_mappings(
Expand All @@ -296,17 +296,24 @@ impl<T: StateReader> State for CachedState<T> {

let storage_unique_updates = storage_updates.keys().map(|k| k.0.clone());

let class_hash_updates = subtract_mappings_keys(
let class_hash_updates: Vec<&Address> = subtract_mappings_keys(
&self.cache.class_hash_writes,
&self.cache.class_hash_initial_values,
)
.collect();
let n_class_hash_updates = class_hash_updates.len();

let compiled_class_hash_updates = subtract_mappings_keys(
&self.cache.compiled_class_hash_writes,
&self.cache.compiled_class_hash_initial_values,
);

let nonce_updates =
subtract_mappings_keys(&self.cache.nonce_writes, &self.cache.nonce_initial_values);

let mut modified_contracts: HashSet<Address> = HashSet::new();
modified_contracts.extend(storage_unique_updates);
modified_contracts.extend(class_hash_updates.cloned());
modified_contracts.extend(class_hash_updates.into_iter().cloned());
modified_contracts.extend(nonce_updates.cloned());

// Add fee transfer storage update before actually charging it, as it needs to be included in the
Expand All @@ -320,7 +327,12 @@ impl<T: StateReader> State for CachedState<T> {
modified_contracts.remove(fee_token_address);
}

Ok((modified_contracts.len(), storage_updates.len()))
Ok(StateChangesCount {
n_storage_updates: storage_updates.len(),
n_class_hash_updates,
n_compiled_class_hash_updates: compiled_class_hash_updates.count(),
n_modified_contracts: modified_contracts.len(),
})
}

/// Returns the class hash for a given contract address.
Expand Down Expand Up @@ -800,7 +812,7 @@ mod tests {

/// This test calculate the number of actual storage changes.
#[test]
fn count_actual_storage_changes_test() {
fn count_actual_state_changes_test() {
let state_reader = InMemoryStateReader::default();

let mut cached_state = CachedState::new(Arc::new(state_reader), HashMap::new());
Expand All @@ -822,14 +834,15 @@ mod tests {
let fee_token_address = Address(123.into());
let sender_address = Address(321.into());

let expected_changes = {
let n_storage_updates = 3 + 1; // + 1 fee transfer balance update
let n_modified_contracts = 2;

(n_modified_contracts, n_storage_updates)
let expected_changes = StateChangesCount {
n_storage_updates: 3 + 1, // + 1 fee transfer balance update,
n_class_hash_updates: 0,
n_compiled_class_hash_updates: 0,
n_modified_contracts: 2,
};

let changes = cached_state
.count_actual_storage_changes(Some((&fee_token_address, &sender_address)))
.count_actual_state_changes(Some((&fee_token_address, &sender_address)))
.unwrap();

assert_eq!(changes, expected_changes);
Expand Down
14 changes: 11 additions & 3 deletions src/state/state_api.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,14 @@ pub trait StateReader {
) -> Result<CompiledClassHash, StateError>;
}

#[derive(Debug, Clone, Eq, PartialEq)]
pub struct StateChangesCount {
pub n_storage_updates: usize,
pub n_class_hash_updates: usize,
pub n_compiled_class_hash_updates: usize,
pub n_modified_contracts: usize,
}

pub trait State {
fn set_contract_class(
&mut self,
Expand Down Expand Up @@ -63,11 +71,11 @@ pub trait State {

fn apply_state_update(&mut self, sate_updates: &StateDiff) -> Result<(), StateError>;

/// Counts the amount of modified contracts and the updates to the storage
fn count_actual_storage_changes(
/// Counts the amount of state changes
fn count_actual_state_changes(
&mut self,
fee_token_and_sender_address: Option<(&Address, &Address)>,
) -> Result<(usize, usize), StateError>;
) -> Result<StateChangesCount, StateError>;

/// Returns the class hash of the contract class at the given address.
/// Returns zero by default if the value is not present
Expand Down
2 changes: 1 addition & 1 deletion src/transaction/declare.rs
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,7 @@ impl Declare {
} else {
self.run_validate_entrypoint(state, &mut resources_manager, block_context)?
};
let changes = state.count_actual_storage_changes(Some((
let changes = state.count_actual_state_changes(Some((
&block_context.starknet_os_config.fee_token_address,
&self.sender_address,
)))?;
Expand Down
2 changes: 1 addition & 1 deletion src/transaction/declare_v2.rs
Original file line number Diff line number Diff line change
Expand Up @@ -328,7 +328,7 @@ impl DeclareV2 {
(info, gas)
};

let storage_changes = state.count_actual_storage_changes(Some((
let storage_changes = state.count_actual_state_changes(Some((
&block_context.starknet_os_config.fee_token_address,
&self.sender_address,
)))?;
Expand Down
4 changes: 2 additions & 2 deletions src/transaction/deploy.rs
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,7 @@ impl Deploy {

let resources_manager = ExecutionResourcesManager::default();

let changes = state.count_actual_storage_changes(None)?;
let changes = state.count_actual_state_changes(None)?;
let actual_resources = calculate_tx_resources(
resources_manager,
&[Some(call_info.clone())],
Expand Down Expand Up @@ -250,7 +250,7 @@ impl Deploy {
block_context.validate_max_n_steps,
)?;

let changes = state.count_actual_storage_changes(None)?;
let changes = state.count_actual_state_changes(None)?;
let actual_resources = calculate_tx_resources(
resources_manager,
&[call_info.clone()],
Expand Down
2 changes: 1 addition & 1 deletion src/transaction/deploy_account.rs
Original file line number Diff line number Diff line change
Expand Up @@ -250,7 +250,7 @@ impl DeployAccount {
resources_manager,
&[Some(constructor_call_info.clone()), validate_info.clone()],
TransactionType::DeployAccount,
state.count_actual_storage_changes(Some((
state.count_actual_state_changes(Some((
&block_context.starknet_os_config.fee_token_address,
&self.contract_address,
)))?,
Expand Down
2 changes: 1 addition & 1 deletion src/transaction/invoke_function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -265,7 +265,7 @@ impl InvokeFunction {
remaining_gas,
)?
};
let changes = state.count_actual_storage_changes(Some((
let changes = state.count_actual_state_changes(Some((
&block_context.starknet_os_config.fee_token_address,
&self.contract_address,
)))?;
Expand Down
2 changes: 1 addition & 1 deletion src/transaction/l1_handler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ impl L1Handler {
)?
};

let changes = state.count_actual_storage_changes(None)?;
let changes = state.count_actual_state_changes(None)?;
let actual_resources = calculate_tx_resources(
resources_manager,
&[call_info.clone()],
Expand Down
16 changes: 4 additions & 12 deletions src/utils.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use crate::core::errors::hash_errors::HashError;
use crate::services::api::contract_classes::deprecated_contract_class::EntryPointType;
use crate::state::state_api::State;
use crate::state::state_api::{State, StateChangesCount};
use crate::{
definitions::transaction_type::TransactionType,
execution::{
Expand Down Expand Up @@ -169,28 +169,20 @@ pub fn calculate_tx_resources(
resources_manager: ExecutionResourcesManager,
call_info: &[Option<CallInfo>],
tx_type: TransactionType,
storage_changes: (usize, usize),
state_changes: StateChangesCount,
l1_handler_payload_size: Option<usize>,
n_reverted_steps: usize,
) -> Result<HashMap<String, usize>, TransactionError> {
let (n_modified_contracts, n_storage_changes) = storage_changes;

let non_optional_calls: Vec<CallInfo> = call_info.iter().flatten().cloned().collect();
let n_deployments = non_optional_calls.iter().map(get_call_n_deployments).sum();

let mut l2_to_l1_messages = Vec::new();

for call_info in non_optional_calls {
l2_to_l1_messages.extend(call_info.get_sorted_l2_to_l1_messages()?)
}

let l1_gas_usage = calculate_tx_gas_usage(
l2_to_l1_messages,
n_modified_contracts,
n_storage_changes,
l1_handler_payload_size,
n_deployments,
);
let l1_gas_usage =
calculate_tx_gas_usage(l2_to_l1_messages, &state_changes, l1_handler_payload_size);

let cairo_usage = resources_manager.cairo_usage.clone();
let tx_syscall_counter = resources_manager.syscall_counter;
Expand Down
4 changes: 2 additions & 2 deletions tests/internals.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1260,7 +1260,7 @@ fn expected_fib_transaction_execution_info(
}
let resources = HashMap::from([
("n_steps".to_string(), n_steps),
("l1_gas_usage".to_string(), 4896),
("l1_gas_usage".to_string(), 5508),
("pedersen_builtin".to_string(), 16),
("range_check_builtin".to_string(), 104),
]);
Expand Down Expand Up @@ -1477,7 +1477,7 @@ fn test_invoke_with_declarev2_tx() {
];
let invoke_tx = invoke_tx(calldata, u128::MAX);

let expected_gas_consumed = 4908;
let expected_gas_consumed = 5551;
let result = invoke_tx
.execute(state, block_context, expected_gas_consumed)
.unwrap();
Expand Down