Skip to content
This repository was archived by the owner on Jul 22, 2024. It is now read-only.

Replace class contract call native #1115

Merged
merged 8 commits into from
Oct 31, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion tests/cairo_1_syscalls.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1556,7 +1556,8 @@ fn replace_class_contract_call() {
block_context.invoke_tx_max_n_steps(),
)
.unwrap();
assert_eq!(result.call_info.unwrap().retdata, vec![17.into()]);
assert_eq!(result.call_info.clone().unwrap().retdata, vec![17.into()]);
assert_eq!(result.call_info.unwrap().failure_flag, false);
}

#[test]
Expand Down
223 changes: 222 additions & 1 deletion tests/cairo_native.rs
Original file line number Diff line number Diff line change
Expand Up @@ -858,7 +858,7 @@ fn replace_class_test() {
assert_eq!(native_result.l2_to_l1_messages, vm_result.l2_to_l1_messages);
assert_eq!(native_result.gas_consumed, vm_result.gas_consumed);
assert_eq!(native_result.failure_flag, vm_result.failure_flag);
assert_eq!(native_result.internal_calls, vm_result.internal_calls);
assert_eq_sorted!(native_result.internal_calls, vm_result.internal_calls);
assert_eq!(native_result.class_hash.unwrap(), CLASS_HASH_A);
assert_eq!(vm_result.class_hash.unwrap(), CASM_CLASS_HASH_A);
assert_eq!(native_result.caller_address, caller_address);
Expand All @@ -874,6 +874,227 @@ fn replace_class_test() {
assert_eq!(native_result.entry_point_type, vm_result.entry_point_type);
}

#[test]
fn replace_class_contract_call() {
fn compare_results(native_result: CallInfo, vm_result: CallInfo) {
assert_eq!(vm_result.retdata, native_result.retdata);
assert_eq!(vm_result.events, native_result.events);
assert_eq!(
vm_result.accessed_storage_keys,
native_result.accessed_storage_keys
);
assert_eq!(vm_result.l2_to_l1_messages, native_result.l2_to_l1_messages);
assert_eq!(vm_result.gas_consumed, native_result.gas_consumed);
assert_eq!(vm_result.failure_flag, false);
assert_eq!(native_result.failure_flag, false);
assert_eq_sorted!(vm_result.internal_calls, native_result.internal_calls);
assert_eq!(
vm_result.accessed_storage_keys,
native_result.accessed_storage_keys
);
assert_eq!(
vm_result.storage_read_values,
native_result.storage_read_values
);
assert_eq!(vm_result.class_hash, native_result.class_hash);
}
// Same execution than cairo_1_syscalls.rs test but comparing results to native execution.

// SET GET_NUMBER_A
// Add get_number_a.cairo to storage
let program_data = include_bytes!("../starknet_programs/cairo2/get_number_a.casm");
let casm_contract_class_a: CasmContractClass = serde_json::from_slice(program_data).unwrap();

let sierra_class_a: cairo_lang_starknet::contract_class::ContractClass = serde_json::from_str(
std::fs::read_to_string("starknet_programs/cairo2/get_number_a.sierra")
.unwrap()
.as_str(),
)
.unwrap();

// Create state reader with class hash data
let mut contract_class_cache = HashMap::new();
let mut native_contract_class_cache = HashMap::new();

let address = Address(Felt252::one());
let class_hash_a: ClassHash = [1; 32];
let nonce = Felt252::zero();

contract_class_cache.insert(
class_hash_a,
CompiledClass::Casm(Arc::new(casm_contract_class_a)),
);
insert_sierra_class_into_cache(
&mut native_contract_class_cache,
class_hash_a,
sierra_class_a,
);

let mut state_reader = InMemoryStateReader::default();
state_reader
.address_to_class_hash_mut()
.insert(address.clone(), class_hash_a);
state_reader
.address_to_nonce_mut()
.insert(address.clone(), nonce.clone());

let mut native_state_reader = InMemoryStateReader::default();
native_state_reader
.address_to_class_hash_mut()
.insert(address.clone(), class_hash_a);

// SET GET_NUMBER_B

// Add get_number_b contract to the state (only its contract_class)

let program_data = include_bytes!("../starknet_programs/cairo2/get_number_b.casm");
let contract_class_b: CasmContractClass = serde_json::from_slice(program_data).unwrap();

let sierra_class_b: cairo_lang_starknet::contract_class::ContractClass = serde_json::from_str(
std::fs::read_to_string("starknet_programs/cairo2/get_number_b.sierra")
.unwrap()
.as_str(),
)
.unwrap();
let class_hash_b: ClassHash = [2; 32];

contract_class_cache.insert(
class_hash_b,
CompiledClass::Casm(Arc::new(contract_class_b)),
);
insert_sierra_class_into_cache(
&mut native_contract_class_cache,
class_hash_b,
sierra_class_b,
);

// SET GET_NUMBER_WRAPPER

// Create program and entry point types for contract class
let program_data = include_bytes!("../starknet_programs/cairo2/get_number_wrapper.casm");
let wrapper_contract_class: CasmContractClass = serde_json::from_slice(program_data).unwrap();
let entrypoints = wrapper_contract_class.clone().entry_points_by_type;
let get_number_entrypoint_selector = &entrypoints.external.get(1).unwrap().selector;
let upgrade_entrypoint_selector: &BigUint = &entrypoints.external.get(0).unwrap().selector;

let wrapper_sierra_class: cairo_lang_starknet::contract_class::ContractClass =
serde_json::from_str(
std::fs::read_to_string("starknet_programs/cairo2/get_number_wrapper.sierra")
.unwrap()
.as_str(),
)
.unwrap();
let native_entrypoints = wrapper_sierra_class.clone().entry_points_by_type;

let native_get_number_entrypoint_selector =
&native_entrypoints.external.get(1).unwrap().selector;
let native_upgrade_entrypoint_selector: &BigUint =
&native_entrypoints.external.get(0).unwrap().selector;

let wrapper_address = Address(Felt252::from(2));
let wrapper_class_hash: ClassHash = [3; 32];

contract_class_cache.insert(
wrapper_class_hash,
CompiledClass::Casm(Arc::new(wrapper_contract_class)),
);
insert_sierra_class_into_cache(
&mut native_contract_class_cache,
wrapper_class_hash,
wrapper_sierra_class,
);

state_reader
.address_to_class_hash_mut()
.insert(wrapper_address.clone(), wrapper_class_hash);
state_reader
.address_to_nonce_mut()
.insert(wrapper_address.clone(), nonce);

native_state_reader
.address_to_class_hash_mut()
.insert(wrapper_address, wrapper_class_hash);

// Create state from the state_reader and contract cache.
let mut state = CachedState::new(Arc::new(state_reader.clone()), contract_class_cache.clone());
let mut native_state = CachedState::new(Arc::new(state_reader), contract_class_cache);
// CALL GET_NUMBER BEFORE REPLACE_CLASS

let calldata = [].to_vec();
let caller_address = Address(0000.into());
let entry_point_type = EntryPointType::External;

let vm_result = execute(
&mut state,
&caller_address,
&address,
get_number_entrypoint_selector,
&calldata,
entry_point_type,
&wrapper_class_hash,
);

let native_result = execute(
&mut native_state,
&caller_address,
&address,
native_get_number_entrypoint_selector,
&calldata,
entry_point_type,
&wrapper_class_hash,
);
compare_results(native_result, vm_result);

// REPLACE_CLASS

let calldata = [Felt252::from_bytes_be(&class_hash_b)].to_vec();

let vm_result = execute(
&mut state,
&caller_address,
&address,
upgrade_entrypoint_selector,
&calldata,
entry_point_type,
&wrapper_class_hash,
);

let native_result = execute(
&mut native_state,
&caller_address,
&address,
native_upgrade_entrypoint_selector,
&calldata,
entry_point_type,
&wrapper_class_hash,
);
compare_results(native_result, vm_result);
// CALL GET_NUMBER AFTER REPLACE_CLASS

let calldata = [].to_vec();

let vm_result = execute(
&mut state,
&caller_address,
&address,
get_number_entrypoint_selector,
&calldata,
entry_point_type,
&wrapper_class_hash,
);

let native_result = execute(
&mut native_state,
&caller_address,
&address,
native_get_number_entrypoint_selector,
&calldata,
entry_point_type,
&wrapper_class_hash,
);
compare_results(native_result, vm_result);
}

fn execute(
state: &mut CachedState<InMemoryStateReader>,
caller_address: &Address,
Expand Down