Skip to content
This repository was archived by the owner on Jan 9, 2025. It is now read-only.

Call restriction checks in contracts #1179

Merged
merged 6 commits into from
Jun 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 18 additions & 9 deletions src/kakarot/accounts/account_contract.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,8 @@ func get_implementation{syscall_ptr: felt*, pedersen_ptr: HashBuiltin*, range_ch
func set_implementation{syscall_ptr: felt*, pedersen_ptr: HashBuiltin*, range_check_ptr}(
implementation_class: felt
) {
// Access control check.
Ownable.assert_only_owner();
return AccountContract.set_implementation(implementation_class);
}

Expand Down Expand Up @@ -203,6 +205,8 @@ func __execute__{
func write_bytecode{
syscall_ptr: felt*, pedersen_ptr: HashBuiltin*, range_check_ptr, bitwise_ptr: BitwiseBuiltin*
}(bytecode_len: felt, bytecode: felt*) {
// Access control check.
Ownable.assert_only_owner();
return AccountContract.write_bytecode(bytecode_len, bytecode);
}

Expand All @@ -220,9 +224,9 @@ func bytecode{
// @dev Compared to bytecode, it does not read the code so it's much cheaper if only len is required.
// @return len The bytecode_len of the smart contract.
@view
func bytecode_len{
syscall_ptr: felt*, pedersen_ptr: HashBuiltin*, range_check_ptr, bitwise_ptr: BitwiseBuiltin*
}() -> (len: felt) {
func bytecode_len{syscall_ptr: felt*, pedersen_ptr: HashBuiltin*, range_check_ptr}() -> (
len: felt
) {
let (len) = AccountContract.bytecode_len();
return (len=len);
}
Expand All @@ -234,6 +238,8 @@ func bytecode_len{
func write_storage{
syscall_ptr: felt*, pedersen_ptr: HashBuiltin*, range_check_ptr, bitwise_ptr: BitwiseBuiltin*
}(storage_addr: felt, value: Uint256) {
// Access control check.
Ownable.assert_only_owner();
return AccountContract.write_storage(storage_addr, value);
}

Expand All @@ -257,16 +263,19 @@ func get_nonce{syscall_ptr: felt*, pedersen_ptr: HashBuiltin*, range_check_ptr}(
// @notice This function set the contract account nonce
@external
func set_nonce{syscall_ptr: felt*, pedersen_ptr: HashBuiltin*, range_check_ptr}(nonce: felt) {
// Access control check.
Ownable.assert_only_owner();
return AccountContract.set_nonce(nonce);
}

// @notice Write valid jumpdests in the account's storage.
// @param jumpdests_len The length of the jumpdests array.
// @param jumpdests The jumpdests array, containing indexes of valid jumpdests.
@external
func write_jumpdests{
syscall_ptr: felt*, pedersen_ptr: HashBuiltin*, range_check_ptr, bitwise_ptr: BitwiseBuiltin*
}(jumpdests_len: felt, jumpdests: felt*) {
func write_jumpdests{syscall_ptr: felt*, pedersen_ptr: HashBuiltin*, range_check_ptr}(
jumpdests_len: felt, jumpdests: felt*
) {
// Access control check.
Ownable.assert_only_owner();
AccountContract.write_jumpdests(jumpdests_len, jumpdests);
return ();
Expand All @@ -276,9 +285,9 @@ func write_jumpdests{
// @param index The index of the jumpdest.
// @return is_valid 1 if the jumpdest is valid, 0 otherwise.
@view
func is_valid_jumpdest{
syscall_ptr: felt*, pedersen_ptr: HashBuiltin*, range_check_ptr, bitwise_ptr: BitwiseBuiltin*
}(index: felt) -> (is_valid: felt) {
func is_valid_jumpdest{syscall_ptr: felt*, pedersen_ptr: HashBuiltin*, range_check_ptr}(
index: felt
) -> (is_valid: felt) {
let is_valid = AccountContract.is_valid_jumpdest(index);
return (is_valid=is_valid);
}
8 changes: 0 additions & 8 deletions src/kakarot/accounts/library.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -127,8 +127,6 @@ namespace AccountContract {
func set_implementation{syscall_ptr: felt*, pedersen_ptr: HashBuiltin*, range_check_ptr}(
new_implementation: felt
) {
// Access control check.
Ownable.assert_only_owner();
replace_class(new_implementation);
Account_implementation.write(new_implementation);
return ();
Expand Down Expand Up @@ -362,8 +360,6 @@ namespace AccountContract {
bitwise_ptr: BitwiseBuiltin*,
}(bytecode_len: felt, bytecode: felt*) {
alloc_locals;
// Access control check.
Ownable.assert_only_owner();
// Recursively store the bytecode.
Account_bytecode_len.write(bytecode_len);
Internals.write_bytecode(bytecode_len=bytecode_len, bytecode=bytecode);
Expand Down Expand Up @@ -417,8 +413,6 @@ namespace AccountContract {
range_check_ptr,
bitwise_ptr: BitwiseBuiltin*,
}(storage_addr: felt, value: Uint256) {
// Access control check.
Ownable.assert_only_owner();
// Write State
storage_write(address=storage_addr + 0, value=value.low);
storage_write(address=storage_addr + 1, value=value.high);
Expand All @@ -439,8 +433,6 @@ namespace AccountContract {
func set_nonce{syscall_ptr: felt*, pedersen_ptr: HashBuiltin*, range_check_ptr}(
new_nonce: felt
) {
// Access control check.
Ownable.assert_only_owner();
Account_nonce.write(new_nonce);
return ();
}
Expand Down
5 changes: 5 additions & 0 deletions src/kakarot/kakarot.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ func upgrade{syscall_ptr: felt*, pedersen_ptr: HashBuiltin*, range_check_ptr}(
func set_native_token{syscall_ptr: felt*, pedersen_ptr: HashBuiltin*, range_check_ptr}(
native_token_address: felt
) {
Ownable.assert_only_owner();
return Kakarot.set_native_token(native_token_address);
}

Expand All @@ -79,6 +80,7 @@ func get_native_token{syscall_ptr: felt*, pedersen_ptr: HashBuiltin*, range_chec
// @param base_fee The new base fee.
@external
func set_base_fee{syscall_ptr: felt*, pedersen_ptr: HashBuiltin*, range_check_ptr}(base_fee: felt) {
Ownable.assert_only_owner();
return Kakarot.set_base_fee(base_fee);
}

Expand All @@ -95,6 +97,7 @@ func get_base_fee{syscall_ptr: felt*, pedersen_ptr: HashBuiltin*, range_check_pt
// @param coinbase The new coinbase address.
@external
func set_coinbase{syscall_ptr: felt*, pedersen_ptr: HashBuiltin*, range_check_ptr}(coinbase: felt) {
Ownable.assert_only_owner();
return Kakarot.set_coinbase(coinbase);
}

Expand All @@ -113,6 +116,7 @@ func get_coinbase{syscall_ptr: felt*, pedersen_ptr: HashBuiltin*, range_check_pt
func set_prev_randao{syscall_ptr: felt*, pedersen_ptr: HashBuiltin*, range_check_ptr}(
prev_randao: Uint256
) {
Ownable.assert_only_owner();
return Kakarot.set_prev_randao(prev_randao);
}

Expand All @@ -131,6 +135,7 @@ func get_prev_randao{syscall_ptr: felt*, pedersen_ptr: HashBuiltin*, range_check
func set_block_gas_limit{syscall_ptr: felt*, pedersen_ptr: HashBuiltin*, range_check_ptr}(
gas_limit_: felt
) {
Ownable.assert_only_owner();
return Kakarot.set_block_gas_limit(gas_limit_);
}

Expand Down
5 changes: 0 additions & 5 deletions src/kakarot/library.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,6 @@ namespace Kakarot {
func set_native_token{syscall_ptr: felt*, pedersen_ptr: HashBuiltin*, range_check_ptr}(
native_token_address: felt
) {
Ownable.assert_only_owner();
Kakarot_native_token_address.write(native_token_address);
return ();
}
Expand All @@ -144,7 +143,6 @@ namespace Kakarot {
func set_base_fee{syscall_ptr: felt*, pedersen_ptr: HashBuiltin*, range_check_ptr}(
base_fee: felt
) {
Ownable.assert_only_owner();
Kakarot_base_fee.write(base_fee);
return ();
}
Expand All @@ -163,7 +161,6 @@ namespace Kakarot {
func set_coinbase{syscall_ptr: felt*, pedersen_ptr: HashBuiltin*, range_check_ptr}(
coinbase: felt
) {
Ownable.assert_only_owner();
Kakarot_coinbase.write(coinbase);
return ();
}
Expand All @@ -182,7 +179,6 @@ namespace Kakarot {
func set_prev_randao{syscall_ptr: felt*, pedersen_ptr: HashBuiltin*, range_check_ptr}(
prev_randao: Uint256
) {
Ownable.assert_only_owner();
Kakarot_prev_randao.write(prev_randao);
return ();
}
Expand All @@ -201,7 +197,6 @@ namespace Kakarot {
func set_block_gas_limit{syscall_ptr: felt*, pedersen_ptr: HashBuiltin*, range_check_ptr}(
block_gas_limit: felt
) {
Ownable.assert_only_owner();
Kakarot_block_gas_limit.write(block_gas_limit);
return ();
}
Expand Down
24 changes: 16 additions & 8 deletions tests/src/kakarot/accounts/test_contract_account.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,15 @@ from starkware.cairo.common.alloc import alloc
from starkware.cairo.common.memcpy import memcpy
from starkware.cairo.common.uint256 import Uint256

from kakarot.accounts.library import AccountContract, Internals as AccountInternals
from kakarot.accounts.library import Internals as AccountInternals
from kakarot.accounts.account_contract import (
initialize,
get_evm_address,
write_bytecode,
bytecode as read_bytecode,
write_jumpdests,
is_valid_jumpdest,
)

func test__initialize{
syscall_ptr: felt*, pedersen_ptr: HashBuiltin*, range_check_ptr, bitwise_ptr: BitwiseBuiltin*
Expand All @@ -23,15 +31,15 @@ func test__initialize{
%}

// When
AccountContract.initialize(kakarot_address, evm_address, implementation_class);
initialize(kakarot_address, evm_address, implementation_class);

return ();
}

func test__get_evm_address__should_return_stored_address{
syscall_ptr: felt*, pedersen_ptr: HashBuiltin*, range_check_ptr, bitwise_ptr: BitwiseBuiltin*
}() -> felt {
let (evm_address) = AccountContract.get_evm_address();
let (evm_address) = get_evm_address();

return evm_address;
}
Expand All @@ -49,16 +57,16 @@ func test__write_bytecode{
segments.write_arg(ids.bytecode, program_input["bytecode"])
%}

AccountContract.write_bytecode(bytecode_len, bytecode);
write_bytecode(bytecode_len, bytecode);

return ();
}

func test__read_bytecode{
func test__bytecode{
syscall_ptr: felt*, pedersen_ptr: HashBuiltin*, range_check_ptr, bitwise_ptr: BitwiseBuiltin*
}() -> (bytecode_len: felt, bytecode: felt*) {
alloc_locals;
let (bytecode_len, bytecode) = AccountContract.bytecode();
let (bytecode_len, bytecode) = read_bytecode();
return (bytecode_len, bytecode);
}

Expand Down Expand Up @@ -103,7 +111,7 @@ func test__write_jumpdests{syscall_ptr: felt*, pedersen_ptr: HashBuiltin*, range
%}

// When
AccountContract.write_jumpdests(jumpdests_len, jumpdests);
write_jumpdests(jumpdests_len, jumpdests);

return ();
}
Expand All @@ -113,7 +121,7 @@ func test__is_valid_jumpdest{syscall_ptr: felt*, pedersen_ptr: HashBuiltin*, ran
tempvar index: felt;
%{ ids.index = program_input["index"] %}

let is_valid = AccountContract.is_valid_jumpdest(index);
let (is_valid) = is_valid_jumpdest(index);

return is_valid;
}
8 changes: 7 additions & 1 deletion tests/src/kakarot/accounts/test_contract_account.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ def test_should_read_bytecode(self, cairo_run, bytecode, storage):
with patch.object(
SyscallHandler, "mock_storage", side_effect=storage
) as mock_storage:
output_len, output = cairo_run("test__read_bytecode")
output_len, output = cairo_run("test__bytecode")
chunk_counts, remainder = divmod(len(bytecode), 31)
addresses = list(range(chunk_counts + (remainder > 0)))
calls = [call(address=address) for address in addresses]
Expand All @@ -143,6 +143,12 @@ def test_should_read_bytecode(self, cairo_run, bytecode, storage):

class TestJumpdests:
class TestWriteJumpdests:
@SyscallHandler.patch("Ownable_owner", 0xDEAD)
def test_should_assert_only_owner(self, cairo_run):
with cairo_error():
cairo_run("test__write_jumpdests", bytecode=[])

@SyscallHandler.patch("Ownable_owner", SyscallHandler.caller_address)
def test__should_store_valid_jumpdests(self, cairo_run):
jumpdests = [0x02, 0x10, 0xFF]
cairo_run("test__write_jumpdests", jumpdests=jumpdests)
Expand Down
Loading