Skip to content
This repository was archived by the owner on Jul 22, 2024. It is now read-only.

add test to check cairo 2 account contract deploy panic failing properly #1045

Merged
merged 2 commits into from
Oct 18, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
148 changes: 148 additions & 0 deletions starknet_programs/cairo2/account_panic.cairo
Original file line number Diff line number Diff line change
@@ -0,0 +1,148 @@
use starknet::account::Call;

mod SUPPORTED_TX_VERSION {
const DEPLOY_ACCOUNT: felt252 = 1;
const DECLARE: felt252 = 2;
const INVOKE: felt252 = 1;
}

#[starknet::interface]
trait IAccount<T> {
fn is_valid_signature(self: @T, hash: felt252, signature: Array<felt252>) -> felt252;
fn supports_interface(self: @T, interface_id: felt252) -> bool;
fn public_key(self: @T) -> felt252;
}

#[starknet::contract]
mod Account {
use super::{Call, IAccount, SUPPORTED_TX_VERSION};
use starknet::{get_caller_address, call_contract_syscall, get_tx_info, VALIDATED};
use zeroable::Zeroable;
use array::{ArrayTrait, SpanTrait};
use ecdsa::check_ecdsa_signature;
use box::BoxTrait;
use result::ResultTrait;

const SIMULATE_TX_VERSION_OFFSET: felt252 = 340282366920938463463374607431768211456; // 2**128
const SRC6_TRAIT_ID: felt252 = 1270010605630597976495846281167968799381097569185364931397797212080166453709; // hash of SNIP-6 trait

#[storage]
struct Storage {
public_key: felt252
}

#[constructor]
fn constructor(ref self: ContractState, public_key: felt252) {
self.public_key.write(public_key);
}

#[external(v0)]
impl AccountImpl of IAccount<ContractState> {
fn is_valid_signature(self: @ContractState, hash: felt252, signature: Array<felt252>) -> felt252 {
let is_valid = self.is_valid_signature_bool(hash, signature.span());
if is_valid { VALIDATED } else { 0 }
}

fn supports_interface(self: @ContractState, interface_id: felt252) -> bool {
interface_id == SRC6_TRAIT_ID
}

fn public_key(self: @ContractState) -> felt252 {
self.public_key.read()
}
}

#[external(v0)]
#[generate_trait]
impl ProtocolImpl of ProtocolTrait {
fn __execute__(ref self: ContractState, calls: Array<Call>) -> Array<Span<felt252>> {
let arr = ArrayTrait::new();
panic_with_felt252('panic');
arr
//self.only_protocol();
// self.only_supported_tx_version(SUPPORTED_TX_VERSION::INVOKE);
// self.execute_multiple_calls(calls)
}

fn __validate__(self: @ContractState, calls: Array<Call>) -> felt252 {
panic_with_felt252('panic');
0
// self.only_protocol();
// self.only_supported_tx_version(SUPPORTED_TX_VERSION::INVOKE);
// self.validate_transaction()
}

fn __validate_declare__(self: @ContractState, class_hash: felt252) -> felt252 {
self.only_protocol();
self.only_supported_tx_version(SUPPORTED_TX_VERSION::DECLARE);
self.validate_transaction()
}

fn __validate_deploy__(self: @ContractState, class_hash: felt252, salt: felt252, public_key: felt252) -> felt252 {
self.only_protocol();
self.only_supported_tx_version(SUPPORTED_TX_VERSION::DEPLOY_ACCOUNT);
self.validate_transaction()
}
}

#[generate_trait]
impl PrivateImpl of PrivateTrait {
fn only_protocol(self: @ContractState) {
let sender = get_caller_address();
assert(sender.is_zero(), 'Account: invalid caller');
}

fn is_valid_signature_bool(self: @ContractState, hash: felt252, signature: Span<felt252>) -> bool {
let is_valid_length = signature.len() == 2_u32;

if !is_valid_length {
return false;
}

check_ecdsa_signature(
hash, self.public_key.read(), *signature.at(0_u32), *signature.at(1_u32)
)
}

fn validate_transaction(self: @ContractState) -> felt252 {
let tx_info = get_tx_info().unbox();
let tx_hash = tx_info.transaction_hash;
let signature = tx_info.signature;

let is_valid = self.is_valid_signature_bool(tx_hash, signature);
assert(is_valid, 'Account: Incorrect tx signature');
VALIDATED
}

fn execute_single_call(self: @ContractState, call: Call) -> Span<felt252> {
let Call{to, selector, calldata} = call;
call_contract_syscall(to, selector, calldata.span()).unwrap()
}

fn execute_multiple_calls(self: @ContractState, mut calls: Array<Call>) -> Array<Span<felt252>> {
let mut res = ArrayTrait::new();
loop {
match calls.pop_front() {
Option::Some(call) => {
let _res = self.execute_single_call(call);
res.append(_res);
},
Option::None(_) => {
break ();
},
};
};
res
}

fn only_supported_tx_version(self: @ContractState, supported_tx_version: felt252) {
let tx_info = get_tx_info().unbox();
let version = tx_info.version;
assert(
version == supported_tx_version ||
version == SIMULATE_TX_VERSION_OFFSET + supported_tx_version,
'Account: Unsupported tx version'
);
}
}
}
113 changes: 113 additions & 0 deletions tests/account_panic.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
use std::{collections::HashMap, sync::Arc};

