Skip to content

Commit

Permalink
Mock call n times (#1839)
Browse files Browse the repository at this point in the history
<!-- Reference any GitHub issues resolved by this PR -->

Closes #

## Introduced changes

<!-- A brief description of the changes -->

- [Add
mock_call](53c9db1),
note different behaviour than other cheatcodes with CheatSpan
- [Add
tests](9158c0e)

## Checklist

<!-- Make sure all of these are complete -->

- [x] Linked relevant issue
- [x] Updated relevant documentation
- [x] Added relevant tests
- [x] Performed self-review of the code
- [x] Added changes to `CHANGELOG.md`
  • Loading branch information
drknzz authored Mar 11, 2024
1 parent 6003c5b commit 28dd681
Show file tree
Hide file tree
Showing 5 changed files with 261 additions and 34 deletions.
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use super::cairo1_execution::execute_entry_point_call_cairo1;
use crate::runtime_extensions::call_to_blockifier_runtime_extension::execution::deprecated::cairo0_execution::execute_entry_point_call_cairo0;
use crate::runtime_extensions::call_to_blockifier_runtime_extension::RuntimeState;
use crate::state::CheatnetState;
use crate::state::{CheatStatus, CheatnetState};
use blockifier::execution::call_info::{CallExecution, Retdata};
use blockifier::{
execution::{
Expand Down Expand Up @@ -64,15 +64,18 @@ pub fn execute_call_entry_point(
CallType::Delegate => AddressOrClassHash::ClassHash(entry_point.class_hash.unwrap()),
};

if let Some(ret_data) =
get_ret_data_by_call_entry_point(entry_point, runtime_state.cheatnet_state)
if let Some(cheat_status) =
get_mocked_function_cheat_status(entry_point, runtime_state.cheatnet_state)
{
runtime_state.cheatnet_state.trace_data.exit_nested_call(
resources,
&Ok(CallInfo::default()),
&identifier,
);
return Ok(mocked_call_info(entry_point.clone(), ret_data));
if let CheatStatus::Cheated(ret_data, _) = (*cheat_status).clone() {
cheat_status.decrement_cheat_span();
runtime_state.cheatnet_state.trace_data.exit_nested_call(
resources,
&Ok(CallInfo::default()),
&identifier,
);
return Ok(mocked_call_info(entry_point.clone(), ret_data.clone()));
}
}
// endregion

Expand Down Expand Up @@ -207,21 +210,18 @@ pub fn execute_constructor_entry_point(
// endregion
}

fn get_ret_data_by_call_entry_point(
fn get_mocked_function_cheat_status<'a>(
call: &CallEntryPoint,
cheatnet_state: &CheatnetState,
) -> Option<Vec<StarkFelt>> {
if let Some(contract_address) = call.code_address {
if let Some(contract_functions) = cheatnet_state.mocked_functions.get(&contract_address) {
let entrypoint_selector = call.entry_point_selector;

let ret_data = contract_functions
.get(&entrypoint_selector)
.map(Clone::clone);
return ret_data;
}
cheatnet_state: &'a mut CheatnetState,
) -> Option<&'a mut CheatStatus<Vec<StarkFelt>>> {
if call.call_type == CallType::Delegate {
return None;
}
None

cheatnet_state
.mocked_functions
.get_mut(&call.storage_address)
.and_then(|contract_functions| contract_functions.get_mut(&call.entry_point_selector))
}

fn mocked_call_info(call: CallEntryPoint, ret_data: Vec<StarkFelt>) -> CallInfo {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,23 +1,42 @@
use crate::state::{CheatSpan, CheatStatus};
use crate::CheatnetState;
use blockifier::execution::execution_utils::felt_to_stark_felt;
use cairo_felt::Felt252;
use conversions::IntoConv;
use starknet_api::core::ContractAddress;
use starknet_api::core::{ContractAddress, EntryPointSelector};
use starknet_api::hash::StarkFelt;
use std::collections::hash_map::Entry;

impl CheatnetState {
pub fn start_mock_call(
pub fn mock_call(
&mut self,
contract_address: ContractAddress,
function_selector: Felt252,
ret_data: &[Felt252],
span: CheatSpan,
) {
let ret_data: Vec<StarkFelt> = ret_data.iter().map(felt_to_stark_felt).collect();

let contract_mocked_functions = self.mocked_functions.entry(contract_address).or_default();

contract_mocked_functions.insert(function_selector.into_(), ret_data);
contract_mocked_functions.insert(
EntryPointSelector(function_selector.into_()),
CheatStatus::Cheated(ret_data, span),
);
}

pub fn start_mock_call(
&mut self,
contract_address: ContractAddress,
function_selector: Felt252,
ret_data: &[Felt252],
) {
self.mock_call(
contract_address,
function_selector,
ret_data,
CheatSpan::Indefinite,
);
}

pub fn stop_mock_call(
Expand Down
5 changes: 3 additions & 2 deletions crates/cheatnet/src/state.rs
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ impl StateReader for ExtendedStateReader {
}
}

#[derive(Debug)]
#[derive(Debug, Clone)]
pub enum CheatStatus<T> {
Cheated(T, CheatSpan),
Uncheated,
Expand Down Expand Up @@ -245,7 +245,8 @@ pub struct CheatnetState {
pub global_warp: Option<(Felt252, CheatSpan)>,
pub elected_contracts: HashMap<ContractAddress, CheatStatus<ContractAddress>>,
pub global_elect: Option<(ContractAddress, CheatSpan)>,
pub mocked_functions: HashMap<ContractAddress, HashMap<EntryPointSelector, Vec<StarkFelt>>>,
pub mocked_functions:
HashMap<ContractAddress, HashMap<EntryPointSelector, CheatStatus<Vec<StarkFelt>>>>,
pub spoofed_contracts: HashMap<ContractAddress, CheatStatus<TxInfoMock>>,
pub global_spoof: Option<(TxInfoMock, CheatSpan)>,
pub replaced_bytecode_contracts: HashMap<ContractAddress, ClassHash>,
Expand Down
210 changes: 203 additions & 7 deletions crates/cheatnet/tests/cheatcodes/mock_call.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,41 +7,68 @@ use crate::{
};
use cairo_felt::Felt252;
use cheatnet::runtime_extensions::forge_runtime_extension::cheatcodes::declare::declare;
use cheatnet::state::CheatnetState;
use cheatnet::state::{CheatSpan, CheatnetState};
use conversions::IntoConv;
use starknet::core::utils::get_selector_from_name;
use starknet_api::core::ContractAddress;

use super::test_environment::TestEnvironment;

trait MockCallTrait {
fn mock_call(
&mut self,
contract_address: &ContractAddress,
function_name: &str,
ret_data: &[u128],
span: CheatSpan,
);
fn start_mock_call(
&mut self,
contract_address: &ContractAddress,
function_selector: Felt252,
function_name: &str,
ret_data: &[u128],
);
fn stop_mock_call(&mut self, contract_address: &ContractAddress, function_selector: Felt252);
fn stop_mock_call(&mut self, contract_address: &ContractAddress, function_name: &str);
}

impl<'a> MockCallTrait for TestEnvironment<'a> {
fn mock_call(
&mut self,
contract_address: &ContractAddress,
function_name: &str,
ret_data: &[u128],
span: CheatSpan,
) {
let ret_data: Vec<Felt252> = ret_data.iter().map(|x| Felt252::from(*x)).collect();
let function_selector = get_selector_from_name(function_name).unwrap();
self.runtime_state.cheatnet_state.mock_call(
*contract_address,
function_selector.into_(),
&ret_data,
span,
);
}

fn start_mock_call(
&mut self,
contract_address: &ContractAddress,
function_selector: Felt252,
function_name: &str,
ret_data: &[u128],
) {
let ret_data: Vec<Felt252> = ret_data.iter().map(|x| Felt252::from(*x)).collect();
let function_selector = get_selector_from_name(function_name).unwrap();
self.runtime_state.cheatnet_state.start_mock_call(
*contract_address,
function_selector,
function_selector.into_(),
&ret_data,
);
}

fn stop_mock_call(&mut self, contract_address: &ContractAddress, function_selector: Felt252) {
fn stop_mock_call(&mut self, contract_address: &ContractAddress, function_name: &str) {
let function_selector = get_selector_from_name(function_name).unwrap();
self.runtime_state
.cheatnet_state
.stop_mock_call(*contract_address, function_selector);
.stop_mock_call(*contract_address, function_selector.into_());
}
}

Expand Down Expand Up @@ -627,3 +654,172 @@ fn mock_call_nonexisting_contract() {

assert_success(output, &ret_data);
}

#[test]
fn mock_call_simple_with_span() {
let mut cheatnet_state = CheatnetState::default();
let mut test_env = TestEnvironment::new(&mut cheatnet_state);

let contract_address = test_env.deploy("MockChecker", &[Felt252::from(420)]);

test_env.mock_call(&contract_address, "get_thing", &[123], CheatSpan::Number(2));

assert_success(
test_env.call_contract(&contract_address, "get_thing", &[]),
&[Felt252::from(123)],
);
assert_success(
test_env.call_contract(&contract_address, "get_thing", &[]),
&[Felt252::from(123)],
);
assert_success(
test_env.call_contract(&contract_address, "get_thing", &[]),
&[Felt252::from(420)],
);
}

#[test]
fn mock_call_proxy_with_span() {
let mut cheatnet_state = CheatnetState::default();
let mut test_env = TestEnvironment::new(&mut cheatnet_state);

let contract_address = test_env.deploy("MockChecker", &[Felt252::from(420)]);
let proxy_address = test_env.deploy("MockCheckerProxy", &[]);

test_env.mock_call(&contract_address, "get_thing", &[123], CheatSpan::Number(2));

assert_success(
test_env.call_contract(&contract_address, "get_thing", &[]),
&[Felt252::from(123)],
);
assert_success(
test_env.call_contract(
&proxy_address,
"get_thing_from_contract",
&[contract_address.into_()],
),
&[Felt252::from(123)],
);
assert_success(
test_env.call_contract(
&proxy_address,
"get_thing_from_contract",
&[contract_address.into_()],
),
&[Felt252::from(420)],
);
}

#[test]
fn mock_call_in_constructor_with_span() {
let mut cheatnet_state = CheatnetState::default();
let mut test_env = TestEnvironment::new(&mut cheatnet_state);

let contracts = get_contracts();

let balance_address = test_env.deploy("HelloStarknet", &[]);

let class_hash = test_env.declare("ConstructorMockChecker", &contracts);
let precalculated_address = test_env
.runtime_state
.cheatnet_state
.precalculate_address(&class_hash, &[balance_address.into_()]);

test_env.mock_call(
&balance_address,
"get_balance",
&[111],
CheatSpan::Number(2),
);

let contract_address = test_env.deploy_wrapper(&class_hash, &[balance_address.into_()]);
assert_eq!(precalculated_address, contract_address);

assert_success(
test_env.call_contract(&contract_address, "get_constructor_balance", &[]),
&[Felt252::from(111)],
);
assert_success(
test_env.call_contract(&balance_address, "get_balance", &[]),
&[Felt252::from(111)],
);
assert_success(
test_env.call_contract(&balance_address, "get_balance", &[]),
&[Felt252::from(0)],
);
}

#[test]
fn mock_call_twice_in_function() {
let mut cheatnet_state = CheatnetState::default();
let mut test_env = TestEnvironment::new(&mut cheatnet_state);

let contracts = get_contracts();

let class_hash = test_env.declare("MockChecker", &contracts);
let precalculated_address = test_env
.runtime_state
.cheatnet_state
.precalculate_address(&class_hash, &[111.into()]);

test_env.mock_call(
&precalculated_address,
"get_thing",
&[222],
CheatSpan::Number(2),
);

let contract_address = test_env.deploy_wrapper(&class_hash, &[111.into()]);
assert_eq!(precalculated_address, contract_address);

assert_success(
test_env.call_contract(&contract_address, "get_thing", &[]),
&[222.into()],
);
assert_success(
test_env.call_contract(&contract_address, "get_thing_twice", &[]),
&[222.into(), 111.into()],
);
assert_success(
test_env.call_contract(&contract_address, "get_thing", &[]),
&[111.into()],
);
}

#[test]
fn mock_call_override_span() {
let mut cheatnet_state = CheatnetState::default();
let mut test_env = TestEnvironment::new(&mut cheatnet_state);

let contract_address = test_env.deploy("MockChecker", &[111.into()]);

test_env.mock_call(&contract_address, "get_thing", &[222], CheatSpan::Number(2));

assert_success(
test_env.call_contract(&contract_address, "get_thing", &[]),
&[Felt252::from(222)],
);

test_env.mock_call(
&contract_address,
"get_thing",
&[333],
CheatSpan::Indefinite,
);

assert_success(
test_env.call_contract(&contract_address, "get_thing", &[]),
&[Felt252::from(333)],
);
assert_success(
test_env.call_contract(&contract_address, "get_thing", &[]),
&[Felt252::from(333)],
);

test_env.stop_mock_call(&contract_address, "get_thing");

assert_success(
test_env.call_contract(&contract_address, "get_thing", &[]),
&[111.into()],
);
}
Loading

0 comments on commit 28dd681

Please sign in to comment.