From f854ed591f607e12cabcf111647d86ee5d7e754a Mon Sep 17 00:00:00 2001 From: antazoey Date: Wed, 7 Aug 2024 09:30:18 -0500 Subject: [PATCH] perf: avoid enriching entire trace when only requesting `return_value` (#2208) Co-authored-by: El De-dog-lo <3859395+fubuloubu@users.noreply.github.com> --- src/ape/types/__init__.py | 3 + src/ape_ethereum/ecosystem.py | 122 +++++++++++++++------------- src/ape_ethereum/provider.py | 9 +- src/ape_ethereum/trace.py | 74 +++++++++++++++-- src/ape_test/provider.py | 17 +++- tests/conftest.py | 9 +- tests/functional/geth/conftest.py | 5 -- tests/functional/geth/test_trace.py | 29 +++++-- tests/functional/test_coverage.py | 4 +- tests/integration/cli/test_test.py | 9 +- 10 files changed, 197 insertions(+), 84 deletions(-) diff --git a/src/ape/types/__init__.py b/src/ape/types/__init__.py index e4656c22e0..489f5d2729 100644 --- a/src/ape/types/__init__.py +++ b/src/ape/types/__init__.py @@ -485,6 +485,9 @@ def __eq__(self, other: Any) -> bool: return NotImplemented +CurrencyValueComparable.__name__ = int.__name__ + + CurrencyValue: TypeAlias = CurrencyValueComparable """ An alias to :class:`~ape.types.CurrencyValueComparable` for diff --git a/src/ape_ethereum/ecosystem.py b/src/ape_ethereum/ecosystem.py index 54c391b681..21cc8c0a45 100644 --- a/src/ape_ethereum/ecosystem.py +++ b/src/ape_ethereum/ecosystem.py @@ -15,7 +15,9 @@ is_hex, is_hex_address, keccak, + to_bytes, to_checksum_address, + to_hex, ) from ethpm_types import ContractType from ethpm_types.abi import ABIType, ConstructorABI, EventABI, MethodABI @@ -57,7 +59,6 @@ StructParser, is_array, returns_array, - to_int, ) from ape.utils.basemodel import _assert_not_ipython_check, only_raise_attribute_error from ape.utils.misc import DEFAULT_MAX_RETRIES_TX, DEFAULT_TRANSACTION_TYPE @@ -161,7 +162,7 @@ def validate_gas_limit(cls, value): return int(value) elif isinstance(value, str) and is_hex(value) and is_0x_prefixed(value): - return to_int(HexBytes(value)) + return int(value, 16) elif is_hex(value): raise ValueError("Gas limit hex str must include '0x' prefix.") @@ -400,7 +401,7 @@ def decode_address(cls, raw_address: RawAddress) -> AddressType: @classmethod def encode_address(cls, address: AddressType) -> RawAddress: - return str(address) + return f"{address}" def decode_transaction_type(self, transaction_type_id: Any) -> type[TransactionAPI]: if isinstance(transaction_type_id, TransactionType): @@ -721,15 +722,16 @@ def decode_returndata(self, abi: MethodABI, raw_data: bytes) -> tuple[Any, ...]: ): # Array of structs or tuples: don't convert to list # Array of anything else: convert to single list - return ( - ( - [ - output_values[0], - ], - ) - if issubclass(type(output_values[0]), Struct) - else ([o for o in output_values[0]],) # type: ignore[union-attr] - ) + + if issubclass(type(output_values[0]), Struct): + return ([output_values[0]],) + + else: + try: + return ([o for o in output_values[0]],) # type: ignore[union-attr] + except Exception: + # On-chains transaction data errors. + return (output_values,) elif returns_array(abi): # Tuple with single item as the array. @@ -747,7 +749,7 @@ def _enrich_value(self, value: Any, **kwargs) -> Any: if len(value) > 24: return humanize_hash(cast(Hash32, value)) - hex_str = HexBytes(value).hex() + hex_str = to_hex(value) if is_hex_address(hex_str): return self._enrich_value(hex_str, **kwargs) @@ -1028,6 +1030,7 @@ def get_abi(_topic: HexStr) -> Optional[LogInputABICollection]: def enrich_trace(self, trace: TraceAPI, **kwargs) -> TraceAPI: kwargs["trace"] = trace if not isinstance(trace, Trace): + # Can only enrich `ape_ethereum.trace.Trace` (or subclass) implementations. return trace elif trace._enriched_calltree is not None: @@ -1047,11 +1050,8 @@ def enrich_trace(self, trace: TraceAPI, **kwargs) -> TraceAPI: # Return value was discovered already. kwargs["return_value"] = return_value - enriched_calltree = self._enrich_calltree(data, **kwargs) - # Cache the result back on the trace. - trace._enriched_calltree = enriched_calltree - + trace._enriched_calltree = self._enrich_calltree(data, **kwargs) return trace def _enrich_calltree(self, call: dict, **kwargs) -> dict: @@ -1080,10 +1080,10 @@ def _enrich_calltree(self, call: dict, **kwargs) -> dict: call["calls"] = [self._enrich_calltree(c, **kwargs) for c in subcalls] # Figure out the contract. - address = call.pop("address", "") + address: AddressType = call.pop("address", "") try: - call["contract_id"] = address = kwargs["contract_address"] = str( - self.decode_address(address) + call["contract_id"] = address = kwargs["contract_address"] = self.decode_address( + address ) except Exception: # Tx was made with a weird address. @@ -1104,10 +1104,11 @@ def _enrich_calltree(self, call: dict, **kwargs) -> dict: else: # Collapse pre-compile address calls if 1 <= address_int <= 9: - if len(call.get("calls", [])) == 1: - return call["calls"][0] - - return {"contract_id": f"{address_int}", "calls": call["calls"]} + return ( + call["calls"][0] + if len(call.get("calls", [])) == 1 + else {"contract_id": f"{address_int}", "calls": call["calls"]} + ) depth = call.get("depth", 0) if depth == 0 and address in self.account_manager: @@ -1115,14 +1116,13 @@ def _enrich_calltree(self, call: dict, **kwargs) -> dict: else: call["contract_id"] = self._enrich_contract_id(call["contract_id"], **kwargs) - if not (contract_type := self.chain_manager.contracts.get(address)): - # Without a contract, we can enrich no further. + if not (contract_type := self._get_contract_type_for_enrichment(address, **kwargs)): + # Without a contract type, we can enrich no further. return call + kwargs["contract_type"] = contract_type if events := call.get("events"): - call["events"] = self._enrich_trace_events( - events, address=address, contract_type=contract_type - ) + call["events"] = self._enrich_trace_events(events, address=address, **kwargs) method_abi: Optional[Union[MethodABI, ConstructorABI]] = None if is_create: @@ -1131,24 +1131,26 @@ def _enrich_calltree(self, call: dict, **kwargs) -> dict: elif call["method_id"] != "0x": method_id_bytes = HexBytes(call["method_id"]) - if method_id_bytes in contract_type.methods: + + # perf: use try/except instead of __contains__ check. + try: method_abi = contract_type.methods[method_id_bytes] + except KeyError: + name = call["method_id"] + else: assert isinstance(method_abi, MethodABI) # For mypy # Check if method name duplicated. If that is the case, use selector. times = len([x for x in contract_type.methods if x.name == method_abi.name]) name = (method_abi.name if times == 1 else method_abi.selector) or call["method_id"] - call = self._enrich_calldata(call, method_abi, contract_type, **kwargs) - - else: - name = call["method_id"] + call = self._enrich_calldata(call, method_abi, **kwargs) else: name = call.get("method_id") or "0x" call["method_id"] = name if method_abi: - call = self._enrich_calldata(call, method_abi, contract_type, **kwargs) + call = self._enrich_calldata(call, method_abi, **kwargs) if kwargs.get("return_value"): # Return value was separately enriched. @@ -1172,10 +1174,12 @@ def _enrich_contract_id(self, address: AddressType, **kwargs) -> str: elif address == ZERO_ADDRESS: return "ZERO_ADDRESS" - if not (contract_type := self.chain_manager.contracts.get(address)): + elif not (contract_type := self._get_contract_type_for_enrichment(address, **kwargs)): + # Without a contract type, we can enrich no further. return address - elif kwargs.get("use_symbol_for_tokens") and "symbol" in contract_type.view_methods: + kwargs["contract_type"] = contract_type + if kwargs.get("use_symbol_for_tokens") and "symbol" in contract_type.view_methods: # Use token symbol as name contract = self.chain_manager.contracts.instance_at( address, contract_type=contract_type @@ -1203,17 +1207,18 @@ def _enrich_calldata( self, call: dict, method_abi: Union[MethodABI, ConstructorABI], - contract_type: ContractType, **kwargs, ) -> dict: calldata = call["calldata"] - if isinstance(calldata, (str, bytes, int)): - calldata_arg = HexBytes(calldata) + if isinstance(calldata, str): + calldata_arg = to_bytes(hexstr=calldata) + elif isinstance(calldata, bytes): + calldata_arg = calldata else: - # Not sure if we can get here. - # Mostly for mypy's sake. + # Already enriched. return call + contract_type = kwargs["contract_type"] if call.get("call_type") and "CREATE" in call.get("call_type", ""): # Strip off bytecode bytecode = ( @@ -1316,18 +1321,15 @@ def _enrich_trace_events( self, events: list[dict], address: Optional[AddressType] = None, - contract_type: Optional[ContractType] = None, + **kwargs, ) -> list[dict]: - return [ - self._enrich_trace_event(e, address=address, contract_type=contract_type) - for e in events - ] + return [self._enrich_trace_event(e, address=address, **kwargs) for e in events] def _enrich_trace_event( self, event: dict, address: Optional[AddressType] = None, - contract_type: Optional[ContractType] = None, + **kwargs, ) -> dict: if "topics" not in event or len(event["topics"]) < 1: # Already enriched or wrong. @@ -1339,16 +1341,11 @@ def _enrich_trace_event( # Cannot enrich further w/o an address. return event - if not contract_type: - try: - contract_type = self.chain_manager.contracts.get(address) - except Exception as err: - logger.debug(f"Error getting contract type during event enrichment: {err}") - return event + if not (contract_type := self._get_contract_type_for_enrichment(address, **kwargs)): + # Without a contract type, we can enrich no further. + return event - if not contract_type: - # Cannot enrich further w/o an contract type. - return event + kwargs["contract_type"] = contract_type # The selector is always the first topic. selector = event["topics"][0] @@ -1393,6 +1390,17 @@ def _enrich_revert_message(self, call: dict) -> dict: return call + def _get_contract_type_for_enrichment( + self, address: AddressType, **kwargs + ) -> Optional[ContractType]: + if not (contract_type := kwargs.get("contract_type")): + try: + contract_type = self.chain_manager.contracts.get(address) + except Exception as err: + logger.debug(f"Error getting contract type during event enrichment: {err}") + + return contract_type + def get_python_types(self, abi_type: ABIType) -> Union[type, Sequence]: return self._python_type_for_abi_type(abi_type) diff --git a/src/ape_ethereum/provider.py b/src/ape_ethereum/provider.py index 31c9643cd7..90ab4fe48e 100644 --- a/src/ape_ethereum/provider.py +++ b/src/ape_ethereum/provider.py @@ -133,6 +133,8 @@ class Web3Provider(ProviderAPI, ABC): _supports_debug_trace_call: Optional[bool] = None + _transaction_trace_cache: dict[str, TransactionTrace] = {} + def __new__(cls, *args, **kwargs): assert_web3_provider_uri_env_var_not_set() @@ -440,10 +442,15 @@ def get_storage( raise # Raise original error def get_transaction_trace(self, transaction_hash: str, **kwargs) -> TraceAPI: + if transaction_hash in self._transaction_trace_cache: + return self._transaction_trace_cache[transaction_hash] + if "call_trace_approach" not in kwargs: kwargs["call_trace_approach"] = self.call_trace_approach - return TransactionTrace(transaction_hash=transaction_hash, **kwargs) + trace = TransactionTrace(transaction_hash=transaction_hash, **kwargs) + self._transaction_trace_cache[transaction_hash] = trace + return trace def send_call( self, diff --git a/src/ape_ethereum/trace.py b/src/ape_ethereum/trace.py index 784933b1f0..9472dea448 100644 --- a/src/ape_ethereum/trace.py +++ b/src/ape_ethereum/trace.py @@ -1,13 +1,14 @@ import json import sys from abc import abstractmethod -from collections import defaultdict +from collections import defaultdict, deque from collections.abc import Iterable, Iterator, Sequence from enum import Enum from functools import cached_property from typing import IO, Any, Optional, Union -from eth_utils import is_0x_prefixed +from eth_utils import is_0x_prefixed, to_hex +from ethpm_types import ContractType, MethodABI from evm_trace import ( CallTreeNode, CallType, @@ -172,6 +173,27 @@ def frames(self) -> Iterator[TraceFrame]: def addresses(self) -> Iterator[AddressType]: yield from self.get_addresses_used() + @cached_property + def root_contract_type(self) -> Optional[ContractType]: + if address := self.transaction.get("to"): + try: + return self.chain_manager.contracts.get(address) + except Exception: + return None + + return None + + @cached_property + def root_method_abi(self) -> Optional[MethodABI]: + method_id = self.transaction.get("data", b"")[:10] + if ct := self.root_contract_type: + try: + return ct.methods[method_id] + except Exception: + return None + + return None + @property def _ecosystem(self) -> EcosystemAPI: if provider := self.network_manager.active_provider: @@ -197,6 +219,25 @@ def get_addresses_used(self, reverse: bool = False): @cached_property def return_value(self) -> Any: + if self._enriched_calltree: + # Only check enrichment output if was already enriched! + # Don't enrich ONLY for return value, as that is very bad performance + # for realistic contract interactions. + return self._return_value_from_enriched_calltree + + elif abi := self.root_method_abi: + return_data = self._return_data_from_trace_frames + if return_data is not None: + try: + return self._ecosystem.decode_returndata(abi, return_data) + except Exception as err: + logger.debug(f"Failed decoding return data from trace frames. Error: {err}") + # Use enrichment method. It is slow but it'll at least work. + + return self._return_value_from_enriched_calltree + + @cached_property + def _return_value_from_enriched_calltree(self) -> Any: calltree = self.enriched_calltree # Check if was cached from enrichment. @@ -227,16 +268,33 @@ def try_get_revert_msg(c) -> Optional[str]: return message # Enrichment call-tree not available. Attempt looking in trace-frames. + return to_hex(self._revert_str_from_trace_frames) + + @cached_property + def _last_frame(self) -> Optional[dict]: try: - frames = list(self.raw_trace_frames) + frame = deque(self.raw_trace_frames, maxlen=1) except Exception as err: logger.error(f"Failed getting traceback: {err}") - frames = [] + return None + + return frame[0] if frame else None - data = frames[-1] if len(frames) > 0 else {} - memory = data.get("memory", []) - if ret := "".join([x[2:] for x in memory[4:]]): - return HexBytes(ret).hex() + @cached_property + def _revert_str_from_trace_frames(self) -> Optional[HexBytes]: + if frame := self._last_frame: + memory = frame.get("memory", []) + if ret := "".join([x[2:] for x in memory[4:]]): + return HexBytes(ret) + + return None + + @cached_property + def _return_data_from_trace_frames(self) -> Optional[HexBytes]: + if frame := self._last_frame: + memory = frame["memory"] + start_pos = int(frame["stack"][2], 16) // 32 + return HexBytes("".join(memory[start_pos:])) return None diff --git a/src/ape_test/provider.py b/src/ape_test/provider.py index 5bb9b3e6ad..05bea7044c 100644 --- a/src/ape_test/provider.py +++ b/src/ape_test/provider.py @@ -17,7 +17,7 @@ from web3.providers.eth_tester.defaults import API_ENDPOINTS, static_return from web3.types import TxParams -from ape.api import BlockAPI, PluginConfig, ReceiptAPI, TestProviderAPI, TransactionAPI +from ape.api import BlockAPI, PluginConfig, ReceiptAPI, TestProviderAPI, TraceAPI, TransactionAPI from ape.exceptions import ( APINotImplementedError, ContractLogicError, @@ -31,6 +31,7 @@ from ape.types import AddressType, BlockID, ContractLog, LogFilter, SnapshotID from ape.utils import DEFAULT_TEST_CHAIN_ID, DEFAULT_TEST_HD_PATH, gas_estimation_error_message from ape_ethereum.provider import Web3Provider +from ape_ethereum.trace import TraceApproach, TransactionTrace if TYPE_CHECKING: from ape.api.accounts import TestAccountAPI @@ -388,6 +389,12 @@ def _get_last_base_fee(self) -> int: raise APINotImplementedError("No base fee found in block.") + def get_transaction_trace(self, transaction_hash: str, **kwargs) -> TraceAPI: + if "call_trace_approach" not in kwargs: + kwargs["call_trace_approach"] = TraceApproach.BASIC + + return EthTesterTransactionTrace(transaction_hash=transaction_hash, **kwargs) + def get_virtual_machine_error(self, exception: Exception, **kwargs) -> VirtualMachineError: if isinstance(exception, ValidationError): match = self._CANNOT_AFFORD_GAS_PATTERN.match(str(exception)) @@ -438,3 +445,11 @@ def _get_latest_block(self) -> BlockAPI: def _get_latest_block_rpc(self) -> dict: return self.evm_backend.get_block_by_number("latest") + + +class EthTesterTransactionTrace(TransactionTrace): + @cached_property + def return_value(self) -> Any: + # perf: skip trying anything else, because eth-tester doesn't + # yet implement any tracing RPCs. + return self._return_value_from_enriched_calltree diff --git a/tests/conftest.py b/tests/conftest.py index df3bd5007a..adb999e2ab 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -92,9 +92,11 @@ def validate_cwd(start_dir): @pytest.fixture def project(): - path = "functional/data/contracts/local" + path = "tests/functional/data/contracts/ethereum/local" with ape.project.temp_config(contracts_folder=path): + ape.project.manifest_path.unlink(missing_ok=True) yield ape.project + ape.project.manifest_path.unlink(missing_ok=True) @pytest.fixture(scope="session") @@ -691,3 +693,8 @@ def shared_contracts_folder(): @pytest.fixture def project_with_contracts(with_dependencies_project_path): return Project(with_dependencies_project_path) + + +@pytest.fixture +def geth_contract(geth_account, vyper_contract_container, geth_provider): + return geth_account.deploy(vyper_contract_container, 0) diff --git a/tests/functional/geth/conftest.py b/tests/functional/geth/conftest.py index 6a402bd16e..e903284534 100644 --- a/tests/functional/geth/conftest.py +++ b/tests/functional/geth/conftest.py @@ -13,11 +13,6 @@ def parity_trace_response(): return TRACE_RESPONSE -@pytest.fixture -def geth_contract(geth_account, vyper_contract_container, geth_provider): - return geth_account.deploy(vyper_contract_container, 0) - - @pytest.fixture def contract_with_call_depth_geth( owner, geth_provider, get_contract_type, leaf_contract_geth, middle_contract_geth diff --git a/tests/functional/geth/test_trace.py b/tests/functional/geth/test_trace.py index 14d50f31bc..d6089ffc62 100644 --- a/tests/functional/geth/test_trace.py +++ b/tests/functional/geth/test_trace.py @@ -33,16 +33,16 @@ │ └── ContractC\.methodC1\( │ windows95="simpler", │ jamaica=345457847457457458457457457, -│ cardinal=ContractA +│ cardinal=Contract[A|C] │ \) \[\d+ gas\] ├── SYMBOL\.callMe\(blue=tx\.origin\) -> tx\.origin \[\d+ gas\] ├── SYMBOL\.methodB2\(trombone=tx\.origin\) \[\d+ gas\] -│ ├── ContractC\.paperwork\(ContractA\) -> \( +│ ├── ContractC\.paperwork\(Contract[A|C]\) -> \( │ │ os="simpler", │ │ country=345457847457457458457457457, -│ │ wings=ContractA +│ │ wings=Contract[A|C] │ │ \) \[\d+ gas\] -│ ├── ContractC\.methodC1\(windows95="simpler", jamaica=0, cardinal=ContractC\) \[\d+ gas\] +│ ├── ContractC\.methodC1\(windows95="simpler", jamaica=0, cardinal=Contract[A|C]\) \[\d+ gas\] │ ├── ContractC\.methodC2\(\) \[\d+ gas\] │ └── ContractC\.methodC2\(\) \[\d+ gas\] ├── ContractC\.addressToValue\(tx.origin\) -> 0 \[\d+ gas\] @@ -53,14 +53,14 @@ │ │ 111344445534535353, │ │ 993453434534534534534977788884443333 │ │ \] \[\d+ gas\] -│ └── ContractC\.methodC1\(windows95="simpler", jamaica=0, cardinal=ContractA\) \[\d+ gas\] +│ └── ContractC\.methodC1\(windows95="simpler", jamaica=0, cardinal=Contract[A|C]\) \[\d+ gas\] └── SYMBOL\.methodB1\(lolol="snitches_get_stiches", dynamo=111\) \[\d+ gas\] ├── ContractC\.getSomeList\(\) -> \[ │ 3425311345134513461345134534531452345, │ 111344445534535353, │ 993453434534534534534977788884443333 │ \] \[\d+ gas\] - └── ContractC\.methodC1\(windows95="simpler", jamaica=111, cardinal=ContractA\) \[\d+ gas\] + └── ContractC\.methodC1\(windows95="simpler", jamaica=111, cardinal=Contract[A|C]\) \[\d+ gas\] """ @@ -307,3 +307,20 @@ def test_call_trace_supports_debug_trace_call(geth_contract, geth_account): trace = CallTrace(tx=tx) _ = trace._traced_call assert trace.supports_debug_trace_call + + +@geth_process_test +def test_return_value(geth_contract, geth_account): + receipt = geth_contract.getFilledArray.transact(sender=geth_account) + trace = receipt.trace + expected = [1, 2, 3] # Hardcoded in contract + assert receipt.return_value == expected + + # In `trace.return_value`, it is still a tuple. + # (unlike receipt.return_value) + actual = trace.return_value[0] + assert actual == expected + + # NOTE: This is very important from a performance perspective! + # (VERY IMPORTANT). We shouldn't need to enrich anything. + assert trace._enriched_calltree is None diff --git a/tests/functional/test_coverage.py b/tests/functional/test_coverage.py index bba0684227..8ffc4c2100 100644 --- a/tests/functional/test_coverage.py +++ b/tests/functional/test_coverage.py @@ -192,8 +192,8 @@ def config_wrapper(self, pytest_config): return ConfigWrapper(pytest_config) @pytest.fixture - def tracker(self, pytest_config): - return CoverageTracker(pytest_config) + def tracker(self, pytest_config, project): + return CoverageTracker(pytest_config, project=project) def test_data(self, tracker): assert tracker.data is not None diff --git a/tests/integration/cli/test_test.py b/tests/integration/cli/test_test.py index 3be6580698..7c40f99707 100644 --- a/tests/integration/cli/test_test.py +++ b/tests/integration/cli/test_test.py @@ -233,11 +233,14 @@ def test_fixture_docs(setup_pytester, integ_project, pytester, eth_tester_provid @skip_projects_except("with-contracts") -def test_gas_flag_when_not_supported(setup_pytester, integ_project, pytester, eth_tester_provider): +def test_gas_flag_when_not_supported( + setup_pytester, project, integ_project, pytester, eth_tester_provider +): _ = eth_tester_provider # Ensure using EthTester for this test. setup_pytester(integ_project) - path = f"{integ_project.path}/tests/test_contract.py::test_contract_interaction_in_tests" - result = pytester.runpytest(path, "--gas") + path = f"{integ_project.path}/tests/test_contract.py" + path_w_test = f"{path}::test_contract_interaction_in_tests" + result = pytester.runpytest(path_w_test, "--gas") actual = "\n".join(result.outlines) expected = ( "Provider 'test' does not support transaction tracing. "