Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: implement flattening api [APE-794] #107

Merged
merged 31 commits into from
Jun 28, 2023
Merged
Show file tree
Hide file tree
Changes from 10 commits
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
f626626
implement flattening api
z80dev Jun 5, 2023
8a0d320
run linter
z80dev Jun 6, 2023
833d9e5
fix: run isort
z80dev Jun 6, 2023
1dcb247
fix: existing error (missing dependency remapping)
Jun 15, 2023
f003cb0
fix: strip 'import' statements & include filename
Jun 15, 2023
d456d14
fix: check for no file path (mypi)
Jun 15, 2023
7901ade
fix: lint
Jun 15, 2023
540f641
implement license detection & fix filename order
Jun 16, 2023
5aaec5e
lint
Jun 16, 2023
ea5c080
rename function and sort imports
Jun 19, 2023
7581e32
Update ape_solidity/compiler.py
z80dev Jun 22, 2023
572c9e0
pre-compile regex for perf
Jun 22, 2023
460c2ba
add test comparing against pre-flattened contracts
Jun 22, 2023
2bd6cc6
fix: regex multiline flag at comp time
Jun 22, 2023
89f9cb9
fix: refactor process_licenses
Jun 22, 2023
52e50f5
fix: rename imports function
Jun 22, 2023
2146547
fix: rename fn to internal name
Jun 22, 2023
8faf5db
fix test diff
Jun 22, 2023
ecd214c
fix: lint
Jun 22, 2023
e7af728
fix: type annotation on get_licenses
Jun 22, 2023
2add5d2
fix: rename flattened file
Jun 22, 2023
3da8bc8
fix: move flattened file again
Jun 22, 2023
c33fd3e
fix: remove extra newline
Jun 22, 2023
130019a
update test file names
Jun 22, 2023
27b6c9b
Merge branch 'main' into implement-flattening
antazoey Jun 26, 2023
a612afc
update tests
Jun 28, 2023
3cee86c
fix: compile regex and remove unused import
Jun 28, 2023
3b36696
fix: leave only version pragma from target file
Jun 28, 2023
a17143c
fix: lint
Jun 28, 2023
0b7dfe4
fix: type on pragma fn
Jun 28, 2023
b9eb249
Merge branch 'main' into implement-flattening
antazoey Jun 28, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
194 changes: 150 additions & 44 deletions ape_solidity/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from eth_utils import add_0x_prefix, is_0x_prefixed
from ethpm_types import ASTNode, HexBytes, PackageManifest
from ethpm_types.ast import ASTClassification
from ethpm_types.source import Content
from pkg_resources import get_distribution
from requests.exceptions import ConnectionError
from semantic_version import NpmSpec, Version # type: ignore
Expand Down Expand Up @@ -453,63 +454,53 @@ def classify_ast(_node: ASTNode):

return contract_types

def get_imports(
self, contract_filepaths: List[Path], base_path: Optional[Path] = None
) -> Dict[str, List[str]]:
# NOTE: Process import remappings _before_ getting the full contract set.
def get_imports_with_raw_name(
z80dev marked this conversation as resolved.
Show resolved Hide resolved
self,
contract_filepaths: List[Path],
base_path: Optional[Path] = None,
) -> Dict[str, List[Tuple[str, str]]]:
contracts_path = base_path or self.config_manager.contracts_folder
import_remapping = self.get_import_remapping(base_path=contracts_path)
contract_filepaths_set = verify_contract_filepaths(contract_filepaths)

def import_str_to_source_id(_import_str: str, source_path: Path) -> str:
quote = '"' if '"' in _import_str else "'"

try:
end_index = _import_str.index(quote) + 1
except ValueError as err:
raise CompilerError(
f"Error parsing import statement '{_import_str}' in '{source_path.name}'."
) from err

import_str_prefix = _import_str[end_index:]
import_str_value = import_str_prefix[: import_str_prefix.index(quote)]
path = (source_path.parent / import_str_value).resolve()
source_id_value = str(get_relative_path(path, contracts_path))

# Get all matches.
matches: List[Tuple[str, str]] = []
for key, value in import_remapping.items():
if key not in source_id_value:
continue

matches.append((key, value))

if not matches:
return source_id_value
imports_dict: Dict[str, List[Tuple[str, str]]] = {}
for src_path, import_strs in get_import_lines(contract_filepaths_set).items():
import_list = []
for import_str in import_strs:
raw_import_item_search = re.search(r'"(.*?)"', import_str)
if raw_import_item_search is None:
raise CompilerError(f"No target filename found in import {import_str}")
raw_import_item = raw_import_item_search.group(1)
import_item = import_str_to_source_id(
import_str, src_path, contracts_path, import_remapping
)

# Convert remapping list back to source using longest match (most exact).
key, value = max(matches, key=lambda x: len(x[0]))
sections = [s for s in source_id_value.split(key) if s]
depth = len(sections) - 1
source_id_value = ""
# Only add to the list if it's not already there, to mimic set behavior
if (import_item, raw_import_item) not in import_list:
import_list.append((import_item, raw_import_item))

index = 0
for section in sections:
if index == depth:
source_id_value += value
source_id_value += section
elif index >= depth:
source_id_value += section
source_id = str(get_relative_path(src_path, contracts_path))
imports_dict[str(source_id)] = import_list

index += 1
return imports_dict

return source_id_value
def get_imports(
self,
contract_filepaths: List[Path],
base_path: Optional[Path] = None,
) -> Dict[str, List[str]]:
# NOTE: Process import remappings _before_ getting the full contract set.
contracts_path = base_path or self.config_manager.contracts_folder
import_remapping = self.get_import_remapping(base_path=contracts_path)
contract_filepaths_set = verify_contract_filepaths(contract_filepaths)

