diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml index 73acc0c..3fa88b7 100644 --- a/.github/workflows/test.yaml +++ b/.github/workflows/test.yaml @@ -65,7 +65,7 @@ jobs: # TODO: Replace with macos-latest when works again. # https://github.com/actions/setup-python/issues/808 os: [ubuntu-latest, macos-12] # eventually add `windows-latest` - python-version: [3.8, 3.9, '3.10', '3.11', '3.12'] + python-version: [3.9, '3.10', '3.11', '3.12'] env: GITHUB_ACCESS_TOKEN: ${{ secrets.GITHUB_TOKEN }} diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 6d3caa5..45b3831 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -30,7 +30,7 @@ repos: rev: 0.7.17 hooks: - id: mdformat - additional_dependencies: [mdformat-gfm, mdformat-frontmatter] + additional_dependencies: [mdformat-gfm, mdformat-frontmatter, mdformat-pyproject] default_language_version: diff --git a/README.md b/README.md index 0f86cdf..7433c5d 100644 --- a/README.md +++ b/README.md @@ -4,7 +4,7 @@ Compile Solidity contracts. ## Dependencies -- [python3](https://www.python.org/downloads) version 3.8 up to 3.12. +- [python3](https://www.python.org/downloads) version 3.9 up to 3.12. ## Installation @@ -62,32 +62,32 @@ solidity: ### Dependency Mapping -To configure import remapping, use your project's `ape-config.yaml` file: - -```yaml -solidity: - import_remapping: - - "@openzeppelin=path/to/open_zeppelin/contracts" -``` - -If you are using the `dependencies:` key in your `ape-config.yaml`, `ape` can automatically -search those dependencies for the path. +By default, `ape-solidity` knows to look at installed dependencies for potential remapping-values and will use those when it notices you are importing them. +For example, if you are using dependencies like: ```yaml dependencies: - - name: OpenZeppelin + - name: openzeppelin github: OpenZeppelin/openzeppelin-contracts version: 4.4.2 - -solidity: - import_remapping: - - "@openzeppelin=OpenZeppelin/4.4.2" ``` -Once you have your dependencies configured, you can import packages using your import keys: +And your source files import from `openzeppelin` this way: ```solidity -import "@openzeppelin/token/ERC721/ERC721.sol"; +import "@openzeppelin/contracts/token/ERC721/ERC721.sol"; +``` + +Ape knows how to resolve the `@openzeppelin` value and find the correct source. + +If you want to override this behavior or add new remappings that are not dependencies, you can add them to your `ape-config.yaml` under the `solidity:` key. +For example, let's say you have downloaded `openzeppelin` somewhere and do not have it installed in Ape. +You can map to your local install of `openzeppelin` this way: + +```yaml +solidity: + import_remapping: + - "@openzeppelin=path/to/openzeppelin" ``` ### Library Linking diff --git a/ape_solidity/__init__.py b/ape_solidity/__init__.py index 4655900..78ecdcc 100644 --- a/ape_solidity/__init__.py +++ b/ape_solidity/__init__.py @@ -1,6 +1,7 @@ from ape import plugins -from .compiler import Extension, SolidityCompiler, SolidityConfig +from ._utils import Extension +from .compiler import SolidityCompiler, SolidityConfig @plugins.register(plugins.Config) diff --git a/ape_solidity/_utils.py b/ape_solidity/_utils.py index e34739a..fe3fdc3 100644 --- a/ape_solidity/_utils.py +++ b/ape_solidity/_utils.py @@ -1,22 +1,17 @@ import json -import os import re +from collections.abc import Iterable from enum import Enum from pathlib import Path -from typing import Dict, List, Optional, Sequence, Set, Union +from typing import Optional, Union from ape.exceptions import CompilerError from ape.utils import pragma_str_to_specifier_set from packaging.specifiers import SpecifierSet -from packaging.version import InvalidVersion from packaging.version import Version -from packaging.version import Version as _Version -from pydantic import BaseModel, field_validator from solcx.install import get_executable from solcx.wrapper import get_solc_version as get_solc_version_from_binary -from ape_solidity.exceptions import IncorrectMappingFormatError - OUTPUT_SELECTION = [ "abi", "bin-runtime", @@ -32,87 +27,11 @@ class Extension(Enum): SOL = ".sol" -class ImportRemapping(BaseModel): - entry: str - packages_cache: Path - - @field_validator("entry", mode="before") - @classmethod - def validate_entry(cls, value): - if len((value or "").split("=")) != 2: - raise IncorrectMappingFormatError() - - return value - - @property - def _parts(self) -> List[str]: - return self.entry.split("=") - - # path normalization needed in case delimiter in remapping key/value - # and system path delimiter are different (Windows as an example) - @property - def key(self) -> str: - return os.path.normpath(self._parts[0]) - - @property - def name(self) -> str: - suffix_str = os.path.normpath(self._parts[1]) - return suffix_str.split(os.path.sep)[0] - - @property - def package_id(self) -> Path: - suffix = Path(self._parts[1]) - data_folder_cache = self.packages_cache / suffix - - try: - _Version(suffix.name) - if not suffix.name.startswith("v"): - suffix = suffix.parent / f"v{suffix.name}" - - except InvalidVersion: - # The user did not specify a version_id suffix in their mapping. - # We try to smartly figure one out, else error. - if len(Path(suffix).parents) == 1 and data_folder_cache.is_dir(): - version_ids = [d.name for d in data_folder_cache.iterdir()] - if len(version_ids) == 1: - # Use only version ID available. - suffix = suffix / version_ids[0] - - elif not version_ids: - raise CompilerError(f"Missing dependency '{suffix}'.") - - else: - options_str = ", ".join(version_ids) - raise CompilerError( - "Ambiguous version reference. " - f"Please set import remapping value to {suffix}/{{version_id}} " - f"where 'version_id' is one of '{options_str}'." - ) - - return suffix - - -class ImportRemappingBuilder: - def __init__(self, contracts_cache: Path): - # import_map maps import keys like `@openzeppelin/contracts` - # to str paths in the compiler cache folder. - self.import_map: Dict[str, str] = {} - self.dependencies_added: Set[Path] = set() - self.contracts_cache = contracts_cache - - def add_entry(self, remapping: ImportRemapping): - path = remapping.package_id - if self.contracts_cache not in path.parents: - path = self.contracts_cache / path - - self.import_map[remapping.key] = str(path) - - -def get_import_lines(source_paths: Set[Path]) -> Dict[Path, List[str]]: - imports_dict: Dict[Path, List[str]] = {} +def get_import_lines(source_paths: Iterable[Path]) -> dict[Path, list[str]]: + imports_dict: dict[Path, list[str]] = {} for filepath in source_paths: import_set = set() - if not filepath.is_file(): + if not filepath or not filepath.is_file(): continue source_lines = filepath.read_text().splitlines() @@ -168,7 +87,7 @@ def get_pragma_spec_from_str(source_str: str) -> Optional[SpecifierSet]: return pragma_str_to_specifier_set(pragma_match.groups()[0]) -def load_dict(data: Union[str, dict]) -> Dict: +def load_dict(data: Union[str, dict]) -> dict: return data if isinstance(data, dict) else json.loads(data) @@ -183,17 +102,12 @@ def add_commit_hash(version: Union[str, Version]) -> Version: return get_solc_version_from_binary(solc, with_commit_hash=True) -def verify_contract_filepaths(contract_filepaths: Sequence[Path]) -> Set[Path]: - invalid_files = [p.name for p in contract_filepaths if p.suffix != Extension.SOL.value] - if not invalid_files: - return set(contract_filepaths) - - sources_str = "', '".join(invalid_files) - raise CompilerError(f"Unable to compile '{sources_str}' using Solidity compiler.") +def get_versions_can_use(pragma_spec: SpecifierSet, options: Iterable[Version]) -> list[Version]: + return sorted(list(pragma_spec.filter(options)), reverse=True) -def select_version(pragma_spec: SpecifierSet, options: Sequence[Version]) -> Optional[Version]: - choices = sorted(list(pragma_spec.filter(options)), reverse=True) +def select_version(pragma_spec: SpecifierSet, options: Iterable[Version]) -> Optional[Version]: + choices = get_versions_can_use(pragma_spec, options) return choices[0] if choices else None diff --git a/ape_solidity/compiler.py b/ape_solidity/compiler.py index 6775c49..a081f9c 100644 --- a/ape_solidity/compiler.py +++ b/ape_solidity/compiler.py @@ -1,20 +1,23 @@ import os import re +from collections.abc import Iterable, Iterator from pathlib import Path -from typing import Any, Dict, List, Optional, Sequence, Set, Tuple, Type, Union, cast +from typing import Any, Optional, Union from ape.api import CompilerAPI, PluginConfig from ape.contracts import ContractInstance -from ape.exceptions import CompilerError, ConfigError, ContractLogicError +from ape.exceptions import CompilerError, ConfigError, ContractLogicError, ProjectError from ape.logging import logger +from ape.managers.project import ProjectManager from ape.types import AddressType, ContractType -from ape.utils import cached_property, get_package_version, get_relative_path +from ape.utils import cached_property, get_full_extension, get_relative_path +from ape.version import version from eth_pydantic_types import HexBytes from eth_utils import add_0x_prefix, is_0x_prefixed -from ethpm_types import PackageManifest from ethpm_types.source import Compiler, Content from packaging.specifiers import SpecifierSet from packaging.version import Version +from pydantic import model_validator from requests.exceptions import ConnectionError from solcx import ( compile_source, @@ -29,21 +32,18 @@ from ape_solidity._utils import ( OUTPUT_SELECTION, Extension, - ImportRemapping, - ImportRemappingBuilder, add_commit_hash, get_import_lines, get_pragma_spec_from_path, get_pragma_spec_from_str, + get_versions_can_use, load_dict, select_version, strip_commit_hash, - verify_contract_filepaths, ) from ape_solidity.exceptions import ( RUNTIME_ERROR_CODE_PREFIX, RUNTIME_ERROR_MAP, - IncorrectMappingFormatError, RuntimeErrorType, RuntimeErrorUnion, SolcCompileError, @@ -60,32 +60,66 @@ DEFAULT_OPTIMIZATION_RUNS = 200 +class ImportRemapping(PluginConfig): + """ + A remapped import set in the config. + """ + + @model_validator(mode="before") + def validate_str(cls, value): + if isinstance(value, str): + parts = value.split("=") + return {"key": parts[0], "value": parts[1]} + + return value + + """ + The key of the remapping, such as ``@openzeppelin``. + """ + key: str + + """ + The value to use in place of the key, + such as ``path/somewhere/else``. + """ + value: str + + def __str__(self) -> str: + return f"{self.key}={self.value}" + + def __eq__(self, other): + if isinstance(other, str): + return str(self) == other + + return super().__eq__(other) + + class SolidityConfig(PluginConfig): """ - Configure the Solidity plugin. + Configure the ape-solidity plugin. """ - import_remapping: List[str] = [] + import_remapping: list[ImportRemapping] = [] """ - Configure re-mappings using a ``=`` separated-str, - e.g. ``"@import_name=path/to/dependency"``. + Custom remappings as a key value map.. + Note: You do not need to specify dependencies here. """ optimize: bool = True """ - Set to ``False`` to disable compiler-optimizations. + Compile with optimization. Defaults to ``True``. """ version: Optional[str] = None """ - The compiler version to use. Defaults to selecting - the best version(s) it can for each file-set. + Hardcode a Solidity version to use. When not set, + ape-solidity attempts to use the best version(s) + available. """ evm_version: Optional[str] = None """ - The EVM version (fork) to use. Defaults to letting - the compiler decide. + Compile targeting this EVM version. """ via_ir: Optional[bool] = None @@ -96,28 +130,34 @@ class SolidityConfig(PluginConfig): """ +def _get_flattened_source(path: Path, name: Optional[str] = None) -> str: + name = name or path.name + result = f"// File: {name}\n" + result += path.read_text() + "\n" + return result + + class SolidityCompiler(CompilerAPI): - _import_remapping_hash: Optional[int] = None - _cached_project_path: Optional[Path] = None - _cached_import_map: Dict[str, str] = {} - _libraries: Dict[str, Dict[str, AddressType]] = {} - _contracts_needing_libraries: Set[Path] = set() + """ + The implementation of the ape-solidity Compiler class. + Implements all methods in :class:`~ape.api.compilers.CompilerAPI`. + Compiles ``.sol`` files into ``ContractTypes`` for usage in the + Ape framework. + """ + + # Libraries adding for linking. See `add_library` method. + _libraries: dict[str, dict[str, AddressType]] = {} @property def name(self) -> str: return "solidity" @property - def config(self) -> SolidityConfig: - return cast(SolidityConfig, self.config_manager.get_config(self.name)) - - @property - def libraries(self) -> Dict[str, Dict[str, AddressType]]: + def libraries(self) -> dict[str, dict[str, AddressType]]: return self._libraries @cached_property - def available_versions(self) -> List[Version]: - # NOTE: Package version should already be included in available versions + def available_versions(self) -> list[Version]: try: return get_installable_solc_versions() except ConnectionError: @@ -126,7 +166,7 @@ def available_versions(self) -> List[Version]: return [] @property - def installed_versions(self) -> List[Version]: + def installed_versions(self) -> list[Version]: """ Returns a lis of installed version WITHOUT their commit hashes. @@ -150,13 +190,16 @@ def latest_installed_version(self) -> Optional[Version]: """ return _try_max(self.installed_versions) - @property - def _settings_version(self) -> Optional[Version]: + def _get_configured_version( + self, project: Optional[ProjectManager] = None + ) -> Optional[Version]: """ A helper property that gets, verifies, and installs (if needed) - the version specified in the settings. + the version specified in the config. """ - if not (version := self.settings.version): + pm = project or self.local_project + config = self.get_config(project=pm) + if not (version := config.version): return None installed_versions = self.installed_versions @@ -166,60 +209,51 @@ def _settings_version(self) -> Optional[Version]: install_solc(base_version, show_progress=True) settings_version = add_commit_hash(base_version) - if specified_commit_hash: - if settings_version != version: - raise ConfigError( - f"Commit hash from settings version {version} " - f"differs from installed: {settings_version}" - ) + if specified_commit_hash and settings_version != version: + raise ConfigError( + f"Commit hash from settings version {version} " + f"differs from installed: {settings_version}" + ) return settings_version @cached_property def _ape_version(self) -> Version: - version_str = get_package_version("eth-ape").split(".dev")[0].strip() - return Version(version_str) + return Version(version.split(".dev")[0].strip()) - def add_library(self, *contracts: ContractInstance): + def add_library(self, *contracts: ContractInstance, project: Optional[ProjectManager] = None): """ Set a library contract type address. This is useful when deploying a library in a local network and then adding the address afterward. Now, when compiling again, it will use the new address. Args: - contracts (``ContractInstance``): The deployed library contract(s). + *contracts (``ContractInstance``): The deployed library contract(s). + project (Optional[ProjectManager]): The project using the library. """ - + pm = project or self.local_project for contract in contracts: - source_id = contract.contract_type.source_id - if not source_id: + if not (source_id := contract.contract_type.source_id): raise CompilerError("Missing source ID.") - - name = contract.contract_type.name - if not name: + elif not (name := contract.contract_type.name): raise CompilerError("Missing contract type name.") self._libraries[source_id] = {name: contract.address} + path = pm.path / source_id + if not path.is_file(): + return + + # Recompile the same source, in case contracts were in there + # that required the libraries. + contract_types = { + ct.name: ct for ct in self.compile((path,), project=project) if ct.name + } + if contract_types: + all_types = {**pm.manifest.contract_types, **contract_types} + pm.update_manifest(contract_types=all_types) - if self._contracts_needing_libraries: - # TODO: Only attempt to re-compile contacts that use the given libraries. - # Attempt to re-compile contracts that needed libraries. - try: - self.project_manager.load_contracts( - [ - self.config_manager.contracts_folder / s - for s in self._contracts_needing_libraries - ], - use_cache=False, - ) - except CompilerError as err: - logger.error( - f"Failed when trying to re-compile contracts requiring libraries.\n{err}" - ) - - self._contracts_needing_libraries = set() - - def get_versions(self, all_paths: Sequence[Path]) -> Set[str]: + def get_versions(self, all_paths: Iterable[Path]) -> set[str]: + _validate_can_compile(all_paths) versions = set() for path in all_paths: # Make sure we have the compiler available to compile this @@ -229,214 +263,140 @@ def get_versions(self, all_paths: Sequence[Path]) -> Set[str]: return versions - def get_import_remapping(self, base_path: Optional[Path] = None) -> Dict[str, str]: + def get_import_remapping(self, project: Optional[ProjectManager] = None) -> dict[str, str]: """ Config remappings like ``'@import_name=path/to/dependency'`` parsed here as ``{'@import_name': 'path/to/dependency'}``. Returns: - Dict[str, str]: Where the key is the import name, e.g. ``"@openzeppelin/contracts"` + Dict[str, str]: Where the key is the import name, e.g. ``"@openzeppelin"` and the value is a stringified relative path (source ID) of the cached contract, - e.g. `".cache/OpenZeppelin/v4.4.2". + e.g. `".cache/openzeppelin/4.4.2". """ - base_path = base_path or self.project_manager.contracts_folder - if not (remappings := self.settings.import_remapping): - return {} - - elif not isinstance(remappings, (list, tuple)) or not isinstance(remappings[0], str): - raise IncorrectMappingFormatError() - - # We use these helpers to transform the values configured - # to values matching files in the compiler cache folder. - builder = ImportRemappingBuilder( - get_relative_path(self.project_manager.compiler_cache_folder, base_path) - ) - packages_cache = self.config_manager.packages_folder - - # Here we hash and validate if there were changes to remappings. - # If there were, we continue, else return the cached value for - # performance reasons. - remappings_tuple = tuple(remappings) - if ( - self._import_remapping_hash - and self._import_remapping_hash == hash(remappings_tuple) - and self.project_manager.compiler_cache_folder.is_dir() - ): - return self._cached_import_map - - # NOTE: Dependencies are only extracted if calling for the first. - # Likely, this was done already before this point, unless - # calling python methods manually. However, dependencies MUST be - # fully loaded to properly evaluate remapping paths. - dependencies = self.project_manager.load_dependencies() - - for item in remappings: - remapping_obj = ImportRemapping(entry=item, packages_cache=packages_cache) - builder.add_entry(remapping_obj) - package_id = remapping_obj.package_id - - # Handle missing version ID - if len(package_id.parts) == 1: - if package_id.name not in dependencies or len(dependencies[package_id.name]) == 0: - logger.warning(f"Missing dependency '{package_id.name}'.") - continue - - elif len(dependencies[package_id.name]) != 1: - logger.warning("version ID missing and unable to evaluate version.") - continue + pm = project or self.local_project + prefix = f"{get_relative_path(pm.contracts_folder, pm.path)}" - version_id = next(iter(dependencies[package_id.name])) - logger.debug(f"for {package_id.name} version ID missing, using {version_id}") - package_id = package_id / version_id + specified = pm.dependencies.install() - data_folder_cache = packages_cache / package_id + # Ensure .cache folder is ready-to-go. + cache_folder = pm.contracts_folder / ".cache" + cache_folder.mkdir(exist_ok=True, parents=True) - # Re-build a downloaded dependency manifest into the .cache directory for imports. - sub_contracts_cache = self.project_manager.compiler_cache_folder / package_id - if not sub_contracts_cache.is_dir() or not list(sub_contracts_cache.iterdir()): - cached_manifest_file = data_folder_cache / f"{remapping_obj.name}.json" - if not cached_manifest_file.is_file(): - logger.debug(f"Unable to find dependency '{package_id}'.") - - else: - manifest = PackageManifest.model_validate_json(cached_manifest_file.read_text()) - self._add_dependencies(manifest, sub_contracts_cache, builder) - - # Update cache and hash - self._cached_project_path = self.project_manager.path - self._cached_import_map = builder.import_map - self._import_remapping_hash = hash(remappings_tuple) - return builder.import_map - - def _add_dependencies( - self, manifest: PackageManifest, cache_dir: Path, builder: ImportRemappingBuilder - ): - if not cache_dir.is_dir(): - cache_dir.mkdir(parents=True) - - sources = manifest.sources or {} + # Start with explicitly configured remappings. + cfg_remappings: dict[str, str] = { + m.key: m.value for m in pm.config.solidity.import_remapping + } + key_map: dict[str, str] = {} - for source_name, src in sources.items(): - cached_source = cache_dir / source_name + def get_cache_id(dep) -> str: + return os.path.sep.join((prefix, ".cache", dep.name, dep.version)) - if cached_source.is_file(): - # Source already present + def unpack(dep): + # Ensure the dependency is installed. + try: + dep.project + except ProjectError: + # Try to compile anyway. + # Let the compiler fail on its own. + return + + for unpacked_dep in dep.unpack(pm.contracts_folder / ".cache"): + _key = key_map.get(unpacked_dep.name, f"@{unpacked_dep.name}") + if _key not in remapping: + remapping[_key] = get_cache_id(unpacked_dep) + # else, was specified or configured more appropriately. + + remapping: dict[str, str] = {} + for key, value in cfg_remappings.items(): + # Check if legacy-style and still accept it. + parts = value.split(os.path.sep) + name = parts[0] + _version = None + if len(parts) > 2: + # Clearly, not pointing at a dependency. + remapping[key] = value continue - # NOTE: Cached source may included sub-directories. - cached_source.parent.mkdir(parents=True, exist_ok=True) - if src.content: - cached_source.touch() - cached_source.write_text(str(src.content)) - - # Add dependency remapping that may be needed. - for compiler in manifest.compilers or []: - settings = compiler.settings or {} - settings_map = settings.get("remappings") or [] - remapping_list = [ - ImportRemapping(entry=x, packages_cache=self.config_manager.packages_folder) - for x in settings_map - ] - for remapping in remapping_list: - builder.add_entry(remapping) - - # Locate the dependency in the .ape packages cache - dependencies = manifest.dependencies or {} - packages_dir = self.config_manager.packages_folder - for dependency_package_name, uri in dependencies.items(): - uri_str = str(uri) - if "://" in uri_str: - uri_str = "://".join(uri_str.split("://")[1:]) # strip off scheme - - dependency_name = str(dependency_package_name) - if str(self.config_manager.packages_folder) in uri_str: - # Using a local dependency - version = "local" - else: - # Check for GitHub-style dependency - match_checks = (r".*/releases/tag/(v?[\d|.]+)", r".*/tree/(v?[\w|.|\d]+)") - version = None - for check in match_checks: - version_match = re.match(check, str(uri_str)) - if version_match: - version = version_match.groups()[0] - if not version.startswith("v") and version[0].isnumeric(): - version = f"v{version}" - - break + elif len(parts) == 2: + _version = parts[1] - # Find matching package - for package in packages_dir.iterdir(): - if package.name.replace("_", "-").lower() == dependency_name: - dependency_name = str(package.name) - break - - dependency_root_path = self.config_manager.packages_folder / Path(dependency_name) - - if version is None: - version_dirs = [ - d - for d in list(dependency_root_path.iterdir()) - if d.is_dir() and not d.name.startswith(".") - ] - if len(version_dirs) == 1: - # Use the only existing version. - version = version_dirs[0].name + if _version is None: + matching_deps = [d for d in pm.dependencies.installed if d.name == name] + if len(matching_deps) == 1: + _version = matching_deps[0].version + else: + # Not obvious if it is pointing at one of these dependencies. + remapping[key] = value + continue - elif (dependency_root_path / "local").is_dir(): - # If not specified, and local exists, use local by default. - version = "local" + # Dependency found. Map to it using the provider key. + dependency = pm.dependencies.get_dependency(name, _version) + key_map[dependency.name] = key + unpack(dependency) - else: - options = ", ".join([x.name for x in dependency_root_path.iterdir()]) - raise CompilerError( - f"Ambiguous dependency version. " - f"Please specify. Available versions: '{options}'." - ) + # Add auto-remapped dependencies. + # (Meaning, the dependencies are specified but their remappings + # are not, so we auto-generate default ones). + for dependency in specified: + unpack(dependency) - dependency_path = dependency_root_path / version / f"{dependency_name}.json" - if dependency_path.is_file(): - sub_manifest = PackageManifest.model_validate_json(dependency_path.read_text()) - dep_id = Path(dependency_name) / version - if dep_id not in builder.dependencies_added: - builder.dependencies_added.add(dep_id) - self._add_dependencies( - sub_manifest, - builder.contracts_cache / dep_id, - builder, - ) + return remapping def get_compiler_settings( - self, contract_filepaths: Sequence[Path], base_path: Optional[Path] = None - ) -> Dict[Version, Dict]: - base_path = base_path or self.config_manager.contracts_folder - files_by_solc_version = self.get_version_map(contract_filepaths, base_path=base_path) - if not files_by_solc_version: + self, contract_filepaths: Iterable[Path], project: Optional[ProjectManager] = None, **kwargs + ) -> dict[Version, dict]: + pm = project or self.local_project + _validate_can_compile(contract_filepaths) + remapping = self.get_import_remapping(project=pm) + imports = self.get_imports_from_remapping(contract_filepaths, remapping, project=pm) + return self._get_settings_from_imports(contract_filepaths, imports, remapping, project=pm) + + def _get_settings_from_imports( + self, + contract_filepaths: Iterable[Path], + import_map: dict[str, list[str]], + remappings: dict[str, str], + project: Optional[ProjectManager] = None, + ): + pm = project or self.local_project + files_by_solc_version = self.get_version_map_from_imports( + contract_filepaths, import_map, project=pm + ) + return self._get_settings_from_version_map(files_by_solc_version, remappings, project=pm) + + def _get_settings_from_version_map( + self, + version_map: dict, + import_remappings: dict[str, str], + project: Optional[ProjectManager] = None, + **kwargs, + ) -> dict[Version, dict]: + pm = project or self.local_project + if not version_map: return {} - import_remappings = self.get_import_remapping(base_path=base_path) - settings: Dict = {} - for solc_version, sources in files_by_solc_version.items(): - version_settings: Dict[str, Union[Any, List[Any]]] = { - "optimizer": {"enabled": self.settings.optimize, "runs": DEFAULT_OPTIMIZATION_RUNS}, + config = self.get_config(project=pm) + settings: dict = {} + for solc_version, sources in version_map.items(): + version_settings: dict[str, Union[Any, list[Any]]] = { + "optimizer": {"enabled": config.optimize, "runs": DEFAULT_OPTIMIZATION_RUNS}, "outputSelection": { - str(get_relative_path(p, base_path)): {"*": OUTPUT_SELECTION, "": ["ast"]} - for p in sources + str(get_relative_path(p, pm.path)): {"*": OUTPUT_SELECTION, "": ["ast"]} + for p in sorted(sources) }, + **kwargs, } - if remappings_used := self._get_used_remappings( - sources, remappings=import_remappings, base_path=base_path - ): + if remappings_used := self._get_used_remappings(sources, import_remappings, project=pm): remappings_str = [f"{k}={v}" for k, v in remappings_used.items()] # Standard JSON input requires remappings to be sorted. version_settings["remappings"] = sorted(remappings_str) - if evm_version := self.settings.evm_version: + if evm_version := config.evm_version: version_settings["evmVersion"] = evm_version - if solc_version >= Version("0.7.5") and self.settings.via_ir is not None: - version_settings["viaIR"] = self.settings.via_ir + if solc_version >= Version("0.7.5") and config.via_ir is not None: + version_settings["viaIR"] = config.via_ir settings[solc_version] = version_settings @@ -447,40 +407,89 @@ def get_compiler_settings( return settings def _get_used_remappings( - self, sources, remappings: Dict[str, str], base_path: Optional[Path] = None - ) -> Dict[str, str]: - base_path = base_path or self.project_manager.contracts_folder - remappings = remappings or self.get_import_remapping(base_path=base_path) + self, + sources: Iterable[Path], + remappings: dict[str, str], + project: Optional[ProjectManager] = None, + ) -> dict[str, str]: + pm = project or self.local_project if not remappings: # No remappings used at all. return {} - relative_cache = get_relative_path(self.project_manager.compiler_cache_folder, base_path) + cache_path = ( + f"{get_relative_path(pm.contracts_folder.absolute(), pm.path)}{os.path.sep}.cache" + ) # Filter out unused import remapping. - return { - k: v - for source in ( - x - for sourceset in self.get_imports(list(sources), base_path=base_path).values() - for x in sourceset - if str(relative_cache) in x - ) - for parent_key in ( - os.path.sep.join(source.split(os.path.sep)[:3]) for source in [source] - ) - for k, v in [(k, v) for k, v in remappings.items() if parent_key in v] - } + result = {} + sources = list(sources) + imports = self.get_imports(sources, project=pm).values() + + for source_list in imports: + for src in source_list: + if not src.startswith(cache_path): + continue + + parent_key = os.path.sep.join(src.split(os.path.sep)[:3]) + for k, v in remappings.items(): + if parent_key in v: + result[k] = v + + return result def get_standard_input_json( - self, contract_filepaths: Sequence[Path], base_path: Optional[Path] = None - ) -> Dict[Version, Dict]: - base_path = base_path or self.config_manager.contracts_folder - files_by_solc_version = self.get_version_map(contract_filepaths, base_path=base_path) - settings = self.get_compiler_settings(contract_filepaths, base_path) - input_jsons = {} + self, + contract_filepaths: Iterable[Path], + project: Optional[ProjectManager] = None, + **overrides, + ) -> dict[Version, dict]: + pm = project or self.local_project + paths = list(contract_filepaths) # Handle if given generator= + remapping = self.get_import_remapping(project=pm) + import_map = self.get_imports_from_remapping(paths, remapping, project=pm) + version_map = self.get_version_map_from_imports(paths, import_map, project=pm) + return self.get_standard_input_json_from_version_map( + version_map, remapping, project=pm, **overrides + ) + + def get_standard_input_json_from( + self, + version_map: dict[Version, set[Path]], + import_remappings: dict[str, str], + project: Optional[ProjectManager] = None, + **overrides, + ): + pm = project or self.local_project + settings = self._get_settings_from_version_map( + version_map, import_remappings, project=pm, **overrides + ) + return self.get_standard_input_json_from_settings(settings, version_map, project=pm) + + def get_standard_input_json_from_version_map( + self, + version_map: dict[Version, set[Path]], + import_remapping: dict[str, str], + project: Optional[ProjectManager] = None, + **overrides, + ): + pm = project or self.local_project + settings = self._get_settings_from_version_map( + version_map, import_remapping, project=pm, **overrides + ) + return self.get_standard_input_json_from_settings(settings, version_map, project=pm) + + def get_standard_input_json_from_settings( + self, + settings: dict[Version, dict], + version_map: dict[Version, set[Path]], + project: Optional[ProjectManager] = None, + ): + pm = project or self.local_project + input_jsons: dict[Version, dict] = {} + for solc_version, vers_settings in settings.items(): - if not list(files_by_solc_version[solc_version]): + if not list(version_map[solc_version]): continue cleaned_version = Version(solc_version.base_version) @@ -488,17 +497,32 @@ def get_standard_input_json( arguments = {"solc_binary": solc_binary, "solc_version": cleaned_version} if solc_version >= Version("0.6.9"): - arguments["base_path"] = base_path + arguments["base_path"] = pm.path if missing_sources := [ - x for x in vers_settings["outputSelection"] if not (base_path / x).is_file() + x for x in vers_settings["outputSelection"] if not (pm.path / x).is_file() ]: + # See if the missing sources are from dependencies (they likely are) + # and cater the error message accordingly. + if dependencies_needed := [x for x in missing_sources if str(x).startswith("@")]: + # Missing dependencies. Should only get here if dependencies are found + # in import-strs but are not installed anywhere (not in project or globally). + missing_str = ", ".join(dependencies_needed) + raise CompilerError( + f"Missing required dependencies '{missing_str}'. " + "Install them using `dependencies:` " + "in an ape-config.yaml or using the `ape pm install` command." + ) + + # Otherwise, we are missing project-level source files for some reason. + # This would only happen if the user passes in unexpected files outside + # of core. missing_src_str = ", ".join(missing_sources) - raise CompilerError(f"Missing sources: '{missing_src_str}'.") + raise CompilerError(f"Sources '{missing_src_str}' not found in '{pm.name}'.") sources = { - x: {"content": (base_path / x).read_text()} - for x in vers_settings["outputSelection"] + x: {"content": (pm.path / x).read_text()} + for x in sorted(vers_settings["outputSelection"]) } input_jsons[solc_version] = { @@ -507,27 +531,53 @@ def get_standard_input_json( "language": "Solidity", } - return input_jsons + return {v: input_jsons[v] for v in sorted(input_jsons)} def compile( - self, contract_filepaths: Sequence[Path], base_path: Optional[Path] = None - ) -> List[ContractType]: - base_path = base_path or self.config_manager.contracts_folder - contract_versions: Dict[str, Version] = {} - contract_types: List[ContractType] = [] - input_jsons = self.get_standard_input_json(contract_filepaths, base_path=base_path) - + self, + contract_filepaths: Iterable[Path], + project: Optional[ProjectManager] = None, + settings: Optional[dict] = None, + ) -> Iterator[ContractType]: + pm = project or self.local_project + settings = settings or {} + paths = [p for p in contract_filepaths] # Handles generator. + source_ids = [f"{get_relative_path(p.absolute(), pm.path)}" for p in paths] + _validate_can_compile(paths) + + # Compile in an isolated env so the .cache folder does not interfere with anything. + with pm.isolate_in_tempdir() as isolated_project: + filepaths = [isolated_project.path / src_id for src_id in source_ids] + yield from self._compile(filepaths, project=isolated_project, settings=settings) + compilers = isolated_project.manifest.compilers + + pm.update_manifest(compilers=compilers) + + def _compile( + self, + contract_filepaths: Iterable[Path], + project: Optional[ProjectManager] = None, + settings: Optional[dict] = None, + ): + pm = project or self.local_project + input_jsons = self.get_standard_input_json( + contract_filepaths, project=pm, **(settings or {}) + ) + contract_versions: dict[str, Version] = {} + contract_types: list[ContractType] = [] for solc_version, input_json in input_jsons.items(): - logger.info(f"Compiling using Solidity compiler '{solc_version}'.") + keys = ( + "\n\t".join(sorted([x for x in input_json.get("sources", {}).keys()])) + or "No input." + ) + log_str = f"Compiling using Solidity compiler '{solc_version}'.\nInput:\n\t{keys}" + logger.info(log_str) cleaned_version = Version(solc_version.base_version) solc_binary = get_executable(version=cleaned_version) - arguments: Dict = {"solc_binary": solc_binary, "solc_version": cleaned_version} + arguments: dict = {"solc_binary": solc_binary, "solc_version": cleaned_version} if solc_version >= Version("0.6.9"): - arguments["base_path"] = base_path - - if self.project_manager.compiler_cache_folder.is_dir(): - arguments["allow_paths"] = self.project_manager.compiler_cache_folder + arguments["base_path"] = pm.path # Allow empty contracts, like Vyper does. arguments["allow_empty"] = True @@ -538,12 +588,11 @@ def compile( raise SolcCompileError(err) from err contracts = output.get("contracts", {}) - # Perf back-out. if not contracts: continue - input_contract_names: List[str] = [] + input_contract_names: list[str] = [] for source_id, contracts_out in contracts.items(): for name, _ in contracts_out.items(): # Filter source files that the user did not ask for, such as @@ -556,8 +605,6 @@ def compile( # ast_data = output["sources"][source_id]["ast"] for contract_name, ct_data in contracts_out.items(): - contract_path = base_path / source_id - if contract_name not in input_contract_names: # Only return ContractTypes explicitly asked for. continue @@ -576,37 +623,29 @@ def compile( f"Unable to compile {contract_name} - missing libraries. " f"Call `{self.add_library.__name__}` with the necessary libraries" ) - self._contracts_needing_libraries.add(contract_path) continue - if previous_version := contract_versions.get(contract_name, None): - if previous_version < solc_version: - # Keep the smallest version for max compat. - continue - - else: - # Remove the previously compiled contract types and re-compile. - contract_types = [ - ct for ct in contract_types if ct.name != contract_name - ] - if contract_name in contract_versions: - del contract_versions[contract_name] + if contract_name in contract_versions: + # Already yield in smaller version. Must not yield again + # or else we will have a contract-type collision. + # (Sources that are required in multiple version-sets will + # hit this). + continue ct_data["contractName"] = contract_name - ct_data["sourceId"] = str( - get_relative_path(base_path / contract_path, base_path) - ) + ct_data["sourceId"] = source_id ct_data["deploymentBytecode"] = {"bytecode": deployment_bytecode} ct_data["runtimeBytecode"] = {"bytecode": runtime_bytecode} ct_data["userdoc"] = load_dict(ct_data["userdoc"]) ct_data["devdoc"] = load_dict(ct_data["devdoc"]) ct_data["sourcemap"] = evm_data["bytecode"]["sourceMap"] contract_type = ContractType.model_validate(ct_data) + yield contract_type contract_types.append(contract_type) contract_versions[contract_name] = solc_version # Output compiler data used. - compilers_used: Dict[Version, Compiler] = {} + compilers_used: dict[Version, Compiler] = {} for ct in contract_types: if not ct.name: # Won't happen, but just for mypy. @@ -628,19 +667,19 @@ def compile( settings=settings, ) + # Update compilers used in project manifest. # First, output compiler information to manifest. compilers_ls = list(compilers_used.values()) - self.project_manager.local_project.add_compiler_data(compilers_ls) - - return contract_types + pm.add_compiler_data(compilers_ls) def compile_code( self, code: str, - base_path: Optional[Path] = None, + project: Optional[ProjectManager] = None, **kwargs, ) -> ContractType: - if settings_version := self._settings_version: + pm = project or self.local_project + if settings_version := self._get_configured_version(project=pm): version = settings_version elif pragma := self._get_pramga_spec_from_str(code): @@ -669,8 +708,8 @@ def compile_code( try: result = compile_source( code, - import_remappings=self.get_import_remapping(base_path=base_path), - base_path=base_path, + import_remappings=self.get_import_remapping(project=pm), + base_path=pm.path, solc_binary=executable, solc_version=cleaned_version, allow_empty=True, @@ -679,137 +718,173 @@ def compile_code( raise SolcCompileError(err) from err output = result[next(iter(result.keys()))] - return ContractType( - abi=output["abi"], - ast=output["ast"], - deploymentBytecode={"bytecode": HexBytes(output["bin"])}, - devdoc=load_dict(output["devdoc"]), - runtimeBytecode={"bytecode": HexBytes(output["bin-runtime"])}, - sourcemap=output["srcmap"], - userdoc=load_dict(output["userdoc"]), - **kwargs, + return ContractType.model_validate( + { + "abi": output["abi"], + "ast": output["ast"], + "deploymentBytecode": {"bytecode": HexBytes(output["bin"])}, + "devdoc": load_dict(output["devdoc"]), + "runtimeBytecode": {"bytecode": HexBytes(output["bin-runtime"])}, + "sourcemap": output["srcmap"], + "userdoc": load_dict(output["userdoc"]), + **kwargs, + } ) - def _get_unmapped_imports( + def get_imports( self, - contract_filepaths: Sequence[Path], - base_path: Optional[Path] = None, - ) -> Dict[str, List[Tuple[str, str]]]: - contracts_path = base_path or self.config_manager.contracts_folder - import_remapping = self.get_import_remapping(base_path=contracts_path) - contract_filepaths_set = verify_contract_filepaths(contract_filepaths) - - imports_dict: Dict[str, List[Tuple[str, str]]] = {} - for src_path, import_strs in get_import_lines(contract_filepaths_set).items(): - import_list = [] - for import_str in import_strs: - raw_import_item_search = re.search(r'"(.*?)"', import_str) - if raw_import_item_search is None: - raise CompilerError(f"No target filename found in import {import_str}") - raw_import_item = raw_import_item_search.group(1) - import_item = _import_str_to_source_id( - import_str, src_path, contracts_path, import_remapping - ) + contract_filepaths: Iterable[Path], + project: Optional[ProjectManager] = None, + ) -> dict[str, list[str]]: + pm = project or self.local_project + remapping = self.get_import_remapping(project=pm) + _validate_can_compile(contract_filepaths) + paths = [x for x in contract_filepaths] # Handle if given generator. + return self.get_imports_from_remapping(paths, remapping, project=pm) + + def get_imports_from_remapping( + self, + paths: Iterable[Path], + remapping: dict[str, str], + project: Optional[ProjectManager] = None, + ) -> dict[str, list[str]]: + pm = project or self.local_project + return self._get_imports(paths, remapping, pm, tracked=set()) # type: ignore + + def _get_imports( + self, + paths: Iterable[Path], + remapping: dict[str, str], + pm: "ProjectManager", + tracked: set[str], + include_raw: bool = False, + ) -> dict[str, Union[dict[str, str], list[str]]]: + result: dict = {} + + for src_path, import_strs in get_import_lines(paths).items(): + source_id = str(get_relative_path(src_path, pm.path)) + if source_id in tracked: + # We have already accumulated imports from this source. + continue - # Only add to the list if it's not already there, to mimic set behavior - if (import_item, raw_import_item) not in import_list: - import_list.append((import_item, raw_import_item)) + tracked.add(source_id) - source_id = str(get_relative_path(src_path, contracts_path)) - imports_dict[str(source_id)] = import_list + # Init with all top-level imports. + import_map = { + x: self._import_str_to_source_id(x, src_path, remapping, project=pm) + for x in import_strs + } + import_source_ids = list(set(list(import_map.values()))) - return imports_dict + # NOTE: Add entry even if empty here. + result[source_id] = import_map if include_raw else import_source_ids - def get_imports( - self, - contract_filepaths: Sequence[Path], - base_path: Optional[Path] = None, - ) -> Dict[str, List[str]]: - contracts_path = base_path or self.config_manager.contracts_folder + # Add imports of imports. + if not result[source_id]: + # Nothing else imported. + continue - def build_map(paths: Set[Path], prev: Optional[Dict] = None) -> Dict[str, List[str]]: - result: Dict[str, List[str]] = prev or {} + # Add known imports. + known_imports = {p: result[p] for p in import_source_ids if p in result} + imp_paths = [pm.path / p for p in import_source_ids if p not in result] + unknown_imports = self._get_imports( + imp_paths, + remapping, + pm, + tracked=tracked, + include_raw=include_raw, + ) + sub_imports = {**known_imports, **unknown_imports} - for src_path, import_strs in get_import_lines(paths).items(): - source_id = str(get_relative_path(src_path, contracts_path)) - if source_id in result: - continue + # All imported sources from imported sources are imported sources. + for sub_set in sub_imports.values(): + if isinstance(sub_set, dict): + for import_str, sub_import in sub_set.items(): + result[source_id][import_str] = sub_import - import_set = { - _import_str_to_source_id(import_str, src_path, contracts_path, import_remapping) - for import_str in import_strs - } - result[source_id] = list(import_set) + else: + for sub_import in sub_set: + if sub_import not in result[source_id]: + result[source_id].append(sub_import) - # Add imports of imports. - import_paths = {contracts_path / p for p in import_set if p not in result} - result = {**result, **build_map(import_paths, prev=result)} + # Keep sorted. + result[source_id] = sorted((result[source_id])) - return result + # Combine results. This ends up like a tree-structure. + result = {**result, **sub_imports} - # NOTE: Process import remapping list _before_ getting the full contract set. - import_remapping = self.get_import_remapping(base_path=contracts_path) - contract_filepaths_set = verify_contract_filepaths(contract_filepaths) - return build_map(contract_filepaths_set) + # Sort final keys and import lists for more predictable compiler behavior. + return {k: result[k] for k in sorted(result.keys())} def get_version_map( self, - contract_filepaths: Union[Path, Sequence[Path]], - base_path: Optional[Path] = None, - ) -> Dict[Version, Set[Path]]: - # Ensure `.cache` folder is built before getting version map. - self.get_import_remapping(base_path=base_path) - - if not isinstance(contract_filepaths, Sequence): - contract_filepaths = [contract_filepaths] - - base_path = base_path or self.project_manager.contracts_folder - contract_filepaths_set = verify_contract_filepaths(contract_filepaths) - sources = [ - p - for p in self.project_manager.source_paths - if p.is_file() and p.suffix == Extension.SOL.value - ] - imports = self.get_imports(sources, base_path) + contract_filepaths: Union[Path, Iterable[Path]], + project: Optional[ProjectManager] = None, + ) -> dict[Version, set[Path]]: + pm = project or self.local_project + paths = ( + [contract_filepaths] + if isinstance(contract_filepaths, Path) + else [p for p in contract_filepaths] + ) + _validate_can_compile(paths) + imports = self.get_imports(paths, project=pm) + return self.get_version_map_from_imports(paths, imports, project=pm) + + def get_version_map_from_imports( + self, + contract_filepaths: Union[Path, Iterable[Path]], + import_map: dict[str, list[str]], + project: Optional[ProjectManager] = None, + ): + pm = project or self.local_project + paths = ( + [contract_filepaths] + if isinstance(contract_filepaths, Path) + else [p for p in contract_filepaths] + ) + path_set: set[Path] = {p for p in paths} # Add imported source files to list of contracts to compile. - source_paths_to_get = contract_filepaths_set.copy() - for source_path in contract_filepaths_set: - imported_source_paths = self._get_imported_source_paths(source_path, base_path, imports) - for imported_source in imported_source_paths: - source_paths_to_get.add(imported_source) + for source_path in paths: + source_id = f"{get_relative_path(source_path, pm.path)}" + if source_id not in import_map or len(import_map[source_id]) == 0: + continue - # Use specified version if given one - if version := self._settings_version: - return {version: source_paths_to_get} + import_set = {pm.path / src_id for src_id in import_map[source_id]} + path_set = path_set.union(import_set) + # Use specified version if given one + if _version := self._get_configured_version(project=pm): + return {_version: path_set} # else: find best version per source file # Build map of pragma-specs. - source_by_pragma_spec = {p: get_pragma_spec_from_path(p) for p in source_paths_to_get} + pragma_map = {p: get_pragma_spec_from_path(p) for p in path_set} # If no Solidity version has been installed previously while fetching the # contract version pragma, we must install a compiler, so choose the latest if ( not self.installed_versions - and not any(source_by_pragma_spec.values()) + and not any(pragma_map.values()) and (latest := self.latest_version) ): install_solc(latest, show_progress=True) # Adjust best-versions based on imports. - files_by_solc_version: Dict[Version, Set[Path]] = {} - for source_file_path in source_paths_to_get: - solc_version = self._get_best_version(source_file_path, source_by_pragma_spec) + files_by_solc_version: dict[Version, set[Path]] = {} + for source_file_path in path_set: + solc_version = self._get_best_version(source_file_path, pragma_map) imported_source_paths = self._get_imported_source_paths( - source_file_path, base_path, imports + source_file_path, pm.path, import_map ) for imported_source_path in imported_source_paths: - imported_pragma_spec = source_by_pragma_spec[imported_source_path] - imported_version = self._get_best_version( - imported_source_path, source_by_pragma_spec - ) + if imported_source_path not in pragma_map: + continue + + imported_pragma_spec = pragma_map[imported_source_path] + imported_version = self._get_best_version(imported_source_path, pragma_map) if imported_pragma_spec is not None and ( str(imported_pragma_spec)[0].startswith("=") @@ -842,8 +917,8 @@ def get_version_map( other_files = [f for f in files_by_solc_version[solc_version] if f != file] used_in_imports = False for other_file in other_files: - source_id = str(get_relative_path(other_file, base_path)) - import_paths = [base_path / i for i in imports.get(source_id, []) if i] + source_id = str(get_relative_path(other_file, pm.path)) + import_paths = [pm.path / i for i in import_map.get(source_id, []) if i] if file in import_paths: used_in_imports = True break @@ -853,15 +928,19 @@ def get_version_map( if not files_by_solc_version[solc_version]: del files_by_solc_version[solc_version] - return {add_commit_hash(v): ls for v, ls in files_by_solc_version.items()} + result = {add_commit_hash(v): ls for v, ls in files_by_solc_version.items()} + + # Sort, so it is a nicer version map and the rest of the compilation flow + # is more predictable. + return {k: result[k] for k in sorted(result)} def _get_imported_source_paths( self, path: Path, base_path: Path, - imports: Dict, - source_ids_checked: Optional[List[str]] = None, - ) -> Set[Path]: + imports: dict, + source_ids_checked: Optional[list[str]] = None, + ) -> set[Path]: source_ids_checked = source_ids_checked or [] source_identifier = str(get_relative_path(path, base_path)) if source_identifier in source_ids_checked: @@ -893,8 +972,8 @@ def _get_pramga_spec_from_str(self, source_str: str) -> Optional[SpecifierSet]: else: # Attempt to use the best-installed version. - for version in self.installed_versions: - if version not in pragma_spec: + for _version in self.installed_versions: + if _version not in pragma_spec: continue logger.warning( @@ -910,7 +989,20 @@ def _get_pramga_spec_from_str(self, source_str: str) -> Optional[SpecifierSet]: return pragma_spec - def _get_best_version(self, path: Path, source_by_pragma_spec: Dict) -> Version: + def _get_best_versions(self, path: Path, options, source_by_pragma_spec: dict) -> list[Version]: + # NOTE: Doesn't install. + if pragma_spec := source_by_pragma_spec.get(path): + res = get_versions_can_use(pragma_spec, list(options)) + elif latest_installed := self.latest_installed_version: + res = [latest_installed] + elif latest := self.latest_version: + res = [latest] + else: + raise SolcInstallError() + + return [add_commit_hash(v) for v in res] + + def _get_best_version(self, path: Path, source_by_pragma_spec: dict) -> Version: compiler_version: Optional[Version] = None if pragma_spec := source_by_pragma_spec.get(path): if selected := select_version(pragma_spec, self.installed_versions): @@ -955,12 +1047,7 @@ def enrich_error(self, err: ContractLogicError) -> ContractLogicError: selector = bytes_message[:4] input_data = bytes_message[4:] - # TODO: Any version after Ape 0.6.11 we can replace this with `err.address`. - if not ( - address := err.contract_address - or getattr(err.txn, "receiver", None) - or getattr(err.txn, "contract_address", None) - ): + if not err.address: return err if not self.network_manager.active_provider: @@ -968,7 +1055,7 @@ def enrich_error(self, err: ContractLogicError) -> ContractLogicError: return err if ( - not (contract := self.chain_manager.contracts.instance_at(address)) + not (contract := self.chain_manager.contracts.instance_at(err.address)) or selector not in contract.contract_type.errors ): return err @@ -988,36 +1075,108 @@ def enrich_error(self, err: ContractLogicError) -> ContractLogicError: ) def _flatten_source( - self, path: Path, base_path: Optional[Path] = None, raw_import_name: Optional[str] = None + self, + path: Path, + project: Optional[ProjectManager] = None, + raw_import_name: Optional[str] = None, + handled: Optional[set[str]] = None, ) -> str: - base_path = base_path or self.config_manager.contracts_folder - imports = self._get_unmapped_imports([path]) - source = "" - for import_list in imports.values(): - for import_path, raw_import_path in sorted(import_list): - source += self._flatten_source( - base_path / import_path, base_path=base_path, raw_import_name=raw_import_path - ) - if raw_import_name: - source += f"// File: {raw_import_name}\n" - else: - source += f"// File: {path.name}\n" - source += path.read_text() + "\n" - return source + pm = project or self.local_project + handled = handled or set() + source_id = f"{get_relative_path(path, pm.path)}" + handled.add(source_id) + remapping = self.get_import_remapping(project=project) + imports = self._get_imports((path,), remapping, pm, tracked=set(), include_raw=True) + relevant_imports = imports.get(source_id, {}) + + final_source = "" + + for import_str, source_id in relevant_imports.items(): # type: ignore + if source_id in handled: + continue + + sub_import_name = import_str.replace("import ", "").strip(" \n\t;\"'") + final_source += self._flatten_source( + pm.path / source_id, + project=pm, + raw_import_name=sub_import_name, + handled=handled, + ) - def flatten_contract(self, path: Path, **kwargs) -> Content: + final_source += _get_flattened_source(path, name=raw_import_name) + return final_source + + def flatten_contract( + self, path: Path, project: Optional[ProjectManager] = None, **kwargs + ) -> Content: # try compiling in order to validate it works - self.compile([path], base_path=self.config_manager.contracts_folder) - source = self._flatten_source(path) - source = remove_imports(source) - source = process_licenses(source) - source = remove_version_pragmas(source) + res = self._flatten_source(path, project=project) + res = remove_imports(res) + res = process_licenses(res) + res = remove_version_pragmas(res) pragma = get_first_version_pragma(path.read_text()) - source = "\n".join([pragma, source]) - lines = source.splitlines() + res = "\n".join([pragma, res]) + lines = res.splitlines() line_dict = {i + 1: line for i, line in enumerate(lines)} return Content(root=line_dict) + def _import_str_to_source_id( + self, + _import_str: str, + source_path: Path, + import_remapping: dict[str, str], + project: Optional[ProjectManager] = None, + ) -> str: + pm = project or self.local_project + quote = '"' if '"' in _import_str else "'" + + try: + end_index = _import_str.index(quote) + 1 + except ValueError as err: + raise CompilerError( + f"Error parsing import statement '{_import_str}' in '{source_path.name}'." + ) from err + + import_str_prefix = _import_str[end_index:] + import_str_value = import_str_prefix[: import_str_prefix.index(quote)] + + # Get all matches. + valid_matches: list[tuple[str, str]] = [] + key = None + base_path = None + for key, value in import_remapping.items(): + if key not in import_str_value: + continue + + valid_matches.append((key, value)) + + if valid_matches: + key, value = max(valid_matches, key=lambda x: len(x[0])) + import_str_value = import_str_value.replace(key, value) + + if import_str_value.startswith("."): + base_path = source_path.parent + elif (pm.path / import_str_value).is_file(): + base_path = pm.path + elif (pm.contracts_folder / import_str_value).is_file(): + base_path = pm.contracts_folder + elif key is not None and key.startswith("@"): + nm = key[1:] + for cfg_dep in pm.config.dependencies: + if ( + cfg_dep.get("name") == nm + and "project" in cfg_dep + and (Path(cfg_dep["project"]) / import_str_value).is_file() + ): + base_path = Path(cfg_dep["project"]) + + if base_path is None: + # No base_path, do as-is. + return import_str_value + + path = (base_path / import_str_value).resolve() + return f"{get_relative_path(path.absolute(), pm.path)}" + def remove_imports(flattened_contract: str) -> str: # Use regex.sub() to remove matched import statements @@ -1037,7 +1196,7 @@ def get_first_version_pragma(source: str) -> str: return "" -def get_licenses(src: str) -> List[Tuple[str, str]]: +def get_licenses(src: str) -> list[tuple[str, str]]: return LICENSES_PATTERN.findall(src) @@ -1084,7 +1243,7 @@ def process_licenses(contract: str) -> str: return contract_with_single_license -def _get_sol_panic(revert_message: str) -> Optional[Type[RuntimeErrorUnion]]: +def _get_sol_panic(revert_message: str) -> Optional[type[RuntimeErrorUnion]]: if revert_message.startswith(RUNTIME_ERROR_CODE_PREFIX): # ape-geth (style) plugins show the hex with the Panic ABI prefix. error_type_val = int( @@ -1100,52 +1259,13 @@ def _get_sol_panic(revert_message: str) -> Optional[Type[RuntimeErrorUnion]]: return None -def _import_str_to_source_id( - _import_str: str, source_path: Path, base_path: Path, import_remapping: Dict[str, str] -) -> str: - quote = '"' if '"' in _import_str else "'" - - try: - end_index = _import_str.index(quote) + 1 - except ValueError as err: - raise CompilerError( - f"Error parsing import statement '{_import_str}' in '{source_path.name}'." - ) from err - - import_str_prefix = _import_str[end_index:] - import_str_value = import_str_prefix[: import_str_prefix.index(quote)] - path = (source_path.parent / import_str_value).resolve() - source_id_value = str(get_relative_path(path, base_path)) - - # Get all matches. - matches: List[Tuple[str, str]] = [] - for key, value in import_remapping.items(): - if key not in source_id_value: - continue - - matches.append((key, value)) - - if not matches: - return source_id_value - - # Convert remapping list back to source using longest match (most exact). - key, value = max(matches, key=lambda x: len(x[0])) - sections = [s for s in source_id_value.split(key) if s] - depth = len(sections) - 1 - source_id_value = "" - - index = 0 - for section in sections: - if index == depth: - source_id_value += value - source_id_value += section - elif index >= depth: - source_id_value += section - - index += 1 +def _try_max(ls: list[Any]): + return max(ls) if ls else None - return source_id_value +def _validate_can_compile(paths: Iterable[Path]): + extensions = {get_full_extension(p): p for p in paths if p} -def _try_max(ls: List[Any]): - return max(ls) if ls else None + for ext, file in extensions.items(): + if ext not in [e.value for e in Extension]: + raise CompilerError(f"Unable to compile '{file.name}' using Solidity compiler.") diff --git a/ape_solidity/exceptions.py b/ape_solidity/exceptions.py index f059a33..5a11aae 100644 --- a/ape_solidity/exceptions.py +++ b/ape_solidity/exceptions.py @@ -1,5 +1,5 @@ from enum import IntEnum -from typing import Dict, Type, Union +from typing import Union from ape.exceptions import CompilerError, ConfigError, ContractLogicError from ape.logging import LogLevel, logger @@ -169,7 +169,7 @@ def __init__(self, **kwargs): PopOnEmptyArrayError, ZeroInitializedVariableError, ] -RUNTIME_ERROR_MAP: Dict[RuntimeErrorType, Type[RuntimeErrorUnion]] = { +RUNTIME_ERROR_MAP: dict[RuntimeErrorType, type[RuntimeErrorUnion]] = { RuntimeErrorType.ASSERTION_ERROR: SolidityAssertionError, RuntimeErrorType.ARITHMETIC_UNDER_OR_OVERFLOW: SolidityArithmeticError, RuntimeErrorType.DIVISION_BY_ZERO_ERROR: DivisionByZeroError, diff --git a/pyproject.toml b/pyproject.toml index 5e7fecd..00925ba 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -16,7 +16,7 @@ write_to = "ape_solidity/version.py" [tool.black] line-length = 100 -target-version = ['py38', 'py39', 'py310', 'py311', 'py312'] +target-version = ['py39', 'py310', 'py311', 'py312'] include = '\.pyi?$' [tool.pytest.ini_options] @@ -30,7 +30,10 @@ addopts = """ """ python_files = "test_*.py" testpaths = "tests" -markers = "fuzzing: Run Hypothesis fuzz test suite" +markers = """ +fuzzing: Run Hypothesis fuzz test suite +install: Tests that will install a solc version (slow) +""" [tool.isort] line_length = 100 diff --git a/setup.py b/setup.py index df01162..2a30d37 100644 --- a/setup.py +++ b/setup.py @@ -69,13 +69,13 @@ include_package_data=True, install_requires=[ "py-solc-x>=2.0.2,<3", - "eth-ape>=0.7.10,<0.8", + "eth-ape>=0.8.1,<0.9", "ethpm-types", # Use the version ape requires "eth-pydantic-types", # Use the version ape requires "packaging", # Use the version ape requires "requests", ], - python_requires=">=3.8,<4", + python_requires=">=3.9,<4", extras_require=extras_require, py_modules=["ape_solidity"], license="Apache-2.0", @@ -91,7 +91,6 @@ "Operating System :: MacOS", "Operating System :: POSIX", "Programming Language :: Python :: 3", - "Programming Language :: Python :: 3.8", "Programming Language :: Python :: 3.9", "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11", diff --git a/tests/BrownieProject/contracts/BrownieContract.sol b/tests/BrownieProject/contracts/BrownieContract.sol index 3ab2338..20945d6 100644 --- a/tests/BrownieProject/contracts/BrownieContract.sol +++ b/tests/BrownieProject/contracts/BrownieContract.sol @@ -2,7 +2,7 @@ pragma solidity ^0.8.4; -contract BrownieContract { +contract CompilingContract { function foo() pure public returns(bool) { return true; } diff --git a/tests/Dependency/ape-config.yaml b/tests/Dependency/ape-config.yaml index 862a031..6598219 100644 --- a/tests/Dependency/ape-config.yaml +++ b/tests/Dependency/ape-config.yaml @@ -1,7 +1,3 @@ dependencies: - - name: TestDependencyOfDependency + - name: dependencyofdependency local: ../DependencyOfDependency - -solidity: - import_remapping: - - "@dependency_remapping=TestDependencyOfDependency/local" diff --git a/tests/Dependency/contracts/Dependency.sol b/tests/Dependency/contracts/Dependency.sol index 42d5ec3..6989416 100644 --- a/tests/Dependency/contracts/Dependency.sol +++ b/tests/Dependency/contracts/Dependency.sol @@ -2,7 +2,7 @@ pragma solidity ^0.8.4; -import "@dependency_remapping/DependencyOfDependency.sol"; +import "@dependencyofdependency/contracts/DependencyOfDependency.sol"; struct DependencyStruct { string name; diff --git a/tests/BrownieStyleDependency/ape-config.yaml b/tests/NonCompilingDependency/ape-config.yaml similarity index 100% rename from tests/BrownieStyleDependency/ape-config.yaml rename to tests/NonCompilingDependency/ape-config.yaml diff --git a/tests/BrownieStyleDependency/contracts/BrownieStyleDependency.sol b/tests/NonCompilingDependency/contracts/CompilingContract.sol similarity index 100% rename from tests/BrownieStyleDependency/contracts/BrownieStyleDependency.sol rename to tests/NonCompilingDependency/contracts/CompilingContract.sol diff --git a/tests/BrownieStyleDependency/contracts/FailingContract.sol b/tests/NonCompilingDependency/contracts/NonCompilingContract.sol similarity index 57% rename from tests/BrownieStyleDependency/contracts/FailingContract.sol rename to tests/NonCompilingDependency/contracts/NonCompilingContract.sol index eab8279..26819f9 100644 --- a/tests/BrownieStyleDependency/contracts/FailingContract.sol +++ b/tests/NonCompilingDependency/contracts/NonCompilingContract.sol @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT pragma solidity ^0.8.4; -This file exists to show that a brownie-style dependency can contain a contract that does not compile. +This file exists to show that a non-fully compiling dependency can contain a contract that does not compile. But the dependency is still usable (provided the project only uses working contract types) diff --git a/tests/ProjectWithinProject/ape-config.yaml b/tests/ProjectWithinProject/ape-config.yaml index 47898ee..0e9c78f 100644 --- a/tests/ProjectWithinProject/ape-config.yaml +++ b/tests/ProjectWithinProject/ape-config.yaml @@ -1,11 +1,6 @@ dependencies: - - name: TestRemapping + - name: remapping local: ../Dependency - - name: TestDependencyOfDependency + - name: dependencyofdependency local: ../DependencyOfDependency - -solidity: - import_remapping: - - "@remapping=TestRemapping" - - "@dependency_remapping=TestDependencyOfDependency/local" diff --git a/tests/ProjectWithinProject/contracts/Contract.sol b/tests/ProjectWithinProject/contracts/Contract.sol index 22663d2..4aa4aff 100644 --- a/tests/ProjectWithinProject/contracts/Contract.sol +++ b/tests/ProjectWithinProject/contracts/Contract.sol @@ -2,7 +2,7 @@ pragma solidity ^0.8.4; -import "@remapping/Dependency.sol"; +import "@remapping/contracts/Dependency.sol"; contract Contract { function foo() pure public returns(bool) { diff --git a/tests/ape-config.yaml b/tests/ape-config.yaml index eecbaa9..5c22a66 100644 --- a/tests/ape-config.yaml +++ b/tests/ape-config.yaml @@ -1,57 +1,30 @@ dependencies: - - name: TestDependency + - name: dependency local: ./Dependency - - name: DependencyOfDependency - local: ./DependencyOfDependency # Make sure can use a Brownie project as a dependency - - name: BrownieDependency + - name: browniedependency local: ./BrownieProject - # Make sure can use a Brownie-style dependency. - # Brownie-style dependencies don't compile on their own, necessarily - # and are more a collection of contract types you can use. - - name: BrownieStyleDependency - local: ./BrownieStyleDependency + # Make sure can use contracts from a non-fully compiling dependency. + - name: noncompilingdependency + local: ./NonCompilingDependency # Ensure we can build a realistic-brownie project with dependencies. - name: vault github: yearn/yearn-vaults - version: 0.4.5 + ref: v0.4.5 # Ensure dependencies using GitHub references work. - - name: vault + - name: vaultmain github: yearn/yearn-vaults ref: master - # Ensure dependencies using NPM dependencies - - name: gnosis + # Ensure NPM dependencies work. + - name: safe npm: "@gnosis.pm/safe-contracts" version: 1.3.0 solidity: - import_remapping: - - "@remapping/contracts=TestDependency" - - "@remapping_2=TestDependency" - - "@remapping_2_brownie=BrownieDependency" - - "@dependency_remapping=DependencyOfDependency" - - # Remapping for showing we can import a contract type from a brownie-style dependency - # (provided the _single_ contract type compiles in the project). - - "@styleofbrownie=BrownieStyleDependency" - - # Ensure yearn-vaults works as a remapping - - "@vault=vault/v0.4.5" - - "@vaultmain=vault/master" - - # Ensure that npm dependencies work as a remapping - - "@gnosis=gnosis/v1.3.0" - - # Needed for Vault - - "@openzeppelin/contracts=OpenZeppelin/v4.7.1" - - # Testing multiple versions of same dependency. - - "@oz/contracts=OpenZeppelin/v4.5.0" - # Using evm_version compatible with older and newer solidity versions. evm_version: constantinople diff --git a/tests/conftest.py b/tests/conftest.py index df927c6..d7b98f7 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,21 +1,15 @@ import shutil from contextlib import contextmanager from pathlib import Path -from shutil import copytree from tempfile import mkdtemp from unittest import mock import ape import pytest import solcx +from ape.utils.os import create_tempdir -from ape_solidity.compiler import Extension - -# NOTE: Ensure that we don't use local paths for these -DATA_FOLDER = Path(mkdtemp()).resolve() -PROJECT_FOLDER = Path(mkdtemp()).resolve() -ape.config.DATA_FOLDER = DATA_FOLDER -ape.config.PROJECT_FOLDER = PROJECT_FOLDER +from ape_solidity._utils import Extension @contextmanager @@ -54,44 +48,51 @@ def temp_solcx_path(monkeypatch): yield path -@pytest.fixture(autouse=True) -def data_folder(): - base_path = Path(__file__).parent / "data" - copytree(base_path, DATA_FOLDER, dirs_exist_ok=True) - return DATA_FOLDER - - -@pytest.fixture -def config(): - return ape.config - +@pytest.fixture(scope="session") +def project(config): + _ = config # Ensure temp data folder gets set first. + root = Path(__file__).parent -@pytest.fixture(autouse=True) -def project(data_folder, config): - _ = data_folder # Ensure happens first. - project_source_dir = Path(__file__).parent - project_dest_dir = PROJECT_FOLDER / project_source_dir.name - - # Delete build / .cache that may exist pre-copy - project_path = Path(__file__).parent + # Delete .build / .cache that may exist pre-copy for path in ( - project_path, - project_path / "BrownieProject", - project_path / "BrownieStyleDependency", - project_path / "Dependency", - project_path / "DependencyOfDependency", - project_path / "ProjectWithinProject", - project_path / "VersionSpecifiedInConfig", + root, + root / "BrownieProject", + root / "BrownieStyleDependency", + root / "Dependency", + root / "DependencyOfDependency", + root / "ProjectWithinProject", + root / "VersionSpecifiedInConfig", ): for cache in (path / ".build", path / "contracts" / ".cache"): if cache.is_dir(): shutil.rmtree(cache) - copytree(project_source_dir, project_dest_dir, dirs_exist_ok=True) - with config.using_project(project_dest_dir) as project: - yield project - if project.local_project._cache_folder.is_dir(): - shutil.rmtree(project.local_project._cache_folder) + root_project = ape.Project(root) + with root_project.isolate_in_tempdir() as tmp_project: + yield tmp_project + + +@pytest.fixture(scope="session", autouse=True) +def config(): + cfg = ape.config + + # Uncomment to install dependencies in actual data folder. + # This will save time running tests. + # project = ape.Project(Path(__file__).parent) + # project.dependencies.install() + + # Ensure we don't persist any .ape data. + real_data_folder = cfg.DATA_FOLDER + with create_tempdir() as path: + cfg.DATA_FOLDER = path + + # Copy in existing packages to save test time + # when running locally. + packages = real_data_folder / "packages" + packages.mkdir(parents=True, exist_ok=True) + shutil.copytree(packages, path / "packages", dirs_exist_ok=True) + + yield cfg @pytest.fixture @@ -125,11 +126,6 @@ def ignore_other_compilers(mocker, compiler_manager, compiler): mock_registered_compilers.return_value = {".json": ape_pm, **valid_compilers} -@pytest.fixture -def vyper_source_path(project): - return project.contracts_folder / "RandomVyperFile.vy" - - @pytest.fixture def account(): return ape.accounts.test_accounts[0] diff --git a/tests/contracts/CircularImport1.sol b/tests/contracts/CircularImport1.sol index 15b06bb..f2bba36 100644 --- a/tests/contracts/CircularImport1.sol +++ b/tests/contracts/CircularImport1.sol @@ -2,7 +2,7 @@ pragma solidity ^0.8.4; -import "CircularImport2.sol"; +import "contracts/CircularImport2.sol"; contract CircularImport1 { function foo() pure public returns(bool) { diff --git a/tests/contracts/CircularImport2.sol b/tests/contracts/CircularImport2.sol index 1da7f40..1f6bbc2 100644 --- a/tests/contracts/CircularImport2.sol +++ b/tests/contracts/CircularImport2.sol @@ -2,7 +2,7 @@ pragma solidity ^0.8.4; -import "CircularImport1.sol"; +import "contracts/CircularImport1.sol"; contract CircularImport2 { function foo() pure public returns(bool) { diff --git a/tests/contracts/ContractUsingLibraryNotInSameSource.sol b/tests/contracts/ContractUsingLibraryNotInSameSource.sol new file mode 100644 index 0000000..db072d8 --- /dev/null +++ b/tests/contracts/ContractUsingLibraryNotInSameSource.sol @@ -0,0 +1,17 @@ +// SPDX-License-Identifier: GPL-3.0 +// Borrowed from Solidity documentation. +pragma solidity >=0.6.0 <0.9.0; + +import "./LibraryFun.sol"; + +contract ContractUsingLibraryNotInSameSource { + Data knownValues; + + function register(uint value) public { + // The library functions can be called without a + // specific instance of the library, since the + // "instance" will be the current contract. + require(ExampleLibrary.insert(knownValues, value)); + } + // In this contract, we can also directly access knownValues.flags, if we want. +} diff --git a/tests/contracts/ImportOlderDependency.sol b/tests/contracts/ImportOlderDependency.sol index af8a797..4c96308 100644 --- a/tests/contracts/ImportOlderDependency.sol +++ b/tests/contracts/ImportOlderDependency.sol @@ -1,6 +1,6 @@ // SPDX-License-Identifier: MIT -import "@remapping/contracts/OlderDependency.sol"; +import "@dependency/contracts/OlderDependency.sol"; contract ImportOlderDependency { function foo() pure public returns(bool) { diff --git a/tests/contracts/ImportSourceWithEqualSignVersion.sol b/tests/contracts/ImportSourceWithEqualSignVersion.sol index f4fa3ba..bc6e841 100644 --- a/tests/contracts/ImportSourceWithEqualSignVersion.sol +++ b/tests/contracts/ImportSourceWithEqualSignVersion.sol @@ -4,8 +4,8 @@ pragma solidity ^0.8.4; // The file SpecificVersionWithEqualSign.sol has pragma spec '=0.8.12'. // This means that these files should all compile using that version. -import "SpecificVersionWithEqualSign.sol"; -import "CompilesOnce.sol"; +import "contracts/SpecificVersionWithEqualSign.sol"; +import "contracts/CompilesOnce.sol"; contract ImportSourceWithEqualSignVersion { function foo() pure public returns(bool) { diff --git a/tests/contracts/ImportSourceWithNoPrefixVersion.sol b/tests/contracts/ImportSourceWithNoPrefixVersion.sol index c0c4ff1..38dfe38 100644 --- a/tests/contracts/ImportSourceWithNoPrefixVersion.sol +++ b/tests/contracts/ImportSourceWithNoPrefixVersion.sol @@ -4,8 +4,8 @@ pragma solidity ^0.8.4; // The file SpecificVersionWithEqualSign.sol has pragma spec '0.8.12'. // This means that these files should all compile using that version. -import "SpecificVersionNoPrefix.sol"; -import "CompilesOnce.sol"; +import "contracts/SpecificVersionNoPrefix.sol"; +import "contracts/CompilesOnce.sol"; contract ImportSourceWithNoPrefixVersion { function foo() pure public returns(bool) { diff --git a/tests/contracts/Imports.sol b/tests/contracts/Imports.sol index 462fcf0..5556066 100644 --- a/tests/contracts/Imports.sol +++ b/tests/contracts/Imports.sol @@ -2,13 +2,13 @@ pragma solidity ^0.8.4; -import * as Depend from "@remapping/contracts/Dependency.sol"; +import * as Depend from "@dependency/contracts/Dependency.sol"; import "./././././././././././././././././././././././././././././././././././MissingPragma.sol"; -import { MyStruct } from "CompilesOnce.sol"; +import { MyStruct } from "contracts/CompilesOnce.sol"; import "./subfolder/Relativecontract.sol"; -import "@remapping_2/Dependency.sol" as Depend2; -import "@remapping_2_brownie/BrownieContract.sol"; +import "@dependency/contracts/Dependency.sol" as Depend2; +import "@browniedependency/contracts/BrownieContract.sol"; import { Struct0, Struct1, @@ -17,11 +17,11 @@ import { Struct4, Struct5 } from "./NumerousDefinitions.sol"; -import "@styleofbrownie/BrownieStyleDependency.sol"; +import "@noncompilingdependency/contracts/CompilingContract.sol"; // Purposely repeat an import to test how the plugin handles that. -import "@styleofbrownie/BrownieStyleDependency.sol"; +import "@noncompilingdependency/contracts/CompilingContract.sol"; -import "@gnosis/common/Enum.sol"; +import "@safe/contracts/common/Enum.sol"; contract Imports { function foo() pure public returns(bool) { diff --git a/tests/contracts/LibraryFun.sol b/tests/contracts/LibraryFun.sol index 3cd7644..6330bd3 100644 --- a/tests/contracts/LibraryFun.sol +++ b/tests/contracts/LibraryFun.sol @@ -9,7 +9,7 @@ struct Data { mapping(uint => bool) flags; } -library Set { +library ExampleLibrary { // Note that the first parameter is of type "storage // reference" and thus only its storage address and not // its contents is passed as part of the call. This is a @@ -46,14 +46,14 @@ library Set { } -contract C { +contract ContractUsingLibraryInSameSource { Data knownValues; function register(uint value) public { // The library functions can be called without a // specific instance of the library, since the // "instance" will be the current contract. - require(Set.insert(knownValues, value)); + require(ExampleLibrary.insert(knownValues, value)); } // In this contract, we can also directly access knownValues.flags, if we want. } diff --git a/tests/contracts/SpecificVersionNoPrefix.sol b/tests/contracts/SpecificVersionNoPrefix.sol index 7e3cd91..7ada74a 100644 --- a/tests/contracts/SpecificVersionNoPrefix.sol +++ b/tests/contracts/SpecificVersionNoPrefix.sol @@ -2,7 +2,7 @@ pragma solidity 0.8.12; -import "CompilesOnce.sol"; +import "./CompilesOnce.sol"; contract SpecificVersionNoPrefix { function foo() pure public returns(bool) { diff --git a/tests/contracts/SpecificVersionNoPrefix2.sol b/tests/contracts/SpecificVersionNoPrefix2.sol index a2750b8..eaa9af0 100644 --- a/tests/contracts/SpecificVersionNoPrefix2.sol +++ b/tests/contracts/SpecificVersionNoPrefix2.sol @@ -6,7 +6,7 @@ pragma solidity 0.8.14; // Both specific versions import the same file. // This is an important test! -import "CompilesOnce.sol"; +import "contracts/CompilesOnce.sol"; contract SpecificVersionNoPrefix2 { function foo() pure public returns(bool) { diff --git a/tests/contracts/SpecificVersionWithEqualSign.sol b/tests/contracts/SpecificVersionWithEqualSign.sol index 623faac..e38e203 100644 --- a/tests/contracts/SpecificVersionWithEqualSign.sol +++ b/tests/contracts/SpecificVersionWithEqualSign.sol @@ -2,7 +2,7 @@ pragma solidity =0.8.12; -import "CompilesOnce.sol"; +import "contracts/CompilesOnce.sol"; contract SpecificVersionWithEqualSign { function foo() pure public returns(bool) { diff --git a/tests/contracts/UseYearn.sol b/tests/contracts/UseYearn.sol index fffd29d..f6c1f7c 100644 --- a/tests/contracts/UseYearn.sol +++ b/tests/contracts/UseYearn.sol @@ -1,8 +1,8 @@ // SPDX-License-Identifier: GPL-3.0 pragma solidity >=0.8.17; -import {VaultAPI} from "@vault/BaseStrategy.sol"; -import {VaultAPI as VaultMain} from "@vaultmain/BaseStrategy.sol"; +import {VaultAPI} from "@vault/contracts/BaseStrategy.sol"; +import {VaultAPI as VaultMain} from "@vaultmain/contracts/BaseStrategy.sol"; interface UseYearn is VaultAPI { diff --git a/tests/contracts/subfolder/UsingDependencyWithinSubFolder.sol b/tests/contracts/subfolder/UsingDependencyWithinSubFolder.sol index f77f92c..65d12f8 100644 --- a/tests/contracts/subfolder/UsingDependencyWithinSubFolder.sol +++ b/tests/contracts/subfolder/UsingDependencyWithinSubFolder.sol @@ -2,7 +2,7 @@ pragma solidity ^0.8.4; -import "@remapping/contracts/Dependency.sol"; +import "@dependency/contracts/Dependency.sol"; contract UsingDependencyWithinSubFolder { function foo() pure public returns(bool) { diff --git a/tests/test_compiler.py b/tests/test_compiler.py index 0a7bc3d..c233ac2 100644 --- a/tests/test_compiler.py +++ b/tests/test_compiler.py @@ -1,540 +1,514 @@ -import json -import re -import shutil from pathlib import Path import pytest import solcx -from ape import reverts -from ape.contracts import ContractContainer +from ape import Project, reverts from ape.exceptions import CompilerError from ape.logging import LogLevel +from ethpm_types import ContractType from packaging.version import Version -from requests.exceptions import ConnectionError -from ape_solidity import Extension -from ape_solidity._utils import OUTPUT_SELECTION from ape_solidity.exceptions import IndexOutOfBoundsError -BASE_PATH = Path(__file__).parent / "contracts" -TEST_CONTRACT_PATHS = [ - p - for p in BASE_PATH.iterdir() - if ".cache" not in str(p) and not p.is_dir() and p.suffix == Extension.SOL.value -] -TEST_CONTRACTS = [str(p.stem) for p in TEST_CONTRACT_PATHS] -PATTERN_REQUIRING_COMMIT_HASH = re.compile(r"\d+\.\d+\.\d+\+commit\.[\d|a-f]+") EXPECTED_NON_SOLIDITY_ERR_MSG = "Unable to compile 'RandomVyperFile.vy' using Solidity compiler." - -# These are tested elsewhere, not in `test_compile`. -normal_test_skips = ( - "DifferentNameThanFile", - "MultipleDefinitions", - "RandomVyperFile", - "LibraryFun", - "JustAStruct", -) raises_because_not_sol = pytest.raises(CompilerError, match=EXPECTED_NON_SOLIDITY_ERR_MSG) -DEFAULT_OPTIMIZER = {"enabled": True, "runs": 200} -@pytest.mark.parametrize( - "contract", - [c for c in TEST_CONTRACTS if all(n not in str(c) for n in normal_test_skips)], -) -def test_compile(project, contract): - assert contract in project.contracts, ", ".join([n for n in project.contracts.keys()]) - contract = project.contracts[contract] - assert contract.source_id == f"{contract.name}.sol" +def test_get_config(project, compiler): + actual = compiler.get_config(project=project) + assert actual.evm_version == "constantinople" -def test_compile_performance(benchmark, compiler, project): +def test_get_import_remapping(project, compiler): + actual = compiler.get_import_remapping(project=project) + expected = { + "@browniedependency": "contracts/.cache/browniedependency/local", + "@dependency": "contracts/.cache/dependency/local", + "@dependencyofdependency": "contracts/.cache/dependencyofdependency/local", + "@noncompilingdependency": "contracts/.cache/noncompilingdependency/local", + "@openzeppelin": "contracts/.cache/openzeppelin/4.5.0", + "@safe": "contracts/.cache/safe/1.3.0", + "@vault": "contracts/.cache/vault/v0.4.5", + "@vaultmain": "contracts/.cache/vaultmain/master", + } + for key, value in expected.items(): + assert key in actual + assert actual[key] == value + + +def test_get_import_remapping_handles_config(project, compiler): """ - See https://pytest-benchmark.readthedocs.io/en/latest/ + Show you can override default remappings. + Normally, these are deduced from dependencies, but you can change them + and/or add new ones. """ - source_path = project.contracts_folder / "MultipleDefinitions.sol" - result = benchmark.pedantic(compiler.compile, args=([source_path],), rounds=1) - assert len(result) > 0 - + new_value = "NEW_VALUE" + cfg = { + "solidity": { + "import_remapping": [ + "@dependency=dependency", # Backwards compat! + "@dependencyofdependency=dependencyofdependency/local", + f"@vaultmain={new_value}", # Changing a dependency + f"@{new_value}={new_value}123", # Adding something new + ] + }, + "dependencies": project.config.dependencies, + } + with project.temp_config(**cfg): + actual = compiler.get_import_remapping(project=project) -def test_compile_solc_not_installed(project, fake_no_installs): - assert len(project.load_contracts(use_cache=False)) > 0 + # Show it is backwards compatible (still works w/o changing cfg) + assert actual["@dependency"] == "contracts/.cache/dependency/local" + assert actual["@dependencyofdependency"] == "contracts/.cache/dependencyofdependency/local" + # Show we can change a dependency. + assert actual["@vaultmain"] == new_value + # Show we can add a new remapping (quiet dependency). + assert actual[f"@{new_value}"] == f"{new_value}123" + # Show other dependency-deduced remappings still work. + assert actual["@browniedependency"] == "contracts/.cache/browniedependency/local" -def test_compile_when_offline(project, compiler, mocker): - # When offline, getting solc versions raises a requests connection error. - # This should trigger the plugin to return an empty list. - patch = mocker.patch("ape_solidity.compiler.get_installable_solc_versions") - patch.side_effect = ConnectionError +def test_get_imports(project, compiler): + source_id = "contracts/ImportSourceWithEqualSignVersion.sol" + path = project.sources.lookup(source_id) + # Source (total) only has these 2 imports. + expected = ( + "contracts/SpecificVersionWithEqualSign.sol", + "contracts/CompilesOnce.sol", + ) + actual = compiler.get_imports((path,), project=project) + assert source_id in actual + assert all(e in actual[source_id] for e in expected) - # Using a non-specific contract - doesn't matter too much which one. - source_path = project.contracts_folder / "MultipleDefinitions.sol" - result = compiler.compile([source_path]) - assert len(result) > 0, "Nothing got compiled." +def test_get_imports_indirect(project, compiler): + """ + Show that twice-removed indirect imports show up. This is required + for accurate version mapping. + """ -def test_compile_multiple_definitions_in_source(project, compiler): - source_path = project.contracts_folder / "MultipleDefinitions.sol" - result = compiler.compile([source_path]) - assert len(result) == 2 - assert [r.name for r in result] == ["IMultipleDefinitions", "MultipleDefinitions"] - assert all(r.source_id == "MultipleDefinitions.sol" for r in result) + source_id = "contracts/IndirectlyImportingMoreConstrainedVersion.sol" + path = project.sources.lookup(source_id) + expected = ( + # These 2 are directly imported. + "contracts/ImportSourceWithEqualSignVersion.sol", + "contracts/IndirectlyImportingMoreConstrainedVersionCompanion.sol", + # These are 2 are imported by the imported. + "contracts/SpecificVersionWithEqualSign.sol", + "contracts/CompilesOnce.sol", + ) + actual = compiler.get_imports((path,), project=project) + assert source_id in actual + actual_str = ", ".join(list(actual[source_id])) + for ex in expected: + assert ex in actual[source_id], f"{ex} not in {actual_str}" - assert project.MultipleDefinitions - assert project.IMultipleDefinitions +def test_get_imports_complex(project, compiler): + """ + `contracts/Imports.sol` imports sources in every possible + way. This test shows that we are able to detect all those + unique ways of importing. + """ + path = project.sources.lookup("contracts/Imports.sol") + assert path is not None, "Failed to find Imports test contract." -def test_compile_specific_order(project, compiler): - # NOTE: This test seems random but it's important! - # It replicates a bug where the first contract had a low solidity version - # and the second had a bunch of imports. - ordered_files = [ - project.contracts_folder / "OlderVersion.sol", - project.contracts_folder / "Imports.sol", - ] - compiler.compile(ordered_files) + actual = compiler.get_imports((path,), project=project) + expected = { + "contracts/CompilesOnce.sol": [], + "contracts/Imports.sol": [ + "contracts/.cache/browniedependency/local/contracts/BrownieContract.sol", + "contracts/.cache/dependency/local/contracts/Dependency.sol", + "contracts/.cache/dependencyofdependency/local/contracts/DependencyOfDependency.sol", + "contracts/.cache/noncompilingdependency/local/contracts/CompilingContract.sol", + "contracts/.cache/safe/1.3.0/contracts/common/Enum.sol", + "contracts/CompilesOnce.sol", + "contracts/MissingPragma.sol", + "contracts/NumerousDefinitions.sol", + "contracts/subfolder/Relativecontract.sol", + ], + "contracts/MissingPragma.sol": [], + "contracts/NumerousDefinitions.sol": [], + "contracts/subfolder/Relativecontract.sol": [], + } + for base, imports in expected.items(): + assert base in actual + assert actual[base] == imports -def test_compile_missing_version(project, compiler, temp_solcx_path): +def test_get_imports_dependencies(project, compiler): """ - Test the compilation of a contract with no defined pragma spec. - - The plugin should implicitly download the latest version to compile the - contract with. `temp_solcx_path` is used to simulate an environment without - compilers installed. + Show all the affected dependency contracts get included in the imports list. """ - assert not solcx.get_installed_solc_versions() - contract_types = compiler.compile([project.contracts_folder / "MissingPragma.sol"]) - assert len(contract_types) == 1 - installed_versions = solcx.get_installed_solc_versions() - assert len(installed_versions) == 1 - assert installed_versions[0] == max(solcx.get_installable_solc_versions()) - - -def test_compile_contract_with_different_name_than_file(project): - file_name = "DifferentNameThanFile.sol" - contract = project.contracts["ApeDifferentNameThanFile"] - assert contract.source_id == file_name + source_id = "contracts/UseYearn.sol" + path = project.sources.lookup(source_id) + import_ls = compiler.get_imports((path,), project=project) + actual = import_ls[source_id] + token_path = "contracts/.cache/openzeppelin/4.5.0/contracts/token" + expected = [ + f"{token_path}/ERC20/ERC20.sol", + f"{token_path}/ERC20/IERC20.sol", + f"{token_path}/ERC20/extensions/IERC20Metadata.sol", + f"{token_path}/ERC20/utils/SafeERC20.sol", + "contracts/.cache/openzeppelin/4.5.0/contracts/utils/Address.sol", + "contracts/.cache/openzeppelin/4.5.0/contracts/utils/Context.sol", + "contracts/.cache/vault/v0.4.5/contracts/BaseStrategy.sol", + "contracts/.cache/vaultmain/master/contracts/BaseStrategy.sol", + ] + assert actual == expected -def test_compile_only_returns_contract_types_for_inputs(compiler, project): - # The compiler has to compile multiple files for 'Imports.sol' (it imports stuff). - # However - it should only return a single contract type in this case. - contract_types = compiler.compile([project.contracts_folder / "Imports.sol"]) - assert len(contract_types) == 1 - assert contract_types[0].name == "Imports" +def test_get_imports_vyper_file(project, compiler): + path = Path(__file__).parent / "contracts" / "RandomVyperFile.vy" + assert path.is_file(), f"Setup failed - file not found {path}" + with raises_because_not_sol: + compiler.get_imports((path,)) -def test_compile_vyper_contract(compiler, vyper_source_path): - with raises_because_not_sol: - compiler.compile([vyper_source_path]) +def test_get_imports_full_project(project, compiler): + paths = [x for x in project.sources.paths if x.suffix == ".sol"] + actual = compiler.get_imports(paths, project=project) + assert len(actual) > 0 + # Prove that every import source also is present in the import map. + for imported_source_ids in actual.values(): + for source_id in imported_source_ids: + assert source_id in actual, f"{source_id}'s imports not present." -def test_compile_just_a_struct(compiler, project): +def test_get_version_map(project, compiler): """ - Before, you would get a nasty index error, even though this is valid Solidity. - The fix involved using nicer access to "contracts" in the standard output JSON. + Test that a strict version pragma is recognized in the version map. """ - contract_types = compiler.compile([project.contracts_folder / "JustAStruct.sol"]) - assert len(contract_types) == 0 - - -def test_get_imports(project, compiler): - test_contract_paths = [ - p - for p in project.contracts_folder.iterdir() - if ".cache" not in str(p) and not p.is_dir() and p.suffix == Extension.SOL.value - ] - import_dict = compiler.get_imports(test_contract_paths, project.contracts_folder) - contract_imports = import_dict["Imports.sol"] - # NOTE: make sure there aren't duplicates - assert len([x for x in contract_imports if contract_imports.count(x) > 1]) == 0 - # NOTE: returning a list - assert isinstance(contract_imports, list) - # NOTE: in case order changes - expected = { - ".cache/BrownieDependency/local/BrownieContract.sol", - ".cache/BrownieStyleDependency/local/BrownieStyleDependency.sol", - ".cache/TestDependency/local/Dependency.sol", - ".cache/gnosis/v1.3.0/common/Enum.sol", - "CompilesOnce.sol", - "MissingPragma.sol", - "NumerousDefinitions.sol", - "subfolder/Relativecontract.sol", - } - assert set(contract_imports) == expected - - -def test_get_imports_cache_folder(project, compiler): - """Test imports when cache folder is configured""" - compile_config = project.config_manager.get_config("compile") - og_cache_colder = compile_config.cache_folder - compile_config.cache_folder = project.path / ".cash" - # assert False - test_contract_paths = [ - p - for p in project.contracts_folder.iterdir() - if ".cache" not in str(p) and not p.is_dir() and p.suffix == Extension.SOL.value - ] - # Using a different base path here because the cache folder is in the project root - import_dict = compiler.get_imports(test_contract_paths, project.path) - contract_imports = import_dict["contracts/Imports.sol"] - # NOTE: make sure there aren't duplicates - assert len([x for x in contract_imports if contract_imports.count(x) > 1]) == 0 - # NOTE: returning a list - assert isinstance(contract_imports, list) - # NOTE: in case order changes - expected = { - ".cash/BrownieDependency/local/BrownieContract.sol", - ".cash/BrownieStyleDependency/local/BrownieStyleDependency.sol", - ".cash/TestDependency/local/Dependency.sol", - ".cash/gnosis/v1.3.0/common/Enum.sol", - "contracts/CompilesOnce.sol", - "contracts/MissingPragma.sol", - "contracts/NumerousDefinitions.sol", - "contracts/subfolder/Relativecontract.sol", - } - assert set(contract_imports) == expected + path = project.sources.lookup("contracts/SpecificVersionWithEqualSign.sol") + actual = compiler.get_version_map((path,), project=project) + expected_version = Version("0.8.12+commit.f00d7308") + expected_sources = ("SpecificVersionWithEqualSign",) + assert expected_version in actual - # Reset because this config is stateful across tests - compile_config.cache_folder = og_cache_colder - shutil.rmtree(og_cache_colder, ignore_errors=True) + actual_ids = [x.stem for x in actual[expected_version]] + assert all(e in actual_ids for e in expected_sources) -def test_get_imports_raises_when_non_solidity_files(compiler, vyper_source_path): - with raises_because_not_sol: - compiler.get_imports([vyper_source_path]) - - -def test_get_import_remapping(compiler, project, config): - import_remapping = compiler.get_import_remapping() - assert import_remapping == { - "@remapping_2_brownie": ".cache/BrownieDependency/local", - "@dependency_remapping": ".cache/DependencyOfDependency/local", - "@remapping_2": ".cache/TestDependency/local", - "@remapping/contracts": ".cache/TestDependency/local", - "@styleofbrownie": ".cache/BrownieStyleDependency/local", - "@openzeppelin/contracts": ".cache/OpenZeppelin/v4.7.1", - "@oz/contracts": ".cache/OpenZeppelin/v4.5.0", - "@vault": ".cache/vault/v0.4.5", - "@vaultmain": ".cache/vault/master", - "@gnosis": ".cache/gnosis/v1.3.0", - } +def test_get_version_map_importing_more_constrained_version(project, compiler): + """ + Test that a strict version pragma in an imported source is recognized + in the version map. + """ + # This file's version is not super constrained, but it imports + # a different source that does have a strict constraint. + path = project.sources.lookup("contracts/ImportSourceWithEqualSignVersion.sol") - with config.using_project(project.path / "ProjectWithinProject") as proj: - # Trigger downloading dependencies in new ProjectWithinProject - dependencies = proj.dependencies - assert dependencies - # Should be different now that we have changed projects. - second_import_remapping = compiler.get_import_remapping() - assert second_import_remapping + actual = compiler.get_version_map((path,), project=project) + expected_version = Version("0.8.12+commit.f00d7308") + expected_sources = ("ImportSourceWithEqualSignVersion", "SpecificVersionWithEqualSign") + assert expected_version in actual - assert import_remapping != second_import_remapping + actual_ids = [x.stem for x in actual[expected_version]] + assert all(e in actual_ids for e in expected_sources) -def test_brownie_project(compiler, config): - brownie_project_path = Path(__file__).parent / "BrownieProject" - with config.using_project(brownie_project_path) as project: - assert isinstance(project.BrownieContract, ContractContainer) +def test_get_version_map_indirectly_importing_more_constrained_version(project, compiler): + """ + Test that a strict version pragma in a source imported by an imported + source (twice removed) is recognized in the version map. + """ + # This file's version is not super constrained, but it imports + # a different source that imports another source that does have a constraint. + path = project.sources.lookup("contracts/IndirectlyImportingMoreConstrainedVersion.sol") - # Ensure can access twice (to make sure caching does not break anything). - _ = project.BrownieContract + actual = compiler.get_version_map((path,), project=project) + expected_version = Version("0.8.12+commit.f00d7308") + expected_sources = ( + "IndirectlyImportingMoreConstrainedVersion", + "ImportSourceWithEqualSignVersion", + "SpecificVersionWithEqualSign", + ) + assert expected_version in actual + actual_ids = [x.stem for x in actual[expected_version]] + assert all(e in actual_ids for e in expected_sources) -def test_compile_single_source_with_no_imports(compiler, config): - # Tests against an important edge case that was discovered - # where the source file was individually compiled and it had no imports. - path = Path(__file__).parent / "DependencyOfDependency" - with config.using_project(path) as project: - assert isinstance(project.DependencyOfDependency, ContractContainer) +def test_get_version_map_dependencies(project, compiler): + """ + Show all the affected dependency contracts get included in the version map. + """ + source_id = "contracts/UseYearn.sol" + older_example = "contracts/ImportOlderDependency.sol" + paths = [project.sources.lookup(x) for x in (source_id, older_example)] + actual = compiler.get_version_map(paths, project=project) + + fail_msg = f"versions: {', '.join([str(x) for x in actual])}" + assert len(actual) == 2, fail_msg + + versions = sorted(list(actual.keys())) + older = versions[0] # Via ImportOlderDependency + latest = versions[1] # via UseYearn + + oz_token = "contracts/.cache/openzeppelin/4.5.0/contracts/token" + expected_latest_source_ids = [ + f"{oz_token}/ERC20/ERC20.sol", + f"{oz_token}/ERC20/IERC20.sol", + f"{oz_token}/ERC20/extensions/IERC20Metadata.sol", + f"{oz_token}/ERC20/utils/SafeERC20.sol", + "contracts/.cache/openzeppelin/4.5.0/contracts/utils/Address.sol", + "contracts/.cache/openzeppelin/4.5.0/contracts/utils/Context.sol", + "contracts/.cache/vault/v0.4.5/contracts/BaseStrategy.sol", + "contracts/.cache/vaultmain/master/contracts/BaseStrategy.sol", + source_id, + ] + expected_older_source_ids = [ + "contracts/.cache/dependency/local/contracts/OlderDependency.sol", + older_example, + ] + expected_latest_source_paths = {project.path / e for e in expected_latest_source_ids} + expected_oldest_source_paths = {project.path / e for e in expected_older_source_ids} + assert len(actual[latest]) == len(expected_latest_source_paths) + assert actual[latest] == expected_latest_source_paths + assert actual[older] == expected_oldest_source_paths -def test_version_specified_in_config_file(compiler, config): - path = Path(__file__).parent / "VersionSpecifiedInConfig" - with config.using_project(path) as project: - source_path = project.contracts_folder / "VersionSpecifiedInConfig.sol" - version_map = compiler.get_version_map(source_path) - actual_versions = ", ".join(str(v) for v in version_map) - fail_msg = f"Actual versions: {actual_versions}" - expected_version = Version("0.8.12+commit.f00d7308") - assert expected_version in version_map, fail_msg - assert version_map[expected_version] == {source_path}, fail_msg +def test_get_version_map_picks_most_constrained_version(project, compiler): + """ + Test that if given both a file that can compile at the latest version + and a file that requires a lesser version but also imports the same file + that could compile at the latest version, that they are all designated + to compile using the lesser version. + """ + source_ids = ( + "contracts/CompilesOnce.sol", + "contracts/IndirectlyImportingMoreConstrainedVersion.sol", + ) + paths = [project.sources.lookup(x) for x in source_ids] + actual = compiler.get_version_map(paths, project=project) + expected_version = Version("0.8.12+commit.f00d7308") + assert expected_version in actual + for path in paths: + assert path in actual[expected_version], f"{path} is missing!" -def test_get_version_map(project, compiler): - # Files are selected in order to trigger `CompilesOnce.sol` to - # get removed from version '0.8.12'. - cache_folder = project.contracts_folder / ".cache" - shutil.rmtree(cache_folder, ignore_errors=True) - - file_paths = [ - project.contracts_folder / "ImportSourceWithEqualSignVersion.sol", - project.contracts_folder / "SpecificVersionNoPrefix.sol", - project.contracts_folder / "CompilesOnce.sol", - project.contracts_folder / "Imports.sol", # Uses mapped imports! - ] - version_map = compiler.get_version_map(file_paths) - assert len(version_map) == 2 +def test_get_version_map_version_specified_in_config_file(compiler): + path = Path(__file__).parent / "VersionSpecifiedInConfig" + project = Project(path) + paths = [p for p in project.sources.paths if p.suffix == ".sol"] + actual = compiler.get_version_map(paths, project=project) expected_version = Version("0.8.12+commit.f00d7308") - latest_version = [v for v in version_map if v != expected_version][0] - assert all([f in version_map[expected_version] for f in file_paths[:-1]]) + assert len(actual) == 1 + assert expected_version in actual + assert len(actual[expected_version]) > 0 - latest_version_sources = version_map[latest_version] - assert len(latest_version_sources) >= 10, "Did the import remappings load correctly?" - assert file_paths[-1] in latest_version_sources - # Will fail if the import remappings have not loaded yet. - assert all([f.is_file() for f in file_paths]) +def test_get_version_map_raises_on_non_solidity_sources(project, compiler): + path = project.contracts_folder / "RandomVyperFile.vy" + with raises_because_not_sol: + compiler.get_version_map((path,), project=project) -def test_get_version_map_single_source(compiler, project): - # Source has no imports - source = project.contracts_folder / "OlderVersion.sol" - actual = compiler.get_version_map([source]) - expected = {Version("0.5.16+commit.9c3226ce"): {source}} - assert len(actual) == 1 - assert actual == expected, f"Actual version: {[k for k in actual.keys()][0]}" +def test_get_version_map_full_project(project, compiler): + paths = [x for x in project.sources.paths if x.suffix == ".sol"] + actual = compiler.get_version_map(paths, project=project) + latest = sorted(list(actual.keys()), reverse=True)[0] + v0812 = Version("0.8.12+commit.f00d7308") + vold = Version("0.4.26+commit.4563c3fc") + assert v0812 in actual + assert vold in actual + v0812_1 = project.path / "contracts/ImportSourceWithEqualSignVersion.sol" + v0812_2 = project.path / "contracts/IndirectlyImportingMoreConstrainedVersion.sol" -def test_get_version_map_raises_on_non_solidity_sources(compiler, vyper_source_path): - with raises_because_not_sol: - compiler.get_version_map([vyper_source_path]) + assert v0812_1 in actual[v0812], "Constrained version files missing" + assert v0812_2 in actual[v0812], "Constrained version files missing" + # TDD: This was happening during development of 0.8.0. + assert v0812_1 not in actual[latest], f"{v0812_1.stem} ended up in latest" + assert v0812_2 not in actual[latest], f"{v0812_2.stem} ended up in latest" -def test_compiler_data_in_manifest(project): - def run_test(manifest): - compilers = [c for c in manifest.compilers if c.name == "solidity"] - latest_version = max(c.version for c in compilers) + # TDD: Old file ending up in multiple spots. + older_file = project.path / "contracts/.cache/dependency/local/contracts/OlderDependency.sol" + assert older_file in actual[vold] + for vers, fileset in actual.items(): + if vers == vold: + continue - compiler_latest = [c for c in compilers if str(c.version) == latest_version][0] - compiler_0812 = [c for c in compilers if str(c.version) == "0.8.12+commit.f00d7308"][0] - compiler_0612 = [c for c in compilers if str(c.version) == "0.6.12+commit.27d51765"][0] - compiler_0426 = [c for c in compilers if str(c.version) == "0.4.26+commit.4563c3fc"][0] + assert older_file not in fileset, f"Oldest file also appears in version {vers}" - # Compiler name test - for compiler in (compiler_latest, compiler_0812, compiler_0612, compiler_0426): - assert compiler.name == "solidity" - assert compiler.settings["optimizer"] == DEFAULT_OPTIMIZER - assert compiler.settings["evmVersion"] == "constantinople" - # No remappings for sources in the following compilers - assert ( - "remappings" not in compiler_0812.settings - ), f"Remappings found: {compiler_0812.settings['remappings']}" +def test_get_compiler_settings(project, compiler): + path = project.sources.lookup("contracts/Imports.sol") + actual = compiler.get_compiler_settings((path,), project=project) + # No reason (when alone) to not use + assert len(actual) == 1 - assert ( - "@openzeppelin/contracts=.cache/OpenZeppelin/v4.7.1" - in compiler_latest.settings["remappings"] - ) - assert "@vault=.cache/vault/v0.4.5" in compiler_latest.settings["remappings"] - assert "@vaultmain=.cache/vault/master" in compiler_latest.settings["remappings"] - common_suffix = ".cache/TestDependency/local" - expected_remappings = ( - "@remapping_2_brownie=.cache/BrownieDependency/local", - "@dependency_remapping=.cache/DependencyOfDependency/local", - f"@remapping_2={common_suffix}", - f"@remapping/contracts={common_suffix}", - "@styleofbrownie=.cache/BrownieStyleDependency/local", - ) - actual_remappings = compiler_latest.settings["remappings"] - assert all(x in actual_remappings for x in expected_remappings) - assert all( - b >= a for a, b in zip(actual_remappings, actual_remappings[1:]) - ), "Import remappings should be sorted" - assert f"@remapping/contracts={common_suffix}" in compiler_0426.settings["remappings"] - assert "UseYearn" in compiler_latest.contractTypes - assert "@gnosis=.cache/gnosis/v1.3.0" in compiler_latest.settings["remappings"] - - # Compiler contract types test - assert set(compiler_0812.contractTypes) == { - "ImportSourceWithEqualSignVersion", - "ImportSourceWithNoPrefixVersion", - "ImportingLessConstrainedVersion", - "IndirectlyImportingMoreConstrainedVersion", - "IndirectlyImportingMoreConstrainedVersionCompanion", - "SpecificVersionNoPrefix", - "SpecificVersionRange", - "SpecificVersionWithEqualSign", - "CompilesOnce", - "IndirectlyImportingMoreConstrainedVersionCompanionImport", - } - assert set(compiler_0612.contractTypes) == {"RangedVersion", "VagueVersion"} - assert set(compiler_0426.contractTypes) == { - "ExperimentalABIEncoderV2", - "SpacesInPragma", - "ImportOlderDependency", - } + # 0.8.12 is hardcoded in some files, but none of those files should be here. + version = next(iter(actual.keys())) + assert version > Version("0.8.12+commit.f00d7308") - # Ensure compiled first so that the local cached manifest exists. - # We want to make ape-solidity has placed the compiler info in there. - project.load_contracts(use_cache=False) - if man := project.local_project.manifest: - run_test(man) - else: - pytest.fail("Manifest was not cached after loading.") - - # The extracted manifest should produce the same result. - run_test(project.extract_manifest()) - - -def test_get_versions(compiler, project): - # NOTE: the expected versions **DO NOT** contain commit hashes here - # because we can only get the commit hash of installed compilers - # and this returns all versions including uninstalled. - versions = compiler.get_versions(project.source_paths) - - # The "latest" version will always be in this list, but avoid - # asserting on it directly to handle new "latest"'s coming out. - expected = ("0.4.26", "0.5.16", "0.6.12", "0.8.12", "0.8.14") - assert all([e in versions for e in expected]) - - -def test_get_compiler_settings(compiler, project): - # We start with the following sources as inputs: - # `forced_812_*` are forced to compile using solc 0.8.12 because its - # import is hard-pinned to it. - forced_812_0 = "ImportSourceWithEqualSignVersion.sol" - forced_812_1 = "SpecificVersionNoPrefix.sol" - # The following are unspecified and not used by the above. - # Thus are compiled on the latest. - latest_0 = "CompilesOnce.sol" - latest_1 = "Imports.sol" # Uses mapped imports! - file_paths = [ - project.contracts_folder / x for x in (forced_812_0, forced_812_1, latest_0, latest_1) - ] + settings = actual[version] + assert settings["optimizer"] == {"enabled": True, "runs": 200} - # Actual should contain all the settings for every file used in a would-be compile. - actual = compiler.get_compiler_settings(file_paths) - - # The following is indirectly used by 0.8.12 from an import. - forced_812_0_import = "SpecificVersionWithEqualSign.sol" - - # These are the versions we are checking in our expectations. - v812 = Version("0.8.12+commit.f00d7308") - latest = max(list(actual.keys())) - - expected_v812_contracts = [forced_812_0, forced_812_0_import, forced_812_1] - expected_latest_contracts = [ - latest_0, - latest_1, - # The following are expected imported sources. - ".cache/BrownieDependency/local/BrownieContract.sol", - "CompilesOnce.sol", - ".cache/TestDependency/local/Dependency.sol", - ".cache/DependencyOfDependency/local/DependencyOfDependency.sol", - "subfolder/Relativecontract.sol", - ".cache/gnosis/v1.3.0/common/Enum.sol", - ] - expected_remappings = [ - "@remapping_2_brownie=.cache/BrownieDependency/local", - "@dependency_remapping=.cache/DependencyOfDependency/local", - "@remapping_2=.cache/TestDependency/local", - "@remapping/contracts=.cache/TestDependency/local", - "@styleofbrownie=.cache/BrownieStyleDependency/local", - "@gnosis=.cache/gnosis/v1.3.0", + # NOTE: These should be sorted! + assert settings["remappings"] == [ + "@browniedependency=contracts/.cache/browniedependency/local", + "@dependency=contracts/.cache/dependency/local", + "@dependencyofdependency=contracts/.cache/dependencyofdependency/local", + "@noncompilingdependency=contracts/.cache/noncompilingdependency/local", + "@safe=contracts/.cache/safe/1.3.0", ] - # Shared compiler defaults tests - expected_source_lists = (expected_v812_contracts, expected_latest_contracts) - for version, expected_sources in zip((v812, latest), expected_source_lists): - expected_sources.sort() - output_selection = actual[version]["outputSelection"] - assert actual[version]["optimizer"] == DEFAULT_OPTIMIZER - for _, item_selection in output_selection.items(): - for key, selection in item_selection.items(): - if key == "*": # All contracts - assert selection == OUTPUT_SELECTION - elif key == "": # All sources - assert selection == ["ast"] - - # Sort to help debug. - actual_sources = sorted([x for x in output_selection.keys()]) - - for expected_source_id in expected_sources: - assert ( - expected_source_id in actual_sources - ), f"{expected_source_id} not one of {', '.join(actual_sources)}" - - # Remappings test - actual_remappings = actual[latest]["remappings"] - assert isinstance(actual_remappings, list) - assert len(actual_remappings) == len(expected_remappings) - assert all(e in actual_remappings for e in expected_remappings) - assert all( - b >= a for a, b in zip(actual_remappings, actual_remappings[1:]) - ), "Import remappings should be sorted" - - # Tests against bug potentially preventing JSON decoding errors related - # to contract verification. - for key, output_json_dict in actual.items(): - assert json.dumps(output_json_dict) - - -def test_evm_version(compiler): - assert compiler.config.evm_version == "constantinople" - - -def test_source_map(project, compiler): - source_path = project.contracts_folder / "MultipleDefinitions.sol" - result = compiler.compile([source_path])[-1] - assert result.sourcemap.root == "124:87:0:-:0;;;;;;;;;;;;;;;;;;;" + # Set in config. + assert settings["evmVersion"] == "constantinople" + + # Should be all files (imports of imports etc.) + actual_files = sorted(list(settings["outputSelection"].keys())) + expected_files = [ + "contracts/.cache/browniedependency/local/contracts/BrownieContract.sol", + "contracts/.cache/dependency/local/contracts/Dependency.sol", + "contracts/.cache/dependencyofdependency/local/contracts/DependencyOfDependency.sol", + "contracts/.cache/noncompilingdependency/local/contracts/CompilingContract.sol", + "contracts/.cache/safe/1.3.0/contracts/common/Enum.sol", + "contracts/CompilesOnce.sol", + "contracts/Imports.sol", + "contracts/MissingPragma.sol", + "contracts/NumerousDefinitions.sol", + "contracts/subfolder/Relativecontract.sol", + ] + assert actual_files == expected_files + + # Output request is the same for all. + expected_output_request = { + "*": [ + "abi", + "bin-runtime", + "devdoc", + "userdoc", + "evm.bytecode.object", + "evm.bytecode.sourceMap", + "evm.deployedBytecode.object", + ], + "": ["ast"], + } + for output in settings["outputSelection"].values(): + assert output == expected_output_request + + +def test_get_standard_input_json(project, compiler): + paths = [x for x in project.sources.paths if x.suffix == ".sol"] + actual = compiler.get_standard_input_json(paths, project=project) + v0812 = Version("0.8.12+commit.f00d7308") + v056 = Version("0.5.16+commit.9c3226ce") + v0426 = Version("0.4.26+commit.4563c3fc") + latest = sorted(list(actual.keys()), reverse=True)[0] + + fail_msg = f"Versions: {', '.join([str(v) for v in actual])}" + assert v0812 in actual, fail_msg + assert v056 in actual, fail_msg + assert v0426 in actual, fail_msg + assert latest in actual, fail_msg + + v0812_sources = list(actual[v0812]["sources"].keys()) + v056_sources = list(actual[v056]["sources"].keys()) + v0426_sources = list(actual[v0426]["sources"].keys()) + latest_sources = list(actual[latest]["sources"].keys()) + + assert "contracts/ImportSourceWithEqualSignVersion.sol" not in latest_sources + assert "contracts/IndirectlyImportingMoreConstrainedVersion.sol" not in latest_sources + + # Some source expectations. + assert "contracts/CompilesOnce.sol" in v0812_sources + assert "contracts/SpecificVersionRange.sol" in v0812_sources + assert "contracts/ImportSourceWithNoPrefixVersion.sol" in v0812_sources + + assert "contracts/OlderVersion.sol" in v056_sources + assert "contracts/ImportOlderDependency.sol" in v0426_sources + assert "contracts/.cache/dependency/local/contracts/OlderDependency.sol" in v0426_sources + + +def test_compile(project, compiler): + path = project.sources.lookup("contracts/Imports.sol") + actual = [c for c in compiler.compile((path,), project=project)] + # We only get back the contracts we requested, even if it had to compile + # others (like imports) to get it to work. + assert len(actual) == 1 + assert isinstance(actual[0], ContractType) + assert actual[0].name == "Imports" + assert actual[0].source_id == "contracts/Imports.sol" + assert actual[0].deployment_bytecode is not None + assert actual[0].runtime_bytecode is not None + assert len(actual[0].abi) > 0 -def test_add_library(project, account, compiler, connection): - with pytest.raises(AttributeError): - # Does not exist yet because library is not deployed or known. - _ = project.C +def test_compile_performance(benchmark, compiler, project): + """ + See https://pytest-benchmark.readthedocs.io/en/latest/ + """ + path = project.sources.lookup("contracts/MultipleDefinitions.sol") + result = benchmark.pedantic( + lambda *args, **kwargs: [x for x in compiler.compile(*args, **kwargs)], + args=((path,),), + kwargs={"project": project}, + rounds=1, + ) + assert len(result) > 0 - library = project.Set.deploy(sender=account) - compiler.add_library(library) - # After deploying and adding the library, we can use contracts that need it. - assert project.C +def test_compile_multiple_definitions_in_source(project, compiler): + """ + Show that if multiple contracts / interfaces are defined in a single + source, that we get all of them when compiling. + """ + source_id = "contracts/MultipleDefinitions.sol" + path = project.sources.lookup(source_id) + result = [c for c in compiler.compile((path,), project=project)] + assert len(result) == 2 + assert [r.name for r in result] == ["IMultipleDefinitions", "MultipleDefinitions"] + assert all(r.source_id == source_id for r in result) + assert project.MultipleDefinitions + assert project.IMultipleDefinitions -def test_enrich_error_when_custom(compiler, project, owner, not_owner, connection): - compiler.compile((project.contracts_folder / "HasError.sol",)) - # Deploy so Ape know about contract type. - contract = owner.deploy(project.HasError, 1) - with pytest.raises(contract.Unauthorized) as err: - contract.withdraw(sender=not_owner) +def test_compile_contract_with_different_name_than_file(project, compiler): + source_id = "contracts/DifferentNameThanFile.sol" + path = project.sources.lookup(source_id) + actual = [c for c in compiler.compile((path,), project=project)] + assert len(actual) == 1 + assert actual[0].source_id == source_id - # TODO: Can remove hasattr check after race condition resolved in Core. - if hasattr(err.value, "inputs"): - assert err.value.inputs == {"addr": not_owner.address, "counter": 123} +def test_compile_only_returns_contract_types_for_inputs(project, compiler): + """ + Test showing only the requested contract types get returned. + """ + path = project.sources.lookup("contracts/Imports.sol") + contract_types = [c for c in compiler.compile((path,), project=project)] + assert len(contract_types) == 1 + assert contract_types[0].name == "Imports" -def test_enrich_error_when_custom_in_constructor(compiler, project, owner, not_owner, connection): - # Deploy so Ape know about contract type. - with reverts(project.HasError.Unauthorized) as err: - not_owner.deploy(project.HasError, 0) - # TODO: After ape 0.6.14, try this again. It is working locally but there - # may be a race condition causing it to fail? I added a fix to core that - # may resolve but I am not sure. - if hasattr(err.value, "inputs"): - assert err.value.inputs == {"addr": not_owner.address, "counter": 123} +def test_compile_vyper_contract(project, compiler): + path = project.contracts_folder / "RandomVyperFile.vy" + with raises_because_not_sol: + _ = [c for c in compiler.compile((path,), project=project)] -def test_enrich_error_when_builtin(project, owner, connection): - contract = project.BuiltinErrorChecker.deploy(sender=owner) - with pytest.raises(IndexOutOfBoundsError): - contract.checkIndexOutOfBounds(sender=owner) +def test_compile_just_a_struct(compiler, project): + """ + Before, you would get a nasty index error, even though this is valid Solidity. + The fix involved using nicer access to "contracts" in the standard output JSON. + """ + path = project.sources.lookup("contracts/JustAStruct.sol") + contract_types = [c for c in compiler.compile((path,), project=project)] + assert len(contract_types) == 0 -# TODO: Not yet used and super slow. -# def test_ast(project, compiler): -# source_path = project.contracts_folder / "MultipleDefinitions.sol" -# actual = compiler.compile([source_path])[-1].ast -# fn_node = actual.children[1].children[0] -# assert actual.ast_type == "SourceUnit" -# assert fn_node.classification == ASTClassification.FUNCTION +def test_compile_produces_source_map(project, compiler): + path = project.sources.lookup("contracts/MultipleDefinitions.sol") + result = [c for c in compiler.compile((path,), project=project)][-1] + assert result.sourcemap.root == "124:87:0:-:0;;;;;;;;;;;;;;;;;;;" -def test_via_ir(project, compiler): - source_path = project.contracts_folder / "StackTooDeep.sol" +def test_compile_via_ir(project, compiler): + path = project.contracts_folder / "StackTooDep.sol" source_code = """ // SPDX-License-Identifier: MIT @@ -580,28 +554,107 @@ def test_via_ir(project, compiler): """ # write source code to file - source_path.write_text(source_code) + path.write_text(source_code) try: - compiler.compile([source_path]) + [c for c in compiler.compile((path,), project=project)] except Exception as e: assert "Stack too deep" in str(e) - compiler.config.via_ir = True + with project.temp_config(solidity={"via_ir": True}): + _ = [c for c in compiler.compile((path,), project=project)] - compiler.compile([source_path]) + # delete source code file + path.unlink() - # delete source code file - source_path.unlink() - # flip the via_ir flag back to False - compiler.config.via_ir = False +@pytest.mark.install +def test_installs_from_compile(project, compiler, temp_solcx_path): + """ + Test the compilation of a contract with no defined pragma spec. + + The plugin should implicitly download the latest version to compile the + contract with. `temp_solcx_path` is used to simulate an environment without + compilers installed. + """ + assert not solcx.get_installed_solc_versions() + path = project.sources.lookup("contracts/MissingPragma.sol") + contract_types = [c for c in compiler.compile((path,), project=project)] + assert len(contract_types) == 1 + installed_versions = solcx.get_installed_solc_versions() + assert len(installed_versions) == 1 + assert installed_versions[0] == max(solcx.get_installable_solc_versions()) + + +def test_compile_project(project, compiler): + """ + Simple test showing the full project indeed compiles. + """ + paths = [x for x in project.sources.paths if x.suffix == ".sol"] + actual = [c for c in compiler.compile(paths, project=project)] + assert len(actual) > 0 + + +def test_compile_outputs_compiler_data_to_manifest(project, compiler): + project.update_manifest(compilers=[]) + path = project.sources.lookup("contracts/CompilesOnce.sol") + _ = [c for c in compiler.compile((path,), project=project)] + assert len(project.manifest.compilers or []) == 1 + actual = project.manifest.compilers[0] + assert actual.name == "solidity" + assert "CompilesOnce" in actual.contractTypes + assert actual.version == "0.8.26+commit.8a97fa7a" + # Compiling again should not add the same compiler again. + _ = [c for c in compiler.compile((path,), project=project)] + length_again = len(project.manifest.compilers or []) + assert length_again == 1 + + +def test_add_library(project, account, compiler, connection): + # Does not exist yet because library is not deployed or known. + with pytest.raises(AttributeError): + _ = project.ContractUsingLibraryInSameSource + with pytest.raises(AttributeError): + _ = project.ContractUsingLibraryNotInSameSource + + library = project.ExampleLibrary.deploy(sender=account) + compiler.add_library(library, project=project) + + # After deploying and adding the library, we can use contracts that need it. + assert project.ContractUsingLibraryNotInSameSource + assert project.ContractUsingLibraryInSameSource -def test_flatten(project, compiler, data_folder, caplog): - source_path = project.contracts_folder / "Imports.sol" +def test_enrich_error_when_custom(compiler, project, owner, not_owner, connection): + path = project.sources.lookup("contracts/HasError.sol") + _ = [c for c in compiler.compile((path,), project=project)] + + # Deploy so Ape know about contract type. + contract = owner.deploy(project.HasError, 1) + with pytest.raises(contract.Unauthorized) as err: + contract.withdraw(sender=not_owner) + + assert err.value.inputs == {"addr": not_owner.address, "counter": 123} + + +def test_enrich_error_when_custom_in_constructor(compiler, project, owner, not_owner, connection): + # Deploy so Ape know about contract type. + with reverts(project.HasError.Unauthorized) as err: + not_owner.deploy(project.HasError, 0) + + assert err.value.inputs == {"addr": not_owner.address, "counter": 123} + + +def test_enrich_error_when_builtin(project, owner, connection): + contract = project.BuiltinErrorChecker.deploy(sender=owner) + with pytest.raises(IndexOutOfBoundsError): + contract.checkIndexOutOfBounds(sender=owner) + + +def test_flatten(project, compiler, caplog): + path = project.sources.lookup("contracts/Imports.sol") with caplog.at_level(LogLevel.WARNING): - compiler.flatten_contract(source_path) + compiler.flatten_contract(path, project=project) actual = caplog.messages[-1] expected = ( "Conflicting licenses found: 'LGPL-3.0-only, MIT'. " @@ -609,15 +662,15 @@ def test_flatten(project, compiler, data_folder, caplog): ) assert actual == expected - source_path = project.contracts_folder / "ImportingLessConstrainedVersion.sol" - flattened_source = compiler.flatten_contract(source_path) - flattened_source_path = data_folder / "ImportingLessConstrainedVersionFlat.sol" - actual = str(flattened_source) - expected = str(flattened_source_path.read_text()) - assert actual == expected + path = project.sources.lookup("contracts/ImportingLessConstrainedVersion.sol") + flattened_source = compiler.flatten_contract(path, project=project) + flattened_source_path = ( + Path(__file__).parent / "data" / "ImportingLessConstrainedVersionFlat.sol" + ) + assert str(flattened_source) == str(flattened_source_path.read_text()) -def test_compile_code(compiler): +def test_compile_code(project, compiler): code = """ contract Contract { function snakes() pure public returns(bool) { @@ -625,7 +678,7 @@ def test_compile_code(compiler): } } """ - actual = compiler.compile_code(code, contractName="TestContractName") + actual = compiler.compile_code(code, project=project, contractName="TestContractName") assert actual.name == "TestContractName" assert len(actual.abi) > 0 assert actual.ast is not None diff --git a/tests/test_integration.py b/tests/test_integration.py index f099339..8b52390 100644 --- a/tests/test_integration.py +++ b/tests/test_integration.py @@ -13,11 +13,12 @@ def runner(): return CliRunner() -def test_compile_using_cli(ape_cli, runner): - result = runner.invoke(ape_cli, ("compile", "--force"), catch_exceptions=False) +def test_compile_using_cli(ape_cli, runner, project): + arguments = ["compile", "--project", f"{project.path}"] + result = runner.invoke(ape_cli, [*arguments, "--force"], catch_exceptions=False) assert result.exit_code == 0 assert "CompilesOnce" in result.output - result = runner.invoke(ape_cli, "compile", catch_exceptions=False) + result = runner.invoke(ape_cli, arguments, catch_exceptions=False) # Already compiled so does not compile again. assert "CompilesOnce" not in result.output @@ -32,12 +33,8 @@ def test_compile_using_cli(ape_cli, runner): "contracts/CompilesOnce.sol", ), ) -def test_compile_specified_contracts(ape_cli, runner, contract_path): - result = runner.invoke(ape_cli, ("compile", contract_path, "--force"), catch_exceptions=False) +def test_compile_specified_contracts(ape_cli, runner, contract_path, project): + arguments = ("compile", "--project", f"{project.path}", contract_path, "--force") + result = runner.invoke(ape_cli, arguments, catch_exceptions=False) assert result.exit_code == 0, result.output - assert "Compiling 'CompilesOnce.sol'" in result.output, f"Failed to compile {contract_path}." - - -def test_force_recompile(ape_cli, runner): - result = runner.invoke(ape_cli, ("compile", "--force"), catch_exceptions=False) - assert result.exit_code == 0 + assert "contracts/CompilesOnce.sol" in result.output, f"Failed to compile {contract_path}."