Skip to content
This repository was archived by the owner on Jan 9, 2025. It is now read-only.

feat: account versioning #1058

Merged
merged 8 commits into from
Apr 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions src/kakarot/accounts/account_contract.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,17 @@ func upgrade{
return AccountContract.upgrade(new_class);
}

// @notice Returns the version of the account class.
// @dev The version is a packed integer with the following format: XXX.YYY.ZZZ where XXX is the
// major version, YYY is the minor version and ZZZ is the patch version.
@view
func version{
syscall_ptr: felt*, pedersen_ptr: HashBuiltin*, range_check_ptr, bitwise_ptr: BitwiseBuiltin*
}() -> (version: felt) {
let version = AccountContract.VERSION;
return (version=version);
}

// @notice Gets the evm address associated with the account.
// @return address The EVM address of the account.
@view
Expand Down
17 changes: 9 additions & 8 deletions src/kakarot/accounts/library.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,9 @@ const BYTES_PER_FELT = 31;
// @notice This file contains the EVM account representation logic.
// @dev: Both EOAs and Contract Accounts are represented by this contract.
namespace AccountContract {
// 000.001.000
const VERSION = 000001000;

// @notice This function is used to initialize the smart contract account.
// @dev The `evm_address` and `kakarot_address` were set during the uninitialized_account creation.
// Reading them from state ensures that they always match the ones the account was created for.
Expand All @@ -99,25 +102,23 @@ namespace AccountContract {
return ();
}

// @notice This function is used to upgrade the smart contract account.
// @notice Upgrade the implementation of the account.
// @param new_class The new class of the account.
func upgrade{
syscall_ptr: felt*,
pedersen_ptr: HashBuiltin*,
range_check_ptr,
bitwise_ptr: BitwiseBuiltin*,
}(new_class: felt) {
alloc_locals;
// Access control check. Only the EOA owner should be able to upgrade its contract.
// with `supports_interface`
internal.assert_only_self();
// TODO: only valid classes should be allowed to be upgraded. Add a validation on the new class interface.
Internals.assert_only_self();
assert_not_zero(new_class);
replace_class(new_class);
Account_implementation.write(new_class);
return ();
}

@view
func get_implementation{syscall_ptr: felt*, pedersen_ptr: HashBuiltin*, range_check_ptr}() -> (
implementation: felt
) {
Expand Down Expand Up @@ -359,7 +360,7 @@ namespace AccountContract {
Ownable.assert_only_owner();
// Recursively store the bytecode.
Account_bytecode_len.write(bytecode_len);
internal.write_bytecode(bytecode_len=bytecode_len, bytecode=bytecode);
Internals.write_bytecode(bytecode_len=bytecode_len, bytecode=bytecode);
return ();
}

Expand All @@ -382,7 +383,7 @@ namespace AccountContract {
}() -> (bytecode_len: felt, bytecode: felt*) {
alloc_locals;
let (bytecode_len) = Account_bytecode_len.read();
let (bytecode_) = internal.load_bytecode(bytecode_len);
let (bytecode_) = Internals.load_bytecode(bytecode_len);
return (bytecode_len, bytecode_);
}

Expand Down Expand Up @@ -437,7 +438,7 @@ namespace AccountContract {
}
}

namespace internal {
namespace Internals {
// @notice asserts that the caller is the account itself
func assert_only_self{syscall_ptr: felt*}() {
let (this) = get_contract_address();
Expand Down
3 changes: 3 additions & 0 deletions src/kakarot/interfaces/interfaces.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,9 @@ namespace IAccount {
func get_evm_address() -> (evm_address: felt) {
}

func version() -> (version: felt) {
}

func bytecode_len() -> (len: felt) {
}

Expand Down
15 changes: 15 additions & 0 deletions tests/src/kakarot/accounts/test_contract_account.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,21 @@ func test__initialize__should_store_given_evm_address{
return ();
}

func test__upgrade{
syscall_ptr: felt*, pedersen_ptr: HashBuiltin*, range_check_ptr, bitwise_ptr: BitwiseBuiltin*
}() {
alloc_locals;

// Given
local new_class: felt;
%{ ids.new_class = program_input["new_class"] %}

// When
AccountContract.upgrade(new_class);

return ();
}

func test__get_evm_address__should_return_stored_address{
syscall_ptr: felt*, pedersen_ptr: HashBuiltin*, range_check_ptr, bitwise_ptr: BitwiseBuiltin*
}() -> felt {
Expand Down
37 changes: 36 additions & 1 deletion tests/src/kakarot/accounts/test_contract_account.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import random
from textwrap import wrap
from unittest.mock import call, patch
from unittest.mock import PropertyMock, call, patch

import pytest
from starkware.starknet.public.abi import (
Expand Down Expand Up @@ -123,3 +123,38 @@ def test_should_read_bytecode(self, cairo_run, bytecode, storage):
calls = [call(address=address) for address in addresses]
mock_storage.assert_has_calls(calls)
assert output[:output_len] == list(bytecode)

class TestUpgrade:
@patch.object(SyscallHandler, "mock_replace_class")
@patch.object(SyscallHandler, "mock_storage")
def test_should_upgrade_account(
self, mock_storage, mock_replace_class, cairo_run
):
cairo_run("test__upgrade", new_class=0xBEEF)

mock_replace_class.assert_called_once_with(class_hash=0xBEEF)
mock_storage.assert_called_once_with(
address=get_storage_var_address("Account_implementation"), value=0xBEEF
)

@patch(
"tests.utils.syscall_handler.SyscallHandler.contract_address",
new_callable=PropertyMock,
return_value=123,
)
@patch(
"tests.utils.syscall_handler.SyscallHandler.caller_address",
new_callable=PropertyMock,
return_value=456,
)
@patch.object(SyscallHandler, "mock_replace_class")
def test_should_fail_caller_not_self(
self,
mock_replace_class,
mock_contract_address,
mock_caller_address,
cairo_run,
):
with cairo_error():
cairo_run("test__upgrade", new_class=0xBEEF)
mock_replace_class.assert_not_called()
18 changes: 18 additions & 0 deletions tests/utils/syscall_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,11 +92,13 @@ class SyscallHandler:
)
contract_address: int = 0xABDE1
caller_address: int = 0xABDE1
class_hash: int = 0xC1A55
patches = {}
mock_call = mock.MagicMock()
mock_library_call = mock.MagicMock()
mock_storage = mock.MagicMock()
mock_event = mock.MagicMock()
mock_replace_class = mock.MagicMock()

def get_contract_address(self, segments, syscall_ptr):
"""
Expand Down Expand Up @@ -262,6 +264,22 @@ def storage_write(self, segments, syscall_ptr):
value=segments.memory[syscall_ptr + 2],
)

def replace_class(self, segments, syscall_ptr):
"""
Record the replaced class hash in the internal mock object and update the class_hash attribute.

Syscall structure is:

struct ReplaceClass {
selector: felt,
class_hash: felt,
}
"""
class_hash = segments.memory[syscall_ptr + 1]
self.mock_replace_class(
class_hash=class_hash,
)

def emit_event(self, segments, syscall_ptr):
"""
Record the call in the internal mock object.
Expand Down
Loading