Skip to content
This repository was archived by the owner on Jul 22, 2024. It is now read-only.

Added fee transfer storage update into count_actual_storage_changes() #960

Merged
merged 10 commits into from
Sep 1, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -298,7 +298,7 @@ mod test {
let transaction = Transaction::InvokeFunction(invoke_function);

let estimated_fee = estimate_fee(&[transaction], state, &block_context).unwrap();
assert_eq!(estimated_fee[0], (2483, 2448));
assert_eq!(estimated_fee[0], (3707, 3672));
}

#[test]
Expand Down Expand Up @@ -1035,7 +1035,7 @@ mod test {

assert_eq!(
estimate_fee(&[deploy, invoke_tx], state, block_context,).unwrap(),
[(0, 3672), (0, 2448)]
[(0, 3672), (0, 3672)]
);
}

Expand Down
30 changes: 25 additions & 5 deletions src/state/cached_state.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,10 @@ use crate::{
core::errors::state_errors::StateError,
services::api::contract_classes::compiled_class::CompiledClass,
state::StateDiff,
utils::{subtract_mappings, to_cache_state_storage_mapping, Address, ClassHash},
utils::{
get_erc20_balance_var_addresses, subtract_mappings, to_cache_state_storage_mapping,
Address, ClassHash,
},
};
use cairo_vm::felt::Felt252;
use getset::{Getters, MutGetters};
Expand Down Expand Up @@ -268,8 +271,11 @@ impl<T: StateReader> State for CachedState<T> {
Ok(())
}

