fix: handle missing contracts folder ID in remappings (#147)
antazoey authored Jun 6, 2024
1 parent 4998efe commit f378390
Showing 5 changed files with 149 additions and 46 deletions.
142 changes: 100 additions & 42 deletions ape_solidity/compiler.py
@@ -1,5 +1,6 @@
import os
import re
from collections import defaultdict
from collections.abc import Iterable, Iterator
from pathlib import Path
from typing import Any, Optional, Union
@@ -8,7 +9,7 @@
from ape.contracts import ContractInstance
from ape.exceptions import CompilerError, ConfigError, ContractLogicError, ProjectError
from ape.logging import logger
from ape.managers.project import ProjectManager
from ape.managers.project import LocalProject, ProjectManager
from ape.types import AddressType, ContractType
from ape.utils import cached_property, get_full_extension, get_relative_path
from ape.version import version
@@ -309,10 +310,12 @@ def unpack(dep):
return

for unpacked_dep in dep.unpack(pm.contracts_folder / ".cache"):
_key = key_map.get(unpacked_dep.name, f"@{unpacked_dep.name}")
if _key not in remapping:
remapping[_key] = get_cache_id(unpacked_dep)
# else, was specified or configured more appropriately.
main_key = key_map.get(unpacked_dep.name)
keys = (main_key,) if main_key else (f"@{unpacked_dep.name}", unpacked_dep.name)
for _key in keys:
if _key not in remapping:
remapping[_key] = get_cache_id(unpacked_dep)
# else, was specified or configured more appropriately.

remapping: dict[str, str] = {}
for key, value in cfg_remappings.items():
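A minimal standalone sketch of the key fallback introduced above, assuming a hypothetical `remapping_keys` helper (not part of ape-solidity): when no remapping key is configured for a dependency, both the `@<name>` and the bare `<name>` forms point at its cache location.

```python
from typing import Optional


def remapping_keys(dep_name: str, configured_key: Optional[str]) -> tuple[str, ...]:
    # A configured key wins; otherwise offer both "@<name>" and "<name>",
    # mirroring the `keys` tuple in the diff above.
    return (configured_key,) if configured_key else (f"@{dep_name}", dep_name)


assert remapping_keys("dependency", None) == ("@dependency", "dependency")
assert remapping_keys("safe", "@safe") == ("@safe",)
```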
@@ -514,19 +517,19 @@ def get_standard_input_json_from_settings(
# and cater the error message accordingly.
if dependencies_needed := [x for x in missing_sources if str(x).startswith("@")]:
# Missing dependencies. Should only get here if dependencies are found
# in import-strs but are not installed anywhere (not in project or globally).
# in import-strs but are not installed (not in project or globally).
missing_str = ", ".join(dependencies_needed)
raise CompilerError(
f"Missing required dependencies '{missing_str}'. "
"Install them using `dependencies:` "
"in an ape-config.yaml or using the `ape pm install` command."
)

# Otherwise, we are missing project-level source files for some reason.
# This would only happen if the user passes in unexpected files outside
# of core.
missing_src_str = ", ".join(missing_sources)
raise CompilerError(f"Sources '{missing_src_str}' not found in '{pm.name}'.")
# Otherwise, we are missing project-level source files for some reason.
# This would only happen if the user passes in unexpected files outside
# of core.
missing_src_str = ", ".join(missing_sources)
raise CompilerError(f"Sources '{missing_src_str}' not found in '{pm.name}'.")

sources = {
x: {"content": (pm.path / x).read_text()}
@@ -914,33 +917,42 @@ def get_version_map_from_imports(

# If being used in another version AND no imports in this version require it,
# remove it from this version.
for solc_version, files in files_by_solc_version.copy().items():
for file in files.copy():
used_in_other_version = any(
[file in ls for v, ls in files_by_solc_version.items() if v != solc_version]
)
if not used_in_other_version:
cleaned_mapped: dict[Version, set[Path]] = defaultdict(set)
for solc_version, files in files_by_solc_version.items():
other_versions = {v: ls for v, ls in files_by_solc_version.items() if v != solc_version}
for file in files:
other_versions_used_in = {v for v in other_versions if file in other_versions[v]}
if not other_versions_used_in:
# This file is only in 1 version, which is perfect.
cleaned_mapped[solc_version].add(file)
continue

other_files = [f for f in files_by_solc_version[solc_version] if f != file]
used_in_imports = False
for other_file in other_files:
source_id = str(get_relative_path(other_file, pm.path))
import_paths = [pm.path / i for i in import_map.get(source_id, []) if i]
if file in import_paths:
used_in_imports = True
break
# This file is in multiple versions. Attempt to clean.
for other_version in other_versions_used_in:
                    # Other files that may need this file are any files in this
                    # version that are not this file and are not also found in the
                    # other version. Make sure this file is not needed before removing it.
other_files_that_may_need_this_file = [
f for f in files if f != file and f not in other_versions[other_version]
]
if other_files_that_may_need_this_file:
# This file is used by other files in this version, so we must keep it.
cleaned_mapped[solc_version].add(file)
continue

if not used_in_imports:
files_by_solc_version[solc_version].remove(file)
if not files_by_solc_version[solc_version]:
del files_by_solc_version[solc_version]
                    # Remove the remaining files that are also present in the other version.
other_files_can_remove = [
f for f in files if f != file and f in other_versions[other_version]
]
for other_file in other_files_can_remove:
if other_file in cleaned_mapped[solc_version]:
cleaned_mapped[solc_version].remove(other_file)

result = {add_commit_hash(v): ls for v, ls in files_by_solc_version.items()}
result = {add_commit_hash(v): ls for v, ls in cleaned_mapped.items()}

# Sort, so it is a nicer version map and the rest of the compilation flow
# is more predictable.
return {k: result[k] for k in sorted(result)}
# is more predictable. Also, remove any lingering empties.
return {k: result[k] for k in sorted(result) if result[k]}

def _get_imported_source_paths(
self,
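To make the cleanup in `get_version_map_from_imports` concrete, here is a hypothetical input and the outcome the new loop produces. The versions are plain strings and the paths are invented for illustration; the real code uses `Version` keys and appends commit hashes via `add_commit_hash`.

```python
files_by_solc_version = {
    "0.8.4": {"contracts/A.sol", "contracts/Shared.sol"},  # A.sol imports Shared.sol
    "0.4.26": {"contracts/Shared.sol"},  # Shared.sol's pragma also allows 0.4.26
}
# A.sol is unique to 0.8.4, so it is kept there. Shared.sol appears under both
# versions, but A.sol (absent from 0.4.26) may still need it under 0.8.4, so it
# stays there; under 0.4.26 nothing else requires it, so it is never added.
# The now-empty 0.4.26 entry is then dropped by the final comprehension:
expected = {"0.8.4": {"contracts/A.sol", "contracts/Shared.sol"}}
```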
@@ -1137,6 +1149,7 @@ def _import_str_to_source_id(
) -> str:
pm = project or self.local_project
quote = '"' if '"' in _import_str else "'"
sep = "\\" if "\\" in _import_str else "/"

try:
end_index = _import_str.index(quote) + 1
@@ -1150,26 +1163,26 @@

# Get all matches.
valid_matches: list[tuple[str, str]] = []
key = None
import_remap_key = None
base_path = None
for key, value in import_remapping.items():
if key not in import_str_value:
for check_remap_key, check_remap_value in import_remapping.items():
if check_remap_key not in import_str_value:
continue

valid_matches.append((key, value))
valid_matches.append((check_remap_key, check_remap_value))

if valid_matches:
key, value = max(valid_matches, key=lambda x: len(x[0]))
import_str_value = import_str_value.replace(key, value)
import_remap_key, import_remap_value = max(valid_matches, key=lambda x: len(x[0]))
import_str_value = import_str_value.replace(import_remap_key, import_remap_value)

if import_str_value.startswith("."):
base_path = source_path.parent
elif (pm.path / import_str_value).is_file():
base_path = pm.path
elif (pm.contracts_folder / import_str_value).is_file():
base_path = pm.contracts_folder
elif key is not None and key.startswith("@"):
nm = key[1:]
elif import_remap_key is not None and import_remap_key.startswith("@"):
nm = import_remap_key[1:]
for cfg_dep in pm.config.dependencies:
if (
cfg_dep.get("name") == nm
@@ -1178,8 +1191,53 @@
):
base_path = Path(cfg_dep["project"])

if base_path is None:
# No base_path, do as-is.
import_str_parts = import_str_value.split(sep)
if base_path is None and ".cache" in import_str_parts:
# No base_path. First, check if the `contracts/` folder is missing,
# which is the case when compiling older Ape projects and some Foundry
# projects as well.
cache_index = import_str_parts.index(".cache")
nm_index = cache_index + 1
version_index = nm_index + 1
if version_index >= len(import_str_parts):
                # Not enough path parts after .cache to identify the dependency and version; return as-is.
return import_str_value

cache_folder_name = import_str_parts[nm_index]
cache_folder_version = import_str_parts[version_index]
dm = pm.dependencies
dependency = dm.get_dependency(cache_folder_name, cache_folder_version)
dep_project = dependency.project

if not isinstance(dep_project, LocalProject):
                # TODO: Handle manifest-based projects as well,
                #   so this works with old compiled manifests.
return import_str_value

contracts_dir = dep_project.contracts_folder
dep_path = dep_project.path
contracts_folder_name = f"{get_relative_path(contracts_dir, dep_path)}"
prefix_pth = dep_path / contracts_folder_name
start_idx = version_index + 1
suffix = sep.join(import_str_parts[start_idx:])
new_path = prefix_pth / suffix

if not new_path.is_file():
# Maybe this source is actually missing...
return import_str_value

adjusted_base_path = f"{sep.join(import_str_parts[:4])}{sep}{contracts_folder_name}"
adjusted_src_id = f"{adjusted_base_path}{sep}{suffix}"

            # Also, correct the import remapping now, since it did not
            # account for the missing contracts folder.
            if key := import_remap_key:
                # The base path now includes the missing contracts folder name.
import_remapping[key] = adjusted_base_path

return adjusted_src_id

elif base_path is None:
# No base_path, return as-is.
return import_str_value

path = (base_path / import_str_value).resolve()
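Tracing the new `.cache` branch with the legacy-style import added to tests/contracts/Imports.sol below makes the correction concrete. This is a standalone sketch of the string handling only (no dependency lookup), mirroring the hardcoded `parts[:4]` base in the diff above.

```python
sep = "/"
# "@noncompilingdependency/subdir/SubCompilingContract.sol" after applying the
# remapping "@noncompilingdependency=contracts/.cache/noncompilingdependency/local":
import_str_value = "contracts/.cache/noncompilingdependency/local/subdir/SubCompilingContract.sol"
parts = import_str_value.split(sep)

nm_index = parts.index(".cache") + 1           # "noncompilingdependency"
version_index = nm_index + 1                   # "local"
suffix = sep.join(parts[version_index + 1:])   # "subdir/SubCompilingContract.sol"

# The dependency's sources live under its contracts/ folder, so splice that
# folder name back into the cached base path and rebuild the source ID.
adjusted_base_path = f"{sep.join(parts[:4])}{sep}contracts"
adjusted_src_id = f"{adjusted_base_path}{sep}{suffix}"
assert adjusted_src_id == (
    "contracts/.cache/noncompilingdependency/local/contracts/subdir/SubCompilingContract.sol"
)
```

The adjusted source ID matches the entry expected in `test_get_imports_complex` below, and the `@noncompilingdependency` remapping is rewritten to the adjusted base path, which is the corrected remapping asserted in `test_get_compiler_settings`.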
9 changes: 9 additions & 0 deletions
@@ -0,0 +1,9 @@
// SPDX-License-Identifier: MIT

pragma solidity ^0.8.4;

contract SubCompilingContract {
function foo() pure public returns(bool) {
return true;
}
}
4 changes: 4 additions & 0 deletions tests/ape-config.yaml
@@ -28,3 +28,7 @@ dependencies:
solidity:
# Using evm_version compatible with older and newer solidity versions.
evm_version: constantinople

import_remapping:
  # Legacy support test (imports that omit the contracts/ folder)
- "@noncompilingdependency=noncompilingdependency"
7 changes: 5 additions & 2 deletions tests/contracts/Imports.sol
@@ -17,12 +17,15 @@ import {
Struct4,
Struct5
} from "./NumerousDefinitions.sol";
import "@noncompilingdependency/contracts/CompilingContract.sol";
import "@noncompilingdependency/CompilingContract.sol";
// Purposely repeat an import to test how the plugin handles that.
import "@noncompilingdependency/contracts/CompilingContract.sol";
import "@noncompilingdependency/CompilingContract.sol";

import "@safe/contracts/common/Enum.sol";

// Purposely exclude the contracts folder to test older Ape-style project imports.
import "@noncompilingdependency/subdir/SubCompilingContract.sol";

contract Imports {
function foo() pure public returns(bool) {
return true;
33 changes: 31 additions & 2 deletions tests/test_compiler.py
@@ -121,6 +121,7 @@ def test_get_imports_complex(project, compiler):
"contracts/.cache/dependency/local/contracts/Dependency.sol",
"contracts/.cache/dependencyofdependency/local/contracts/DependencyOfDependency.sol",
"contracts/.cache/noncompilingdependency/local/contracts/CompilingContract.sol",
"contracts/.cache/noncompilingdependency/local/contracts/subdir/SubCompilingContract.sol", # noqa: E501
"contracts/.cache/safe/1.3.0/contracts/common/Enum.sol",
"contracts/CompilesOnce.sol",
"contracts/MissingPragma.sol",
@@ -239,7 +240,29 @@ def test_get_version_map_dependencies(project, compiler):
actual = compiler.get_version_map(paths, project=project)

fail_msg = f"versions: {', '.join([str(x) for x in actual])}"
assert len(actual) == 2, fail_msg
actual_len = len(actual)

# Expecting one old version for ImportOlderDependency and one version for Yearn stuff.
expected_len = 2

if actual_len > expected_len:
        # Weird anomaly that sometimes occurs in CI/CD tests (at least at the time of writing).
# Including additional debug information.
alt_map: dict = {}
for version, src_ids in actual.items():
for src_id in src_ids:
if src_id in alt_map:
other_version = alt_map[src_id]
versions_str = ", ".join([str(other_version), str(version)])
pytest.fail(f"{src_id} in multiple version '{versions_str}'")
else:
alt_map[src_id] = version

# No duplicated versions found but still have unexpected extras.
pytest.fail(f"Unexpected number of versions. {fail_msg}")

elif actual_len < expected_len:
pytest.fail(fail_msg)

versions = sorted(list(actual.keys()))
older = versions[0] # Via ImportOlderDependency
@@ -355,8 +378,13 @@ def test_get_compiler_settings(project, compiler):
"@browniedependency=contracts/.cache/browniedependency/local",
"@dependency=contracts/.cache/dependency/local",
"@dependencyofdependency=contracts/.cache/dependencyofdependency/local",
"@noncompilingdependency=contracts/.cache/noncompilingdependency/local",
        # The remapping below was auto-corrected because imports were excluding the contracts/ suffix.
"@noncompilingdependency=contracts/.cache/noncompilingdependency/local/contracts",
"@safe=contracts/.cache/safe/1.3.0",
"browniedependency=contracts/.cache/browniedependency/local",
"dependency=contracts/.cache/dependency/local",
"dependencyofdependency=contracts/.cache/dependencyofdependency/local",
"safe=contracts/.cache/safe/1.3.0",
]

# Set in config.
@@ -369,6 +397,7 @@
"contracts/.cache/dependency/local/contracts/Dependency.sol",
"contracts/.cache/dependencyofdependency/local/contracts/DependencyOfDependency.sol",
"contracts/.cache/noncompilingdependency/local/contracts/CompilingContract.sol",
"contracts/.cache/noncompilingdependency/local/contracts/subdir/SubCompilingContract.sol",
"contracts/.cache/safe/1.3.0/contracts/common/Enum.sol",
"contracts/CompilesOnce.sol",
"contracts/Imports.sol",
