diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml index 84714f16..7fdf26eb 100644 --- a/.github/workflows/test.yaml +++ b/.github/workflows/test.yaml @@ -68,7 +68,7 @@ jobs: # TODO: Replace with macos-latest when works again. # https://github.com/actions/setup-python/issues/808 os: [ubuntu-latest, macos-12] # eventually add `windows-latest` - python-version: [3.8, 3.9, "3.10", "3.11", "3.12"] + python-version: [3.9, "3.10", "3.11", "3.12"] env: GETH_VERSION: 1.12.0 diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 23d53854..47442f7b 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -30,7 +30,7 @@ repos: rev: 0.7.17 hooks: - id: mdformat - additional_dependencies: [mdformat-gfm, mdformat-frontmatter] + additional_dependencies: [mdformat-gfm, mdformat-frontmatter, mdformat-pyproject] default_language_version: python: python3 diff --git a/README.md b/README.md index fd688607..5b548838 100644 --- a/README.md +++ b/README.md @@ -4,7 +4,7 @@ Ape compiler plugin around [VVM](https://github.com/vyperlang/vvm) ## Dependencies -- [python3](https://www.python.org/downloads) version 3.8 up to 3.12. +- [python3](https://www.python.org/downloads) version 3.9 up to 3.12. ## Installation diff --git a/ape_vyper/_cli.py b/ape_vyper/_cli.py index 665bfff1..bfddd372 100644 --- a/ape_vyper/_cli.py +++ b/ape_vyper/_cli.py @@ -2,7 +2,7 @@ import ape import click -from ape.cli import ape_cli_context +from ape.cli import ape_cli_context, project_option @click.group @@ -12,14 +12,17 @@ def cli(): @cli.command(short_help="Flatten select contract source files") @ape_cli_context() +@project_option() @click.argument("CONTRACT", type=click.Path(exists=True, resolve_path=True)) @click.argument("OUTFILE", type=click.Path(exists=False, resolve_path=True, writable=True)) -def flatten(cli_ctx, contract: Path, outfile: Path): +def flatten(cli_ctx, project, contract: Path, outfile: Path): """ Flatten a contract into a single file """ with Path(outfile).open("w") as fout: content = ape.compilers.vyper.flatten_contract( - Path(contract), base_path=ape.project.contracts_folder + Path(contract), + base_path=ape.project.contracts_folder, + project=project, ) fout.write(str(content)) diff --git a/ape_vyper/ast.py b/ape_vyper/ast.py index 234ca492..e9b40838 100644 --- a/ape_vyper/ast.py +++ b/ape_vyper/ast.py @@ -1,7 +1,5 @@ """Utilities for dealing with Vyper AST""" -from typing import List - from ethpm_types import ABI, MethodABI from ethpm_types.abi import ABIType from vyper.ast import parse_to_ast # type: ignore @@ -16,11 +14,11 @@ } -def funcdef_decorators(funcdef: FunctionDef) -> List[str]: +def funcdef_decorators(funcdef: FunctionDef) -> list[str]: return [d.id for d in funcdef.get("decorator_list") or []] -def funcdef_inputs(funcdef: FunctionDef) -> List[ABIType]: +def funcdef_inputs(funcdef: FunctionDef) -> list[ABIType]: """Get a FunctionDef's defined input args""" args = funcdef.get("args") # TODO: Does Vyper allow complex input types, like structs and arrays? 
@@ -31,7 +29,7 @@ def funcdef_inputs(funcdef: FunctionDef) -> List[ABIType]:
     )


-def funcdef_outputs(funcdef: FunctionDef) -> List[ABIType]:
+def funcdef_outputs(funcdef: FunctionDef) -> list[ABIType]:
     """Get a FunctionDef's outputs, or return values"""
     returns = funcdef.get("returns")

@@ -46,9 +44,9 @@ def funcdef_outputs(funcdef: FunctionDef) -> List[ABIType]:
     elif isinstance(returns, Subscript):
         # An array type
         length = returns.slice.value.value
-        array_type = returns.value.id
-        # TOOD: Is this an acurrate way to define a fixed length array for ABI?
-        return [ABIType.model_validate({"type": f"{array_type}[{length}]"})]
+        if array_type := getattr(returns.value, "id", None):
+            # TODO: Is this an accurate way to define a fixed length array for ABI?
+            return [ABIType.model_validate({"type": f"{array_type}[{length}]"})]

     raise NotImplementedError(f"Unhandled return type {type(returns)}")

@@ -81,7 +79,7 @@ def funcdef_to_abi(func: FunctionDef) -> ABI:
     )


-def module_to_abi(module: Module) -> List[ABI]:
+def module_to_abi(module: Module) -> list[ABI]:
     """
     Create a list of MethodABIs from a Vyper AST Module instance.
     """
@@ -92,7 +90,7 @@ def module_to_abi(module: Module) -> List[ABI]:
     return abi


-def source_to_abi(source: str) -> List[ABI]:
+def source_to_abi(source: str) -> list[ABI]:
     """
     Given Vyper source code, return a list of Ape ABI elements needed for an external interface.
     This currently does not include complex types or events.
diff --git a/ape_vyper/compiler.py b/ape_vyper/compiler.py
index 85fb3b58..5a9aff2b 100644
--- a/ape_vyper/compiler.py
+++ b/ape_vyper/compiler.py
@@ -3,28 +3,39 @@
 import shutil
 import time
 from base64 import b64encode
+from collections import defaultdict
+from collections.abc import Iterable, Iterator
 from fnmatch import fnmatch
 from importlib import import_module
 from pathlib import Path
-from typing import Any, Dict, Iterable, Iterator, List, Optional, Sequence, Set, Tuple, Union, cast
+from typing import Any, Optional, Union, cast

 import vvm  # type: ignore
-from ape.api import PluginConfig
+from ape.api import PluginConfig, TraceAPI
 from ape.api.compiler import CompilerAPI
 from ape.exceptions import ContractLogicError
 from ape.logging import logger
-from ape.types import ContractSourceCoverage, ContractType, SourceTraceback, TraceFrame
-from ape.utils import GithubClient, cached_property, get_relative_path, pragma_str_to_specifier_set
+from ape.managers.project import ProjectManager
+from ape.types import ContractSourceCoverage, ContractType, SourceTraceback
+from ape.utils import (
+    cached_property,
+    get_full_extension,
+    get_relative_path,
+    pragma_str_to_specifier_set,
+)
+from ape.utils._github import _GithubClient
 from eth_pydantic_types import HexBytes
 from eth_utils import is_0x_prefixed
 from ethpm_types import ASTNode, PackageManifest, PCMap, SourceMapItem
 from ethpm_types.ast import ASTClassification
 from ethpm_types.contract_type import SourceMap
 from ethpm_types.source import Compiler, Content, ContractSource, Function, SourceLocation
+from evm_trace import TraceFrame
 from evm_trace.enums import CALL_OPCODES
+from evm_trace.geth import create_call_node_data
 from packaging.specifiers import InvalidSpecifier, SpecifierSet
 from packaging.version import Version
-from pydantic import field_serializer, field_validator
+from pydantic import field_serializer, field_validator, model_validator
 from vvm import compile_standard as vvm_compile_standard
 from vvm.exceptions import VyperError  # type: ignore

@@ -68,6 +79,38 @@
 }


+class Remapping(PluginConfig):
+    key: str
+    dependency_name: str
+    dependency_version: Optional[str] = None
+
+    @model_validator(mode="before")
+    @classmethod
+    def validate_str(cls, value):
+        if isinstance(value, str):
+            parts = value.split("=")
+            key = parts[0].strip()
+            value = parts[1].strip()
+            if "@" in value:
+                value_parts = value.split("@")
+                dep_name = value_parts[0].strip()
+                dep_version = value_parts[1].strip()
+            else:
+                dep_name = value
+                dep_version = None
+
+            return {"key": key, "dependency_name": dep_name, "dependency_version": dep_version}
+
+        return value
+
+    def __str__(self) -> str:
+        value = self.dependency_name
+        if _version := self.dependency_version:
+            value = f"{value}@{_version}"
+
+        return f"{self.key}={value}"
+
+
 class VyperConfig(PluginConfig):
     version: Optional[SpecifierSet] = None
     """
@@ -80,7 +123,7 @@ class VyperConfig(PluginConfig):
     The evm-version or hard-fork name.
     """

-    import_remapping: List[str] = []
+    import_remapping: list[Remapping] = []
     """
     Configuration of an import name mapped to a dependency listing.
     To use a specific version of a dependency, specify using ``@`` symbol.
@@ -125,7 +168,7 @@ def get_version_pragma_spec(source: Union[str, Path]) -> Optional[SpecifierSet]:
     Returns:
         ``packaging.specifiers.SpecifierSet``, or None if no valid pragma is found.
     """
-    _version_pragma_patterns: Tuple[str, str] = (
+    _version_pragma_patterns: tuple[str, str] = (
         r"(?:\n|^)\s*#\s*@version\s*([^\n]*)",
         r"(?:\n|^)\s*#\s*pragma\s+version\s*([^\n]*)",
     )
@@ -185,9 +228,9 @@ def get_evmversion_pragma(source: Union[str, Path]) -> Optional[str]:


 def get_optimization_pragma_map(
-    contract_filepaths: Sequence[Path], base_path: Path
-) -> Dict[str, Optimization]:
-    pragma_map: Dict[str, Optimization] = {}
+    contract_filepaths: Iterable[Path], base_path: Path
+) -> dict[str, Optimization]:
+    pragma_map: dict[str, Optimization] = {}

     for path in contract_filepaths:
         pragma = get_optimization_pragma(path) or True
@@ -198,9 +241,9 @@ def get_optimization_pragma_map(


 def get_evm_version_pragma_map(
-    contract_filepaths: Sequence[Path], base_path: Path
-) -> Dict[str, str]:
-    pragmas: Dict[str, str] = {}
+    contract_filepaths: Iterable[Path], base_path: Path
+) -> dict[str, str]:
+    pragmas: dict[str, str] = {}

     for path in contract_filepaths:
         pragma = get_evmversion_pragma(path)
         if not pragma:
@@ -217,22 +260,17 @@ class VyperCompiler(CompilerAPI):
     def name(self) -> str:
         return "vyper"

-    @property
-    def settings(self) -> VyperConfig:
-        return cast(VyperConfig, super().settings)
-
-    @property
-    def evm_version(self) -> Optional[str]:
-        return self.settings.evm_version
-
     def get_imports(
-        self, contract_filepaths: Sequence[Path], base_path: Optional[Path] = None
-    ) -> Dict[str, List[str]]:
-        base_path = (base_path or self.project_manager.contracts_folder).absolute()
-        import_map = {}
+        self,
+        contract_filepaths: Iterable[Path],
+        project: Optional[ProjectManager] = None,
+    ) -> dict[str, list[str]]:
+        pm = project or self.local_project
+        import_map: defaultdict = defaultdict(list)
+        dependencies = self.get_dependencies(project=pm)
         for path in contract_filepaths:
             content = path.read_text().splitlines()
-            source_id = str(get_relative_path(path.absolute(), base_path.absolute()))
+            source_id = str(get_relative_path(path.absolute(), pm.path.absolute()))
             for line in content:
                 if line.startswith("import "):
                     import_line_parts = line.replace("import ", "").split(" ")
@@ -249,17 +287,44 @@ def get_imports(
                     # NOTE: Defaults to JSON (assuming from input JSON or a local JSON),
                     # unless a Vyper file exists.
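+                    # Resolve the import to a source id: prefer a local .vy file, then a local
+                    # .json interface, then Vyper built-ins, then a compiled dependency project.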
- ext = "vy" if (base_path / f"{suffix}.vy").is_file() else "json" + import_source_id = None + if (pm.interfaces_folder.parent / f"{suffix}.vy").is_file(): + import_source_id = f"{suffix}.vy" + + elif (pm.interfaces_folder.parent / f"{suffix}.json").is_file(): + import_source_id = f"{suffix}.json" + + elif suffix.startswith(f"vyper{os.path.sep}"): + # Vyper built-ins. + import_source_id = f"{suffix}.json" + + elif suffix.split(os.path.sep)[0] in dependencies: + dependency_name = suffix.split(os.path.sep)[0] + filestem = suffix.replace(f"{dependency_name}{os.path.sep}", "") + for version_str, dep_project in pm.dependencies[dependency_name].items(): + dependency = pm.dependencies.get_dependency(dependency_name, version_str) + path_id = dependency.package_id.replace("/", "_") + dependency_source_prefix = ( + f"{get_relative_path(dep_project.contracts_folder, dep_project.path)}" + ) + source_id_stem = f"{dependency_source_prefix}{os.path.sep}{filestem}" + for ext in (".vy", ".json"): + if f"{source_id_stem}{ext}" in dep_project.sources: + import_source_id = os.path.sep.join( + (path_id, version_str, f"{source_id_stem}{ext}") + ) + break + + else: + logger.error(f"Unable to find dependency {suffix}") + continue - import_source_id = f"{suffix}.{ext}" - if source_id not in import_map: - import_map[source_id] = [import_source_id] - elif import_source_id not in import_map[source_id]: + if import_source_id and import_source_id not in import_map[source_id]: import_map[source_id].append(import_source_id) - return import_map + return dict(import_map) - def get_versions(self, all_paths: Sequence[Path]) -> Set[str]: + def get_versions(self, all_paths: Iterable[Path]) -> set[str]: versions = set() for path in all_paths: if version_spec := get_version_pragma_spec(path): @@ -293,14 +358,14 @@ def package_version(self) -> Optional[Version]: return Version(version_str) if version_str else None @cached_property - def available_versions(self) -> List[Version]: + def available_versions(self) -> list[Version]: # NOTE: Package version should already be included in available versions max_retries = 10 buffer = 1 times_tried = 0 result = [] headers = None - if token := os.environ.get(GithubClient.TOKEN_KEY): + if token := os.environ.get(_GithubClient.TOKEN_KEY): auth = b64encode(token.encode()).decode() headers = {"Authorization": f"Basic {auth}"} @@ -332,7 +397,7 @@ def available_versions(self) -> List[Version]: return result @property - def installed_versions(self) -> List[Version]: + def installed_versions(self) -> list[Version]: # Doing this so it prefers package version package_version = self.package_version versions = [package_version] if package_version else [] @@ -348,62 +413,76 @@ def vyper_json(self): except ImportError: return None - @property - def config_version_pragma(self) -> Optional[SpecifierSet]: - if version := self.settings.version: - return version - - return None + def get_dependencies( + self, project: Optional[ProjectManager] = None + ) -> dict[str, ProjectManager]: + pm = project or self.local_project + config = self.get_config(pm) + dependencies: dict[str, ProjectManager] = {} + handled: set[str] = set() + + # Add remappings from config. + for remapping in config.import_remapping: + name = remapping.dependency_name + if not (_version := remapping.dependency_version): + versions = pm.dependencies[name] + if len(versions) == 1: + _version = versions[0] + else: + continue - @property - def remapped_manifests(self) -> Dict[str, PackageManifest]: - """ - Interface import manifests. 
- """ + dependency = pm.dependencies.get_dependency(name, _version) + dep_id = f"{dependency.name}_{dependency.version}" + if dep_id in handled: + continue - dependencies: Dict[str, PackageManifest] = {} + handled.add(dep_id) - for remapping in self.settings.import_remapping: - key, value = remapping.split("=") + try: + dependency.compile() + except Exception as err: + logger.warning( + f"Failed to compile dependency '{dependency.name}' @ '{dependency.version}'.\n" + f"Reason: {err}" + ) + continue - if remapping in dependencies: - dependency = dependencies[remapping] - else: - parts = value.split("@") - dep_name = parts[0] - dependency_versions = self.project_manager.dependencies[dep_name] - if not dependency_versions: - raise VyperCompileError(f"Missing dependency '{dep_name}'.") + dependencies[remapping.key] = dependency.project - elif len(parts) == 1 and len(dependency_versions) < 2: - # Use only version. - version = list(dependency_versions.keys())[0] + # Add auto-remapped dependencies. + for dependency in pm.dependencies.specified: + dep_id = f"{dependency.name}_{dependency.version}" + if dep_id in handled: + continue - elif parts[1] not in dependency_versions: - raise VyperCompileError(f"Missing dependency '{dep_name}'.") + handled.add(dep_id) - else: - version = parts[1] + try: + dependency.compile() + except Exception as err: + logger.warning( + f"Failed to compile dependency '{dependency.name}' @ '{dependency.version}'.\n" + f"Reason: {err}" + ) + continue - dependency = dependency_versions[version].compile() - dependencies[remapping] = dependency + dependencies[dependency.name] = dependency.project return dependencies - @property - def import_remapping(self) -> Dict[str, Dict]: + def get_import_remapping(self, project: Optional[ProjectManager] = None) -> dict[str, dict]: """ Configured interface imports from dependencies. 
""" - - interfaces = {} - - for remapping in self.settings.import_remapping: - key, _ = remapping.split("=") - for name, ct in (self.remapped_manifests[remapping].contract_types or {}).items(): - interfaces[f"{key}/{name}.json"] = { - "abi": [x.model_dump(mode="json", by_alias=True) for x in ct.abi] - } + pm = project or self.local_project + dependencies = self.get_dependencies(project=pm) + interfaces: dict[str, dict] = {} + for key, dependency_project in dependencies.items(): + manifest = dependency_project.manifest + for name, ct in (manifest.contract_types or {}).items(): + filename = f"{key}/{name}.json" + abi_list = [x.model_dump(mode="json", by_alias=True) for x in ct.abi] + interfaces[filename] = {"abi": abi_list} return interfaces @@ -415,36 +494,55 @@ def classify_ast(self, _node: ASTNode): self.classify_ast(child) def compile( - self, contract_filepaths: Sequence[Path], base_path: Optional[Path] = None - ) -> List[ContractType]: - contract_types = [] - base_path = base_path or self.config_manager.contracts_folder + self, + contract_filepaths: Iterable[Path], + project: Optional[ProjectManager] = None, + settings: Optional[dict] = None, + ) -> Iterator[ContractType]: + pm = project or self.local_project + settings = settings or {} sources = [p for p in contract_filepaths if p.parent.name != "interfaces"] - version_map = self.get_version_map(sources) - compiler_data = self._get_compiler_arguments(version_map, base_path) - all_settings = self.get_compiler_settings(sources, base_path=base_path) - contract_versions: Dict[str, Tuple[Version, str]] = {} + contract_types: list[ContractType] = [] + if version := settings.get("version", None): + version_map = {Version(version): set(sources)} + else: + version_map = self.get_version_map(sources, project=project) + + compiler_data = self._get_compiler_arguments(version_map, project=pm) + all_settings: dict = self.get_compiler_settings( + sources, project=project, **(settings or {}) + ) + contract_versions: dict[str, tuple[Version, str]] = {} + import_remapping = self.get_import_remapping(project=pm) for vyper_version, version_settings in all_settings.items(): - for settings_key, settings in version_settings.items(): - source_ids = settings["outputSelection"] - optimization_paths = {p: base_path / p for p in source_ids} - input_json = { + for settings_key, settings_set in version_settings.items(): + source_ids = settings_set["outputSelection"] + optimization_paths = {p: pm.path / p for p in source_ids} + input_json: dict = { "language": "Vyper", - "settings": settings, + "settings": settings_set, "sources": { s: {"content": p.read_text()} for s, p in optimization_paths.items() }, } - if interfaces := self.import_remapping: + if interfaces := import_remapping: input_json["interfaces"] = interfaces + # Output compiler details. + keys = ( + "\n\t".join(sorted([x for x in input_json.get("sources", {}).keys()])) + or "No input." 
+ ) + log_str = f"Compiling using Vyper compiler '{vyper_version}'.\nInput:\n\t{keys}" + logger.info(log_str) + vyper_binary = compiler_data[vyper_version]["vyper_binary"] try: result = vvm_compile_standard( input_json, - base_path=base_path, + base_path=pm.path, vyper_version=vyper_version, vyper_binary=vyper_binary, ) @@ -454,7 +552,7 @@ def compile( for source_id, output_items in result["contracts"].items(): content = { i + 1: ln - for i, ln in enumerate((base_path / source_id).read_text().splitlines()) + for i, ln in enumerate((pm.path / source_id).read_text().splitlines()) } for name, output in output_items.items(): # De-compress source map to get PC POS map. @@ -505,9 +603,10 @@ def compile( ) contract_types.append(contract_type) contract_versions[name] = (vyper_version, settings_key) + yield contract_type # Output compiler data used. - compilers_used: Dict[Version, Dict[str, Compiler]] = {} + compilers_used: dict[Version, dict[str, Compiler]] = {} for ct in contract_types: if not ct.name: # Won't happen, but just for mypy. @@ -546,12 +645,12 @@ def compile( # NOTE: This method handles merging contractTypes and filtered out # no longer used Compilers. - self.project_manager.local_project.add_compiler_data(compilers_ls) - - return contract_types + pm.add_compiler_data(compilers_ls) - def compile_code(self, code: str, base_path: Optional[Path] = None, **kwargs) -> ContractType: - base_path = base_path or self.project_manager.contracts_folder + def compile_code( + self, code: str, project: Optional[ProjectManager] = None, **kwargs + ) -> ContractType: + pm = project or self.local_project # Figure out what compiler version we need for this contract... version = self._source_vyper_version(code) @@ -559,7 +658,7 @@ def compile_code(self, code: str, base_path: Optional[Path] = None, **kwargs) -> _install_vyper(version) try: - result = vvm.compile_source(code, base_path=base_path, vyper_version=version) + result = vvm.compile_source(code, base_path=pm.path, vyper_version=version) except Exception as err: raise VyperCompileError(str(err)) from err @@ -579,42 +678,43 @@ def first_full_release(versions: Iterable[Version]) -> Optional[Version]: for vers in versions: if not vers.is_devrelease and not vers.is_postrelease and not vers.is_prerelease: return vers + return None if version_spec is None: if version := first_full_release(self.installed_versions + self.available_versions): return version + raise VyperInstallError("No available version.") return next(version_spec.filter(self.available_versions)) - def _flatten_source( - self, path: Path, base_path: Optional[Path] = None, raw_import_name: Optional[str] = None - ) -> str: - base_path = base_path or self.config_manager.contracts_folder + def _flatten_source(self, path: Path, project: Optional[ProjectManager] = None) -> str: + pm = project or self.local_project # Get the non stdlib import paths for our contracts imports = list( filter( lambda x: not x.startswith("vyper/"), - [y for x in self.get_imports([path], base_path).values() for y in x], + [y for x in self.get_imports((path,), project=pm).values() for y in x], ) ) - dependencies: Dict[str, PackageManifest] = {} - for key, manifest in self.remapped_manifests.items(): + dependencies: dict[str, PackageManifest] = {} + dependency_projects = self.get_dependencies(project=pm) + for key, dependency_project in dependency_projects.items(): package = key.split("=")[0] - + base = dependency_project.path if hasattr(dependency_project, "path") else package + manifest = dependency_project.manifest if 
manifest.sources is None: continue for source_id in manifest.sources.keys(): - import_match = f"{package}/{source_id}" + import_match = f"{base}/{source_id}" dependencies[import_match] = manifest - flattened_source = "" interfaces_source = "" - og_source = (base_path / path).read_text() + og_source = (pm.path / path).read_text() # Get info about imports and source meta aliases = extract_import_aliases(og_source) @@ -622,20 +722,27 @@ def _flatten_source( stdlib_imports, _, source_without_imports = extract_imports(source_without_meta) for import_path in sorted(imports): - import_file = base_path / import_path + import_file = None + for base in (pm.path, pm.interfaces_folder): + for opt in {import_path, import_path.replace(f"interfaces{os.path.sep}", "")}: + try_import_file = base / opt + if try_import_file.is_file(): + import_file = try_import_file + break + + if import_file is None: + import_file = pm.path / import_path # Vyper imported interface names come from their file names file_name = iface_name_from_file(import_file) # If we have a known alias, ("import X as Y"), use the alias as interface name iface_name = aliases[file_name] if file_name in aliases else file_name - # We need to compare without extensions because sometimes they're made up for some - # reason. TODO: Cleaner way to deal with this? - def _match_source(import_path: str) -> Optional[PackageManifest]: - import_path_name = ".".join(import_path.split(".")[:-1]) + def _match_source(imp_path: str) -> Optional[PackageManifest]: for source_path in dependencies.keys(): - if source_path.startswith(import_path_name): + if source_path.endswith(imp_path): return dependencies[source_path] + return None if matched_source := _match_source(import_path): @@ -650,11 +757,10 @@ def _match_source(import_path: str) -> Optional[PackageManifest]: interfaces_source += generate_interface(abis, iface_name) continue - # Vyper imported interface names come from their file names - file_name = iface_name_from_file(import_file) # Generate an ABI from the source code - abis = source_to_abi(import_file.read_text()) - interfaces_source += generate_interface(abis, iface_name) + elif import_file.is_file(): + abis = source_to_abi(import_file.read_text()) + interfaces_source += generate_interface(abis, iface_name) def no_nones(it: Iterable[Optional[str]]) -> Iterable[str]: # Type guard like generator to remove Nones and make mypy happy @@ -675,23 +781,33 @@ def format_source(source: str) -> str: return format_source(flattened_source) - def flatten_contract(self, path: Path, base_path: Optional[Path] = None) -> Content: + def flatten_contract( + self, + path: Path, + project: Optional[ProjectManager] = None, + **kwargs, + ) -> Content: """ Returns the flattened contract suitable for compilation or verification as a single file """ - source = self._flatten_source(path, base_path, path.name) - return Content({i: ln for i, ln in enumerate(source.splitlines())}) + pm = project or self.local_project + src = self._flatten_source(path, project=pm) + return Content({i: ln for i, ln in enumerate(src.splitlines())}) def get_version_map( - self, contract_filepaths: Sequence[Path], base_path: Optional[Path] = None - ) -> Dict[Version, Set[Path]]: - version_map: Dict[Version, Set[Path]] = {} - source_path_by_version_spec: Dict[SpecifierSet, Set[Path]] = {} + self, + contract_filepaths: Iterable[Path], + project: Optional[ProjectManager] = None, + ) -> dict[Version, set[Path]]: + pm = project or self.local_project + config = self.get_config(pm) + version_map: dict[Version, 
set[Path]] = {} + source_path_by_version_spec: dict[SpecifierSet, set[Path]] = {} source_paths_without_pragma = set() # Sort contract_filepaths to promote consistent, reproduce-able behavior for path in sorted(contract_filepaths): - if config_spec := self.config_version_pragma: + if config_spec := config.version: _safe_append(source_path_by_version_spec, config_spec, path) elif pragma := get_version_pragma_spec(path): _safe_append(source_path_by_version_spec, pragma, path) @@ -747,31 +863,36 @@ def get_version_map( return version_map def get_compiler_settings( - self, contract_filepaths: Sequence[Path], base_path: Optional[Path] = None - ) -> Dict[Version, Dict]: + self, + contract_filepaths: Iterable[Path], + project: Optional[ProjectManager] = None, + **kwargs, + ) -> dict[Version, dict]: + pm = project or self.local_project valid_paths = [p for p in contract_filepaths if p.suffix == ".vy"] - contracts_path = base_path or self.config_manager.contracts_folder - files_by_vyper_version = self.get_version_map(valid_paths, base_path=contracts_path) + if version := kwargs.pop("version", None): + files_by_vyper_version = {Version(version): set(valid_paths)} + else: + files_by_vyper_version = self.get_version_map(valid_paths, project=pm) + if not files_by_vyper_version: return {} - compiler_data = self._get_compiler_arguments(files_by_vyper_version, contracts_path) + compiler_data = self._get_compiler_arguments(files_by_vyper_version, project=pm) settings = {} for version, data in compiler_data.items(): source_paths = list(files_by_vyper_version.get(version, [])) if not source_paths: continue - output_selection: Dict[str, Set[str]] = {} - optimizations_map = get_optimization_pragma_map(source_paths, contracts_path) - evm_version_map = get_evm_version_pragma_map(source_paths, contracts_path) - default_evm_version = ( - data.get("evm_version") - or data.get("evmVersion") - or EVM_VERSION_DEFAULT.get(version.base_version) - ) + output_selection: dict[str, set[str]] = {} + optimizations_map = get_optimization_pragma_map(source_paths, pm.path) + evm_version_map = get_evm_version_pragma_map(source_paths, pm.path) + default_evm_version = data.get( + "evm_version", data.get("evmVersion") + ) or EVM_VERSION_DEFAULT.get(version.base_version) for source_path in source_paths: - source_id = str(get_relative_path(source_path.absolute(), contracts_path)) + source_id = str(get_relative_path(source_path.absolute(), pm.path)) optimization = optimizations_map.get(source_id, True) evm_version = evm_version_map.get(source_id, default_evm_version) settings_key = f"{optimization}%{evm_version}".lower() @@ -780,7 +901,7 @@ def get_compiler_settings( else: output_selection[settings_key].add(source_id) - version_settings: Dict[str, Dict] = {} + version_settings: dict[str, dict] = {} for settings_key, selection in output_selection.items(): optimization, evm_version = settings_key.split("%") if optimization == "true": @@ -846,7 +967,7 @@ def _profile(_name: str, _full_name: str): # Some statements are too difficult to know right away where they belong, # such as statement related to kwarg-default auto-generated implicit lookups. # function_name -> (pc, location) - pending_statements: Dict[str, List[Tuple[int, SourceLocation]]] = {} + pending_statements: dict[str, list[tuple[int, SourceLocation]]] = {} for pc, item in contract_source.pcmap.root.items(): pc_int = int(pc) @@ -928,7 +1049,7 @@ def _profile(_name: str, _full_name: str): ] # Sort the autogenerated ABIs so we can loop through them in the correct order. 
autogenerated_abis.sort(key=lambda a: len(a.inputs)) - buckets: Dict[str, List[Tuple[int, SourceLocation]]] = { + buckets: dict[str, list[tuple[int, SourceLocation]]] = { a.selector: [] for a in autogenerated_abis } selector_index = 0 @@ -977,15 +1098,18 @@ def _profile(_name: str, _full_name: str): # Auto-getter found. Profile function without statements. contract_coverage.include(method.name, method.selector) - def _get_compiler_arguments(self, version_map: Dict, base_path: Path) -> Dict[Version, Dict]: - base_path = base_path or self.project_manager.contracts_folder + def _get_compiler_arguments( + self, version_map: dict, project: Optional[ProjectManager] = None + ) -> dict[Version, dict]: + pm = project or self.local_project + config = self.get_config(pm) + evm_version = config.evm_version arguments_map = {} for vyper_version, source_paths in version_map.items(): bin_arg = self._get_vyper_bin(vyper_version) arguments_map[vyper_version] = { - "base_path": str(base_path), - "evm_version": self.evm_version - or EVM_VERSION_DEFAULT.get(vyper_version.base_version), + "base_path": f"{pm.path}", + "evm_version": evm_version, "vyper_version": str(vyper_version), "vyper_binary": bin_arg, } @@ -1026,7 +1150,7 @@ def enrich_error(self, err: ContractLogicError) -> ContractLogicError: return err runtime_error_cls = RUNTIME_ERROR_MAP[error_type] - tx_kwargs: Dict = { + tx_kwargs: dict = { "contract_address": err.contract_address, "source_traceback": err.source_traceback, "trace": err.trace, @@ -1039,17 +1163,15 @@ def enrich_error(self, err: ContractLogicError) -> ContractLogicError: ) def trace_source( - self, contract_type: ContractType, trace: Iterator[TraceFrame], calldata: HexBytes + self, contract_source: ContractSource, trace: TraceAPI, calldata: HexBytes ) -> SourceTraceback: - if source_contract_type := self.project_manager._create_contract_source(contract_type): - return self._get_traceback(source_contract_type, trace, calldata) - - return SourceTraceback.model_validate([]) + frames = trace.get_raw_frames() + return self._get_traceback(contract_source, frames, calldata) def _get_traceback( self, contract_src: ContractSource, - trace: Iterator[TraceFrame], + frames: Iterator[dict], calldata: HexBytes, previous_depth: Optional[int] = None, ) -> SourceTraceback: @@ -1058,16 +1180,16 @@ def _get_traceback( completed = False pcmap = PCMap.model_validate({}) - for frame in trace: - if frame.op in CALL_OPCODES: - start_depth = frame.depth + for frame in frames: + if frame["op"] in CALL_OPCODES: + start_depth = frame["depth"] called_contract, sub_calldata = self._create_contract_from_call(frame) if called_contract: - ext = Path(called_contract.source_id).suffix + ext = get_full_extension(Path(called_contract.source_id)) if ext.endswith(".vy"): # Called another Vyper contract. sub_trace = self._get_traceback( - called_contract, trace, sub_calldata, previous_depth=frame.depth + called_contract, frames, sub_calldata, previous_depth=frame["depth"] ) traceback.extend(sub_trace) @@ -1076,48 +1198,48 @@ def _get_traceback( compiler = self.compiler_manager.registered_compilers[ext] try: sub_trace = compiler.trace_source( - called_contract.contract_type, trace, sub_calldata + called_contract.contract_type, frames, sub_calldata ) traceback.extend(sub_trace) except NotImplementedError: # Compiler not supported. Fast forward out of this call. - for fr in trace: - if fr.depth <= start_depth: + for fr in frames: + if fr["depth"] <= start_depth: break continue else: # Contract not found. 
Fast forward out of this call. - for fr in trace: - if fr.depth <= start_depth: + for fr in frames: + if fr["depth"] <= start_depth: break continue - elif frame.op in _RETURN_OPCODES: + elif frame["op"] in _RETURN_OPCODES: # For the base CALL, don't mark as completed until trace is gone. # This helps in cases where we failed to detect a subcall properly. completed = previous_depth is not None pcs_to_try_adding = set() - if "PUSH" in frame.op and frame.pc in contract_src.pcmap: + if "PUSH" in frame["op"] and frame["pc"] in contract_src.pcmap: # Check if next op is SSTORE to properly use AST from push op. - next_frame: Optional[TraceFrame] = frame - loc = contract_src.pcmap[frame.pc] - pcs_to_try_adding.add(frame.pc) + next_frame: Optional[dict] = frame + loc = contract_src.pcmap[frame["pc"]] + pcs_to_try_adding.add(frame["pc"]) - while next_frame and "PUSH" in next_frame.op: - next_frame = next(trace, None) - if next_frame and "PUSH" in next_frame.op: - pcs_to_try_adding.add(next_frame.pc) + while next_frame and "PUSH" in next_frame["op"]: + next_frame = next(frames, None) + if next_frame and "PUSH" in next_frame["op"]: + pcs_to_try_adding.add(next_frame["pc"]) is_non_payable_hit = False - if next_frame and next_frame.op == "SSTORE": + if next_frame and next_frame["op"] == "SSTORE": push_location = tuple(loc["location"]) # type: ignore - pcmap = PCMap.model_validate({next_frame.pc: {"location": push_location}}) + pcmap = PCMap.model_validate({next_frame["pc"]: {"location": push_location}}) - elif next_frame and next_frame.op in _RETURN_OPCODES: + elif next_frame and next_frame["op"] in _RETURN_OPCODES: completed = True else: @@ -1131,22 +1253,22 @@ def _get_traceback( else: pcmap = contract_src.pcmap - pcs_to_try_adding.add(frame.pc) + pcs_to_try_adding.add(frame["pc"]) pcs_to_try_adding = {pc for pc in pcs_to_try_adding if pc in pcmap} if not pcs_to_try_adding: if ( - frame.op == "REVERT" - and frame.pc + 1 in pcmap + frame["op"] == "REVERT" + and frame["pc"] + 1 in pcmap and RuntimeErrorType.USER_ASSERT.value - in str(pcmap[frame.pc + 1].get("dev", "")) + in str(pcmap[frame["pc"] + 1].get("dev", "")) ): # Not sure why this happens. Maybe an off-by-1 bug in Vyper. 
- pcs_to_try_adding.add(frame.pc + 1) + pcs_to_try_adding.add(frame["pc"] + 1) - pc_groups: List[List] = [] + pc_groups: list[list] = [] for pc in pcs_to_try_adding: location = ( - cast(Tuple[int, int, int, int], tuple(pcmap[pc].get("location") or [])) or None + cast(tuple[int, int, int, int], tuple(pcmap[pc].get("location") or [])) or None ) dev_item = pcmap[pc].get("dev", "") dev = str(dev_item).replace("dev: ", "") @@ -1221,9 +1343,9 @@ def _get_traceback( or not isinstance(traceback.last.closure, Function) ): depth = ( - frame.depth + 1 - if traceback.last and traceback.last.depth == frame.depth - else frame.depth + frame["depth"] + 1 + if traceback.last and traceback.last.depth == frame["depth"] + else frame["depth"] ) traceback.add_jump( @@ -1253,8 +1375,26 @@ def _get_traceback( return traceback + def _create_contract_from_call(self, frame: dict) -> tuple[Optional[ContractSource], HexBytes]: + evm_frame = TraceFrame(**frame) + data = create_call_node_data(evm_frame) + calldata = data.get("calldata", HexBytes("")) + if not (address := (data.get("address", evm_frame.contract_address) or None)): + return None, calldata + + try: + address = self.provider.network.ecosystem.decode_address(address) + except Exception: + return None, calldata + + if address not in self.chain_manager.contracts: + return None, calldata -def _safe_append(data: Dict, version: Union[Version, SpecifierSet], paths: Union[Path, Set]): + called_contract = self.chain_manager.contracts[address] + return self.local_project._create_contract_source(called_contract), calldata + + +def _safe_append(data: dict, version: Union[Version, SpecifierSet], paths: Union[Path, set]): if isinstance(paths, Path): paths = {paths} if version in data: @@ -1267,13 +1407,13 @@ def _is_revert_jump(op: str, value: Optional[int], revert_pc: int) -> bool: return op == "JUMPI" and value is not None and value == revert_pc -def _has_empty_revert(opcodes: List[str]) -> bool: +def _has_empty_revert(opcodes: list[str]) -> bool: return (len(opcodes) > 12 and opcodes[-13] == "JUMPDEST" and opcodes[-9] == "REVERT") or ( len(opcodes) > 4 and opcodes[-5] == "JUMPDEST" and opcodes[-1] == "REVERT" ) -def _get_pcmap(bytecode: Dict) -> PCMap: +def _get_pcmap(bytecode: dict) -> PCMap: # Find the non payable value check. src_info = bytecode["sourceMapFull"] pc_data = {pc: {"location": ln} for pc, ln in src_info["pc_pos_map"].items()} @@ -1338,13 +1478,13 @@ def _get_pcmap(bytecode: Dict) -> PCMap: return PCMap.model_validate(pc_data) -def _get_legacy_pcmap(ast: ASTNode, src_map: List[SourceMapItem], opcodes: List[str]): +def _get_legacy_pcmap(ast: ASTNode, src_map: list[SourceMapItem], opcodes: list[str]): """ For Vyper versions <= 0.3.7, allows us to still get a PCMap. """ pc = 0 - pc_map_list: List[Tuple[int, Dict[str, Optional[Any]]]] = [] + pc_map_list: list[tuple[int, dict[str, Optional[Any]]]] = [] last_value = None revert_pc = -1 if _has_empty_revert(opcodes): @@ -1379,7 +1519,7 @@ def _get_legacy_pcmap(ast: ASTNode, src_map: List[SourceMapItem], opcodes: List[ if stmt: # Add located item. 
line_nos = list(stmt.line_numbers) - item: Dict = {"location": line_nos} + item: dict = {"location": line_nos} is_revert_jump = _is_revert_jump(op, last_value, revert_pc) if op == "REVERT" or is_revert_jump: dev = None @@ -1429,7 +1569,7 @@ def _get_legacy_pcmap(ast: ASTNode, src_map: List[SourceMapItem], opcodes: List[ return PCMap.model_validate(pcmap_data) -def _find_non_payable_check(src_map: List[SourceMapItem], opcodes: List[str]) -> Optional[int]: +def _find_non_payable_check(src_map: list[SourceMapItem], opcodes: list[str]) -> Optional[int]: pc = 0 revert_pc = -1 if _has_empty_revert(opcodes): @@ -1450,7 +1590,7 @@ def _find_non_payable_check(src_map: List[SourceMapItem], opcodes: List[str]) -> return None -def _is_non_payable_check(opcodes: List[str], op: str, revert_pc: int) -> bool: +def _is_non_payable_check(opcodes: list[str], op: str, revert_pc: int) -> bool: return ( len(opcodes) >= 3 and op == "CALLVALUE" @@ -1460,7 +1600,7 @@ def _is_non_payable_check(opcodes: List[str], op: str, revert_pc: int) -> bool: ) -def _get_revert_pc(opcodes: List[str]) -> int: +def _get_revert_pc(opcodes: list[str]) -> int: """ Starting in vyper 0.2.14, reverts without a reason string are optimized with a jump to the "end" of the bytecode. @@ -1472,7 +1612,7 @@ def _get_revert_pc(opcodes: List[str]) -> int: ) -def _is_immutable_member_load(opcodes: List[str]): +def _is_immutable_member_load(opcodes: list[str]): is_code_copy = len(opcodes) > 5 and opcodes[5] == "CODECOPY" return not is_code_copy and opcodes and is_0x_prefixed(opcodes[0]) @@ -1502,7 +1642,7 @@ def _extend_return(function: Function, traceback: SourceTraceback, last_pc: int, traceback.add_jump(location, function, 1, last_pcs, source_path=source_path) -def _is_fallback_check(opcodes: List[str], op: str) -> bool: +def _is_fallback_check(opcodes: list[str], op: str) -> bool: return ( "JUMP" in op and len(opcodes) >= 7 @@ -1510,11 +1650,3 @@ def _is_fallback_check(opcodes: List[str], op: str) -> bool: and opcodes[6] == "SHR" and opcodes[5] == "0xE0" ) - - -# def _version_to_specifier(version: str) -> str: -# pragma_str = " ".join(version.split()).replace("^", "~=") -# if pragma_str and pragma_str[0].isnumeric(): -# return f"=={pragma_str}" -# -# return pragma_str diff --git a/ape_vyper/exceptions.py b/ape_vyper/exceptions.py index b3467af9..3234d483 100644 --- a/ape_vyper/exceptions.py +++ b/ape_vyper/exceptions.py @@ -1,5 +1,5 @@ from enum import Enum -from typing import Dict, Optional, Type, Union +from typing import Optional, Union from ape.exceptions import CompilerError, ContractLogicError from ape.utils import USER_ASSERT_TAG @@ -157,7 +157,7 @@ def __init__(self, **kwargs): super().__init__(RuntimeErrorType.FALLBACK_NOT_DEFINED, **kwargs) -RUNTIME_ERROR_MAP: Dict[RuntimeErrorType, Type[ContractLogicError]] = { +RUNTIME_ERROR_MAP: dict[RuntimeErrorType, type[ContractLogicError]] = { RuntimeErrorType.NONPAYABLE_CHECK: NonPayableError, RuntimeErrorType.INVALID_CALLDATA_OR_VALUE: InvalidCalldataOrValueError, RuntimeErrorType.INDEX_OUT_OF_RANGE: IndexOutOfRangeError, diff --git a/ape_vyper/interface.py b/ape_vyper/interface.py index 1b5acfcd..eb367645 100644 --- a/ape_vyper/interface.py +++ b/ape_vyper/interface.py @@ -3,7 +3,7 @@ """ from pathlib import Path -from typing import Any, Dict, List, Optional, Tuple, Union +from typing import Any, Optional, Union from ethpm_types import ABI, MethodABI from ethpm_types.abi import ABIType @@ -22,7 +22,7 @@ def iface_name_from_file(fpath: Path) -> str: return fpath.name.split(".")[0] -def 
generate_inputs(inputs: List[ABIType]) -> str: +def generate_inputs(inputs: list[ABIType]) -> str: """Generate the source code input args from ABI inputs""" return ", ".join(f"{i.name}: {i.type}" for i in inputs) @@ -34,14 +34,14 @@ def generate_method(abi: MethodABI) -> str: return f"def {abi.name}({inputs}){return_maybe}: {abi.stateMutability}\n" -def abi_to_type(iface: Dict[str, Any]) -> Optional[ABI]: +def abi_to_type(iface: dict[str, Any]) -> Optional[ABI]: """Convert a dict JSON-like interface to an ethpm-types ABI type""" if iface["type"] == "function": return MethodABI.model_validate(iface) return None -def generate_interface(abi: Union[List[Dict[str, Any]], List[ABI]], iface_name: str) -> str: +def generate_interface(abi: Union[list[dict[str, Any]], list[ABI]], iface_name: str) -> str: """ Generate a Vyper interface source code from an ABI spec @@ -70,10 +70,10 @@ def generate_interface(abi: Union[List[Dict[str, Any]], List[ABI]], iface_name: return f"{source}\n" -def extract_meta(source_code: str) -> Tuple[Optional[str], str]: +def extract_meta(source_code: str) -> tuple[Optional[str], str]: """Extract version pragma, and returne cleaned source""" version_pragma: Optional[str] = None - cleaned_source_lines: List[str] = [] + cleaned_source_lines: list[str] = [] """ Pragma format changed a bit. @@ -94,7 +94,7 @@ def extract_meta(source_code: str) -> Tuple[Optional[str], str]: return (version_pragma, "\n".join(cleaned_source_lines)) -def extract_imports(source: str) -> Tuple[str, str, str]: +def extract_imports(source: str) -> tuple[str, str, str]: """ Extract import lines from the source, return them and the source without imports @@ -121,7 +121,7 @@ def extract_imports(source: str) -> Tuple[str, str, str]: ) -def extract_import_aliases(source: str) -> Dict[str, str]: +def extract_import_aliases(source: str) -> dict[str, str]: """ Extract import aliases from import lines diff --git a/pyproject.toml b/pyproject.toml index 4f748355..9d0c5d65 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -15,7 +15,7 @@ write_to = "ape_vyper/version.py" [tool.black] line-length = 100 -target-version = ['py38', 'py39', 'py310', 'py311', 'py312'] +target-version = ['py39', 'py310', 'py311', 'py312'] include = '\.pyi?$' [tool.pytest.ini_options] diff --git a/setup.py b/setup.py index 16326b9a..57092a0d 100644 --- a/setup.py +++ b/setup.py @@ -1,5 +1,4 @@ #!/usr/bin/env python -# -*- coding: utf-8 -*- from setuptools import find_packages, setup extras_require = { @@ -59,13 +58,13 @@ url="https://github.com/ApeWorX/ape-vyper", include_package_data=True, install_requires=[ - "eth-ape>=0.7.13,<0.8", + "eth-ape>=0.8.2,<0.9", "ethpm-types", # Use same version as eth-ape "tqdm", # Use same version as eth-ape "vvm>=0.2.0,<0.3", "vyper~=0.3.7", ], - python_requires=">=3.8,<4", + python_requires=">=3.9,<4", extras_require=extras_require, py_modules=["ape_vyper"], entry_points={ @@ -86,7 +85,6 @@ "Operating System :: MacOS", "Operating System :: POSIX", "Programming Language :: Python :: 3", - "Programming Language :: Python :: 3.8", "Programming Language :: Python :: 3.9", "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11", diff --git a/tests/ape-config.yaml b/tests/ape-config.yaml index ec6f7227..bd35d1e3 100644 --- a/tests/ape-config.yaml +++ b/tests/ape-config.yaml @@ -3,10 +3,5 @@ contracts_folder: contracts/passing_contracts # Specify a dependency to use in Vyper imports. 
dependencies: - - name: ExampleDependency + - name: exampledependency local: ./ExampleDependency - -vyper: - # Allows importing dependencies. - import_remapping: - - "exampledep=ExampleDependency" diff --git a/tests/conftest.py b/tests/conftest.py index c75352c4..0d33c100 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,21 +1,15 @@ import os import shutil +import tempfile from contextlib import contextmanager from pathlib import Path from tempfile import mkdtemp -from typing import List import ape import pytest import vvm # type: ignore from click.testing import CliRunner -# NOTE: Ensure that we don't use local paths for these -DATA_FOLDER = Path(mkdtemp()).resolve() -PROJECT_FOLDER = Path(mkdtemp()).resolve() -ape.config.DATA_FOLDER = DATA_FOLDER -ape.config.PROJECT_FOLDER = PROJECT_FOLDER - BASE_CONTRACTS_PATH = Path(__file__).parent / "contracts" TEMPLATES_PATH = BASE_CONTRACTS_PATH / "templates" FAILING_BASE = BASE_CONTRACTS_PATH / "failing_contracts" @@ -46,7 +40,32 @@ } -def contract_test_cases(passing: bool) -> List[str]: +@pytest.fixture(scope="session", autouse=True) +def from_tests_dir(): + # Makes default project correct. + here = Path(__file__).parent + orig = Path.cwd() + if orig != here: + os.chdir(f"{here}") + + yield + + if Path.cwd() != orig: + os.chdir(f"{orig}") + + +@pytest.fixture(scope="session", autouse=True) +def config(): + cfg = ape.config + + # Ensure we don't persist any .ape data. + with tempfile.TemporaryDirectory() as temp_dir: + path = Path(temp_dir).resolve() + cfg.DATA_FOLDER = path + yield cfg + + +def contract_test_cases(passing: bool) -> list[str]: """ Returns test-case names for outputting nicely with pytest. """ @@ -126,16 +145,6 @@ def temp_vvm_path(monkeypatch): yield path -@pytest.fixture -def data_folder(): - return DATA_FOLDER - - -@pytest.fixture -def project_folder(): - return PROJECT_FOLDER - - @pytest.fixture def compiler_manager(): return ape.compilers @@ -146,33 +155,24 @@ def compiler(compiler_manager): return compiler_manager.vyper -@pytest.fixture -def config(): - return ape.config - - -@pytest.fixture(autouse=True) -def project(config, project_folder): +@pytest.fixture(scope="session", autouse=True) +def project(config): project_source_dir = Path(__file__).parent - project_dest_dir = project_folder / project_source_dir.name - shutil.rmtree(project_dest_dir, ignore_errors=True) # Delete build / .cache that may exist pre-copy - project_path = Path(__file__).parent - cache = project_path / ".build" + cache = project_source_dir / ".build" shutil.rmtree(cache, ignore_errors=True) - shutil.copytree(project_source_dir, project_dest_dir, dirs_exist_ok=True) - with config.using_project(project_dest_dir) as project: - yield project - shutil.rmtree(project.local_project._cache_folder, ignore_errors=True) + root_project = ape.Project(project_source_dir) + with root_project.isolate_in_tempdir() as tmp_project: + yield tmp_project @pytest.fixture def geth_provider(): - if not ape.networks.active_provider or ape.networks.provider.name != "geth": + if not ape.networks.active_provider or ape.networks.provider.name != "node": with ape.networks.ethereum.local.use_provider( - "geth", provider_settings={"uri": "http://127.0.0.1:5550"} + "node", provider_settings={"uri": "http://127.0.0.1:5550"} ) as provider: yield provider else: diff --git a/tests/contracts/passing_contracts/flatten_me.vy b/tests/contracts/passing_contracts/flatten_me.vy index 456e2457..77b993d2 100644 --- a/tests/contracts/passing_contracts/flatten_me.vy +++ 
b/tests/contracts/passing_contracts/flatten_me.vy @@ -4,7 +4,7 @@ from vyper.interfaces import ERC20 from interfaces import IFace2 as IFaceTwo import interfaces.IFace as IFace -import exampledep.Dependency as Dep +import exampledependency.Dependency as Dep @external diff --git a/tests/contracts/passing_contracts/use_iface.vy b/tests/contracts/passing_contracts/use_iface.vy index ac136ddc..38f99932 100644 --- a/tests/contracts/passing_contracts/use_iface.vy +++ b/tests/contracts/passing_contracts/use_iface.vy @@ -4,7 +4,7 @@ import interfaces.IFace as IFace # Import from input JSON (ape-config.yaml). -import exampledep.Dependency as Dep +import exampledependency.Dependency as Dep from interfaces import IFace2 as IFace2 diff --git a/tests/projects/coverage_project/ape-config.yaml b/tests/projects/coverage_project/ape-config.yaml index 8db90db3..679a1df8 100644 --- a/tests/projects/coverage_project/ape-config.yaml +++ b/tests/projects/coverage_project/ape-config.yaml @@ -1,8 +1,8 @@ ethereum: local: - default_provider: geth + default_provider: node -geth: +node: ethereum: local: # Connect to node running from other tests. diff --git a/tests/test_cli.py b/tests/test_cli.py index fa9a55db..1e5b3ba9 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -21,9 +21,12 @@ ) def test_cli_flatten(project, contract_name, expected, cli_runner): path = project.contracts_folder / contract_name + arguments = ["flatten", str(path)] + end = ("--project", str(project.path)) with create_tempdir() as tmpdir: file = tmpdir / "flatten.vy" - result = cli_runner.invoke(cli, ("flatten", str(path), str(file)), catch_exceptions=False) + arguments.extend([str(file), *end]) + result = cli_runner.invoke(cli, arguments, catch_exceptions=False) assert result.exit_code == 0, result.stderr_bytes output = file.read_text() for expect in expected: diff --git a/tests/test_compiler.py b/tests/test_compiler.py index 7e3cae92..2278acd3 100644 --- a/tests/test_compiler.py +++ b/tests/test_compiler.py @@ -1,8 +1,10 @@ import re +from pathlib import Path +import ape import pytest import vvm # type: ignore -from ape.exceptions import ContractLogicError +from ape.exceptions import CompilerError, ContractLogicError from ethpm_types import ContractType from packaging.version import Version from vvm.exceptions import VyperError # type: ignore @@ -48,39 +50,42 @@ def test_compile_project(project): assert len(contracts) == len( [p.name for p in project.contracts_folder.glob("*.vy") if p.is_file()] ) - assert contracts["contract_039"].source_id == "contract_039.vy" - assert contracts["contract_no_pragma"].source_id == "contract_no_pragma.vy" - assert contracts["older_version"].source_id == "older_version.vy" + prefix = "contracts/passing_contracts" + assert contracts["contract_039"].source_id == f"{prefix}/contract_039.vy" + assert contracts["contract_no_pragma"].source_id == f"{prefix}/contract_no_pragma.vy" + assert contracts["older_version"].source_id == f"{prefix}/older_version.vy" @pytest.mark.parametrize("contract_name", PASSING_CONTRACT_NAMES) def test_compile_individual_contracts(project, contract_name, compiler): path = project.contracts_folder / contract_name - assert compiler.compile([path]) + assert list(compiler.compile((path,), project=project)) @pytest.mark.parametrize( "contract_name", [n for n in FAILING_CONTRACT_NAMES if n != "contract_unknown_pragma.vy"] ) def test_compile_failures(contract_name, compiler): + failing_project = ape.Project(FAILING_BASE) path = FAILING_BASE / contract_name with 
pytest.raises(VyperCompileError, match=EXPECTED_FAIL_PATTERNS[path.stem]) as err: - compiler.compile([path], base_path=FAILING_BASE) + list(compiler.compile((path,), project=failing_project)) assert isinstance(err.value.base_err, VyperError) def test_install_failure(compiler): + failing_project = ape.Project(FAILING_BASE) path = FAILING_BASE / "contract_unknown_pragma.vy" with pytest.raises(VyperInstallError, match="No available version to install."): - compiler.compile([path]) + list(compiler.compile((path,), project=failing_project)) def test_get_version_map(project, compiler, all_versions): vyper_files = [ x for x in project.contracts_folder.iterdir() if x.is_file() and x.suffix == ".vy" ] - actual = compiler.get_version_map(vyper_files) + actual = compiler.get_version_map(vyper_files, project=project) expected_versions = [Version(v) for v in all_versions] for version, sources in actual.items(): @@ -171,9 +176,9 @@ def run_test(manifest): for compiler in (true_latest, vyper_028): assert compiler.settings["optimize"] is True - project.local_project.update_manifest(compilers=[]) + project.update_manifest(compilers=[]) project.load_contracts(use_cache=False) - run_test(project.local_project.manifest) + run_test(project.manifest) man = project.extract_manifest() run_test(man) @@ -185,7 +190,7 @@ def test_compile_parse_dev_messages(compiler, dev_revert_source, project): The compiler will output a map that maps dev messages to line numbers. See contract_with_dev_messages.vy for more information. """ - result = compiler.compile([dev_revert_source], base_path=project.contracts_folder) + result = list(compiler.compile((dev_revert_source,), project=project)) assert len(result) == 1 @@ -204,18 +209,22 @@ def test_get_imports(compiler, project): vyper_files = [ x for x in project.contracts_folder.iterdir() if x.is_file() and x.suffix == ".vy" ] - actual = compiler.get_imports(vyper_files) + actual = compiler.get_imports(vyper_files, project=project) + prefix = "contracts/passing_contracts" builtin_import = "vyper/interfaces/ERC20.json" local_import = "interfaces/IFace.vy" local_from_import = "interfaces/IFace2.vy" - dependency_import = "exampledep/Dependency.json" - - assert len(actual["contract_037.vy"]) == 1 - assert set(actual["contract_037.vy"]) == {builtin_import} - assert len(actual["use_iface.vy"]) == 3 - assert set(actual["use_iface.vy"]) == {local_import, local_from_import, dependency_import} - assert len(actual["use_iface2.vy"]) == 1 - assert set(actual["use_iface2.vy"]) == {local_import} + dep_key = project.dependencies.get_dependency("exampledependency", "local").package_id.replace( + "/", "_" + ) + dependency_import = f"{dep_key}/local/contracts/Dependency.vy" + assert set(actual[f"{prefix}/contract_037.vy"]) == {builtin_import} + assert set(actual[f"{prefix}/use_iface.vy"]) == { + local_import, + local_from_import, + dependency_import, + } + assert set(actual[f"{prefix}/use_iface2.vy"]) == {local_import} @pytest.mark.parametrize("src,vers", [("contract_039", "0.3.9"), ("contract_037", "0.3.7")]) @@ -225,15 +234,16 @@ def test_pc_map(compiler, project, src, vers): from `compile_src()` which includes the uncompressed source map data. 
""" - path = project.contracts_folder / f"{src}.vy" - result = compiler.compile([path], base_path=project.contracts_folder)[0] + path = project.sources.lookup(src) + result = list(compiler.compile((path,), project=project))[0] actual = result.pcmap.root code = path.read_text() vvm.install_vyper(vers) - compile_result = vvm.compile_source(code, vyper_version=vers, evm_version=compiler.evm_version)[ - "" - ] - src_map = compile_result["source_map"] + cfg = compiler.get_config(project=project) + evm_version = cfg.evm_version + compile_result = vvm.compile_source(code, vyper_version=vers, evm_version=evm_version) + std_result = compile_result[""] + src_map = std_result["source_map"] lines = code.splitlines() # Use the old-fashioned way of gathering PCMap to ensure our creative way works @@ -389,7 +399,7 @@ def test_enrich_error_handle_when_name(compiler, geth_provider, mocker): def test_trace_source(account, geth_provider, project, traceback_contract, arguments): receipt = traceback_contract.addBalance(*arguments, sender=account) actual = receipt.source_traceback - base_folder = project.contracts_folder + base_folder = Path(__file__).parent / "contracts" / "passing_contracts" contract_name = traceback_contract.contract_type.name expected = rf""" Traceback (most recent call last) @@ -437,7 +447,7 @@ def test_trace_err_source(account, geth_provider, project, traceback_contract): receipt = geth_provider.get_receipt(txn.txn_hash.hex()) actual = receipt.source_traceback - base_folder = project.contracts_folder + base_folder = Path(__file__).parent / "contracts" / "passing_contracts" contract_name = traceback_contract.contract_type.name version_key = contract_name.split("traceback_contract_")[-1] expected = rf""" @@ -478,19 +488,20 @@ def test_compile_with_version_set_in_config(config, projects_path, compiler, moc path = projects_path / "version_in_config" version_from_config = "0.3.7" spy = mocker.patch("ape_vyper.compiler.vvm_compile_standard") - with config.using_project(path) as project: - contract = project.contracts_folder / "v_contract.vy" - settings = compiler.get_compiler_settings((contract,)) - assert str(list(settings.keys())[0]) == version_from_config + project = ape.Project(path) - # Show it uses this version in the compiler. - project.load_contracts(use_cache=False) - assert str(spy.call_args[1]["vyper_version"]) == version_from_config + contract = project.contracts_folder / "v_contract.vy" + settings = compiler.get_compiler_settings((contract,), project=project) + assert str(list(settings.keys())[0]) == version_from_config + # Show it uses this version in the compiler. 
+ project.load_contracts(use_cache=False) + assert str(spy.call_args[1]["vyper_version"]) == version_from_config -def test_compile_code(compiler, dev_revert_source): + +def test_compile_code(project, compiler, dev_revert_source): code = dev_revert_source.read_text() - actual = compiler.compile_code(code, contractName="MyContract") + actual = compiler.compile_code(code, project=project, contractName="MyContract") assert isinstance(actual, ContractType) assert actual.name == "MyContract" assert len(actual.abi) > 1 @@ -501,13 +512,13 @@ def test_compile_code(compiler, dev_revert_source): def test_compile_with_version_set_in_settings_dict(config, compiler_manager, projects_path): path = projects_path / "version_in_config" contract = path / "contracts" / "v_contract.vy" - - with config.using_project(path): - expected = ( - '.*Version specification "0.3.10" is not compatible with compiler version "0.3.3"' - ) - with pytest.raises(VyperCompileError, match=expected): - compiler_manager.compile([contract], settings={"version": "0.3.3"}) + project = ape.Project(path) + expected = '.*Version specification "0.3.10" is not compatible with compiler version "0.3.3"' + iterator = compiler_manager.compile( + (contract,), project=project, settings={"vyper": {"version": "0.3.3"}} + ) + with pytest.raises(CompilerError, match=expected): + _ = list(iterator) @pytest.mark.parametrize( @@ -529,8 +540,8 @@ def test_compile_with_version_set_in_settings_dict(config, compiler_manager, pro ) def test_flatten_contract(all_versions, project, contract_name, compiler): path = project.contracts_folder / contract_name - source = compiler.flatten_contract(path) + source = compiler.flatten_contract(path, project=project) source_code = str(source) version = compiler._source_vyper_version(source_code) vvm.install_vyper(str(version)) - vvm.compile_source(source_code, base_path=project.contracts_folder, vyper_version=version) + vvm.compile_source(source_code, base_path=project.path, vyper_version=version) diff --git a/tests/test_coverage.py b/tests/test_coverage.py index 3a8d5287..420a6280 100644 --- a/tests/test_coverage.py +++ b/tests/test_coverage.py @@ -2,10 +2,9 @@ import shutil import xml.etree.ElementTree as ET from pathlib import Path -from typing import List +from typing import Optional import pytest -from ape.utils import create_tempdir LINES_VALID = 8 MISSES = 0 @@ -50,17 +49,37 @@ def coverage_project_path(projects_path): def coverage_project(config, coverage_project_path): build_dir = coverage_project_path / ".build" shutil.rmtree(build_dir, ignore_errors=True) - with create_tempdir(name="coverage_project") as base_dir: - shutil.copytree(coverage_project_path, base_dir, dirs_exist_ok=True) - with config.using_project(base_dir) as project: - yield project + + with config.Project.create_temporary_project() as tmp_project: + shutil.copytree(coverage_project_path, tmp_project.path, dirs_exist_ok=True) + yield tmp_project shutil.rmtree(build_dir, ignore_errors=True) @pytest.fixture -def setup_pytester(pytester, coverage_project_path): - tests_path = coverage_project_path / "tests" +def setup_pytester(pytester, coverage_project): + tests_path = coverage_project.tests_folder + + # Make other files + def _make_all_files(base: Path, prefix: Optional[Path] = None): + if not base.is_dir(): + return + + for file in base.iterdir(): + if file.is_dir() and not file.name == "tests": + _make_all_files(file, prefix=Path(file.name)) + elif file.is_file(): + name = (prefix / file.name).as_posix() if prefix else file.name + + if name == 
"ape-config.yaml": + # Hack in in-memory overrides for testing purposes. + text = str(coverage_project.config) + else: + text = file.read_text() + + src = {name: text.splitlines()} + pytester.makefile(file.suffix, **src) # Assume all tests should pass num_passes = 0 @@ -80,6 +99,7 @@ def setup_pytester(pytester, coverage_project_path): num_failed += len([x for x in content.split("\n") if x.startswith("def test_fail_")]) pytester.makepyfile(**test_files) + _make_all_files(coverage_project.path) # Check for a conftest.py conftest = tests_path / "conftest.py" @@ -92,14 +112,18 @@ def setup_pytester(pytester, coverage_project_path): def test_coverage(geth_provider, setup_pytester, coverage_project, pytester): passed, failed = setup_pytester - result = pytester.runpytest("--coverage") - result.assert_outcomes(passed=passed, failed=failed) + result = pytester.runpytest_subprocess("--coverage") + try: + result.assert_outcomes(passed=passed, failed=failed) + except ValueError: + pytest.fail(str(result.stderr)) + actual = _get_coverage_report(result.outlines) expected = [x.strip() for x in EXPECTED_COVERAGE_REPORT.split("\n")] _assert_coverage(actual, expected) # Ensure XML was created. - base_dir = coverage_project.local_project._cache_folder + base_dir = pytester.path / ".build" xml_path = base_dir / "coverage.xml" _assert_xml(xml_path) html_path = base_dir / "htmlcov" @@ -109,7 +133,7 @@ def test_coverage(geth_provider, setup_pytester, coverage_project, pytester): _assert_html(index) -def _get_coverage_report(lines: List[str]) -> List[str]: +def _get_coverage_report(lines: list[str]) -> list[str]: ret = [] started = False for line in lines: @@ -134,7 +158,7 @@ def _get_coverage_report(lines: List[str]) -> List[str]: return ret -def _assert_coverage(actual: List[str], expected: List[str]): +def _assert_coverage(actual: list[str], expected: list[str]): for idx, (a_line, e_line) in enumerate(zip(actual, expected)): message = f"Failed at index {idx}. Expected={e_line}, Actual={a_line}" assert re.match(e_line, a_line), message