Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat!: pydantic v2 support [APE-1412] #55

Merged
merged 4 commits into from
Dec 8, 2023
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.4.0
rev: v4.5.0
hooks:
- id: check-yaml

Expand All @@ -10,7 +10,7 @@ repos:
- id: isort

- repo: https://github.com/psf/black
rev: 23.9.1
rev: 23.10.1
hooks:
- id: black
name: black
Expand All @@ -21,7 +21,7 @@ repos:
- id: flake8

- repo: https://github.com/pre-commit/mirrors-mypy
rev: v1.5.1
rev: v1.6.1
hooks:
- id: mypy
additional_dependencies: [types-PyYAML, types-requests, types-setuptools, pydantic]
Expand Down
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ web3 = Web3(HTTPProvider("https://path.to.my.node"))
txn_hash = "0x..."
struct_logs = web3.manager.request_blocking("debug_traceTransaction", [txn_hash]).structLogs
for item in struct_logs:
frame = TraceFrame.parse_obj(item)
frame = TraceFrame.model_validate(item)
```

If you want to get the call-tree node, you can do:
Expand Down Expand Up @@ -69,7 +69,7 @@ If you are using a node that supports the `trace_transaction` RPC, you can use `
from evm_trace import CallType, ParityTraceList

raw_trace_list = web3.manager.request_blocking("trace_transaction", [txn_hash])
trace_list = ParityTraceList.parse_obj(raw_trace_list)
trace_list = ParityTraceList.model_validate(raw_trace_list)
```

And to make call-tree nodes, you can do:
Expand Down
8 changes: 4 additions & 4 deletions evm_trace/__init__.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
from .base import CallTreeNode
from .enums import CallType
from .geth import (
from evm_trace.base import CallTreeNode
from evm_trace.enums import CallType
from evm_trace.geth import (
TraceFrame,
create_trace_frames,
get_calltree_from_geth_call_trace,
get_calltree_from_geth_trace,
)
from .parity import ParityTrace, ParityTraceList, get_calltree_from_parity_trace
from evm_trace.parity import ParityTrace, ParityTraceList, get_calltree_from_parity_trace

__all__ = [
"CallTreeNode",
Expand Down
7 changes: 0 additions & 7 deletions evm_trace/_pydantic_compat.py

This file was deleted.

31 changes: 11 additions & 20 deletions evm_trace/base.py
Original file line number Diff line number Diff line change
@@ -1,25 +1,19 @@
from functools import cached_property, singledispatchmethod
from typing import List, Optional

from ethpm_types import BaseModel as _BaseModel
from ethpm_types import HexBytes
from eth_pydantic_types import HexBytes
from pydantic import BaseModel as _BaseModel
from pydantic import ConfigDict, field_validator

from evm_trace._pydantic_compat import validator
from evm_trace.display import get_tree_display
from evm_trace.enums import CallType


class BaseModel(_BaseModel):
class Config:
# NOTE: Due to https://github.com/samuelcolvin/pydantic/issues/1241 we have
# to add this cached property workaround in order to avoid this error:

# TypeError: cannot pickle '_thread.RLock' object

keep_untouched = (cached_property, singledispatchmethod)
arbitrary_types_allowed = True
underscore_attrs_are_private = True
copy_on_model_validation = "none"
model_config = ConfigDict(
ignored_types=(cached_property, singledispatchmethod),
arbitrary_types_allowed=True,
)


class CallTreeNode(BaseModel):
Expand Down Expand Up @@ -77,17 +71,14 @@ def __repr__(self) -> str:
def __getitem__(self, index: int) -> "CallTreeNode":
return self.calls[index]

@validator("calldata", "returndata", "address", pre=True)
@field_validator("calldata", "returndata", "address", mode="before")
def validate_bytes(cls, value):
return HexBytes(value) if isinstance(value, str) else value

@validator("value", "depth", pre=True)
@field_validator("value", "depth", mode="before")
def validate_ints(cls, value):
if not value:
return 0

return int(value, 16) if isinstance(value, str) else value
return (int(value, 16) if isinstance(value, str) else value) if value else 0

@validator("gas_limit", "gas_cost", pre=True)
@field_validator("gas_limit", "gas_cost", mode="before")
def validate_optional_ints(cls, value):
return int(value, 16) if isinstance(value, str) else value
29 changes: 12 additions & 17 deletions evm_trace/geth.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,19 @@
import math
from typing import Dict, Iterator, List, Optional

from eth_pydantic_types import HashBytes20, HexBytes
from eth_utils import to_int
from ethpm_types import HexBytes
from pydantic import Field, RootModel, field_validator

from evm_trace._pydantic_compat import Field, validator
from evm_trace.base import BaseModel, CallTreeNode
from evm_trace.enums import CALL_OPCODES, CallType
from evm_trace.utils import to_address


class TraceMemory(BaseModel):
__root__: List[HexBytes] = []
class TraceMemory(RootModel[List[HexBytes]]):
root: List[HexBytes] = []
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I still defined root here because that was the only way to make it default to empty list when it got used in other models.


def get(self, offset: HexBytes, size: HexBytes):
return extract_memory(offset, size, self.__root__)
return extract_memory(offset, size, self.root)


class TraceFrame(BaseModel):
Expand Down Expand Up @@ -49,15 +48,15 @@ class TraceFrame(BaseModel):
storage: Dict[HexBytes, HexBytes] = {}
"""Contract storage."""

contract_address: Optional[HexBytes] = None
contract_address: Optional[HashBytes20] = None
"""The address producing the frame."""

@validator("pc", "gas", "gas_cost", "depth", pre=True)
@field_validator("pc", "gas", "gas_cost", "depth", mode="before")
def validate_ints(cls, value):
return int(value, 16) if isinstance(value, str) else value

@property
def address(self) -> Optional[HexBytes]:
def address(self) -> Optional[HashBytes20]:
"""
The address of this CALL frame.
Only returns a value if this frame's opcode is a call-based opcode.
Expand All @@ -66,7 +65,7 @@ def address(self) -> Optional[HexBytes]:
if not self.contract_address and (
self.op in CALL_OPCODES and CallType.CREATE.value not in self.op
):
self.contract_address = HexBytes(self.stack[-2][-20:])
self.contract_address = HashBytes20.__eth_pydantic_validate__(self.stack[-2][-20:])

return self.contract_address

Expand Down Expand Up @@ -104,7 +103,7 @@ def _get_create_frames(frame: TraceFrame, frames: Iterator[Dict]) -> List[TraceF
create_frames = [frame]
start_depth = frame.depth
for next_frame in frames:
next_frame_obj = TraceFrame.parse_obj(next_frame)
next_frame_obj = TraceFrame.model_validate(next_frame)
depth = next_frame_obj.depth

if CallType.CREATE.value in next_frame_obj.op:
Expand All @@ -116,11 +115,7 @@ def _get_create_frames(frame: TraceFrame, frames: Iterator[Dict]) -> List[TraceF
# the first frame after the CREATE with an equal depth.
if len(next_frame_obj.stack) > 0:
raw_addr = HexBytes(next_frame_obj.stack[-1][-40:])
try:
frame.contract_address = HexBytes(to_address(raw_addr))
except Exception:
# Potentially, a transaction was made with poor data.
frame.contract_address = raw_addr
frame.contract_address = HashBytes20.__eth_pydantic_validate__(raw_addr)

create_frames.append(next_frame_obj)
break
Expand Down Expand Up @@ -279,7 +274,7 @@ def _create_node(
node_kwargs["last_create_depth"].pop()
for subcall in node_kwargs.get("calls", [])[::-1]:
if subcall.call_type in (CallType.CREATE, CallType.CREATE2):
subcall.address = HexBytes(to_address(frame.stack[-1][-40:]))
subcall.address = HashBytes20.__eth_pydantic_validate__(frame.stack[-1][-40:])
if len(frame.stack) >= 5:
subcall.calldata = frame.memory.get(frame.stack[-4], frame.stack[-5])

Expand Down
39 changes: 16 additions & 23 deletions evm_trace/parity.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from typing import Any, Dict, List, Optional, Union, cast

from evm_trace._pydantic_compat import Field, validator
from pydantic import Field, RootModel, field_validator

from evm_trace.base import BaseModel, CallTreeNode
from evm_trace.enums import CallType

Expand All @@ -18,7 +19,7 @@ class CallAction(BaseModel):
# only used to recover the specific call type
call_type: str = Field(alias="callType", repr=False)

@validator("value", "gas", pre=True)
@field_validator("value", "gas", mode="before")
def convert_integer(cls, v):
return int(v, 16)

Expand All @@ -32,7 +33,7 @@ class CreateAction(BaseModel):
init: str
value: int

@validator("value", "gas", pre=True)
@field_validator("value", "gas", mode="before")
def convert_integer(cls, v):
return int(v, 16)

Expand All @@ -41,7 +42,7 @@ class SelfDestructAction(BaseModel):
address: str
balance: int

@validator("balance", pre=True)
@field_validator("balance", mode="before")
def convert_integer(cls, v):
return int(v, 16) if isinstance(v, str) else int(v)

Expand All @@ -59,7 +60,7 @@ class ActionResult(BaseModel):
for gas per zero byte and ``16`` gas per non-zero byte.
"""

@validator("gas_used", pre=True)
@field_validator("gas_used", mode="before")
def convert_integer(cls, v):
return int(v, 16) if isinstance(v, str) else int(v)

Expand Down Expand Up @@ -95,26 +96,19 @@ class ParityTrace(BaseModel):
trace_address: List[int] = Field(alias="traceAddress")
transaction_hash: str = Field(alias="transactionHash")

@validator("call_type", pre=True)
def convert_call_type(cls, v, values) -> CallType:
if isinstance(values["action"], CallAction):
v = values["action"].call_type
value = v.upper()
@field_validator("call_type", mode="before")
def convert_call_type(cls, value, info) -> CallType:
if isinstance(info.data["action"], CallAction):
value = info.data["action"].call_type

value = value.upper()
if value == "SUICIDE":
value = "SELFDESTRUCT"

return CallType(value)


class ParityTraceList(BaseModel):
__root__: List[ParityTrace]

# pydantic models with custom root don't have this by default
def __iter__(self):
return iter(self.__root__)

def __getitem__(self, item):
return self.__root__[item]
ParityTraceList = RootModel[List[ParityTrace]]


def get_calltree_from_parity_trace(
Expand All @@ -137,7 +131,7 @@ def get_calltree_from_parity_trace(
Returns:
:class:`~evm_trace.base.CallTreeNode`
"""
root = root or traces[0]
root = root or traces.root[0]
failed = root.error is not None
node_kwargs: Dict[Any, Any] = {
"call_type": root.call_type,
Expand Down Expand Up @@ -187,7 +181,7 @@ def get_calltree_from_parity_trace(
address=selfdestruct_action.address,
)

trace_list: List[ParityTrace] = list(traces)
trace_list: List[ParityTrace] = traces.root
subtraces: List[ParityTrace] = [
sub
for sub in trace_list
Expand All @@ -196,5 +190,4 @@ def get_calltree_from_parity_trace(
]
node_kwargs["calls"] = [get_calltree_from_parity_trace(traces, root=sub) for sub in subtraces]
node_kwargs = {**node_kwargs, **root_kwargs}
node = CallTreeNode.parse_obj(node_kwargs)
return node
return CallTreeNode.model_validate(node_kwargs)
7 changes: 0 additions & 7 deletions evm_trace/utils.py

This file was deleted.

15 changes: 5 additions & 10 deletions evm_trace/vmtrace.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,11 @@

from eth.vm.memory import Memory
from eth.vm.stack import Stack
from eth_pydantic_types import Address, HexBytes
from eth_utils import to_int
from ethpm_types import HexBytes
from msgspec import Struct
from msgspec.json import Decoder

from evm_trace.utils import to_address

# opcodes grouped by the number of items they pop from the stack
# fmt: off
POP_OPCODES = {
Expand Down Expand Up @@ -141,7 +139,7 @@ def to_trace_frames(
)

if op.op in ["CALL", "DELEGATECALL", "STATICCALL"]:
call_address = to_address(stack.values[-2][1])
call_address = Address.__eth_pydantic_validate__(stack.values[-2][1])

if op.ex:
if op.ex.mem:
Expand Down Expand Up @@ -188,9 +186,6 @@ def from_rpc_response(buffer: bytes) -> Union[VMTrace, List[VMTrace]]:
"""
Decode structured data from a raw `trace_replayTransaction` or `trace_replayBlockTransactions`.
"""
resp = Decoder(RPCResponse, dec_hook=dec_hook).decode(buffer)

if isinstance(resp.result, list):
return [i.vmTrace for i in resp.result]

return resp.result.vmTrace
response = Decoder(RPCResponse, dec_hook=dec_hook).decode(buffer)
result: Union[List[RPCTraceResult], RPCTraceResult] = response.result
return [i.vmTrace for i in result] if isinstance(result, list) else result.vmTrace
1 change: 1 addition & 0 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ exclude =
.eggs
docs
build
evm_trace/version.py
per-file-ignores =
# The traces have to be formatted this way for the tests.
tests/expected_traces.py: E501
8 changes: 4 additions & 4 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@
"eth-hash[pysha3]", # For eth-utils address checksumming
],
"lint": [
"black>=23.9.1,<24", # Auto-formatter and linter
"mypy>=1.5.1,<2", # Static type analyzer
"black>=23.10.1,<24", # Auto-formatter and linter
"mypy>=1.6.1,<2", # Static type analyzer
"types-setuptools", # Needed for mypy type shed
"flake8>=6.1.0,<7", # Style linter
"isort>=5.10.1,<6", # Import sorting linter
Expand Down Expand Up @@ -57,11 +57,11 @@
url="https://github.com/ApeWorX/evm-trace",
include_package_data=True,
install_requires=[
"pydantic>=1.10.1,<3",
"pydantic>=2.3.0,<3",
"py-evm>=0.7.0a3,<0.8",
"eth-utils>=2.1,<3",
"ethpm-types>=0.5.0,<0.6",
"msgspec>=0.8",
"eth-pydantic-types>=0.1.0a3",
],
python_requires=">=3.8,<4",
extras_require=extras_require,
Expand Down
Loading