Skip to content

Commit

Permalink
fix: flake8 type check
Browse files Browse the repository at this point in the history
  • Loading branch information
antazoey committed Oct 29, 2024
1 parent 357d602 commit 08f31b9
Show file tree
Hide file tree
Showing 10 changed files with 104 additions and 75 deletions.
18 changes: 11 additions & 7 deletions ape_vyper/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,22 +3,24 @@
from collections.abc import Iterable
from enum import Enum
from pathlib import Path
from typing import Any, Optional, Union
from typing import TYPE_CHECKING, Any, Optional, Union

import vvm # type: ignore
from ape.exceptions import ProjectError
from ape.logging import logger
from ape.managers import ProjectManager
from ape.types import SourceTraceback
from ape.utils import get_relative_path
from eth_utils import is_0x_prefixed
from ethpm_types import ASTNode, PCMap, SourceMapItem
from ethpm_types.source import Function
from packaging.specifiers import InvalidSpecifier, SpecifierSet
from packaging.version import Version

from ape_vyper.exceptions import RuntimeErrorType, VyperInstallError

if TYPE_CHECKING:
from ape.types.trace import SourceTraceback
from ethpm_types.source import Function
from packaging.version import Version

Optimization = Union[str, bool]
EVM_VERSION_DEFAULT = {
"0.2.15": "berlin",
Expand Down Expand Up @@ -54,7 +56,7 @@ def __str__(self) -> str:
return self.value


def install_vyper(version: Version):
def install_vyper(version: "Version"):
for attempt in range(MAX_INSTALL_RETRIES):
try:
vvm.install_vyper(version, show_progress=True)
Expand Down Expand Up @@ -255,7 +257,7 @@ def seek() -> Optional[Path]:
return None


def safe_append(data: dict, version: Union[Version, SpecifierSet], paths: Union[Path, set]):
def safe_append(data: dict, version: Union["Version", SpecifierSet], paths: Union[Path, set]):
if isinstance(paths, Path):
paths = {paths}
if version in data:
Expand Down Expand Up @@ -478,7 +480,9 @@ def is_immutable_member_load(opcodes: list[str]):
return not is_code_copy and opcodes and is_0x_prefixed(opcodes[0])


def extend_return(function: Function, traceback: SourceTraceback, last_pc: int, source_path: Path):
def extend_return(
function: "Function", traceback: "SourceTraceback", last_pc: int, source_path: Path
):
return_ast_result = [x for x in function.ast.children if x.ast_type == "Return"]
if not return_ast_result:
return
Expand Down
34 changes: 19 additions & 15 deletions ape_vyper/compiler/_versions/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,10 @@
from typing import TYPE_CHECKING, Any, Optional

from ape.logging import logger
from ape.managers.project import ProjectManager
from ape.utils import ManagerAccessMixin, clean_path, get_relative_path
from ethpm_types import ASTNode, ContractType, SourceMap
from ethpm_types.ast import ASTClassification
from ethpm_types.source import Content
from packaging.version import Version
from vvm import compile_standard as vvm_compile_standard # type: ignore
from vvm.exceptions import VyperError # type: ignore

Expand All @@ -25,10 +23,13 @@
get_pcmap,
)
from ape_vyper.exceptions import VyperCompileError
from ape_vyper.imports import ImportMap

if TYPE_CHECKING:
from ape.managers.project import ProjectManager
from packaging.version import Version

from ape_vyper.compiler.api import VyperCompiler
from ape_vyper.imports import ImportMap


class BaseVyperCompiler(ManagerAccessMixin):
Expand All @@ -39,7 +40,7 @@ class BaseVyperCompiler(ManagerAccessMixin):
def __init__(self, api: "VyperCompiler"):
self.api = api

def get_import_remapping(self, project: Optional[ProjectManager] = None) -> dict[str, dict]:
def get_import_remapping(self, project: Optional["ProjectManager"] = None) -> dict[str, dict]:
# Overridden on 0.4 to not use.
# Import remappings are for Vyper versions 0.2 - 0.3 to
# create the interfaces dict.
Expand All @@ -48,11 +49,11 @@ def get_import_remapping(self, project: Optional[ProjectManager] = None) -> dict