imports_dict: Dict[str, List[str]] = {}
for src_path, import_strs in get_import_lines(contract_filepaths_set).items():
import_set = set()
for import_str in import_strs:
import_item = import_str_to_source_id(import_str, src_path)
import_item = import_str_to_source_id(
import_str, src_path, contracts_path, import_remapping
)
import_set.add(import_item)

source_id = str(get_relative_path(src_path, contracts_path))
Expand Down Expand Up @@ -738,6 +729,74 @@ def enrich_error(self, err: ContractLogicError) -> ContractLogicError:
txn=err.txn,
)

def _flatten_source(self, path: Path, base_path=None, raw_import_name=None) -> str:
z80dev marked this conversation as resolved.
Show resolved Hide resolved
base_path = base_path or self.config_manager.contracts_folder
imports = self.get_imports_with_raw_name([path])
source = ""
for import_list in imports.values():
for import_path, raw_import_path in sorted(import_list):
source += self._flatten_source(
base_path / import_path, base_path=base_path, raw_import_name=raw_import_path
)
if raw_import_name:
source += f"// File: {raw_import_name}\n"
else:
source += f"// File: {path.name}\n"
source += path.read_text() + "\n"
return source

def flatten_contract(self, path: Path, **kwargs) -> Content:
"""Flatten a contract.
antazoey marked this conversation as resolved.
Show resolved Hide resolved

Args:
path: Path to contract file.
**kwargs: Keyword arguments

Returns:
Content of flattened contract.
"""
# try compiling in order to validate it works
self.compile([path], base_path=self.config_manager.contracts_folder)
source = self._flatten_source(path)
antazoey marked this conversation as resolved.
Show resolved Hide resolved
source = remove_imports(source)
source = process_licenses(source)
lines = source.splitlines()
line_dict = {i + 1: line for i, line in enumerate(lines)}
return Content(__root__=line_dict)


def remove_imports(flattened_contract: str) -> str:
# Define a regex pattern that matches import statements
# Both single and multi-line imports will be matched
pattern = r"import\s+((.*?)(?=;)|[\s\S]*?from\s+(.*?)(?=;));\s"
antazoey marked this conversation as resolved.
Show resolved Hide resolved

# Use re.sub() to remove matched import statements
no_imports_contract = re.sub(pattern, "", flattened_contract, flags=re.MULTILINE)

return no_imports_contract


def get_licenses(source: str) -> List[str]:
pattern = r"(// SPDX-License-Identifier:\s*([^\n]*)\s)"
z80dev marked this conversation as resolved.
Show resolved Hide resolved
matches = re.findall(pattern, source)
return matches


def process_licenses(contract: str) -> str:
# Extract SPDX license identifiers
licenses = get_licenses(contract)

# Ensure all licenses are identical
unique_licenses = {license[1] for license in licenses}
if len(unique_licenses) > 1:
raise CompilerError(f"Conflicting licenses found: {unique_licenses}")

contract = contract.replace(licenses[0][0], "")
z80dev marked this conversation as resolved.
Show resolved Hide resolved

contract = f"// SPDX-License-Identifier: {licenses[0][1]}\n\n{contract}"

return contract


def _get_sol_panic(revert_message: str) -> Optional[Type[RuntimeErrorUnion]]:
if revert_message.startswith(RUNTIME_ERROR_CODE_PREFIX):
Expand All @@ -753,3 +812,50 @@ def _get_sol_panic(revert_message: str) -> Optional[Type[RuntimeErrorUnion]]:
return RUNTIME_ERROR_MAP[RuntimeErrorType(error_type_val)]

return None


def import_str_to_source_id(
z80dev marked this conversation as resolved.
Show resolved Hide resolved
_import_str: str, source_path: Path, base_path: Path, import_remapping: Dict[str, str]
) -> str:
quote = '"' if '"' in _import_str else "'"

try:
end_index = _import_str.index(quote) + 1
except ValueError as err:
raise CompilerError(
f"Error parsing import statement '{_import_str}' in '{source_path.name}'."
) from err

import_str_prefix = _import_str[end_index:]
import_str_value = import_str_prefix[: import_str_prefix.index(quote)]
path = (source_path.parent / import_str_value).resolve()
source_id_value = str(get_relative_path(path, base_path))

# Get all matches.
matches: List[Tuple[str, str]] = []
for key, value in import_remapping.items():
if key not in source_id_value:
continue

matches.append((key, value))

if not matches:
return source_id_value

# Convert remapping list back to source using longest match (most exact).
key, value = max(matches, key=lambda x: len(x[0]))
sections = [s for s in source_id_value.split(key) if s]
depth = len(sections) - 1
source_id_value = ""

index = 0
for section in sections:
if index == depth:
source_id_value += value
source_id_value += section
elif index >= depth:
source_id_value += section

index += 1

return source_id_value
3 changes: 3 additions & 0 deletions tests/ape-config.yaml
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
dependencies:
- name: TestDependency
local: ./Dependency
- name: DependencyOfDependency
local: ./DependencyOfDependency

# Make sure can use a Brownie project as a dependency
- name: BrownieDependency
Expand All @@ -27,6 +29,7 @@ solidity:
- "@remapping/contracts=TestDependency"
- "@remapping_2=TestDependency"
- "@remapping_2_brownie=BrownieDependency"
- "@dependency_remapping=DependencyOfDependency"

# Remapping for showing we can import a contract type from a brownie-style dependency
# (provided the _single_ contract type compiles in the project).
Expand Down