Skip to content
This repository was archived by the owner on Jul 22, 2024. It is now read-only.

Save SierraProgram + ContractEntryPoints instead of SierraContractClass in CompiledProgram #1112

Merged
merged 2 commits into from
Oct 31, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 13 additions & 13 deletions src/execution/execution_entry_point.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,9 @@ use crate::{
validate_contract_deployed, Address,
},
};
use cairo_lang_sierra::program::Program as SierraProgram;
use cairo_lang_starknet::casm_contract_class::{CasmContractClass, CasmContractEntryPoint};
use cairo_lang_starknet::contract_class::ContractEntryPoints;
use cairo_vm::{
felt::Felt252,
types::{
Expand Down Expand Up @@ -164,12 +166,12 @@ impl ExecutionEntryPoint {
}
}
}
CompiledClass::Sierra(sierra_contract_class) => {
CompiledClass::Sierra(sierra_program_and_entrypoints) => {
let mut transactional_state = state.create_transactional();

match self.native_execute(
&mut transactional_state,
sierra_contract_class,
sierra_program_and_entrypoints,
tx_execution_context,
block_context,
) {
Expand Down Expand Up @@ -622,7 +624,7 @@ impl ExecutionEntryPoint {
fn native_execute<S: StateReader>(
&self,
_state: &mut CachedState<S>,
_contract_class: Arc<cairo_lang_starknet::contract_class::ContractClass>,
_sierra_program_and_entrypoints: Arc<(SierraProgram, ContractEntryPoints)>,
_tx_execution_context: &mut TransactionExecutionContext,
_block_context: &BlockContext,
) -> Result<CallInfo, TransactionError> {
Expand All @@ -636,7 +638,7 @@ impl ExecutionEntryPoint {
fn native_execute<S: StateReader>(
&self,
state: &mut CachedState<S>,
contract_class: Arc<cairo_lang_starknet::contract_class::ContractClass>,
sierra_program_and_entrypoints: Arc<(SierraProgram, ContractEntryPoints)>,
tx_execution_context: &TransactionExecutionContext,
block_context: &BlockContext,
) -> Result<CallInfo, TransactionError> {
Expand All @@ -647,34 +649,32 @@ impl ExecutionEntryPoint {
use serde_json::json;

use crate::syscalls::business_logic_syscall_handler::SYSCALL_BASE;
let sierra_program = &sierra_program_and_entrypoints.0;
let contract_entrypoints = &sierra_program_and_entrypoints.1;

let entry_point = match self.entry_point_type {
EntryPointType::External => contract_class
.entry_points_by_type
EntryPointType::External => contract_entrypoints
.external
.iter()
.find(|entry_point| entry_point.selector == self.entry_point_selector.to_biguint())
.unwrap(),
EntryPointType::Constructor => contract_class
.entry_points_by_type
EntryPointType::Constructor => contract_entrypoints
.constructor
.iter()
.find(|entry_point| entry_point.selector == self.entry_point_selector.to_biguint())
.unwrap(),
EntryPointType::L1Handler => contract_class
.entry_points_by_type
EntryPointType::L1Handler => contract_entrypoints
.l1_handler
.iter()
.find(|entry_point| entry_point.selector == self.entry_point_selector.to_biguint())
.unwrap(),
};

let sierra_program = contract_class.extract_sierra_program().unwrap();
let program_registry: ProgramRegistry<CoreType, CoreLibfunc> =
ProgramRegistry::new(&sierra_program).unwrap();
ProgramRegistry::new(sierra_program).unwrap();

let native_context = NativeContext::new();
let mut native_program = native_context.compile(&sierra_program).unwrap();
let mut native_program = native_context.compile(sierra_program).unwrap();
let contract_storage_state =
ContractStorageState::new(state, self.contract_address.clone());

Expand Down
3 changes: 2 additions & 1 deletion src/services/api/contract_classes/compiled_class.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ use crate::services::api::contract_classes::deprecated_contract_class::AbiType;
use crate::{ContractEntryPoint, EntryPointType};

use super::deprecated_contract_class::ContractClass;
use cairo_lang_sierra::program::Program as SierraProgram;
use cairo_lang_starknet::abi::Contract;
use cairo_lang_starknet::casm_contract_class::CasmContractClass;
use cairo_lang_starknet::contract_class::{
Expand All @@ -24,7 +25,7 @@ use starknet::core::types::ContractClass::{Legacy, Sierra};
pub enum CompiledClass {
Deprecated(Arc<ContractClass>),
Casm(Arc<CasmContractClass>),
Sierra(Arc<SierraContractClass>),
Sierra(Arc<(SierraProgram, ContractEntryPoints)>),
}

impl TryInto<CasmContractClass> for CompiledClass {
Expand Down
2 changes: 0 additions & 2 deletions src/state/cached_state.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,6 @@ use std::{
sync::Arc,
};

pub type SierraProgramCache =
HashMap<ClassHash, cairo_lang_starknet::contract_class::ContractClass>;
pub type ContractClassCache = HashMap<ClassHash, CompiledClass>;

pub const UNINITIALIZED_CLASS_HASH: &ClassHash = &[0u8; 32];
Expand Down
4 changes: 3 additions & 1 deletion src/syscalls/native_syscall_handler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -565,7 +565,9 @@ where
.ok_or(ContractClassError::NoneEntryPointType)?
.is_empty()),
CompiledClass::Casm(class) => Ok(class.entry_points_by_type.constructor.is_empty()),
CompiledClass::Sierra(class) => Ok(class.entry_points_by_type.constructor.is_empty()),
CompiledClass::Sierra(sierra_program_and_entrypoints) => {
Ok(sierra_program_and_entrypoints.1.constructor.is_empty())
}
}
}
}
70 changes: 48 additions & 22 deletions tests/cairo_native.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

use crate::CallType::Call;
use cairo_lang_starknet::casm_contract_class::CasmContractEntryPoints;
use cairo_lang_starknet::contract_class::ContractClass;
use cairo_lang_starknet::contract_class::ContractEntryPoints;
use cairo_vm::felt::Felt252;
use num_bigint::BigUint;
Expand All @@ -27,6 +28,19 @@ use std::collections::HashMap;
use std::collections::HashSet;
use std::sync::Arc;

fn insert_sierra_class_into_cache(
contract_class_cache: &mut HashMap<ClassHash, CompiledClass>,
class_hash: ClassHash,
sierra_class: ContractClass,
) {
let sierra_program = sierra_class.extract_sierra_program().unwrap();
let entry_points = sierra_class.entry_points_by_type;
contract_class_cache.insert(
class_hash,
CompiledClass::Sierra(Arc::new((sierra_program, entry_points))),
);
}

#[test]
fn integration_test_erc20() {
let sierra_contract_class: cairo_lang_starknet::contract_class::ContractClass =
Expand Down Expand Up @@ -54,9 +68,10 @@ fn integration_test_erc20() {

let caller_address = Address(123456789.into());

contract_class_cache.insert(
insert_sierra_class_into_cache(
&mut contract_class_cache,
NATIVE_CLASS_HASH,
CompiledClass::Sierra(Arc::new(sierra_contract_class)),
sierra_contract_class,
);
contract_class_cache.insert(
CASM_CLASS_HASH,
Expand Down Expand Up @@ -449,13 +464,16 @@ fn call_contract_test() {
let callee_class_hash: ClassHash = [2; 32];
let callee_nonce = Felt252::zero();

contract_class_cache.insert(
insert_sierra_class_into_cache(
&mut contract_class_cache,
caller_class_hash,
CompiledClass::Sierra(Arc::new(caller_contract_class)),
caller_contract_class,
);
contract_class_cache.insert(

insert_sierra_class_into_cache(
&mut contract_class_cache,
callee_class_hash,
CompiledClass::Sierra(Arc::new(callee_contract_class)),
callee_contract_class,
);

let mut state_reader = InMemoryStateReader::default();
Expand Down Expand Up @@ -534,14 +552,16 @@ fn call_echo_contract_test() {
let callee_class_hash: ClassHash = [2; 32];
let callee_nonce = Felt252::zero();

contract_class_cache.insert(
insert_sierra_class_into_cache(
&mut contract_class_cache,
caller_class_hash,
CompiledClass::Sierra(Arc::new(caller_contract_class)),
caller_contract_class,
);

contract_class_cache.insert(
insert_sierra_class_into_cache(
&mut contract_class_cache,
callee_class_hash,
CompiledClass::Sierra(Arc::new(callee_contract_class)),
callee_contract_class,
);

let mut state_reader = InMemoryStateReader::default();
Expand Down Expand Up @@ -622,14 +642,16 @@ fn call_events_contract_test() {
let callee_class_hash: ClassHash = [2; 32];
let callee_nonce = Felt252::zero();

contract_class_cache.insert(
insert_sierra_class_into_cache(
&mut contract_class_cache,
caller_class_hash,
CompiledClass::Sierra(Arc::new(caller_contract_class)),
caller_contract_class,
);

contract_class_cache.insert(
insert_sierra_class_into_cache(
&mut contract_class_cache,
callee_class_hash,
CompiledClass::Sierra(Arc::new(callee_contract_class)),
callee_contract_class,
);

let mut state_reader = InMemoryStateReader::default();
Expand Down Expand Up @@ -840,14 +862,16 @@ fn deploy_syscall_test() {
let deployee_class_hash: ClassHash = Felt252::one().to_be_bytes();
let _deployee_nonce = Felt252::zero();

contract_class_cache.insert(
insert_sierra_class_into_cache(
&mut contract_class_cache,
deployer_class_hash,
CompiledClass::Sierra(Arc::new(deployer_contract_class)),
deployer_contract_class,
);

contract_class_cache.insert(
insert_sierra_class_into_cache(
&mut contract_class_cache,
deployee_class_hash,
CompiledClass::Sierra(Arc::new(deployee_contract_class)),
deployee_contract_class,
);

let mut state_reader = InMemoryStateReader::default();
Expand Down Expand Up @@ -945,14 +969,16 @@ fn deploy_syscall_address_unavailable_test() {
// Insert contract to be deployed so that its address is taken
let deployee_address = expected_deployed_contract_address;

contract_class_cache.insert(
insert_sierra_class_into_cache(
&mut contract_class_cache,
deployer_class_hash,
CompiledClass::Sierra(Arc::new(deployer_contract_class)),
deployer_contract_class,
);

contract_class_cache.insert(
insert_sierra_class_into_cache(
&mut contract_class_cache,
deployee_class_hash,
CompiledClass::Sierra(Arc::new(deployee_contract_class)),
deployee_contract_class,
);

let mut state_reader = InMemoryStateReader::default();
Expand Down