Skip to content
This repository was archived by the owner on Jan 9, 2025. It is now read-only.

feat: upgrade contracts invoked if outdated #1153

Merged
merged 6 commits into from
May 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 20 additions & 4 deletions src/kakarot/account.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -118,12 +118,16 @@ namespace Account {
address=address, code_len=0, code=bytecode, nonce=0, balance=balance_ptr
);
return account;
} else {
tempvar address = new model.Address(starknet=starknet_address, evm=evm_address);
let balance = fetch_balance(address);
assert balance_ptr = new Uint256(balance.low, balance.high);
}

tempvar address = new model.Address(starknet=starknet_address, evm=evm_address);
let balance = fetch_balance(address);
assert balance_ptr = new Uint256(balance.low, balance.high);

// Upgrade the target starknet contract's class if it's not the latest one.
// The contract must be deployed on starknet already.
Internals.check_and_upgrade_account_class(address);

let (bytecode_len, bytecode) = IAccount.bytecode(contract_address=starknet_address);
let (nonce) = IAccount.get_nonce(contract_address=starknet_address);

Expand Down Expand Up @@ -671,4 +675,16 @@ namespace Internals {

return _cache_storage_keys(evm_address, storage_keys_len - 1, storage_keys + Uint256.SIZE);
}

func check_and_upgrade_account_class{
syscall_ptr: felt*, pedersen_ptr: HashBuiltin*, range_check_ptr
}(address: model.Address*) {
let (account_impl) = IAccount.get_implementation(address.starknet);
let (latest_impl) = Kakarot_account_contract_class_hash.read();
if (account_impl == latest_impl) {
return ();
}
IAccount.set_implementation(address.starknet, latest_impl);
return ();
}
}
14 changes: 14 additions & 0 deletions src/kakarot/accounts/account_contract.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,20 @@ func get_evm_address{syscall_ptr: felt*, pedersen_ptr: HashBuiltin*, range_check
return AccountContract.get_evm_address();
}

@view
func get_implementation{syscall_ptr: felt*, pedersen_ptr: HashBuiltin*, range_check_ptr}() -> (
implementation: felt
) {
return AccountContract.get_implementation();
}

@external
func set_implementation{syscall_ptr: felt*, pedersen_ptr: HashBuiltin*, range_check_ptr}(
implementation_class: felt
) {
return AccountContract.set_implementation(implementation_class);
}

// @notice Checks if the account was initialized.
// @return is_initialized: 1 if the account has been initialized 0 otherwise.
@view
Expand Down
10 changes: 10 additions & 0 deletions src/kakarot/accounts/library.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,16 @@ namespace AccountContract {
return (implementation=implementation);
}

func set_implementation{syscall_ptr: felt*, pedersen_ptr: HashBuiltin*, range_check_ptr}(
new_implementation: felt
) {
// Access control check.
Ownable.assert_only_owner();
replace_class(new_implementation);
Account_implementation.write(new_implementation);
return ();
}

