Skip to content
This repository was archived by the owner on Jul 22, 2024. It is now read-only.

Add class hash check for declarev2 #819

Merged
merged 9 commits into from
Jul 17, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 16 additions & 1 deletion src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,7 @@ mod test {
use std::path::PathBuf;

use crate::core::contract_address::{compute_deprecated_class_hash, compute_sierra_class_hash};
use crate::definitions::constants::INITIAL_GAS_COST;
use crate::definitions::{
block_context::StarknetChainId,
constants::{
Expand Down Expand Up @@ -836,7 +837,7 @@ mod test {
tx_type: TransactionType::Declare,
validate_entry_point_selector: VALIDATE_DECLARE_ENTRY_POINT_SELECTOR.clone(),
version: 1.into(),
max_fee: 2,
max_fee: INITIAL_GAS_COST,
signature: vec![],
nonce: 0.into(),
hash_value: 0.into(),
Expand Down Expand Up @@ -998,4 +999,18 @@ mod test {
[(0, 1224), (0, 0)]
);
}

#[test]
fn test_declare_v2_with_invalid_compiled_class_hash() {
let (block_context, mut state) = create_account_tx_test_state().unwrap();
let mut declare_v2 = declarev2_tx();
declare_v2.compiled_class_hash = Felt252::from(1);
let declare_tx = Transaction::DeclareV2(Box::new(declare_v2));

let err = declare_tx
.execute(&mut state, &block_context, INITIAL_GAS_COST)
.unwrap_err();

assert_eq!(err.to_string(), "Invalid compiled class, expected class hash: \"1948962768849191111780391610229754715773924969841143100991524171924131413970\", but received: \"1\"".to_string());
}
}
2 changes: 1 addition & 1 deletion src/testing/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ lazy_static! {
pub static ref TEST_CLASS_HASH: Felt252 = felt_str!("272");
pub static ref TEST_EMPTY_CONTRACT_CLASS_HASH: Felt252 = felt_str!("274");
pub static ref TEST_ERC20_CONTRACT_CLASS_HASH: Felt252 = felt_str!("4112");
pub static ref TEST_FIB_COMPILED_CONTRACT_CLASS_HASH: Felt252 = felt_str!("27727");
pub static ref TEST_FIB_COMPILED_CONTRACT_CLASS_HASH: Felt252 = felt_str!("1948962768849191111780391610229754715773924969841143100991524171924131413970");

// Storage keys.
pub static ref TEST_ERC20_ACCOUNT_BALANCE_KEY: Felt252 =
Expand Down
22 changes: 18 additions & 4 deletions src/transaction/declare_v2.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use super::{verify_version, Transaction};
use crate::core::contract_address::compute_sierra_class_hash;
use crate::core::contract_address::{compute_casm_class_hash, compute_sierra_class_hash};
use crate::definitions::constants::QUERY_VERSION_BASE;
use crate::services::api::contract_classes::deprecated_contract_class::EntryPointType;

Expand Down Expand Up @@ -381,6 +381,13 @@ impl DeclareV2 {
})
.map_err(|e| TransactionError::SierraCompileError(e.to_string()))?;

let casm_class_hash = compute_casm_class_hash(casm_class)?;
if casm_class_hash != self.compiled_class_hash {
Comment on lines +384 to +385
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does it make sense to always have this check? Maybe we could have it as an optional feature, or add it only for debugging.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No, it's not necessary to always check this. We want to check this whenever a casm contract class and casm class hash is sent from the outside, because otherwise we couldn't guarantee that the compiled contract is really the corresponded to that sierra contract class. It wouldn't be a problem if the contract was already compiled on our side. I'll work on that next, but preferred to have this implemented first.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm in favor of making it a debug assertion. Either we computed it and they are the same by construction, or it was passed via one of the unchecked functions and the responsibility lies in the caller. The point of that was not having to compute the hash in the first place.
Unless I'm misunderstanding it, of course.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Another option may be to don't compile and check at all, and take the casm contract class and class hash as valid. But sound pretty insecure. We should think a better way of guarantee that.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is insecure, but that's the compromise made for performance and that's why we have to be extra explicit and opt-in. If you hold it wrong, it breaks.
The only way to guarantee correctness is to compute stuff ourselves, but then we only moved the costs and increased complexity for no gain for the user. If we're gonna check, just compute it on construction and never again.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Think of it like using unsafe in Rust. The keyword doesn't mean you can do whatever you like, but that you go to battle without the shield, it's your job to dodge the blows now. When the programmer opts for passing the hash themselves, it's their job to make sure it's the correct one, and if they get hurt they'll deal with it.

return Err(TransactionError::InvalidCompiledClassHash(
casm_class_hash.to_string(),
self.compiled_class_hash.to_string(),
));
}
state.set_compiled_class_hash(&self.sierra_class_hash, &self.compiled_class_hash)?;
state.set_compiled_class(&self.compiled_class_hash, casm_class.clone())?;

Expand Down Expand Up @@ -453,7 +460,7 @@ mod tests {
use std::{collections::HashMap, fs::File, io::BufReader, path::PathBuf};

use super::DeclareV2;
use crate::core::contract_address::compute_sierra_class_hash;
use crate::core::contract_address::{compute_casm_class_hash, compute_sierra_class_hash};
use crate::definitions::constants::QUERY_VERSION_BASE;
use crate::services::api::contract_classes::compiled_class::CompiledClass;
use crate::state::state_api::StateReader;
Expand Down Expand Up @@ -487,12 +494,15 @@ mod tests {
let sierra_contract_class: cairo_lang_starknet::contract_class::ContractClass =
serde_json::from_reader(reader).unwrap();
let sender_address = Address(1.into());
let casm_class =
CasmContractClass::from_contract_class(sierra_contract_class.clone(), true).unwrap();
let casm_class_hash = compute_casm_class_hash(&casm_class).unwrap();

// create internal declare v2

let internal_declare = DeclareV2::new_with_tx_hash(
&sierra_contract_class,
Felt252::one(),
casm_class_hash,
sender_address,
0,
version,
Expand Down Expand Up @@ -553,12 +563,16 @@ mod tests {
serde_json::from_reader(reader).unwrap();
let sierra_class_hash = compute_sierra_class_hash(&sierra_contract_class).unwrap();
let sender_address = Address(1.into());
let casm_class =
CasmContractClass::from_contract_class(sierra_contract_class.clone(), true).unwrap();
let casm_class_hash = compute_casm_class_hash(&casm_class).unwrap();

// create internal declare v2

let internal_declare = DeclareV2::new_with_tx_hash(
let internal_declare = DeclareV2::new_with_sierra_class_hash_and_tx_hash(
&sierra_contract_class,
sierra_class_hash,
casm_class_hash,
sender_address,
0,
version,
Expand Down
2 changes: 2 additions & 0 deletions src/transaction/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -135,4 +135,6 @@ pub enum TransactionError {
CallInfoIsNone,
#[error("Unsupported version {0:?}")]
UnsupportedVersion(String),
#[error("Invalid compiled class, expected class hash: {0:?}, but received: {1:?}")]
InvalidCompiledClassHash(String, String),
}
3 changes: 2 additions & 1 deletion starknet_programs/cairo1/fibonacci_dispatcher.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,8 @@ mod Dispatcher {
class_hash: felt252, selector: felt252, a: felt252, b: felt252, n: felt252
) -> felt252 {
FibonacciLibraryDispatcher {
class_hash: starknet::class_hash_const::<27727>(),
// THIS VALUE IS THE HASH OF THE FIBONACCI CASM CLASS HASH. THE SAME AS THE CONSTANT: TEST_FIB_COMPILED_CONTRACT_CLASS_HASH
class_hash: starknet::class_hash_const::<1948962768849191111780391610229754715773924969841143100991524171924131413970>(),
selector
}.fib(a, b, n)
}
Expand Down
3 changes: 2 additions & 1 deletion starknet_programs/cairo2/fibonacci_dispatcher.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,8 @@ mod Dispatcher {
self: @ContractState, class_hash: felt252, selector: felt252, a: felt252, b: felt252, n: felt252
) -> felt252 {
FibonacciLibraryDispatcher {
class_hash: starknet::class_hash_const::<27727>(),
// THIS VALUE IS THE HASH OF THE FIBONACCI CASM CLASS HASH.
class_hash: starknet::class_hash_const::<2889767417435368609058888822622483550637539736178264636938129582300971548553>(),
selector
}.fib(a, b, n)
}
Expand Down
70 changes: 61 additions & 9 deletions tests/internals.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,9 @@ use cairo_vm::vm::{
use lazy_static::lazy_static;
use num_bigint::BigUint;
use num_traits::{FromPrimitive, Num, One, Zero};
use starknet_in_rust::core::contract_address::compute_sierra_class_hash;
use starknet_in_rust::core::contract_address::{
compute_casm_class_hash, compute_sierra_class_hash,
};
use starknet_in_rust::core::errors::state_errors::StateError;
use starknet_in_rust::definitions::constants::{
DEFAULT_CAIRO_RESOURCE_FEE_WEIGHTS, VALIDATE_ENTRY_POINT_SELECTOR,
Expand Down Expand Up @@ -74,7 +76,8 @@ lazy_static! {
static ref TEST_CLASS_HASH: Felt252 = felt_str!("272");
static ref TEST_EMPTY_CONTRACT_CLASS_HASH: Felt252 = felt_str!("274");
static ref TEST_ERC20_CONTRACT_CLASS_HASH: Felt252 = felt_str!("4112");
static ref TEST_FIB_COMPILED_CONTRACT_CLASS_HASH: Felt252 = felt_str!("27727");
static ref TEST_FIB_COMPILED_CONTRACT_CLASS_HASH_CAIRO1: Felt252 = felt_str!("1948962768849191111780391610229754715773924969841143100991524171924131413970");
static ref TEST_FIB_COMPILED_CONTRACT_CLASS_HASH_CAIRO2: Felt252 = felt_str!("2889767417435368609058888822622483550637539736178264636938129582300971548553");

// Storage keys.
// NOTE: this key corresponds to the lower 128 bits of an U256
Expand Down Expand Up @@ -699,6 +702,9 @@ fn declarev2_tx() -> DeclareV2 {
let program_data = include_bytes!("../starknet_programs/cairo1/fibonacci.sierra");
let sierra_contract_class: SierraContractClass = serde_json::from_slice(program_data).unwrap();
let sierra_class_hash = compute_sierra_class_hash(&sierra_contract_class).unwrap();
let casm_class =
CasmContractClass::from_contract_class(sierra_contract_class.clone(), true).unwrap();
let casm_class_hash = compute_casm_class_hash(&casm_class).unwrap();

DeclareV2 {
sender_address: TEST_ACCOUNT_CONTRACT_ADDRESS.clone(),
Expand All @@ -709,23 +715,32 @@ fn declarev2_tx() -> DeclareV2 {
signature: vec![],
nonce: 0.into(),
hash_value: 0.into(),
compiled_class_hash: TEST_FIB_COMPILED_CONTRACT_CLASS_HASH.clone(),
compiled_class_hash: casm_class_hash,
sierra_contract_class,
sierra_class_hash,
casm_class: Default::default(),
casm_class: casm_class.into(),
skip_execute: false,
skip_fee_transfer: false,
skip_validate: false,
}
}

fn deploy_fib_syscall() -> Deploy {
let contract_hash;
#[cfg(not(feature = "cairo_1_tests"))]
{
contract_hash = felt_to_hash(&TEST_FIB_COMPILED_CONTRACT_CLASS_HASH_CAIRO2.clone())
}
#[cfg(feature = "cairo_1_tests")]
{
contract_hash = felt_to_hash(&TEST_FIB_COMPILED_CONTRACT_CLASS_HASH_CAIRO1.clone())
}
Deploy {
hash_value: 0.into(),
version: 1.into(),
contract_address: TEST_FIB_CONTRACT_ADDRESS.clone(),
contract_address_salt: 0.into(),
contract_hash: felt_to_hash(&TEST_FIB_COMPILED_CONTRACT_CLASS_HASH.clone()),
contract_hash,
constructor_calldata: Vec::new(),
tx_type: TransactionType::Deploy,
skip_execute: false,
Expand Down Expand Up @@ -862,14 +877,23 @@ fn test_declarev2_tx() {
]);
let fee = calculate_tx_fee(&resources, *GAS_PRICE, &block_context).unwrap();

let contract_hash;
#[cfg(not(feature = "cairo_1_tests"))]
{
contract_hash = TEST_FIB_COMPILED_CONTRACT_CLASS_HASH_CAIRO2.clone();
}
#[cfg(feature = "cairo_1_tests")]
{
contract_hash = TEST_FIB_COMPILED_CONTRACT_CLASS_HASH_CAIRO1.clone();
}
let expected_execution_info = TransactionExecutionInfo::new(
Some(CallInfo {
call_type: Some(CallType::Call),
contract_address: TEST_ACCOUNT_CONTRACT_ADDRESS.clone(),
class_hash: Some(felt_to_hash(&TEST_ACCOUNT_CONTRACT_CLASS_HASH)),
entry_point_selector: Some(VALIDATE_DECLARE_ENTRY_POINT_SELECTOR.clone()),
entry_point_type: Some(EntryPointType::External),
calldata: vec![TEST_FIB_COMPILED_CONTRACT_CLASS_HASH.clone()],
calldata: vec![contract_hash],
execution_resources: ExecutionResources {
n_steps: 12,
..Default::default()
Expand Down Expand Up @@ -943,6 +967,15 @@ fn expected_execute_call_info() -> CallInfo {
}

fn expected_fib_execute_call_info() -> CallInfo {
let contract_hash;
#[cfg(not(feature = "cairo_1_tests"))]
{
contract_hash = felt_to_hash(&TEST_FIB_COMPILED_CONTRACT_CLASS_HASH_CAIRO2.clone());
}
#[cfg(feature = "cairo_1_tests")]
{
contract_hash = felt_to_hash(&TEST_FIB_COMPILED_CONTRACT_CLASS_HASH_CAIRO1.clone());
}
CallInfo {
caller_address: Address(Felt252::zero()),
call_type: Some(CallType::Call),
Expand Down Expand Up @@ -972,7 +1005,7 @@ fn expected_fib_execute_call_info() -> CallInfo {
internal_calls: vec![CallInfo {
caller_address: TEST_ACCOUNT_CONTRACT_ADDRESS.clone(),
call_type: Some(CallType::Call),
class_hash: Some(felt_to_hash(&TEST_FIB_COMPILED_CONTRACT_CLASS_HASH.clone())),
class_hash: Some(contract_hash),
entry_point_selector: Some(Felt252::from_bytes_be(&calculate_sn_keccak(b"fib"))),
entry_point_type: Some(EntryPointType::External),
calldata: vec![Felt252::from(42), Felt252::from(0), Felt252::from(0)],
Expand Down Expand Up @@ -1760,9 +1793,18 @@ fn test_library_call_with_declare_v2() {
)
};

let casm_contract_hash;
#[cfg(not(feature = "cairo_1_tests"))]
{
casm_contract_hash = TEST_FIB_COMPILED_CONTRACT_CLASS_HASH_CAIRO2.clone()
}
#[cfg(feature = "cairo_1_tests")]
{
casm_contract_hash = TEST_FIB_COMPILED_CONTRACT_CLASS_HASH_CAIRO1.clone()
}
// Create an execution entry point
let calldata = vec![
TEST_FIB_COMPILED_CONTRACT_CLASS_HASH.clone(),
casm_contract_hash,
Felt252::from_bytes_be(&calculate_sn_keccak(b"fib")),
1.into(),
1.into(),
Expand Down Expand Up @@ -1798,11 +1840,21 @@ fn test_library_call_with_declare_v2() {
)
.unwrap();

let casm_contract_hash;
#[cfg(not(feature = "cairo_1_tests"))]
{
casm_contract_hash = TEST_FIB_COMPILED_CONTRACT_CLASS_HASH_CAIRO2.clone()
}
#[cfg(feature = "cairo_1_tests")]
{
casm_contract_hash = TEST_FIB_COMPILED_CONTRACT_CLASS_HASH_CAIRO1.clone()
}

let expected_internal_call_info = CallInfo {
caller_address: Address(0.into()),
call_type: Some(CallType::Delegate),
contract_address: address.clone(),
class_hash: Some(TEST_FIB_COMPILED_CONTRACT_CLASS_HASH.clone().to_be_bytes()),
class_hash: Some(casm_contract_hash.to_be_bytes()),
entry_point_selector: Some(external_entrypoint_selector.into()),
entry_point_type: Some(EntryPointType::External),
#[cfg(not(feature = "cairo_1_tests"))]
Expand Down