Skip to content

Commit

Permalink
feat: IPython integration with ABI-related constructs (#2174)
Browse files Browse the repository at this point in the history
  • Loading branch information
antazoey authored Jul 12, 2024
1 parent a5cea52 commit e1c9dcf
Show file tree
Hide file tree
Showing 19 changed files with 381 additions and 97 deletions.
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@
"web3[tester]>=6.17.2,<7",
# ** Dependencies maintained by ApeWorX **
"eip712>=0.2.7,<0.3",
"ethpm-types>=0.6.9,<0.7",
"ethpm-types>=0.6.14,<0.7",
"eth_pydantic_types>=0.1.0,<0.2",
"evmchains>=0.0.10,<0.1",
"evm-trace>=0.2.0,<0.3",
Expand Down
125 changes: 116 additions & 9 deletions src/ape/contracts/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from eth_utils import to_hex
from ethpm_types.abi import ConstructorABI, ErrorABI, EventABI, MethodABI
from ethpm_types.contract_type import ABI_W_SELECTOR_T, ContractType
from IPython.lib.pretty import for_type

from ape.api import AccountAPI, Address, ReceiptAPI, TransactionAPI
from ape.api.address import BaseAddress
Expand All @@ -32,7 +33,7 @@
MethodNonPayableError,
MissingDeploymentBytecodeError,
)
from ape.logging import logger
from ape.logging import get_rich_console, logger
from ape.types import AddressType, ContractLog, LogFilter, MockContractLog
from ape.utils import (
BaseInterfaceModel,
Expand All @@ -41,7 +42,7 @@
log_instead_of_fail,
singledispatchmethod,
)
from ape.utils.abi import StructParser
from ape.utils.abi import StructParser, _enrich_natspec
from ape.utils.basemodel import (
ExtraAttributesMixin,
ExtraModelAttributes,
Expand Down Expand Up @@ -144,16 +145,65 @@ def __init__(self, contract: "ContractInstance", abis: list[MethodABI]) -> None:
self.contract = contract
self.abis = abis

# If there is a natspec, inject it as the "doc-str" for this method.
# This greatly helps integrate with IPython.
self.__doc__ = self.info

@log_instead_of_fail(default="<ContractMethodHandler>")
def __repr__(self) -> str:
# `<ContractName 0x1234...AbCd>.method_name`
return f"{self.contract.__repr__()}.{self.abis[-1].name}"

@log_instead_of_fail()
def _repr_pretty_(self, printer, cycle):
"""
Show the NatSpec of a Method in any IPython console (including ``ape console``).
"""
console = get_rich_console()
output = self._get_info(enrich=True) or "\n".join(abi.signature for abi in self.abis)
console.print(output)

def __str__(self) -> str:
# `method_name(type1 arg1, ...) -> return_type`
abis = sorted(self.abis, key=lambda abi: len(abi.inputs or []))
return abis[-1].signature

@property
def info(self) -> str:
"""
The NatSpec documentation of the method, if one exists.
Else, returns the empty string.
"""
return self._get_info()

def _get_info(self, enrich: bool = False) -> str:
infos: list[str] = []

for abi in self.abis:
if abi.selector not in self.contract.contract_type.natspecs:
continue

natspec = self.contract.contract_type.natspecs[abi.selector]
header = abi.signature
natspec_str = natspec.replace("\n", "\n ")
infos.append(f"{header}\n {natspec_str}")

if enrich:
infos = [_enrich_natspec(n) for n in infos]

# Same as number of ABIs, regardless of NatSpecs.
number_infos = len(infos)
if number_infos == 1:
return infos[0]

# Ensure some distinction of the infos using number-prefixes.
numeric_infos = []
for idx, info in enumerate(infos):
num_info = f"{idx + 1}: {info}"
numeric_infos.append(num_info)

return "\n\n".join(numeric_infos)

def encode_input(self, *args) -> HexBytes:
selected_abi = _select_method_abi(self.abis, args)
arguments = self.conversion_manager.convert_method_args(selected_abi, args)
Expand Down Expand Up @@ -408,23 +458,48 @@ class ContractEvent(BaseInterfaceModel):
abi: EventABI
_logs: Optional[list[ContractLog]] = None

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)

# Inject the doc-str using the NatSpec to better integrate with IPython.
# NOTE: This must happen AFTER super().__init__().
self.__doc__ = self.info

