Skip to content

Commit

Permalink
fix: issue with extra suffix parts in imported sources (#148)
Browse files Browse the repository at this point in the history
  • Loading branch information
antazoey authored Jun 13, 2024
1 parent f378390 commit a84cf38
Show file tree
Hide file tree
Showing 7 changed files with 199 additions and 140 deletions.
116 changes: 81 additions & 35 deletions ape_solidity/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,9 +54,15 @@
# Define a regex pattern that matches import statements
# Both single and multi-line imports will be matched
IMPORTS_PATTERN = re.compile(
r"import\s+((.*?)(?=;)|[\s\S]*?from\s+(.*?)(?=;));\s", flags=re.MULTILINE
r"import\s+(([\s\S]*?)(?=;)|[\s\S]*?from\s+([^\s;]+));\s*", flags=re.MULTILINE
)
LICENSES_PATTERN = re.compile(r"(// SPDX-License-Identifier:\s*([^\n]*)\s)")

# Comment patterns
SINGLE_LINE_COMMENT_PATTERN = re.compile(r"^\s*//")
MULTI_LINE_COMMENT_START_PATTERN = re.compile(r"/\*")
MULTI_LINE_COMMENT_END_PATTERN = re.compile(r"\*/")

VERSION_PRAGMA_PATTERN = re.compile(r"pragma solidity[^;]*;")
DEFAULT_OPTIMIZATION_RUNS = 200

Expand Down Expand Up @@ -142,7 +148,7 @@ class SolidityConfig(PluginConfig):
def _get_flattened_source(path: Path, name: Optional[str] = None) -> str:
name = name or path.name
result = f"// File: {name}\n"
result += path.read_text() + "\n"
result += f"{path.read_text().rstrip()}\n"
return result


Expand Down Expand Up @@ -373,12 +379,15 @@ def _get_settings_from_imports(
files_by_solc_version = self.get_version_map_from_imports(
contract_filepaths, import_map, project=pm
)
return self._get_settings_from_version_map(files_by_solc_version, remappings, project=pm)
return self._get_settings_from_version_map(
files_by_solc_version, remappings, import_map=import_map, project=pm
)

def _get_settings_from_version_map(
self,
version_map: dict,
import_remappings: dict[str, str],
import_map: Optional[dict[str, list[str]]] = None,
project: Optional[ProjectManager] = None,
**kwargs,
) -> dict[Version, dict]:
Expand All @@ -397,7 +406,9 @@ def _get_settings_from_version_map(
},
**kwargs,
}
if remappings_used := self._get_used_remappings(sources, import_remappings, project=pm):
if remappings_used := self._get_used_remappings(
sources, import_remappings, import_map=import_map, project=pm
):
remappings_str = [f"{k}={v}" for k, v in remappings_used.items()]

# Standard JSON input requires remappings to be sorted.
Expand All @@ -421,6 +432,7 @@ def _get_used_remappings(
self,
sources: Iterable[Path],
remappings: dict[str, str],
import_map: Optional[dict[str, list[str]]] = None,
project: Optional[ProjectManager] = None,
) -> dict[str, str]:
pm = project or self.local_project
Expand All @@ -435,7 +447,8 @@ def _get_used_remappings(
# Filter out unused import remapping.
result = {}
sources = list(sources)
imports = self.get_imports(sources, project=pm).values()
import_map = import_map or self.get_imports(sources, project=pm)
imports = import_map.values()

for source_list in imports:
for src in source_list:
Expand All @@ -461,32 +474,20 @@ def get_standard_input_json(
import_map = self.get_imports_from_remapping(paths, remapping, project=pm)
version_map = self.get_version_map_from_imports(paths, import_map, project=pm)
return self.get_standard_input_json_from_version_map(
version_map, remapping, project=pm, **overrides
version_map, remapping, project=pm, import_map=import_map, **overrides
)

def get_standard_input_json_from(
self,
version_map: dict[Version, set[Path]],
import_remappings: dict[str, str],
project: Optional[ProjectManager] = None,
**overrides,
):
pm = project or self.local_project
settings = self._get_settings_from_version_map(
version_map, import_remappings, project=pm, **overrides
)
return self.get_standard_input_json_from_settings(settings, version_map, project=pm)

def get_standard_input_json_from_version_map(
self,
version_map: dict[Version, set[Path]],
import_remapping: dict[str, str],
import_map: Optional[dict[str, list[str]]] = None,
project: Optional[ProjectManager] = None,
**overrides,
):
pm = project or self.local_project
settings = self._get_settings_from_version_map(
version_map, import_remapping, project=pm, **overrides
version_map, import_remapping, import_map=import_map, project=pm, **overrides
)
return self.get_standard_input_json_from_settings(settings, version_map, project=pm)

Expand Down Expand Up @@ -571,8 +572,16 @@ def _compile(
settings: Optional[dict] = None,
):
pm = project or self.local_project
input_jsons = self.get_standard_input_json(
contract_filepaths, project=pm, **(settings or {})
remapping = self.get_import_remapping(project=pm)
paths = list(contract_filepaths) # Handle if given generator=
import_map = self.get_imports_from_remapping(paths, remapping, project=pm)
version_map = self.get_version_map_from_imports(paths, import_map, project=pm)
input_jsons = self.get_standard_input_json_from_version_map(
version_map,
remapping,
project=pm,
import_map=import_map,
**(settings or {}),
)
contract_versions: dict[str, Version] = {}
contract_types: list[ContractType] = []
Expand Down Expand Up @@ -608,7 +617,7 @@ def _compile(
for name, _ in contracts_out.items():
# Filter source files that the user did not ask for, such as
# imported relative files that are not part of the input.
for input_file_path in contract_filepaths:
for input_file_path in paths:
if source_id in str(input_file_path):
input_contract_names.append(name)

Expand Down Expand Up @@ -1096,14 +1105,17 @@ def enrich_error(self, err: ContractLogicError) -> ContractLogicError:

def _flatten_source(
self,
path: Path,
path: Union[Path, str],
project: Optional[ProjectManager] = None,
raw_import_name: Optional[str] = None,
handled: Optional[set[str]] = None,
) -> str:
pm = project or self.local_project
handled = handled or set()
source_id = f"{get_relative_path(path, pm.path)}"

path = Path(path)
source_id = f"{get_relative_path(path, pm.path)}" if path.is_absolute() else f"{path}"

handled.add(source_id)
remapping = self.get_import_remapping(project=project)
imports = self._get_imports((path,), remapping, pm, tracked=set(), include_raw=True)
Expand All @@ -1116,26 +1128,36 @@ def _flatten_source(
continue

sub_import_name = import_str.replace("import ", "").strip(" \n\t;\"'")
final_source += self._flatten_source(
sub_source = self._flatten_source(
pm.path / source_id,
project=pm,
raw_import_name=sub_import_name,
handled=handled,
)
final_source += sub_source

flattened_src = _get_flattened_source(path, name=raw_import_name)
if flattened_src and final_source.rstrip():
final_source = f"{final_source.rstrip()}\n\n{flattened_src}"
elif flattened_src:
final_source = flattened_src

final_source += _get_flattened_source(path, name=raw_import_name)
return final_source

def flatten_contract(
self, path: Path, project: Optional[ProjectManager] = None, **kwargs
) -> Content:
# try compiling in order to validate it works
res = self._flatten_source(path, project=project)
res = remove_imports(res)
res = process_licenses(res)
res = remove_version_pragmas(res)
pragma = get_first_version_pragma(path.read_text())
res = "\n".join([pragma, res])

# Simple auto-format.
while "\n\n\n" in res:
res = res.replace("\n\n\n", "\n\n")

lines = res.splitlines()
line_dict = {i + 1: line for i, line in enumerate(lines)}
return Content(root=line_dict)
Expand Down Expand Up @@ -1244,11 +1266,37 @@ def _import_str_to_source_id(
return f"{get_relative_path(path.absolute(), pm.path)}"


def remove_imports(flattened_contract: str) -> str:
# Use regex.sub() to remove matched import statements
no_imports_contract = IMPORTS_PATTERN.sub("", flattened_contract)
def remove_imports(source_code: str) -> str:
in_multi_line_comment = False
result_lines = []

lines = source_code.splitlines()
for line in lines:
# Check if we're entering a multi-line comment
if MULTI_LINE_COMMENT_START_PATTERN.search(line):
in_multi_line_comment = True

# If inside a multi-line comment, just add the line to the result
if in_multi_line_comment:
result_lines.append(line)
# Check if this line ends the multi-line comment
if MULTI_LINE_COMMENT_END_PATTERN.search(line):
in_multi_line_comment = False
continue

# Skip single-line comments
if SINGLE_LINE_COMMENT_PATTERN.match(line):
result_lines.append(line)
continue

# Skip import statements in non-comment lines
if IMPORTS_PATTERN.search(line):
continue

# Add the line to the result if it's not an import statement
result_lines.append(line)

return no_imports_contract
return "\n".join(result_lines)


def remove_version_pragmas(flattened_contract: str) -> str:
Expand Down Expand Up @@ -1285,9 +1333,7 @@ def process_licenses(contract: str) -> str:
license_line, root_license = extracted_licenses[-1]

# Get the unique license identifiers. All licenses in a contract _should_ be the same.
unique_license_identifiers = {
license_identifier for _, license_identifier in extracted_licenses
}
unique_license_identifiers = {lid for _, lid in extracted_licenses}

# If we have more than one unique license identifier, warn the user and use the root.
if len(unique_license_identifiers) > 1:
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@
include_package_data=True,
install_requires=[
"py-solc-x>=2.0.2,<3",
"eth-ape>=0.8.1,<0.9",
"eth-ape>=0.8.4,<0.9",
"ethpm-types", # Use the version ape requires
"eth-pydantic-types", # Use the version ape requires
"packaging", # Use the version ape requires
Expand Down
4 changes: 4 additions & 0 deletions tests/contracts/Imports.sol
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,10 @@ import "@safe/contracts/common/Enum.sol";
// Purposely exclude the contracts folder to test older Ape-style project imports.
import "@noncompilingdependency/subdir/SubCompilingContract.sol";

// Showing sources with extra extensions are by default excluded,
// unless used as an import somewhere in a non-excluded source.
import "./Source.extra.ext.sol";

contract Imports {
function foo() pure public returns(bool) {
return true;
Expand Down
10 changes: 10 additions & 0 deletions tests/contracts/Source.extra.ext.sol
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
// SPDX-License-Identifier: MIT
pragma solidity ^0.8.4;

// Showing sources with extra extensions are by default excluded,
// unless used as an import somewhere in a non-excluded source.
contract SourceExtraExt {
function foo() pure public returns(bool) {
return true;
}
}
4 changes: 0 additions & 4 deletions tests/data/ImportingLessConstrainedVersionFlat.sol
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,6 @@ pragma solidity =0.8.12;

// File: ./SpecificVersionRange.sol



contract SpecificVersionRange {
function foo() pure public returns(bool) {
return true;
Expand All @@ -13,8 +11,6 @@ contract SpecificVersionRange {

// File: ImportingLessConstrainedVersion.sol



// The file we are importing specific range '>=0.8.12 <0.8.15';
// This means on its own, the plugin would use 0.8.14 if its installed.
// However - it should use 0.8.12 because of this file's requirements.
Expand Down
Loading

0 comments on commit a84cf38

Please sign in to comment.