Skip to content

Commit

Permalink
feat: Allow adding dependencies from Etherscan (#114)
Browse files Browse the repository at this point in the history
  • Loading branch information
antazoey authored Feb 2, 2024
1 parent cf743ee commit c3fc86c
Show file tree
Hide file tree
Showing 14 changed files with 336 additions and 33 deletions.
21 changes: 21 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -138,3 +138,24 @@ etherscan:
uri: https://custom.scan
api_uri: https://api.custom.scan/api
```
## Dependencies
You can use dependencies from Etherscan in your projects.
Configure them like this:
```yaml
dependencies:
- name: Spork
etherscan: "0xb624FdE1a972B1C89eC1dAD691442d5E8E891469"
ecosystem: ethereum
network: mainnet
```
Then, access contract types from the dependency in your code:
```python
from ape import project

spork_contract_type = project.dependencies["Spork"]["etherscan"].Spork
```
6 changes: 6 additions & 0 deletions ape_etherscan/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from ape import plugins

from .config import EtherscanConfig
from .dependency import EtherscanDependency
from .explorer import Etherscan
from .query import EtherscanQueryEngine
from .utils import NETWORKS
Expand All @@ -22,3 +23,8 @@ def query_engines():
@plugins.register(plugins.Config)
def config_class():
return EtherscanConfig


@plugins.register(plugins.DependencyPlugin)
def dependencies():
yield "etherscan", EtherscanDependency
10 changes: 6 additions & 4 deletions ape_etherscan/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@

from ape_etherscan.config import EtherscanConfig
from ape_etherscan.exceptions import (
ContractNotVerifiedError,
UnhandledResultError,
UnsupportedEcosystemError,
UnsupportedNetworkError,
Expand Down Expand Up @@ -302,7 +303,7 @@ def get_source_code(self) -> SourceCodeResponse:
}
result = self._get(params=params)

if not (result_list := result.value or []):
if not (result_list := result.value):
return SourceCodeResponse()

elif len(result_list) > 1:
Expand All @@ -312,9 +313,10 @@ def get_source_code(self) -> SourceCodeResponse:
if not isinstance(data, dict):
raise UnhandledResultError(result, data)

abi = data.get("ABI") or ""
name = data.get("ContractName") or "unknown"
return SourceCodeResponse(abi, name)
if data.get("ABI") == "Contract source code not verified":
raise ContractNotVerifiedError(result, self._address)

return SourceCodeResponse.model_validate(data)

def verify_source_code(
self,
Expand Down
64 changes: 64 additions & 0 deletions ape_etherscan/dependency.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
from ape.api.projects import DependencyAPI
from ape.exceptions import ProjectError
from ape.types import AddressType
from ethpm_types import PackageManifest
from hexbytes import HexBytes
from pydantic import AnyUrl, HttpUrl, field_validator

from .explorer import Etherscan


class EtherscanDependency(DependencyAPI):
etherscan: str
ecosystem: str = "ethereum"
network: str = "mainnet"

@field_validator("etherscan", mode="before")
@classmethod
def handle_int(cls, value):
return value if isinstance(value, str) else HexBytes(value).hex()

@property
def version_id(self) -> str:
return f"{self.ecosystem}_{self.network}"

@property
def address(self) -> AddressType:
return self.network_manager.ethereum.decode_address(self.etherscan)

@property
def uri(self) -> AnyUrl:
return HttpUrl(f"{self.explorer.get_address_url(self.address)}#code")

@property
def explorer(self) -> Etherscan:
if self.network_manager.active_provider:
explorer = self.provider.network.explorer
if isinstance(explorer, Etherscan):
# Could be using a different network.
return explorer
else:
return self.network_manager.ethereum.mainnet.explorer

# Assume Ethereum
return self.network_manager.ethereum.mainnet.explorer

def extract_manifest(self, use_cache: bool = True) -> PackageManifest:
ecosystem = self.network_manager.get_ecosystem(self.ecosystem)
network = ecosystem.get_network(self.network)

ctx = None
if self.network_manager.active_provider is None:
ctx = network.use_default_provider()
ctx.__enter__()

try:
manifest = self.explorer.get_manifest(self.address)
finally:
if ctx:
ctx.__exit__(None)

if not manifest:
raise ProjectError(f"Etherscan dependency '{self.name}' not verified.")

return manifest
9 changes: 9 additions & 0 deletions ape_etherscan/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,15 @@ def __init__(self, response: Union[Response, "EtherscanResponse"], message: str)
super().__init__(f"Response indicated failure: {message}")


class ContractNotVerifiedError(EtherscanResponseError):
"""
Raised when a contract is not verified on Etherscan.
"""

def __init__(self, response: Union[Response, "EtherscanResponse"], address: str):
super().__init__(response, f"Contract '{address}' not verified.")


class UnhandledResultError(EtherscanResponseError):
"""
Raised in specific client module where the result from Etherscan
Expand Down
68 changes: 57 additions & 11 deletions ape_etherscan/explorer.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,20 @@
import json
from json.decoder import JSONDecodeError
from typing import Optional

from ape.api import ExplorerAPI, PluginConfig
from ape.contracts import ContractInstance
from ape.exceptions import ProviderNotConnectedError
from ape.logging import logger
from ape.types import AddressType, ContractType
from ethpm_types import Compiler, PackageManifest
from ethpm_types.source import Source

from ape_etherscan.client import ClientFactory, get_etherscan_api_uri, get_etherscan_uri
from ape_etherscan.client import (
ClientFactory,
SourceCodeResponse,
get_etherscan_api_uri,
get_etherscan_uri,
)
from ape_etherscan.exceptions import ContractNotVerifiedError
from ape_etherscan.types import EtherscanInstance
from ape_etherscan.verify import SourceVerifier

Expand Down Expand Up @@ -47,23 +53,63 @@ def _client_factory(self) -> ClientFactory:
)
)

def get_contract_type(self, address: AddressType) -> Optional[ContractType]:
def get_manifest(self, address: AddressType) -> Optional[PackageManifest]:
try:
response = self._get_source_code(address)
except ContractNotVerifiedError:
return None

settings = {
"optimizer": {
"enabled": response.optimization_used,
"runs": response.optimization_runs,
},
}

code = response.source_code
if code.startswith("{"):
# JSON verified.
data = json.loads(code)
compiler = Compiler(
name=data.get("language", "Solidity"),
version=response.compiler_version,
settings=data.get("settings", settings),
contractTypes=[response.name],
)
source_data = data.get("sources", {})
sources = {
src_id: Source(content=cont.get("content", ""))
for src_id, cont in source_data.items()
}

else:
# A flattened source.
source_id = f"{response.name}.sol"
compiler = Compiler(
name="Solidity",
version=response.compiler_version,
settings=settings,
contractTypes=[response.name],
)
sources = {source_id: Source(content=response.source_code)}

return PackageManifest(compilers=[compiler], sources=sources)

def _get_source_code(self, address: AddressType) -> SourceCodeResponse:
if not self.conversion_manager.is_type(address, AddressType):
# Handle non-checksummed addresses
address = self.conversion_manager.convert(str(address), AddressType)

client = self._client_factory.get_contract_client(address)
source_code = client.get_source_code()
if not (abi_string := source_code.abi):
return None
return client.get_source_code()

def get_contract_type(self, address: AddressType) -> Optional[ContractType]:
try:
abi = json.loads(abi_string)
except JSONDecodeError as err:
logger.error(f"Error with contract ABI: {err}")
source_code = self._get_source_code(address)
except ContractNotVerifiedError:
return None

contract_type = ContractType(abi=abi, contractName=source_code.name)
contract_type = ContractType(abi=source_code.abi, contractName=source_code.name)
if source_code.name == "Vyper_contract" and "symbol" in contract_type.view_methods:
try:
contract = ContractInstance(address, contract_type)
Expand Down
47 changes: 42 additions & 5 deletions ape_etherscan/types.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
import json
import re
from dataclasses import dataclass
from typing import Dict, List, Union

from ape.utils import cached_property
from ethpm_types import BaseModel
from pydantic import Field, field_validator

from ape_etherscan.exceptions import EtherscanResponseError, get_request_error

Expand All @@ -17,10 +20,44 @@ class EtherscanInstance:
api_uri: str


@dataclass
class SourceCodeResponse:
abi: str = ""
name: str = "unknown"
class SourceCodeResponse(BaseModel):
abi: List = Field([], alias="ABI")
name: str = Field("unknown", alias="ContractName")
source_code: str = Field("", alias="SourceCode")
compiler_version: str = Field("", alias="CompilerVersion")
optimization_used: bool = Field(True, alias="OptimizationUsed")
optimization_runs: int = Field(200, alias="Runs")
evm_version: str = Field("Default", alias="EVMVersion")
library: str = Field("", alias="Library")
license_type: str = Field("", alias="LicenseType")
proxy: bool = Field(False, alias="Proxy")
implementation: str = Field("", alias="Implementation")
swarm_source: str = Field("", alias="SwarmSource")

@field_validator("optimization_used", "proxy", mode="before")
@classmethod
def validate_bools(cls, value):
return bool(int(value))

@field_validator("abi", mode="before")
@classmethod
def validate_abi(cls, value):
return json.loads(value)

@field_validator("source_code", mode="before")
@classmethod
def validate_source_code(cls, value):
if value.startswith("{"):
# NOTE: Have to deal with very poor JSON
# response from Etherscan.
fixed = re.sub(r"\r\n\s*", "", value)
fixed = re.sub(r"\r\n\s*", "", fixed)
if fixed.startswith("{{"):
fixed = fixed[1:-1]

return fixed

return value


@dataclass
Expand Down Expand Up @@ -49,7 +86,7 @@ def value(self) -> ResponseValue:

message = response_data.get("message", "")
is_error = response_data.get("isError", 0) or message == "NOTOK"
if is_error and self.raise_on_exceptions:
if is_error is True and self.raise_on_exceptions:
raise get_request_error(self.response, self.ecosystem)

result = response_data.get("result", message)
Expand Down
48 changes: 38 additions & 10 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,9 @@ def address(contract_to_verify):
@pytest.fixture(scope="session")
def contract_address_map(address):
return {
"get_contract_response": address,
"get_contract_response_flattened": address,
"get_contract_response_json": "0x000075Dc60EdE898f11b0d5C6cA31D7A6D050eeD",
"get_contract_response_not_verified": "0x5777d92f208679DB4b9778590Fa3CAB3aC9e2168",
"get_proxy_contract_response": "0x55A8a39bc9694714E2874c1ce77aa1E599461E18",
"get_vyper_contract_response": "0xdA816459F1AB5631232FE5e97a05BBBb94970c95",
}
Expand Down Expand Up @@ -419,8 +421,20 @@ def side_effect(self):

def _get_contract_type_response(self, file_name: str) -> Any:
test_data_path = MOCK_RESPONSES_PATH / f"{file_name}.json"
with open(test_data_path) as response_data_file:
return self.get_mock_response(response_data_file, file_name=file_name)
assert test_data_path.is_file(), f"Setup failed - missing test data {file_name}"
if "flattened" in file_name:
with open(test_data_path) as response_data_file:
return self.get_mock_response(response_data_file, file_name=file_name)

else:
# NOTE: Since the JSON is messed up for these, we can' load the mocks
# even without a weird hack.
content = (
MOCK_RESPONSES_PATH / "get_contract_response_json_source_code.json"
).read_text()
data = json.loads(test_data_path.read_text())
data["SourceCode"] = content
return self.get_mock_response(data, file_name=file_name)

def _expected_get_ct_params(self, address: str) -> Dict:
return {"module": "contract", "action": "getsourcecode", "address": address}
Expand Down Expand Up @@ -462,23 +476,37 @@ def get_mock_response(
self, response_data: Optional[Union[IO, Dict, str, MagicMock]] = None, **kwargs
):
if isinstance(response_data, str):
return self.get_mock_response({"result": response_data})
return self.get_mock_response({"result": response_data, **kwargs})

elif isinstance(response_data, _io.TextIOWrapper):
return self.get_mock_response(json.load(response_data), **kwargs)

elif isinstance(response_data, MagicMock):
# Mock wasn't set.
response_data = {}
response_data = {**kwargs}

assert isinstance(response_data, dict)
return self._get_mock_response(response_data=response_data, **kwargs)

def _get_mock_response(
self,
response_data: Optional[Dict] = None,
response_text: Optional[str] = None,
*args,
**kwargs,
):
response = self.mocker.MagicMock(spec=Response)
assert isinstance(response_data, dict) # For mypy
overrides: Dict = kwargs.get("response_overrides", {})
response.json.return_value = {**response_data, **overrides}
response.text = json.dumps(response_data or {})
if response_data:
assert isinstance(response_data, dict) # For mypy
overrides: Dict = kwargs.get("response_overrides", {})
response.json.return_value = {**response_data, **overrides}
if not response_text:
response_text = json.dumps(response_data or {})

response.status_code = 200
if response_text:
response.text = response_text

response.status_code = 200
for key, val in kwargs.items():
setattr(response, key, val)

Expand Down

Large diffs are not rendered by default.

Loading

0 comments on commit c3fc86c

Please sign in to comment.