@log_instead_of_fail(default="<ContractEvent>")
def __repr__(self) -> str:
return self.abi.signature

@log_instead_of_fail()
def _repr_pretty_(self, printer, cycle):
console = get_rich_console()
console.print(self._get_info(enrich=True))

@property
def name(self) -> str:
"""
The name of the contract event, as defined in the contract.
"""

return self.abi.name

@property
def info(self) -> str:
"""
NatSpec info derived from the contract-type developer-documentation.
"""
return self._get_info()

def _get_info(self, enrich: bool = False) -> str:
info_str = self.abi.signature
if spec := self.contract.contract_type.natspecs.get(self.abi.selector):
spec_indented = spec.replace("\n", "\n ")
info_str = f"{info_str}\n {spec_indented}"

return _enrich_natspec(info_str) if enrich else info_str

def __iter__(self) -> Iterator[ContractLog]: # type: ignore[override]
"""
Get all logs that have occurred for this event.
"""

yield from self.range(self.chain_manager.blocks.height + 1)

@property
Expand Down Expand Up @@ -800,12 +875,42 @@ def decode_input(self, calldata: bytes) -> tuple[str, dict[str, Any]]:
input_dict = ecosystem.decode_calldata(method, rest_calldata)
return method.selector, input_dict

def _create_custom_error_type(self, abi: ErrorABI) -> type[CustomError]:
def _create_custom_error_type(self, abi: ErrorABI, **kwargs) -> type[CustomError]:
def exec_body(namespace):
namespace["abi"] = abi
namespace["contract"] = self
for key, val in kwargs.items():
namespace[key] = val

error_type = types.new_class(abi.name, (CustomError,), {}, exec_body)
natspecs = self.contract_type.natspecs

def _get_info(enrich: bool = False) -> str:
if not (natspec := natspecs.get(abi.selector)):
return ""

elif enrich:
natspec = _enrich_natspec(natspec)

return f"{abi.signature}\n {natspec}"

def _repr_pretty_(cls, *args, **kwargs):
console = get_rich_console()
output = _get_info(enrich=True) or repr(cls)
console.print(output)

def repr_pretty_for_assignment(cls, *args, **kwargs):
return _repr_pretty_(error_type, *args, **kwargs)

info = _get_info()
error_type.info = error_type.__doc__ = info # type: ignore
if info:
error_type._repr_pretty_ = repr_pretty_for_assignment # type: ignore

# Register the dynamically-created type with IPython so it integrates.
for_type(type(error_type), _repr_pretty_)

return types.new_class(abi.name, (CustomError,), {}, exec_body)
return error_type


class ContractInstance(BaseAddress, ContractTypeWrapper):
Expand Down Expand Up @@ -1063,7 +1168,7 @@ def get_error_by_signature(self, signature: str) -> type[CustomError]:
raise err

for contract_err in options:
if contract_err.abi.signature == signature:
if contract_err.abi and contract_err.abi.signature == signature:
return contract_err

raise err
Expand Down Expand Up @@ -1107,14 +1212,16 @@ def _errors_(self) -> dict[str, list[type[CustomError]]]:
for abi in abi_list:
error_type = None
for existing_cls in prior_errors:
if existing_cls.abi.signature == abi.signature:
if existing_cls.abi and existing_cls.abi.signature == abi.signature:
# Error class was previously defined by contract at same address.
error_type = existing_cls
break

if error_type is None:
# Error class is being defined for the first time.
error_type = self._create_custom_error_type(abi)
error_type = self._create_custom_error_type(
abi, contract_address=self.address
)
self.chain_manager.contracts._cache_error(self.address, error_type)

errors_to_add.append(error_type)
Expand Down
32 changes: 31 additions & 1 deletion src/ape/logging.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from typing import IO, Any, Optional, Union

import click
from rich.console import Console as RichConsole
from yarl import URL


Expand Down Expand Up @@ -288,4 +289,33 @@ def sanitize_url(url: str) -> str:
logger = ApeLogger.create()


__all__ = ["DEFAULT_LOG_LEVEL", "logger", "LogLevel", "ApeLogger"]
class _RichConsoleFactory:
rich_console_map: dict[str, RichConsole] = {}

def get_console(self, file: Optional[IO[str]] = None, **kwargs) -> RichConsole:
# Configure custom file console
file_id = str(file)
if file_id not in self.rich_console_map:
self.rich_console_map[file_id] = RichConsole(file=file, width=100, **kwargs)