fn count_actual_storage_changes(&mut self) -> (usize, usize) {
let storage_updates = subtract_mappings(
fn count_actual_storage_changes(
&mut self,
fee_token_and_sender_address: Option<(&Address, &Address)>,
) -> (usize, usize) {
let mut storage_updates = subtract_mappings(
self.cache.storage_writes.clone(),
self.cache.storage_initial_values.clone(),
);
Expand Down Expand Up @@ -301,6 +307,16 @@ impl<T: StateReader> State for CachedState<T> {
modified_contracts.len()
};

// Add fee transfer storage update before actually charging it, as it needs to be included in the
// calculation of the final fee.
if let Some((fee_token_address, sender_address)) = fee_token_and_sender_address {
let (sender_low_key, _) = get_erc20_balance_var_addresses(sender_address).unwrap();
storage_updates.insert(
(fee_token_address.clone(), sender_low_key),
Felt252::default(),
);
}

(n_modified_contracts, storage_updates.len())
}

Expand Down Expand Up @@ -705,13 +721,17 @@ mod tests {
((address_two, storage_key_two), Felt252::from(1)),
]);

let fee_token_address = Address(123.into());
let sender_address = Address(321.into());

let expected_changes = {
let n_storage_updates = 3;
let n_storage_updates = 3 + 1; // + 1 fee transfer balance update
let n_modified_contracts = 2;

(n_modified_contracts, n_storage_updates)
};
let changes = cached_state.count_actual_storage_changes();
let changes =
cached_state.count_actual_storage_changes(Some((&fee_token_address, &sender_address)));

assert_eq!(changes, expected_changes);
}
Expand Down
5 changes: 4 additions & 1 deletion src/state/state_api.rs
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,10 @@ pub trait State {
fn apply_state_update(&mut self, sate_updates: &StateDiff) -> Result<(), StateError>;

/// Counts the amount of modified contracts and the updates to the storage
fn count_actual_storage_changes(&mut self) -> (usize, usize);
fn count_actual_storage_changes(
&mut self,
fee_token_and_sender_address: Option<(&Address, &Address)>,
) -> (usize, usize);

fn get_class_hash_at(&mut self, contract_address: &Address) -> Result<ClassHash, StateError>;

Expand Down
2 changes: 1 addition & 1 deletion src/testing/state.rs
Original file line number Diff line number Diff line change
Expand Up @@ -584,7 +584,7 @@ mod tests {
.unwrap();
let actual_resources = HashMap::from([
("n_steps".to_string(), 3457),
("l1_gas_usage".to_string(), 2448),
("l1_gas_usage".to_string(), 3672),
("range_check_builtin".to_string(), 80),
("pedersen_builtin".to_string(), 16),
]);
Expand Down
7 changes: 5 additions & 2 deletions src/transaction/declare.rs
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,10 @@ impl Declare {
} else {
self.run_validate_entrypoint(state, &mut resources_manager, block_context)?
};
let changes = state.count_actual_storage_changes();
let changes = state.count_actual_storage_changes(Some((
&block_context.starknet_os_config.fee_token_address,
&self.sender_address,
)));
let actual_resources = calculate_tx_resources(
resources_manager,
&vec![validate_info.clone()],
Expand Down Expand Up @@ -435,7 +438,7 @@ mod tests {

let actual_resources = HashMap::from([
("n_steps".to_string(), 2715),
("l1_gas_usage".to_string(), 1224),
("l1_gas_usage".to_string(), 2448),
("range_check_builtin".to_string(), 63),
("pedersen_builtin".to_string(), 15),
]);
Expand Down
5 changes: 4 additions & 1 deletion src/transaction/declare_v2.rs
Original file line number Diff line number Diff line change
Expand Up @@ -320,7 +320,10 @@ impl DeclareV2 {
(info, gas)
};

let storage_changes = state.count_actual_storage_changes();
let storage_changes = state.count_actual_storage_changes(Some((
&block_context.starknet_os_config.fee_token_address,
&self.sender_address,
)));
let actual_resources = calculate_tx_resources(
resources_manager,
&[execution_result.call_info.clone()],
Expand Down
4 changes: 2 additions & 2 deletions src/transaction/deploy.rs
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,7 @@ impl Deploy {

let resources_manager = ExecutionResourcesManager::default();

let changes = state.count_actual_storage_changes();
let changes = state.count_actual_storage_changes(None);
let actual_resources = calculate_tx_resources(
resources_manager,
&[Some(call_info.clone())],
Expand Down Expand Up @@ -245,7 +245,7 @@ impl Deploy {
block_context.validate_max_n_steps,
)?;

let changes = state.count_actual_storage_changes();
let changes = state.count_actual_storage_changes(None);
let actual_resources = calculate_tx_resources(
resources_manager,
&[call_info.clone()],
Expand Down
5 changes: 4 additions & 1 deletion src/transaction/deploy_account.rs
Original file line number Diff line number Diff line change
Expand Up @@ -213,7 +213,10 @@ impl DeployAccount {
resources_manager,
&[Some(constructor_call_info.clone()), validate_info.clone()],
TransactionType::DeployAccount,
state.count_actual_storage_changes(),
state.count_actual_storage_changes(Some((
&block_context.starknet_os_config.fee_token_address,
&self.contract_address,
))),
None,
0,
)
Expand Down
5 changes: 4 additions & 1 deletion src/transaction/invoke_function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -245,7 +245,10 @@ impl InvokeFunction {
remaining_gas,
)?
};
let changes = state.count_actual_storage_changes();
let changes = state.count_actual_storage_changes(Some((
&block_context.starknet_os_config.fee_token_address,
&self.contract_address,
)));
let actual_resources = calculate_tx_resources(
resources_manager,
&vec![call_info.clone(), validate_info.clone()],
Expand Down
2 changes: 1 addition & 1 deletion src/transaction/l1_handler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ impl L1Handler {
)?
};

let changes = state.count_actual_storage_changes();
let changes = state.count_actual_storage_changes(None);
let actual_resources = calculate_tx_resources(
resources_manager,
&[call_info.clone()],
Expand Down
50 changes: 49 additions & 1 deletion src/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,14 @@ use cairo_vm::{
felt::Felt252, serde::deserialize_program::BuiltinName, vm::runners::builtin_runner,
};
use cairo_vm::{types::relocatable::Relocatable, vm::vm_core::VirtualMachine};
use num_integer::Integer;
use num_traits::{Num, ToPrimitive};
use serde::{Deserialize, Serialize};
use serde_json::Value;
use sha3::{Digest, Keccak256};
use starknet_crypto::FieldElement;
use starknet::core::types::FromByteArrayError;
use starknet_api::core::L2_ADDRESS_UPPER_BOUND;
use starknet_crypto::{pedersen_hash, FieldElement};
use std::{
collections::{HashMap, HashSet},
hash::Hash,
Expand Down Expand Up @@ -270,6 +273,51 @@ where
keys1.into_iter().collect()
}

/// Returns the storage address of a StarkNet storage variable given its name and arguments.
pub fn get_storage_var_address(
storage_var_name: &str,
args: &[Felt252],
) -> Result<Felt252, FromByteArrayError> {
let felt_to_field_element = |felt: &Felt252| -> Result<FieldElement, FromByteArrayError> {
FieldElement::from_bytes_be(&felt.to_be_bytes())
};

let args = args
.iter()
.map(|felt| felt_to_field_element(felt))
.collect::<Result<Vec<_>, _>>()?;

let storage_var_name_hash =
FieldElement::from_bytes_be(&calculate_sn_keccak(storage_var_name.as_bytes()))?;
let storage_key_hash = args
.iter()
.fold(storage_var_name_hash, |res, arg| pedersen_hash(&res, arg));

let storage_key = field_element_to_felt(&storage_key_hash).mod_floor(&Felt252::from_bytes_be(
&L2_ADDRESS_UPPER_BOUND.to_bytes_be(),
));

Ok(storage_key)
}

/// Gets storage keys for a Uint256 storage variable.
pub fn get_uint256_storage_var_addresses(
storage_var_name: &str,
args: &[Felt252],
) -> Result<(Felt252, Felt252), FromByteArrayError> {
let low_key = get_storage_var_address(storage_var_name, args)?;
let high_key = &low_key + &Felt252::from(1);
Ok((low_key, high_key))
}

pub fn get_erc20_balance_var_addresses(
contract_address: &Address,
) -> Result<([u8; 32], [u8; 32]), FromByteArrayError> {
let (felt_low, felt_high) =
get_uint256_storage_var_addresses("ERC20_balances", &[contract_address.clone().0])?;
Ok((felt_low.to_be_bytes(), felt_high.to_be_bytes()))
}

//* ----------------------------
//* Execution entry point utils
//* ----------------------------
Expand Down
4 changes: 2 additions & 2 deletions tests/deploy_account.rs
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ fn internal_deploy_account() {
("n_steps", 3612),
("pedersen_builtin", 23),
("range_check_builtin", 83),
("l1_gas_usage", 3672)
("l1_gas_usage", 4896)
]
.into_iter()
.map(|(k, v)| (k.to_string(), v))
Expand Down Expand Up @@ -264,7 +264,7 @@ fn internal_deploy_account_cairo1() {
("n_steps", n_steps),
("pedersen_builtin", 23),
("range_check_builtin", 87),
("l1_gas_usage", 4896)
("l1_gas_usage", 6120)
]
.into_iter()
.map(|(k, v)| (k.to_string(), v))
Expand Down
12 changes: 6 additions & 6 deletions tests/internals.rs
Original file line number Diff line number Diff line change
Expand Up @@ -696,9 +696,9 @@ fn expected_fib_fee_transfer_info(fee: u128) -> CallInfo {
],
}],
storage_read_values: vec![
INITIAL_BALANCE.clone() - Felt252::from(1252),
INITIAL_BALANCE.clone() - Felt252::from(2476),
Felt252::zero(),
Felt252::from(1252),
Felt252::from(2476),
Felt252::zero(),
],
accessed_storage_keys: HashSet::from([
Expand Down Expand Up @@ -920,7 +920,7 @@ fn test_declare_tx() {
("n_steps".to_string(), 2715),
("range_check_builtin".to_string(), 63),
("pedersen_builtin".to_string(), 15),
("l1_gas_usage".to_string(), 2448),
("l1_gas_usage".to_string(), 3672),
]);
let fee = calculate_tx_fee(&resources, *GAS_PRICE, &block_context).unwrap();

Expand Down Expand Up @@ -1008,7 +1008,7 @@ fn test_declarev2_tx() {
("n_steps".to_string(), 2715),
("range_check_builtin".to_string(), 63),
("pedersen_builtin".to_string(), 15),
("l1_gas_usage".to_string(), 1224),
("l1_gas_usage".to_string(), 2448),
]);
let fee = calculate_tx_fee(&resources, *GAS_PRICE, &block_context).unwrap();

Expand Down Expand Up @@ -1226,7 +1226,7 @@ fn expected_transaction_execution_info(block_context: &BlockContext) -> Transact
let resources = HashMap::from([
("n_steps".to_string(), 4135),
("pedersen_builtin".to_string(), 16),
("l1_gas_usage".to_string(), 2448),
("l1_gas_usage".to_string(), 3672),
("range_check_builtin".to_string(), 101),
]);
let fee = calculate_tx_fee(&resources, *GAS_PRICE, block_context).unwrap();
Expand Down Expand Up @@ -1782,7 +1782,7 @@ fn test_state_for_declare_tx() {
// ])
// );

let fee = Felt252::from(2476);
let fee = Felt252::from(3700);

// Check state.cache
assert_eq!(
Expand Down