use cairo_vm::felt::Felt252;
use starknet_in_rust::{
core::contract_address::compute_casm_class_hash,
definitions::{block_context::BlockContext, constants::TRANSACTION_VERSION},
services::api::contract_classes::compiled_class::CompiledClass,
state::{cached_state::CachedState, in_memory_state_reader::InMemoryStateReader},
transaction::{InvokeFunction, Transaction},
utils::{calculate_sn_keccak, Address},
CasmContractClass,
};

#[test]
fn account_panic() {
let account_data = include_bytes!("../starknet_programs/cairo2/account_panic.casm");
let contract_data = include_bytes!("../starknet_programs/cairo2/contract_a.casm");

let account_contract_class: CasmContractClass = serde_json::from_slice(account_data).unwrap();
let account_class_hash = compute_casm_class_hash(&account_contract_class)
.unwrap()
.to_be_bytes();

let contract_class: CasmContractClass = serde_json::from_slice(contract_data).unwrap();
let contract_class_hash_felt = compute_casm_class_hash(&contract_class).unwrap();
let contract_class_hash = contract_class_hash_felt.to_be_bytes();

let account_address = Address(1111.into());
let contract_address = Address(0000.into());
let nonce = 0.into();

let block_context = BlockContext::default();

let mut contract_class_cache = HashMap::new();

contract_class_cache.insert(
account_class_hash,
CompiledClass::Casm(Arc::new(account_contract_class)),
);
contract_class_cache.insert(
contract_class_hash,
CompiledClass::Casm(Arc::new(contract_class.clone())),
);

let mut state_reader = InMemoryStateReader::default();
state_reader
.address_to_class_hash_mut()
.insert(account_address.clone(), account_class_hash);
state_reader
.address_to_nonce_mut()
.insert(account_address.clone(), nonce);
state_reader
.address_to_class_hash_mut()
.insert(contract_address.clone(), contract_class_hash);
state_reader
.address_to_nonce_mut()
.insert(contract_address, 1.into());
let mut state = CachedState::new(Arc::new(state_reader), contract_class_cache);

let selector = Felt252::from_bytes_be(&calculate_sn_keccak(b"__execute__"));

// arguments of contract_a contract
// calldata is a Vec of Call, which is
/*
#[derive(Drop, Serde)]
struct Call {
to: ContractAddress,
selector: felt252,
calldata: Array<felt252>
}
*/
let selector_contract = &contract_class
.entry_points_by_type
.external
.get(0)
.unwrap()
.selector;
// calldata of contract_a is 1 value.
let calldata: Vec<_> = [
1.into(),
contract_class_hash_felt,
selector_contract.into(),
1.into(),
2.into(),
]
.to_vec();

// set up remaining structures

let invoke = InvokeFunction::new(
account_address,
Felt252::new(selector),
0,
TRANSACTION_VERSION.clone(),
calldata,
vec![],
block_context.starknet_os_config().chain_id().clone(),
Some(0.into()),
)
.unwrap();

let tx = Transaction::InvokeFunction(invoke);
let exec_info = tx
.execute(&mut state, &block_context, u128::MAX)
.expect("failed to invoke");
let call_info = exec_info.call_info.as_ref().unwrap();

assert_eq!(exec_info.revert_error, None);

// 482670963043u128 == 'panic'
assert_eq!(call_info.retdata[0], 482670963043u128.into());
assert!(call_info.failure_flag);
}