return self.rich_console_map[file_id]


_factory = _RichConsoleFactory()


def get_rich_console(file: Optional[IO[str]] = None, **kwargs) -> RichConsole:
"""
Get an Ape-configured rich console.
Args:
file (Optional[IO[str]]): The file to output to. Will default
to using stdout.
Returns:
``rich.Console``.
"""
return _factory.get_console(file)


__all__ = ["DEFAULT_LOG_LEVEL", "logger", "LogLevel", "ApeLogger", "get_rich_console"]
25 changes: 8 additions & 17 deletions src/ape/managers/chain.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
import pandas as pd
from eth_pydantic_types import HexBytes
from ethpm_types import ABI, ContractType
from rich import get_console
from rich.box import SIMPLE
from rich.console import Console as RichConsole
from rich.table import Table
Expand Down Expand Up @@ -38,7 +37,7 @@
TransactionNotFoundError,
UnknownSnapshotError,
)
from ape.logging import logger
from ape.logging import get_rich_console, logger
from ape.managers.base import BaseManager
from ape.types import AddressType, GasReport, SnapshotID, SourceTraceback
from ape.utils import (
Expand Down Expand Up @@ -1414,8 +1413,6 @@ class ReportManager(BaseManager):
**NOTE**: This class is not part of the public API.
"""

rich_console_map: dict[str, RichConsole] = {}

def show_gas(self, report: GasReport, file: Optional[IO[str]] = None):
tables: list[Table] = []

Expand Down Expand Up @@ -1459,7 +1456,7 @@ def show_gas(self, report: GasReport, file: Optional[IO[str]] = None):
def echo(
self, *rich_items, file: Optional[IO[str]] = None, console: Optional[RichConsole] = None
):
console = console or self._get_console(file)
console = console or get_rich_console(file)
console.print(*rich_items)

def show_source_traceback(
Expand All @@ -1469,28 +1466,22 @@ def show_source_traceback(
console: Optional[RichConsole] = None,
failing: bool = True,
):
console = console or self._get_console(file)
console = console or get_rich_console(file)
style = "red" if failing else None
console.print(str(traceback), style=style)

def show_events(
self, events: list, file: Optional[IO[str]] = None, console: Optional[RichConsole] = None
):
console = console or self._get_console(file)
console = console or get_rich_console(file)
console.print("Events emitted:")
for event in events:
console.print(event)

def _get_console(self, file: Optional[IO[str]] = None) -> RichConsole:
if not file:
return get_console()

# Configure custom file console
file_id = str(file)
if file_id not in self.rich_console_map:
self.rich_console_map[file_id] = RichConsole(file=file, width=100)

return self.rich_console_map[file_id]
def _get_console(self, *args, **kwargs):
# TODO: Delete this method in v0.9.
# It only exists for backwards compat.
return get_rich_console(*args, **kwargs)


class ChainManager(BaseManager):
Expand Down
7 changes: 7 additions & 0 deletions src/ape/utils/abi.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from ape.logging import logger

ARRAY_PATTERN = re.compile(r"[(*\w,? )]*\[\d*]")
NATSPEC_KEY_PATTERN = re.compile(r"(@\w+)")


def is_array(abi_type: Union[str, ABIType]) -> bool:
Expand Down Expand Up @@ -517,3 +518,9 @@ def decode_value(self, abi_type: str, value: Any) -> Any:
# ecosystem API through the calling function.

return value


def _enrich_natspec(natspec: str) -> str:
# Ensure the natspec @-words are highlighted.
replacement = r"[bright_red]\1[/]"
return re.sub(NATSPEC_KEY_PATTERN, replacement, natspec)
4 changes: 2 additions & 2 deletions src/ape/utils/basemodel.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ class ManagerAccessMixin:
_test_runner: ClassVar[Optional["PytestApeRunner"]] = None

@classproperty
def provider(self) -> "ProviderAPI":
def provider(cls) -> "ProviderAPI":
"""
The current active provider if connected to one.
Expand All @@ -145,7 +145,7 @@ def provider(self) -> "ProviderAPI":
Returns:
:class:`~ape.api.providers.ProviderAPI`
"""
if provider := self.network_manager.active_provider:
if provider := cls.network_manager.active_provider:
return provider

raise ProviderNotConnectedError()
Expand Down
Loading

0 comments on commit e1c9dcf

Please sign in to comment.