Skip to content
This repository was archived by the owner on Jan 9, 2025. It is now read-only.

staging: fixes all possible tests #1182

Merged
merged 2 commits into from
Jun 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 15 additions & 12 deletions kakarot_scripts/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,21 +131,22 @@
WEB3 = Web3()

try:
response = requests.post(
RPC_CLIENT.url,
json={
"jsonrpc": "2.0",
"method": "starknet_chainId",
"params": [],
"id": 0,
},
)
payload = json.loads(response.text)
starknet_chain_id = int(payload["result"], 16)

if WEB3.is_connected():
chain_id = WEB3.eth.chain_id
else:
response = requests.post(
RPC_CLIENT.url,
json={
"jsonrpc": "2.0",
"method": "starknet_chainId",
"params": [],
"id": 0,
},
)
payload = json.loads(response.text)

chain_id = int(payload["result"], 16)
chain_id = starknet_chain_id
except (
requests.exceptions.ConnectionError,
requests.exceptions.MissingSchema,
Expand All @@ -155,10 +156,12 @@
f"⚠️ Could not get chain Id from {NETWORK['rpc_url']}: {e}, defaulting to KKRT"
)
chain_id = int.from_bytes(b"KKRT", "big")
starknet_chain_id = int.from_bytes(b"KKRT", "big")


class ChainId(IntEnum):
chain_id = chain_id
starknet_chain_id = starknet_chain_id


NETWORK["chain_id"] = ChainId.chain_id
Expand Down
3 changes: 2 additions & 1 deletion kakarot_scripts/utils/kakarot.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
NETWORK,
RPC_CLIENT,
WEB3,
ChainId,
)
from kakarot_scripts.utils.starknet import call as _call_starknet
from kakarot_scripts.utils.starknet import fund_address as _fund_starknet_address
Expand Down Expand Up @@ -334,7 +335,7 @@ async def get_eoa(private_key=None, amount=10) -> Account:
return Account(
address=starknet_address,
client=RPC_CLIENT,
chain=NETWORK["chain_id"],
chain=ChainId.starknet_chain_id,
# This is somehow a hack because we put EVM private key into a
# Stark signer KeyPair to have both a regular Starknet account
# and the access to the private key
Expand Down
5 changes: 3 additions & 2 deletions kakarot_scripts/utils/starknet.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
RPC_CLIENT,
SOURCE_DIR,
ArtifactType,
ChainId,
)

logging.basicConfig()
Expand Down Expand Up @@ -122,7 +123,7 @@ async def get_starknet_account(
return Account(
address=address,
client=RPC_CLIENT,
chain=NETWORK["chain_id"],
chain=ChainId.starknet_chain_id,
key_pair=key_pair,
)

Expand Down Expand Up @@ -347,7 +348,7 @@ async def deploy_starknet_account(class_hash=None, private_key=None, amount=1):
salt=salt,
key_pair=key_pair,
client=RPC_CLIENT,
chain=NETWORK["chain_id"],
chain=ChainId.starknet_chain_id,
constructor_calldata=constructor_calldata,
max_fee=_max_fee,
)
Expand Down
6 changes: 3 additions & 3 deletions tests/end_to_end/PlainOpcodes/test_plain_opcodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,8 @@ async def test_should_return_starknet_timestamp(
self, plain_opcodes, block_timestamp
):
assert pytest.approx(
await plain_opcodes.opcodeTimestamp(), abs=10
) == await block_timestamp("latest")
await plain_opcodes.opcodeTimestamp(), abs=20
) == await block_timestamp("pending")

class TestBlockhash:
@pytest.mark.xfail(reason="Need to fix blockhash on real Starknet network")
Expand All @@ -54,7 +54,7 @@ async def test_should_return_zero_with_invalid_block_number(
self, plain_opcodes, block_number
):
blockhash_invalid_number = await plain_opcodes.opcodeBlockHash(
await block_number("latest") + 1
await block_number("latest") + 10
)

assert int.from_bytes(blockhash_invalid_number, byteorder="big") == 0
Expand Down
17 changes: 2 additions & 15 deletions tests/end_to_end/test_kakarot.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ async def class_hashes():


@pytest_asyncio.fixture(scope="session")
async def origin(evm: Contract):
async def origin(evm: Contract, max_fee):
"""
Deploys the origin's Starknet contract to the correct address.
"""
Expand All @@ -70,7 +70,7 @@ async def origin(evm: Contract):
is_deployed = (await evm.functions["is_deployed"].call(evm_address)).deployed
if is_deployed:
return evm_address
tx = await evm.functions["deploy_account"].invoke_v1(evm_address, max_fee=100)
tx = await evm.functions["deploy_account"].invoke_v1(evm_address, max_fee=max_fee)
await wait_for_transaction(tx.hash)
return evm_address

Expand Down Expand Up @@ -101,7 +101,6 @@ async def test_execute(
params: dict,
request,
evm: Contract,
addresses,
max_fee,
origin,
):
Expand Down Expand Up @@ -168,8 +167,6 @@ async def test_should_return_same_as_deployed_address(
class TestDeployExternallyOwnedAccount:
async def test_should_deploy_starknet_contract_at_corresponding_address(
self,
starknet: FullNodeClient,
fund_starknet_address,
deploy_externally_owned_account,
compute_starknet_address,
get_contract,
Expand All @@ -187,11 +184,8 @@ class TestRegisterAccount:
async def test_should_fail_when_sender_is_not_account(
self,
starknet: FullNodeClient,
fund_starknet_address,
deploy_externally_owned_account,
register_account,
compute_starknet_address,
get_contract,
random_seed,
):
evm_address = generate_random_evm_address(random_seed)
Expand All @@ -205,11 +199,8 @@ async def test_should_fail_when_sender_is_not_account(
async def test_should_fail_when_account_is_already_registered(
self,
starknet: FullNodeClient,
fund_starknet_address,
deploy_externally_owned_account,
register_account,
compute_starknet_address,
get_contract,
):
evm_address = generate_random_evm_address(random_seed)
await deploy_externally_owned_account(evm_address)
Expand All @@ -222,7 +213,6 @@ class TestSetAccountStorage:
class TestWriteAccountBytecode:
async def test_should_set_account_bytecode(
self,
starknet: FullNodeClient,
deploy_externally_owned_account,
invoke,
compute_starknet_address,
Expand Down Expand Up @@ -253,8 +243,6 @@ async def test_should_fail_not_owner(
starknet: FullNodeClient,
deploy_externally_owned_account,
invoke,
compute_starknet_address,
get_contract,
random_seed,
other,
):
Expand All @@ -278,7 +266,6 @@ class TestWriteAccountNonce:

async def test_should_set_account_nonce(
self,
starknet: FullNodeClient,
deploy_externally_owned_account,
invoke,
compute_starknet_address,
Expand Down
Loading