Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat!: upgrade to 0.8 #144

Merged
merged 20 commits into from
Jun 1, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
feat: 08
  • Loading branch information
antazoey committed May 20, 2024
commit 240dcb06a4d522e728bb05d49896781d41387b9b
2 changes: 1 addition & 1 deletion .github/workflows/test.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ jobs:
# TODO: Replace with macos-latest when works again.
# https://github.com/actions/setup-python/issues/808
os: [ubuntu-latest, macos-12] # eventually add `windows-latest`
python-version: [3.8, 3.9, '3.10', '3.11', '3.12']
python-version: [3.9, '3.10', '3.11', '3.12']

env:
GITHUB_ACCESS_TOKEN: ${{ secrets.GITHUB_TOKEN }}
Expand Down
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ Compile Solidity contracts.

## Dependencies

- [python3](https://www.python.org/downloads) version 3.8 up to 3.12.
- [python3](https://www.python.org/downloads) version 3.9 up to 3.12.

## Installation

Expand Down
19 changes: 10 additions & 9 deletions ape_solidity/_utils.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import json
import os
import re
from collections.abc import Iterable
from enum import Enum
from pathlib import Path
from typing import Dict, List, Optional, Sequence, Set, Union
from typing import Optional, Union

from ape.exceptions import CompilerError
from ape.utils import pragma_str_to_specifier_set
Expand Down Expand Up @@ -45,7 +46,7 @@ def validate_entry(cls, value):
return value

@property
def _parts(self) -> List[str]:
def _parts(self) -> list[str]:
return self.entry.split("=")

# path normalization needed in case delimiter in remapping key/value
Expand Down Expand Up @@ -96,8 +97,8 @@ class ImportRemappingBuilder:
def __init__(self, contracts_cache: Path):
# import_map maps import keys like `@openzeppelin/contracts`
# to str paths in the compiler cache folder.
self.import_map: Dict[str, str] = {}
self.dependencies_added: Set[Path] = set()
self.import_map: dict[str, str] = {}
self.dependencies_added: set[Path] = set()
self.contracts_cache = contracts_cache

def add_entry(self, remapping: ImportRemapping):
Expand All @@ -108,8 +109,8 @@ def add_entry(self, remapping: ImportRemapping):
self.import_map[remapping.key] = str(path)


def get_import_lines(source_paths: Set[Path]) -> Dict[Path, List[str]]:
imports_dict: Dict[Path, List[str]] = {}
def get_import_lines(source_paths: set[Path]) -> dict[Path, list[str]]:
imports_dict: dict[Path, list[str]] = {}
for filepath in source_paths:
import_set = set()
if not filepath.is_file():
Expand Down Expand Up @@ -168,7 +169,7 @@ def get_pragma_spec_from_str(source_str: str) -> Optional[SpecifierSet]:
return pragma_str_to_specifier_set(pragma_match.groups()[0])


def load_dict(data: Union[str, dict]) -> Dict:
def load_dict(data: Union[str, dict]) -> dict:
return data if isinstance(data, dict) else json.loads(data)


Expand All @@ -183,7 +184,7 @@ def add_commit_hash(version: Union[str, Version]) -> Version:
return get_solc_version_from_binary(solc, with_commit_hash=True)


def verify_contract_filepaths(contract_filepaths: Sequence[Path]) -> Set[Path]:
def verify_contract_filepaths(contract_filepaths: Iterable[Path]) -> set[Path]:
invalid_files = [p.name for p in contract_filepaths if p.suffix != Extension.SOL.value]
if not invalid_files:
return set(contract_filepaths)
Expand All @@ -192,7 +193,7 @@ def verify_contract_filepaths(contract_filepaths: Sequence[Path]) -> Set[Path]:
raise CompilerError(f"Unable to compile '{sources_str}' using Solidity compiler.")


def select_version(pragma_spec: SpecifierSet, options: Sequence[Version]) -> Optional[Version]:
def select_version(pragma_spec: SpecifierSet, options: Iterable[Version]) -> Optional[Version]:
choices = sorted(list(pragma_spec.filter(options)), reverse=True)
return choices[0] if choices else None

Expand Down
101 changes: 52 additions & 49 deletions ape_solidity/compiler.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import os
import re
from collections.abc import Iterable, Iterator
from pathlib import Path
from typing import Any, Dict, List, Optional, Sequence, Set, Tuple, Type, Union, cast
from typing import Any, Optional, Union, cast

from ape.api import CompilerAPI, PluginConfig
from ape.contracts import ContractInstance
Expand Down Expand Up @@ -65,7 +66,7 @@ class SolidityConfig(PluginConfig):
Configure the Solidity plugin.
"""

import_remapping: List[str] = []
import_remapping: list[str] = []
"""
Configure re-mappings using a ``=`` separated-str,
e.g. ``"@import_name=path/to/dependency"``.
Expand Down Expand Up @@ -99,9 +100,9 @@ class SolidityConfig(PluginConfig):
class SolidityCompiler(CompilerAPI):
_import_remapping_hash: Optional[int] = None
_cached_project_path: Optional[Path] = None
_cached_import_map: Dict[str, str] = {}
_libraries: Dict[str, Dict[str, AddressType]] = {}
_contracts_needing_libraries: Set[Path] = set()
_cached_import_map: dict[str, str] = {}
_libraries: dict[str, dict[str, AddressType]] = {}
_contracts_needing_libraries: set[Path] = set()

@property
def name(self) -> str:
Expand All @@ -112,11 +113,11 @@ def config(self) -> SolidityConfig:
return cast(SolidityConfig, self.config_manager.get_config(self.name))

@property
def libraries(self) -> Dict[str, Dict[str, AddressType]]:
def libraries(self) -> dict[str, dict[str, AddressType]]:
return self._libraries

@cached_property
def available_versions(self) -> List[Version]:
def available_versions(self) -> list[Version]:
# NOTE: Package version should already be included in available versions
try:
return get_installable_solc_versions()
Expand All @@ -126,7 +127,7 @@ def available_versions(self) -> List[Version]:
return []

@property
def installed_versions(self) -> List[Version]:
def installed_versions(self) -> list[Version]:
"""
Returns a lis of installed version WITHOUT their
commit hashes.
Expand Down Expand Up @@ -219,7 +220,7 @@ def add_library(self, *contracts: ContractInstance):

self._contracts_needing_libraries = set()

def get_versions(self, all_paths: Sequence[Path]) -> Set[str]:
def get_versions(self, all_paths: Iterable[Path]) -> set[str]:
versions = set()
for path in all_paths:
# Make sure we have the compiler available to compile this
Expand All @@ -229,7 +230,7 @@ def get_versions(self, all_paths: Sequence[Path]) -> Set[str]:

return versions

def get_import_remapping(self, base_path: Optional[Path] = None) -> Dict[str, str]:
def get_import_remapping(self, base_path: Optional[Path] = None) -> dict[str, str]:
"""
Config remappings like ``'@import_name=path/to/dependency'`` parsed here
as ``{'@import_name': 'path/to/dependency'}``.
Expand Down Expand Up @@ -407,17 +408,17 @@ def _add_dependencies(
)

def get_compiler_settings(
self, contract_filepaths: Sequence[Path], base_path: Optional[Path] = None
) -> Dict[Version, Dict]:
self, contract_filepaths: Iterable[Path], base_path: Optional[Path] = None
) -> dict[Version, dict]:
base_path = base_path or self.config_manager.contracts_folder
files_by_solc_version = self.get_version_map(contract_filepaths, base_path=base_path)
if not files_by_solc_version:
return {}

import_remappings = self.get_import_remapping(base_path=base_path)
settings: Dict = {}
settings: dict = {}
for solc_version, sources in files_by_solc_version.items():
version_settings: Dict[str, Union[Any, List[Any]]] = {
version_settings: dict[str, Union[Any, list[Any]]] = {
"optimizer": {"enabled": self.settings.optimize, "runs": DEFAULT_OPTIMIZATION_RUNS},
"outputSelection": {
str(get_relative_path(p, base_path)): {"*": OUTPUT_SELECTION, "": ["ast"]}
Expand Down Expand Up @@ -447,8 +448,8 @@ def get_compiler_settings(
return settings

def _get_used_remappings(
self, sources, remappings: Dict[str, str], base_path: Optional[Path] = None
) -> Dict[str, str]:
self, sources, remappings: dict[str, str], base_path: Optional[Path] = None
) -> dict[str, str]:
base_path = base_path or self.project_manager.contracts_folder
remappings = remappings or self.get_import_remapping(base_path=base_path)
if not remappings:
Expand All @@ -473,8 +474,8 @@ def _get_used_remappings(
}

def get_standard_input_json(
self, contract_filepaths: Sequence[Path], base_path: Optional[Path] = None
) -> Dict[Version, Dict]:
self, contract_filepaths: Iterable[Path], base_path: Optional[Path] = None
) -> dict[Version, dict]:
base_path = base_path or self.config_manager.contracts_folder
files_by_solc_version = self.get_version_map(contract_filepaths, base_path=base_path)
settings = self.get_compiler_settings(contract_filepaths, base_path)
Expand Down Expand Up @@ -510,18 +511,18 @@ def get_standard_input_json(
return input_jsons

def compile(
self, contract_filepaths: Sequence[Path], base_path: Optional[Path] = None
) -> List[ContractType]:
self, contract_filepaths: Iterable[Path], base_path: Optional[Path] = None
) -> Iterator[ContractType]:
base_path = base_path or self.config_manager.contracts_folder
contract_versions: Dict[str, Version] = {}
contract_types: List[ContractType] = []
contract_versions: dict[str, Version] = {}
contract_types: list[ContractType] = []
input_jsons = self.get_standard_input_json(contract_filepaths, base_path=base_path)

for solc_version, input_json in input_jsons.items():
logger.info(f"Compiling using Solidity compiler '{solc_version}'.")
cleaned_version = Version(solc_version.base_version)
solc_binary = get_executable(version=cleaned_version)
arguments: Dict = {"solc_binary": solc_binary, "solc_version": cleaned_version}
arguments: dict = {"solc_binary": solc_binary, "solc_version": cleaned_version}

if solc_version >= Version("0.6.9"):
arguments["base_path"] = base_path
Expand All @@ -543,7 +544,7 @@ def compile(
if not contracts:
continue

input_contract_names: List[str] = []
input_contract_names: list[str] = []
for source_id, contracts_out in contracts.items():
for name, _ in contracts_out.items():
# Filter source files that the user did not ask for, such as
Expand Down Expand Up @@ -606,7 +607,7 @@ def compile(
contract_versions[contract_name] = solc_version

# Output compiler data used.
compilers_used: Dict[Version, Compiler] = {}
compilers_used: dict[Version, Compiler] = {}
for ct in contract_types:
if not ct.name:
# Won't happen, but just for mypy.
Expand All @@ -632,7 +633,9 @@ def compile(
compilers_ls = list(compilers_used.values())
self.project_manager.local_project.add_compiler_data(compilers_ls)

return contract_types
# Yield contract-types afterward to ensure we yield only the latest types.
# This avoids collision errors for shared imported contracts across versions.
yield from contract_types

def compile_code(
self,
Expand Down Expand Up @@ -692,14 +695,14 @@ def compile_code(

def _get_unmapped_imports(
self,
contract_filepaths: Sequence[Path],
contract_filepaths: Iterable[Path],
base_path: Optional[Path] = None,
) -> Dict[str, List[Tuple[str, str]]]:
) -> dict[str, list[tuple[str, str]]]:
contracts_path = base_path or self.config_manager.contracts_folder
import_remapping = self.get_import_remapping(base_path=contracts_path)
contract_filepaths_set = verify_contract_filepaths(contract_filepaths)

imports_dict: Dict[str, List[Tuple[str, str]]] = {}
imports_dict: dict[str, list[tuple[str, str]]] = {}
for src_path, import_strs in get_import_lines(contract_filepaths_set).items():
import_list = []
for import_str in import_strs:
Expand All @@ -722,13 +725,13 @@ def _get_unmapped_imports(

def get_imports(
self,
contract_filepaths: Sequence[Path],
contract_filepaths: Iterable[Path],
base_path: Optional[Path] = None,
) -> Dict[str, List[str]]:
) -> dict[str, list[str]]:
contracts_path = base_path or self.config_manager.contracts_folder

def build_map(paths: Set[Path], prev: Optional[Dict] = None) -> Dict[str, List[str]]:
result: Dict[str, List[str]] = prev or {}
def build_map(paths: set[Path], prev: Optional[dict] = None) -> dict[str, list[str]]:
result: dict[str, list[str]] = prev or {}

for src_path, import_strs in get_import_lines(paths).items():
source_id = str(get_relative_path(src_path, contracts_path))
Expand All @@ -754,13 +757,13 @@ def build_map(paths: Set[Path], prev: Optional[Dict] = None) -> Dict[str, List[s

def get_version_map(
self,
contract_filepaths: Union[Path, Sequence[Path]],
contract_filepaths: Union[Path, Iterable[Path]],
base_path: Optional[Path] = None,
) -> Dict[Version, Set[Path]]:
) -> dict[Version, set[Path]]:
# Ensure `.cache` folder is built before getting version map.
self.get_import_remapping(base_path=base_path)

if not isinstance(contract_filepaths, Sequence):
if not isinstance(contract_filepaths, Iterable):
contract_filepaths = [contract_filepaths]

base_path = base_path or self.project_manager.contracts_folder
Expand Down Expand Up @@ -798,7 +801,7 @@ def get_version_map(
install_solc(latest, show_progress=True)

# Adjust best-versions based on imports.
files_by_solc_version: Dict[Version, Set[Path]] = {}
files_by_solc_version: dict[Version, set[Path]] = {}
for source_file_path in source_paths_to_get:
solc_version = self._get_best_version(source_file_path, source_by_pragma_spec)
imported_source_paths = self._get_imported_source_paths(
Expand Down Expand Up @@ -859,9 +862,9 @@ def _get_imported_source_paths(
self,
path: Path,
base_path: Path,
imports: Dict,
source_ids_checked: Optional[List[str]] = None,
) -> Set[Path]:
imports: dict,
source_ids_checked: Optional[list[str]] = None,
) -> set[Path]:
source_ids_checked = source_ids_checked or []
source_identifier = str(get_relative_path(path, base_path))
if source_identifier in source_ids_checked:
Expand Down Expand Up @@ -910,7 +913,7 @@ def _get_pramga_spec_from_str(self, source_str: str) -> Optional[SpecifierSet]:

return pragma_spec

def _get_best_version(self, path: Path, source_by_pragma_spec: Dict) -> Version:
def _get_best_version(self, path: Path, source_by_pragma_spec: dict) -> Version:
compiler_version: Optional[Version] = None
if pragma_spec := source_by_pragma_spec.get(path):
if selected := select_version(pragma_spec, self.installed_versions):
Expand Down Expand Up @@ -1037,7 +1040,7 @@ def get_first_version_pragma(source: str) -> str:
return ""


def get_licenses(src: str) -> List[Tuple[str, str]]:
def get_licenses(src: str) -> list[tuple[str, str]]:
return LICENSES_PATTERN.findall(src)


Expand Down Expand Up @@ -1084,7 +1087,7 @@ def process_licenses(contract: str) -> str:
return contract_with_single_license


def _get_sol_panic(revert_message: str) -> Optional[Type[RuntimeErrorUnion]]:
def _get_sol_panic(revert_message: str) -> Optional[type[RuntimeErrorUnion]]:
if revert_message.startswith(RUNTIME_ERROR_CODE_PREFIX):
# ape-geth (style) plugins show the hex with the Panic ABI prefix.
error_type_val = int(
Expand All @@ -1101,7 +1104,7 @@ def _get_sol_panic(revert_message: str) -> Optional[Type[RuntimeErrorUnion]]:


def _import_str_to_source_id(
_import_str: str, source_path: Path, base_path: Path, import_remapping: Dict[str, str]
_import_str: str, source_path: Path, base_path: Path, import_remapping: dict[str, str]
) -> str:
quote = '"' if '"' in _import_str else "'"

Expand All @@ -1118,18 +1121,18 @@ def _import_str_to_source_id(
source_id_value = str(get_relative_path(path, base_path))

# Get all matches.
matches: List[Tuple[str, str]] = []
import_matches: list[tuple[str, str]] = []
for key, value in import_remapping.items():
if key not in source_id_value:
continue

matches.append((key, value))
import_matches.append((key, value))

if not matches:
if not import_matches:
return source_id_value

# Convert remapping list back to source using longest match (most exact).
key, value = max(matches, key=lambda x: len(x[0]))
key, value = max(import_matches, key=lambda x: len(x[0]))
sections = [s for s in source_id_value.split(key) if s]
depth = len(sections) - 1
source_id_value = ""
Expand All @@ -1147,5 +1150,5 @@ def _import_str_to_source_id(
return source_id_value


def _try_max(ls: List[Any]):
def _try_max(ls: list[Any]):
return max(ls) if ls else None
Loading
Loading