// @return address The EVM address of the account
func get_evm_address{syscall_ptr: felt*, pedersen_ptr: HashBuiltin*, range_check_ptr}() -> (
address: felt
Expand Down
1 change: 1 addition & 0 deletions src/kakarot/instructions/system_operations.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -1195,6 +1195,7 @@ namespace CallHelper {
Memory.load_n(args_size.low, calldata, args_offset.low);

// 2. Build child_evm

let code_account = State.get_account(code_address);
local code_len: felt = code_account.code_len;
local code: felt* = code_account.code;
Expand Down
6 changes: 6 additions & 0 deletions src/kakarot/interfaces/interfaces.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,12 @@ namespace IAccount {
func get_evm_address() -> (evm_address: felt) {
}

func get_implementation() -> (implementation: felt) {
}

func set_implementation(implementation: felt) {
}

func version() -> (version: felt) {
}

Expand Down
93 changes: 92 additions & 1 deletion tests/end_to_end/test_account.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import pytest
import pytest_asyncio
from eth_utils import keccak
from starknet_py.net.full_node_client import FullNodeClient
from starkware.starknet.public.abi import get_storage_var_address

Expand Down Expand Up @@ -41,7 +42,7 @@ async def new_account(max_fee):
return account


@pytest_asyncio.fixture(scope="module")
@pytest_asyncio.fixture(scope="function")
async def counter(deploy_contract, new_account):
return await deploy_contract(
"PlainOpcodes",
Expand All @@ -50,9 +51,19 @@ async def counter(deploy_contract, new_account):
)


@pytest_asyncio.fixture(scope="module")
async def caller(deploy_contract, owner):
return await deploy_contract(
"PlainOpcodes",
"Caller",
caller_eoa=owner.starknet_contract,
)


@pytest.fixture(autouse=True)
async def cleanup(invoke, class_hashes):
yield

await invoke(
"kakarot",
"set_account_contract_class_hash",
Expand All @@ -72,6 +83,18 @@ async def assert_counter_transaction_success(counter, new_account):
assert await counter.count() == prev_count + 1


async def assert_caller_contract_increases_counter(caller, counter, new_account):
"""
Assert that the transaction sent, other than upgrading the account contract, is successful.
"""
prev_count = await counter.count()
inc_selector = keccak(b"inc()")[0:4]
await caller.call(
counter.address, inc_selector, caller_eoa=new_account.starknet_contract
)
assert await counter.count() == prev_count + 1


@pytest.mark.asyncio(scope="session")
@pytest.mark.AccountContract
class TestAccount:
Expand All @@ -83,6 +106,7 @@ async def test_should_upgrade_outdated_account_on_transfer(
counter,
new_account,
class_hashes,
cleanup,
):
prev_class = await starknet.get_class_hash_at(
new_account.starknet_contract.address
Expand Down Expand Up @@ -135,3 +159,70 @@ async def test_should_update_cairo1_helpers_class(
)
== target_class
)

class TestAutoUpgradeContracts:
async def test_should_upgrade_outdated_contract_transaction_target(
self,
starknet: FullNodeClient,
invoke,
call,
counter,
new_account,
class_hashes,
):
counter_starknet_address = (
await call(
"kakarot",
"get_starknet_address",
int(counter.address, 16),
)
).starknet_address
prev_class = await starknet.get_class_hash_at(counter_starknet_address)
target_class = class_hashes["account_contract_fixture"]
assert prev_class != target_class
assert prev_class == class_hashes["account_contract"]

await invoke(
"kakarot",
"set_account_contract_class_hash",
target_class,
)

await assert_counter_transaction_success(counter, new_account)

new_class = await starknet.get_class_hash_at(counter_starknet_address)
assert new_class == target_class

async def test_should_upgrade_outdated_contract_called_contract(
self,
starknet: FullNodeClient,
invoke,
counter,
call,
caller,
new_account,
class_hashes,
cleanup,
):
counter_starknet_address = (
await call(
"kakarot",
"get_starknet_address",
int(counter.address, 16),
)
).starknet_address
prev_class = await starknet.get_class_hash_at(counter_starknet_address)
target_class = class_hashes["account_contract_fixture"]
assert prev_class != target_class
assert prev_class == class_hashes["account_contract"]

await invoke(
"kakarot",
"set_account_contract_class_hash",
target_class,
)

await assert_caller_contract_increases_counter(caller, counter, new_account)

new_class = await starknet.get_class_hash_at(counter_starknet_address)
assert new_class == target_class
4 changes: 4 additions & 0 deletions tests/fixtures/account_contract_fixture.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,10 @@ from kakarot.accounts.account_contract import (
storage,
get_nonce,
set_nonce,
get_implementation,
set_implementation,
is_valid_jumpdest,
write_jumpdests,
)

// make sure the class hash is different
Expand Down
2 changes: 2 additions & 0 deletions tests/utils/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,8 @@

ZERO_ADDRESS = "0x" + 40 * "0"

ACCOUNT_CLASS_IMPLEMENTATION = 0xC0DEC1A55

BLOCK_NUMBER = 0x42
BLOCK_TIMESTAMP = int(time())

Expand Down
6 changes: 5 additions & 1 deletion tests/utils/syscall_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
get_storage_var_address,
)

from tests.utils.constants import CHAIN_ID
from tests.utils.constants import ACCOUNT_CLASS_IMPLEMENTATION, CHAIN_ID
from tests.utils.uint256 import int_to_uint256, uint256_to_int


Expand Down Expand Up @@ -173,6 +173,10 @@ class SyscallHandler:
get_selector_from_name(
"verify_signature_secp256r1"
): cairo_verify_signature_secp256r1,
get_selector_from_name("get_implementation"): lambda addr, data: [
ACCOUNT_CLASS_IMPLEMENTATION
],
get_selector_from_name("set_implementation"): lambda addr, data: [],
}

def get_contract_address(self, segments, syscall_ptr):
Expand Down
Loading