Skip to content

Commit

Permalink
feat: support configuring test account initial balance (#2173)
Browse files Browse the repository at this point in the history
  • Loading branch information
antazoey authored Jul 9, 2024
1 parent 1e726ab commit 5d2db65
Show file tree
Hide file tree
Showing 10 changed files with 129 additions and 53 deletions.
3 changes: 2 additions & 1 deletion docs/userguides/testing.md
Original file line number Diff line number Diff line change
Expand Up @@ -117,12 +117,13 @@ def test_my_method(owner, receiver):
...
```

You can configure your accounts by changing the `mnemonic` or `number_of_accounts` settings in the `test` section of your `ape-config.yaml` file:
You can configure your accounts by changing the `mnemonic`, `number_of_accounts`, and `balance` in the `test` section of your `ape-config.yaml` file:

```yaml
test:
mnemonic: test test test test test test test test test test test junk
number_of_accounts: 5
balance: 100_000 ETH
```
If you are running tests against `anvil`, your generated test accounts may not correspond to the `anvil`'s default generated accounts despite using the same mnemonic. In such a case, you are able to specify a custom derivation path in `ape-config.yaml`:
Expand Down
2 changes: 2 additions & 0 deletions src/ape/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@
from ape.utils.process import JoinableQueue, spawn
from ape.utils.testing import (
DEFAULT_NUMBER_OF_TEST_ACCOUNTS,
DEFAULT_TEST_ACCOUNT_BALANCE,
DEFAULT_TEST_CHAIN_ID,
DEFAULT_TEST_HD_PATH,
DEFAULT_TEST_MNEMONIC,
Expand All @@ -82,6 +83,7 @@
"DEFAULT_LIVE_NETWORK_BASE_FEE_MULTIPLIER",
"DEFAULT_LOCAL_TRANSACTION_ACCEPTANCE_TIMEOUT",
"DEFAULT_NUMBER_OF_TEST_ACCOUNTS",
"DEFAULT_TEST_ACCOUNT_BALANCE",
"DEFAULT_TEST_CHAIN_ID",
"DEFAULT_TEST_MNEMONIC",
"DEFAULT_TEST_HD_PATH",
Expand Down
1 change: 1 addition & 0 deletions src/ape/utils/testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
DEFAULT_TEST_MNEMONIC = "test test test test test test test test test test test junk"
DEFAULT_TEST_HD_PATH = "m/44'/60'/0'/0"
DEFAULT_TEST_CHAIN_ID = 1337
DEFAULT_TEST_ACCOUNT_BALANCE = int(10e21) # 10,000 Ether (in Wei)
GeneratedDevAccount = namedtuple("GeneratedDevAccount", ("address", "private_key"))
"""
An account key-pair generated from the test mnemonic. Set the test mnemonic
Expand Down
77 changes: 41 additions & 36 deletions src/ape_node/provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

from eth_pydantic_types import HexBytes
from eth_typing import HexStr
from eth_utils import add_0x_prefix, to_hex, to_wei
from eth_utils import add_0x_prefix, to_hex
from evmchains import get_random_rpc
from geth.chain import initialize_chain
from geth.process import BaseGethProcess
Expand All @@ -21,17 +21,15 @@
from ape.api import PluginConfig, SubprocessProvider, TestProviderAPI
from ape.logging import LogLevel, logger
from ape.types import SnapshotID
from ape.utils import (
from ape.utils.misc import ZERO_ADDRESS, log_instead_of_fail, raises_not_implemented
from ape.utils.process import JoinableQueue, spawn
from ape.utils.testing import (
DEFAULT_NUMBER_OF_TEST_ACCOUNTS,
DEFAULT_TEST_ACCOUNT_BALANCE,
DEFAULT_TEST_CHAIN_ID,
DEFAULT_TEST_HD_PATH,
DEFAULT_TEST_MNEMONIC,
ZERO_ADDRESS,
JoinableQueue,
generate_dev_accounts,
log_instead_of_fail,
raises_not_implemented,
spawn,
)
from ape_ethereum.provider import (
DEFAULT_HOSTNAME,
Expand Down Expand Up @@ -98,7 +96,7 @@ def __init__(
mnemonic: str = DEFAULT_TEST_MNEMONIC,
number_of_accounts: int = DEFAULT_NUMBER_OF_TEST_ACCOUNTS,
chain_id: int = DEFAULT_TEST_CHAIN_ID,
initial_balance: Union[str, int] = to_wei(10000, "ether"),
initial_balance: Union[str, int] = DEFAULT_TEST_ACCOUNT_BALANCE,
executable: Optional[str] = None,
auto_disconnect: bool = True,
extra_funded_accounts: Optional[list[str]] = None,
Expand Down Expand Up @@ -131,8 +129,9 @@ def __init__(
self._clean()

geth_kwargs["dev_mode"] = True
hd_path = hd_path or DEFAULT_TEST_HD_PATH
accounts = generate_dev_accounts(
mnemonic, number_of_accounts=number_of_accounts, hd_path=hd_path or DEFAULT_TEST_HD_PATH
mnemonic, number_of_accounts=number_of_accounts, hd_path=hd_path
)
addresses = [a.address for a in accounts]
addresses.extend(extra_funded_accounts or [])
Expand All @@ -152,20 +151,22 @@ def from_uri(cls, uri: str, data_folder: Path, **kwargs):
port = parsed_uri.port if parsed_uri.port is not None else DEFAULT_PORT
mnemonic = kwargs.get("mnemonic", DEFAULT_TEST_MNEMONIC)
number_of_accounts = kwargs.get("number_of_accounts", DEFAULT_NUMBER_OF_TEST_ACCOUNTS)
balance = kwargs.get("initial_balance", DEFAULT_TEST_ACCOUNT_BALANCE)
extra_accounts = [
HexBytes(a).hex().lower() for a in kwargs.get("extra_funded_accounts", [])
]

return cls(
data_folder,
hostname=parsed_uri.host,
port=port,
mnemonic=mnemonic,
number_of_accounts=number_of_accounts,
executable=kwargs.get("executable"),
auto_disconnect=kwargs.get("auto_disconnect", True),
executable=kwargs.get("executable"),
extra_funded_accounts=extra_accounts,
hd_path=kwargs.get("hd_path", DEFAULT_TEST_HD_PATH),
hostname=parsed_uri.host,
initial_balance=balance,
mnemonic=mnemonic,
number_of_accounts=number_of_accounts,
port=port,
)

@property
Expand Down Expand Up @@ -325,33 +326,15 @@ def connect(self):
self.start()

def start(self, timeout: int = 20):
# NOTE: Using JSON mode to ensure types can be passed as CLI args.
test_config = self.config_manager.get_config("test").model_dump(mode="json")

# Allow configuring a custom executable besides your $PATH geth.
if self.settings.executable is not None:
test_config["executable"] = self.settings.executable

test_config["ipc_path"] = self.ipc_path
test_config["auto_disconnect"] = self._test_runner is None or test_config.get(
"disconnect_providers_after", True
)

# Include extra accounts to allocated funds to at genesis.
extra_accounts = self.settings.ethereum.local.get("extra_funded_accounts", [])
extra_accounts.extend(self.provider_settings.get("extra_funded_accounts", []))
extra_accounts = list({HexBytes(a).hex().lower() for a in extra_accounts})
test_config["extra_funded_accounts"] = extra_accounts

process = GethDevProcess.from_uri(self.uri, self.data_dir, **test_config)
process.connect(timeout=timeout)
geth_dev = self._create_process()
geth_dev.connect(timeout=timeout)
if not self.web3.is_connected():
process.disconnect()
geth_dev.disconnect()
raise ConnectionError("Unable to connect to locally running geth.")
else:
self.web3.middleware_onion.inject(geth_poa_middleware, layer=0)

self._process = process
self._process = geth_dev

# For subprocess-provider
if self._process is not None and (process := self._process.proc):
Expand All @@ -366,6 +349,28 @@ def start(self, timeout: int = 20):
spawn(self.consume_stdout_queue)
spawn(self.consume_stderr_queue)

def _create_process(self) -> GethDevProcess:
# NOTE: Using JSON mode to ensure types can be passed as CLI args.
test_config = self.config_manager.get_config("test").model_dump(mode="json")

# Allow configuring a custom executable besides your $PATH geth.
if self.settings.executable is not None:
test_config["executable"] = self.settings.executable

test_config["ipc_path"] = self.ipc_path
test_config["auto_disconnect"] = self._test_runner is None or test_config.get(
"disconnect_providers_after", True
)

# Include extra accounts to allocated funds to at genesis.
extra_accounts = self.settings.ethereum.local.get("extra_funded_accounts", [])
extra_accounts.extend(self.provider_settings.get("extra_funded_accounts", []))
extra_accounts = list({HexBytes(a).hex().lower() for a in extra_accounts})
test_config["extra_funded_accounts"] = extra_accounts
test_config["initial_balance"] = self.test_config.balance

return GethDevProcess.from_uri(self.uri, self.data_dir, **test_config)

def disconnect(self):
# Must disconnect process first.
if self._process is not None:
Expand Down
42 changes: 31 additions & 11 deletions src/ape_test/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,13 @@
from ape import plugins
from ape.api import PluginConfig
from ape.api.networks import LOCAL_NETWORK_NAME
from ape.utils import DEFAULT_NUMBER_OF_TEST_ACCOUNTS, DEFAULT_TEST_HD_PATH, DEFAULT_TEST_MNEMONIC
from ape.utils.basemodel import ManagerAccessMixin
from ape.utils.testing import (
DEFAULT_NUMBER_OF_TEST_ACCOUNTS,
DEFAULT_TEST_ACCOUNT_BALANCE,
DEFAULT_TEST_HD_PATH,
DEFAULT_TEST_MNEMONIC,
)
from ape_test.accounts import TestAccount, TestAccountContainer
from ape_test.provider import EthTesterProviderConfig, LocalProvider

Expand Down Expand Up @@ -110,41 +116,55 @@ class CoverageConfig(PluginConfig):


class ApeTestConfig(PluginConfig):
mnemonic: str = DEFAULT_TEST_MNEMONIC
balance: int = DEFAULT_TEST_ACCOUNT_BALANCE
"""
The mnemonic to use when generating the test accounts.
The starting-balance of every test account in Wei (NOT Ether).
"""

number_of_accounts: NonNegativeInt = DEFAULT_NUMBER_OF_TEST_ACCOUNTS
coverage: CoverageConfig = CoverageConfig()
"""
The number of test accounts to generate in the provider.
Configuration related to coverage reporting.
"""

disconnect_providers_after: bool = True
"""
Set to ``False`` to keep providers connected at the end of the test run.
"""

gas: GasConfig = GasConfig()
"""
Configuration related to gas reporting.
"""

coverage: CoverageConfig = CoverageConfig()
hd_path: str = DEFAULT_TEST_HD_PATH
"""
Configuration related to coverage reporting.
The hd_path to use when generating the test accounts.
"""

disconnect_providers_after: bool = True
mnemonic: str = DEFAULT_TEST_MNEMONIC
"""
Set to ``False`` to keep providers connected at the end of the test run.
The mnemonic to use when generating the test accounts.
"""

hd_path: str = DEFAULT_TEST_HD_PATH
number_of_accounts: NonNegativeInt = DEFAULT_NUMBER_OF_TEST_ACCOUNTS
"""
The hd_path to use when generating the test accounts.
The number of test accounts to generate in the provider.
"""

provider: EthTesterProviderConfig = EthTesterProviderConfig()
"""
Settings for the provider.
"""

@field_validator("balance", mode="before")
@classmethod
def validate_balance(cls, value):
return (
value
if isinstance(value, int)
else ManagerAccessMixin.conversion_manager.convert(value, int)
)


@plugins.register(plugins.Config)
def config_class():
Expand Down
4 changes: 3 additions & 1 deletion src/ape_test/provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,10 +62,12 @@ def tester(self):
return

hd_path = (self.config.hd_path or DEFAULT_TEST_HD_PATH).rstrip("/")
state_overrides = {"balance": self.test_config.balance}
self._evm_backend = PyEVMBackend.from_mnemonic(
genesis_state_overrides=state_overrides,
hd_path=hd_path,
mnemonic=self.config.mnemonic,
num_accounts=self.config.number_of_accounts,
hd_path=hd_path,
)
endpoints = {**API_ENDPOINTS}
endpoints["eth"] = merge(endpoints["eth"], {"chainId": static_return(chain_id)})
Expand Down
14 changes: 14 additions & 0 deletions tests/functional/geth/test_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -602,3 +602,17 @@ def test_trace_approach_config(project):
with project.temp_config(node=node_cfg):
provider = project.network_manager.ethereum.local.get_provider("node")
assert provider.call_trace_approach is TraceApproach.GETH_STRUCT_LOG_PARSE


def test_start(mocker, convert, project, geth_provider):
amount = convert("100_000 ETH", int)
spy = mocker.spy(GethDevProcess, "from_uri")

with project.temp_config(test={"balance": amount}):
try:
geth_provider.start()
except Exception:
pass # Exceptions are fine here.

actual = spy.call_args[1]["balance"]
assert actual == amount
22 changes: 20 additions & 2 deletions tests/functional/test_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from requests import HTTPError
from web3.exceptions import ContractPanicError

from ape import convert
from ape.exceptions import (
APINotImplementedError,
BlockNotFoundError,
Expand All @@ -19,9 +20,10 @@
TransactionNotFoundError,
)
from ape.types import LogFilter
from ape.utils import DEFAULT_TEST_CHAIN_ID
from ape.utils.testing import DEFAULT_TEST_ACCOUNT_BALANCE, DEFAULT_TEST_CHAIN_ID
from ape_ethereum.provider import WEB3_PROVIDER_URI_ENV_VAR_NAME, Web3Provider, _sanitize_web3_url
from ape_ethereum.transactions import TransactionStatusEnum, TransactionType
from ape_test import LocalProvider


def test_uri(eth_tester_provider):
Expand Down Expand Up @@ -200,7 +202,7 @@ def test_provider_get_balance(project, networks, accounts):
balance = networks.provider.get_balance(accounts.test_accounts[0].address)

assert type(balance) is int
assert balance == 1000000000000000000000000
assert balance == DEFAULT_TEST_ACCOUNT_BALANCE


def test_set_timestamp(ethereum):
Expand Down Expand Up @@ -475,3 +477,19 @@ def disconnect(self):
finally:
if WEB3_PROVIDER_URI_ENV_VAR_NAME in os.environ:
del os.environ[WEB3_PROVIDER_URI_ENV_VAR_NAME]


def test_account_balance_state(project, eth_tester_provider, owner):
amount = convert("100_000 ETH", int)

with project.temp_config(test={"balance": amount}):
# NOTE: Purposely using a different instance of the provider
# for better testing isolation.
provider = LocalProvider(
name="test",
network=eth_tester_provider.network,
request_header=eth_tester_provider.request_header,
)
provider.connect()
bal = provider.get_balance(owner.address)
assert bal == amount
11 changes: 11 additions & 0 deletions tests/functional/test_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
from ape_test import ApeTestConfig


class TestApeTestConfig:
def test_balance_set_from_currency_str(self):
curr_val = "10 Eth"
data = {"balance": curr_val}
cfg = ApeTestConfig.model_validate(data)
actual = cfg.balance
expected = 10_000_000_000_000_000_000 # 10 ETH in WEI
assert actual == expected
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import pytest

INITIAL_BALANCE = 1_000_1 * 10**18


@pytest.fixture(scope="session")
def alice(accounts):
Expand Down Expand Up @@ -27,10 +29,10 @@ def start_block_number(chain):

def test_isolation_first(alice, bob, chain, start_block_number):
assert chain.provider.get_block("latest").number == start_block_number
assert bob.balance == 1_000_001 * 10**18
assert bob.balance == INITIAL_BALANCE
alice.transfer(bob, "1 ether")


def test_isolation_second(bob, chain, start_block_number):
assert chain.provider.get_block("latest").number == start_block_number
assert bob.balance == 1_000_001 * 10**18
assert bob.balance == INITIAL_BALANCE

0 comments on commit 5d2db65

Please sign in to comment.