Skip to content

Commit

Permalink
fix: issue with extra suffix parts
Browse files Browse the repository at this point in the history
  • Loading branch information
antazoey committed Jun 10, 2024
1 parent f378390 commit a8d3787
Show file tree
Hide file tree
Showing 4 changed files with 41 additions and 9 deletions.
32 changes: 24 additions & 8 deletions ape_solidity/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -373,12 +373,15 @@ def _get_settings_from_imports(
files_by_solc_version = self.get_version_map_from_imports(
contract_filepaths, import_map, project=pm
)
return self._get_settings_from_version_map(files_by_solc_version, remappings, project=pm)
return self._get_settings_from_version_map(
files_by_solc_version, remappings, import_map=import_map, project=pm
)

def _get_settings_from_version_map(
self,
version_map: dict,
import_remappings: dict[str, str],
import_map: Optional[dict[str, list[str]]] = None,
project: Optional[ProjectManager] = None,
**kwargs,
) -> dict[Version, dict]:
Expand All @@ -397,7 +400,9 @@ def _get_settings_from_version_map(
},
**kwargs,
}
if remappings_used := self._get_used_remappings(sources, import_remappings, project=pm):
if remappings_used := self._get_used_remappings(
sources, import_remappings, import_map=import_map, project=pm
):
remappings_str = [f"{k}={v}" for k, v in remappings_used.items()]

# Standard JSON input requires remappings to be sorted.
Expand All @@ -421,6 +426,7 @@ def _get_used_remappings(
self,
sources: Iterable[Path],
remappings: dict[str, str],
import_map: Optional[dict[str, list[str]]] = None,
project: Optional[ProjectManager] = None,
) -> dict[str, str]:
pm = project or self.local_project
Expand All @@ -435,7 +441,8 @@ def _get_used_remappings(
# Filter out unused import remapping.
result = {}
sources = list(sources)
imports = self.get_imports(sources, project=pm).values()
import_map = import_map or self.get_imports(sources, project=pm)
imports = import_map.values()

for source_list in imports:
for src in source_list:
Expand All @@ -461,7 +468,7 @@ def get_standard_input_json(
import_map = self.get_imports_from_remapping(paths, remapping, project=pm)
version_map = self.get_version_map_from_imports(paths, import_map, project=pm)
return self.get_standard_input_json_from_version_map(
version_map, remapping, project=pm, **overrides
version_map, remapping, project=pm, import_map=import_map, **overrides
)

def get_standard_input_json_from(
Expand All @@ -481,12 +488,13 @@ def get_standard_input_json_from_version_map(
self,
version_map: dict[Version, set[Path]],
import_remapping: dict[str, str],
import_map: Optional[dict[str, list[str]]] = None,
project: Optional[ProjectManager] = None,
**overrides,
):
pm = project or self.local_project
settings = self._get_settings_from_version_map(
version_map, import_remapping, project=pm, **overrides
version_map, import_remapping, import_map=import_map, project=pm, **overrides
)
return self.get_standard_input_json_from_settings(settings, version_map, project=pm)

Expand Down Expand Up @@ -571,8 +579,16 @@ def _compile(
settings: Optional[dict] = None,
):
pm = project or self.local_project
input_jsons = self.get_standard_input_json(
contract_filepaths, project=pm, **(settings or {})
remapping = self.get_import_remapping(project=pm)
paths = list(contract_filepaths) # Handle if given generator=
import_map = self.get_imports_from_remapping(paths, remapping, project=pm)
version_map = self.get_version_map_from_imports(paths, import_map, project=pm)
input_jsons = self.get_standard_input_json_from_version_map(
version_map,
remapping,
project=pm,
import_map=import_map,
**(settings or {}),
)
contract_versions: dict[str, Version] = {}
contract_types: list[ContractType] = []
Expand Down Expand Up @@ -608,7 +624,7 @@ def _compile(
for name, _ in contracts_out.items():
# Filter source files that the user did not ask for, such as
# imported relative files that are not part of the input.
for input_file_path in contract_filepaths:
for input_file_path in paths:
if source_id in str(input_file_path):
input_contract_names.append(name)

Expand Down
4 changes: 4 additions & 0 deletions tests/contracts/Imports.sol
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,10 @@ import "@safe/contracts/common/Enum.sol";
// Purposely exclude the contracts folder to test older Ape-style project imports.
import "@noncompilingdependency/subdir/SubCompilingContract.sol";

// Showing sources with extra extensions are by default excluded,
// unless used as an import somewhere in a non-excluded source.
import "./Source.extra.ext.sol";

contract Imports {
function foo() pure public returns(bool) {
return true;
Expand Down
10 changes: 10 additions & 0 deletions tests/contracts/Source.extra.ext.sol
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
// SPDX-License-Identifier: MIT
pragma solidity ^0.8.4;

// Showing sources with extra extensions are by default excluded,
// unless used as an import somewhere in a non-excluded source.
contract SourceExtraExt {
function foo() pure public returns(bool) {
return true;
}
}
4 changes: 3 additions & 1 deletion tests/test_compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from ape import Project, reverts
from ape.exceptions import CompilerError
from ape.logging import LogLevel
from ape.utils import get_full_extension
from ethpm_types import ContractType
from packaging.version import Version

Expand Down Expand Up @@ -126,6 +127,7 @@ def test_get_imports_complex(project, compiler):
"contracts/CompilesOnce.sol",
"contracts/MissingPragma.sol",
"contracts/NumerousDefinitions.sol",
"contracts/Source.extra.ext.sol",
"contracts/subfolder/Relativecontract.sol",
],
"contracts/MissingPragma.sol": [],
Expand Down Expand Up @@ -623,7 +625,7 @@ def test_compile_project(project, compiler):
"""
Simple test showing the full project indeed compiles.
"""
paths = [x for x in project.sources.paths if x.suffix == ".sol"]
paths = [x for x in project.sources.paths if get_full_extension(x) == ".sol"]
actual = [c for c in compiler.compile(paths, project=project)]
assert len(actual) > 0

Expand Down

0 comments on commit a8d3787

Please sign in to comment.