def compile(
self,
vyper_version: Version,
vyper_version: "Version",
settings: dict,
import_map: ImportMap,
import_map: "ImportMap",
compiler_data: dict,
project: Optional[ProjectManager] = None,
project: Optional["ProjectManager"] = None,
):
pm = project or self.local_project
for settings_key, settings_set in settings.items():
Expand Down Expand Up @@ -155,10 +156,10 @@ def compile(

def get_settings(
self,
version: Version,
version: "Version",
source_paths: Iterable[Path],
compiler_data: dict,
project: Optional[ProjectManager] = None,
project: Optional["ProjectManager"] = None,
) -> dict:
pm = project or self.local_project
default_optimization = self._get_default_optimization(version)
Expand Down Expand Up @@ -210,7 +211,7 @@ def _classify_ast(self, _node: ASTNode):
self._classify_ast(child)

def _get_sources_dictionary(
self, source_ids: Iterable[str], project: Optional[ProjectManager] = None, **kwargs
self, source_ids: Iterable[str], project: Optional["ProjectManager"] = None, **kwargs
) -> dict[str, dict]:
"""
Generate input for the "sources" key in the input JSON.
Expand All @@ -225,7 +226,7 @@ def _get_sources_dictionary(
def _get_selection_dictionary(
self,
selection: Iterable[str],
project: Optional[ProjectManager] = None,
project: Optional["ProjectManager"] = None,
**kwargs,
) -> dict:
"""
Expand All @@ -238,7 +239,10 @@ def _get_selection_dictionary(
return {s: ["*"] for s in selection if (pm.path / s).is_file() if "interfaces" not in s}

def _get_compile_kwargs(
self, vyper_version: Version, compiler_data: dict, project: Optional[ProjectManager] = None
self,
vyper_version: "Version",
compiler_data: dict,
project: Optional["ProjectManager"] = None,
) -> dict:
"""
Generate extra kwargs to pass to Vyper.
Expand All @@ -249,14 +253,14 @@ def _get_compile_kwargs(
comp_kwargs["base_path"] = pm.path
return comp_kwargs

def _get_base_compile_kwargs(self, vyper_version: Version, compiler_data: dict):
def _get_base_compile_kwargs(self, vyper_version: "Version", compiler_data: dict):
vyper_binary = compiler_data[vyper_version]["vyper_binary"]
comp_kwargs = {"vyper_version": vyper_version, "vyper_binary": vyper_binary}
return comp_kwargs

def _get_pcmap(
self,
vyper_version: Version,
vyper_version: "Version",
ast: Any,
src_map: list,
opcodes: list[str],
Expand All @@ -274,7 +278,7 @@ def _parse_source_map(self, raw_source_map: Any) -> SourceMap:
# All versions < 0.4 use this one
return SourceMap(root=raw_source_map)

def _get_default_optimization(self, vyper_version: Version) -> Optimization:
def _get_default_optimization(self, vyper_version: "Version") -> Optimization:
"""
The default value for "optimize" in the settings for input JSON.
"""
Expand Down
9 changes: 5 additions & 4 deletions ape_vyper/compiler/_versions/vyper_02.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
from typing import Any

from packaging.version import Version
from typing import TYPE_CHECKING, Any

from ape_vyper._utils import get_legacy_pcmap
from ape_vyper.compiler._versions.base import BaseVyperCompiler

if TYPE_CHECKING:
from packaging.version import Version


class Vyper02Compiler(BaseVyperCompiler):
"""
Expand All @@ -15,7 +16,7 @@ class Vyper02Compiler(BaseVyperCompiler):

def _get_pcmap(
self,
vyper_version: Version,
vyper_version: "Version",
ast: Any,
src_map: list,
opcodes: list[str],
Expand Down
25 changes: 15 additions & 10 deletions ape_vyper/compiler/_versions/vyper_04.py
Original file line number Diff line number Diff line change
@@ -1,34 +1,36 @@
import os
from collections.abc import Iterable
from pathlib import Path
from typing import Optional
from typing import TYPE_CHECKING, Optional

from ape.managers import ProjectManager
from ape.utils import get_full_extension, get_relative_path
from ethpm_types import SourceMap
from packaging.version import Version

from ape_vyper._utils import FileType, Optimization
from ape_vyper.compiler._versions.base import BaseVyperCompiler
from ape_vyper.imports import ImportMap

if TYPE_CHECKING:
from ape.managers.project import ProjectManager
from packaging.version import Version


class Vyper04Compiler(BaseVyperCompiler):
"""
Compiler for Vyper>=0.4.0.
"""

def get_import_remapping(self, project: Optional[ProjectManager] = None) -> dict[str, dict]:
def get_import_remapping(self, project: Optional["ProjectManager"] = None) -> dict[str, dict]:
# Import remappings are not used in 0.4.
# You always import via module or package name.
return {}

def get_settings(
self,
version: Version,
version: "Version",
source_paths: Iterable[Path],
compiler_data: dict,
project: Optional[ProjectManager] = None,
project: Optional["ProjectManager"] = None,
) -> dict:
pm = project or self.local_project

Expand All @@ -43,7 +45,7 @@ def get_settings(
return settings

def _get_sources_dictionary(
self, source_ids: Iterable[str], project: Optional[ProjectManager] = None, **kwargs
self, source_ids: Iterable[str], project: Optional["ProjectManager"] = None, **kwargs
) -> dict[str, dict]:
pm = project or self.local_project
if not source_ids:
Expand Down Expand Up @@ -83,18 +85,21 @@ def _get_sources_dictionary(
return src_dict

def _get_compile_kwargs(
self, vyper_version: Version, compiler_data: dict, project: Optional[ProjectManager] = None
self,
vyper_version: "Version",
compiler_data: dict,
project: Optional["ProjectManager"] = None,
) -> dict:
return self._get_base_compile_kwargs(vyper_version, compiler_data)

def _get_default_optimization(self, vyper_version: Version) -> Optimization:
def _get_default_optimization(self, vyper_version: "Version") -> Optimization:
return "gas"

def _parse_source_map(self, raw_source_map: dict) -> SourceMap:
return SourceMap(root=raw_source_map["pc_pos_map_compressed"])

def _get_selection_dictionary(
self, selection: Iterable[str], project: Optional[ProjectManager] = None, **kwargs
self, selection: Iterable[str], project: Optional["ProjectManager"] = None, **kwargs
) -> dict:
pm = project or self.local_project
return {
Expand Down
27 changes: 15 additions & 12 deletions ape_vyper/compiler/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,19 +6,15 @@
from functools import cached_property
from importlib import import_module
from pathlib import Path
from typing import Optional
from typing import TYPE_CHECKING, Optional

import vvm # type: ignore
from ape.api import CompilerAPI, PluginConfig, TraceAPI
from ape.exceptions import ContractLogicError
from ape.logging import logger
from ape.managers import ProjectManager
from ape.managers.project import LocalProject
from ape.types import ContractSourceCoverage, SourceTraceback
from ape.utils import get_full_extension, get_relative_path
from ape.utils._github import _GithubClient
from eth_pydantic_types import HexBytes
from ethpm_types import ContractType
from ethpm_types.source import Compiler, Content, ContractSource
from packaging.specifiers import SpecifierSet
from packaging.version import Version
Expand All @@ -36,6 +32,13 @@
from ape_vyper.imports import ImportMap, ImportResolver
from ape_vyper.traceback import SourceTracer

if TYPE_CHECKING:
from ape.exceptions import ContractLogicError
from ape.types.coverage import ContractSourceCoverage
from ape.types.trace import SourceTraceback
from eth_pydantic_types import HexBytes
from ethpm_types.contract_type import ContractType


class VyperCompiler(CompilerAPI):
_dependencies_by_project: dict[str, dict[str, ProjectManager]] = {}
Expand Down Expand Up @@ -245,7 +248,7 @@ def compile(
contract_filepaths: Iterable[Path],
project: Optional[ProjectManager] = None,
settings: Optional[dict] = None,
) -> Iterator[ContractType]:
) -> Iterator["ContractType"]:
pm = project or self.local_project
original_settings = self.compiler_settings
self.compiler_settings = {**self.compiler_settings, **(settings or {})}
Expand All @@ -258,7 +261,7 @@ def _compile(
self, contract_filepaths: Iterable[Path], project: Optional[ProjectManager] = None
):
pm = project or self.local_project
contract_types: list[ContractType] = []
contract_types: list["ContractType"] = []
import_map = self._import_resolver.get_imports(pm, contract_filepaths)
config = self.get_config(pm)
version_map = self._get_version_map_from_import_map(
Expand Down Expand Up @@ -328,7 +331,7 @@ def _compile(

def compile_code(
self, code: str, project: Optional[ProjectManager] = None, **kwargs
) -> ContractType:
) -> "ContractType":
# NOTE: We are unable to use `vvm.compile_code()` because it does not
# appear to honor altered VVM install paths, thus always re-installs
# Vyper in our tests because of the monkeypatch. Also, their approach
Expand Down Expand Up @@ -515,18 +518,18 @@ def _get_compiler_settings_from_version_map(
return settings

def init_coverage_profile(
self, source_coverage: ContractSourceCoverage, contract_source: ContractSource
self, source_coverage: "ContractSourceCoverage", contract_source: ContractSource
):
profiler = CoverageProfiler(source_coverage)
profiler.initialize(contract_source)

def enrich_error(self, err: ContractLogicError) -> ContractLogicError:
def enrich_error(self, err: "ContractLogicError") -> "ContractLogicError":
return enrich_error(err)

# TODO: In 0.9, make sure project is a kwarg here.
def trace_source(
self, contract_source: ContractSource, trace: TraceAPI, calldata: HexBytes
) -> SourceTraceback:
self, contract_source: ContractSource, trace: TraceAPI, calldata: "HexBytes"
) -> "SourceTraceback":
return SourceTracer.trace(trace.get_raw_frames(), contract_source, calldata)

def _get_compiler_arguments(
Expand Down
12 changes: 7 additions & 5 deletions ape_vyper/coverage.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,21 @@
from fnmatch import fnmatch
from typing import Optional
from typing import TYPE_CHECKING, Optional

from ape.types import ContractSourceCoverage
from ape.utils import ManagerAccessMixin
from ethpm_types.source import ContractSource
from ethpm_types.utils import SourceLocation

from ape_vyper.exceptions import RuntimeErrorType

if TYPE_CHECKING:
from ape.types import ContractSourceCoverage
from ethpm_types.source import ContractSource


class CoverageProfiler(ManagerAccessMixin):
def __init__(self, source_coverage: ContractSourceCoverage):
def __init__(self, source_coverage: "ContractSourceCoverage"):
self._coverage = source_coverage

def initialize(self, contract_source: ContractSource):
def initialize(self, contract_source: "ContractSource"):
exclusions = self.config_manager.get_config("test").coverage.exclude
contract_name = contract_source.contract_type.name or "__UnknownContract__"

Expand Down
Loading

0 comments on commit 08f31b9

Please sign in to comment.