Skip to content

Commit

Permalink
fix: imports
Browse files Browse the repository at this point in the history
  • Loading branch information
antazoey committed Apr 15, 2024
1 parent d7df5dc commit 38fe812
Showing 1 changed file with 42 additions and 10 deletions.
52 changes: 42 additions & 10 deletions ape_vyper/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,47 +235,79 @@ def evm_version(self) -> Optional[str]:
def get_imports(
self, contract_filepaths: Sequence[Path], base_path: Optional[Path] = None
) -> Dict[str, List[str]]:
base_path = (base_path or self.project_manager.contracts_folder).absolute()
contracts_path = (base_path or self.project_manager.contracts_folder).absolute()
return self._get_imports(contract_filepaths, base_path=contracts_path, handled=set())

def _get_imports(
self,
contract_filepaths: Sequence[Path],
base_path: Optional[Path] = None,
handled: Optional[Set[str]] = None,
):
handled = handled or set()
contracts_path = (base_path or self.project_manager.contracts_folder).absolute()
import_map = {}
for path in contract_filepaths:
if not path.is_file():
continue

content = path.read_text().splitlines()
source_id = str(get_relative_path(path.absolute(), base_path.absolute()))
source_id = str(get_relative_path(path.absolute(), contracts_path.absolute()))
if source_id in handled:
continue

handled.add(source_id)
for line in content:
if line.startswith("import "):
import_line_parts = line.replace("import ", "").split(" ")
prefix = import_line_parts[0].replace(".", os.path.sep)
prefix = import_line_parts[0]

elif line.startswith("from ") and " import " in line:
import_line_parts = line.replace("from ", "").strip().split(" ")
module_name = import_line_parts[0].strip().replace(".", os.path.sep)
module_name = import_line_parts[0].strip()
prefix = os.path.sep.join([module_name, import_line_parts[2].strip()])

else:
# Not an import line
continue

while f"{os.path.sep}{os.path.sep}" in prefix:
prefix = prefix.replace(f"{os.path.sep}{os.path.sep}", os.path.sep)
dots = ""
while prefix.startswith("."):
dots += prefix[0]
prefix = prefix[1:]

# Replace rest of dots with slashes.
prefix = prefix.replace(".", os.path.sep)

prefix = prefix.lstrip(os.path.sep)
full_path = (path.parent / dots / prefix.lstrip(os.path.sep)).resolve()
prefix = str(full_path).replace(f"{contracts_path}", "").lstrip(os.path.sep)

# NOTE: Defaults to JSON (assuming from input JSON or a local JSON),
# unless a Vyper file exists.
if (base_path / f"{prefix}{FileType.SOURCE}").is_file():
if (contracts_path / f"{prefix}{FileType.SOURCE}").is_file():
ext = FileType.SOURCE.value
elif (base_path / f"{prefix}{FileType.SOURCE}").is_file():
elif (contracts_path / f"{prefix}{FileType.SOURCE}").is_file():
ext = FileType.INTERFACE.value
elif (base_path / f"{prefix}{FileType.INTERFACE}").is_file():
elif (contracts_path / f"{prefix}{FileType.INTERFACE}").is_file():
ext = FileType.INTERFACE.value
else:
ext = ".json"

full_path = full_path.parent / f"{full_path.name}{ext}"

import_source_id = f"{prefix}{ext}"
if source_id not in import_map:
import_map[source_id] = [import_source_id]
elif import_source_id not in import_map[source_id]:
import_map[source_id].append(import_source_id)

# Also include imports of imports.
sub_imports = self._get_imports(
(full_path,), base_path=contracts_path, handled=handled
)
for sub_import_ls in sub_imports.values():
import_map[source_id].extend(sub_import_ls)

return import_map

def get_versions(self, all_paths: Sequence[Path]) -> Set[str]:
Expand Down

0 comments on commit 38fe812

Please sign in to comment.