Skip to content

Commit

Permalink
Resolve types using get_type_hints
Browse files Browse the repository at this point in the history
  • Loading branch information
avalentino committed Jan 5, 2025
1 parent a4527d1 commit 7d5a81a
Show file tree
Hide file tree
Showing 6 changed files with 125 additions and 3 deletions.
1 change: 1 addition & 0 deletions .flake8
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ per-file-ignores =
bpack/utils.py: D103
bpack/codecs.py: A005
bpack/typing.py: A005
bpack/tests/test_future_annotations.py: D,T003,NQA102
statistics = True
count = True
extend-exclude = examples/*
Expand Down
3 changes: 2 additions & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,8 @@ repos:
(?x)^(
bpack/tests/test_field_descriptor.py|
bpack/tests/test_record_descriptor.py|
bpack/tests/test_utils.py
bpack/tests/test_utils.py|
bpack/tests/test_future_annotations.py
)
- repo: https://github.com/pycqa/flake8
Expand Down
40 changes: 39 additions & 1 deletion bpack/descriptors.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import builtins
import warnings
import dataclasses
from typing import Optional, Union
from typing import Optional, Union, get_type_hints
from collections.abc import Iterator, Sequence

import bpack.utils
Expand Down Expand Up @@ -62,6 +62,11 @@ def _resolve_type(type_):
Replace :class:`typing.Annotated` types with the corresponding
not-annotated ones.
"""
if isinstance(type_, str):
raise TypeError(
f"the 'type_' parameter cannot be a string (type_: {type_!r})"
)

if bpack.utils.is_sequence_type(type_):
etype = bpack.utils.effective_type(type_)
try:
Expand Down Expand Up @@ -97,6 +102,10 @@ class BinFieldDescriptor:
def _validate_type(self):
if self.type is None:
raise TypeError(f"invalid type '{self.type!r}'")
elif isinstance(self.type, str):
raise TypeError(
f"'{self.__class__.__name__}.type' cannot be a string"
)

def _validate_size(self):
msg = f"invalid size: {self.size!r} (must be a positive integer)"
Expand Down Expand Up @@ -133,6 +142,12 @@ def _validate_enum_type(self):

def __post_init__(self):
"""Finalize BinFieldDescriptor instance initialization."""
if isinstance(self.type, str):
raise TypeError(
f"the 'type' parameter cannot be a string "
f"(type_: {self.type!r})"
)

if self.offset is not None:
self._validate_offset()

Expand Down Expand Up @@ -201,6 +216,11 @@ def update_from_type(self, type_: type):
"""Update the field descriptor according to the specified type."""
if self.type is not None:
raise TypeError("the type attribute is already set")
if isinstance(type_, str):
raise TypeError(
f"the 'type_' parameter cannot be a string (type_: {type_!r})"
)

if bpack.typing.is_annotated(type_):
_, params = bpack.typing.get_args(type_)
valid = True
Expand All @@ -221,7 +241,11 @@ def update_from_type(self, type_: type):
self.size = params.size
elif bpack.utils.is_sequence_type(type_):
etype = bpack.utils.effective_type(type_, keep_annotations=True)

# this is needed to set "signed" and "size"
self.update_from_type(etype)

# restore the proper sequence type
self.type = _resolve_type(type_)
else:
self.type = type_
Expand Down Expand Up @@ -284,6 +308,12 @@ def is_field(obj) -> bool:


def _update_field_metadata(field_, **kwargs):
type_ = kwargs.get("type")
if isinstance(type_, str):
raise TypeError(
f"the 'type' parameter cannot be a string (type_: {type_!r})"
)

metadata = field_.metadata.copy() if field_.metadata is not None else {}
metadata.update(**kwargs)
field_.metadata = types.MappingProxyType(metadata)
Expand Down Expand Up @@ -444,6 +474,9 @@ class to be decorated
else:
cls = dataclasses.dataclass(cls, **kwargs)

# import inspect
# types_ = inspect.get_annotations(cls)
types_ = get_type_hints(cls, include_extras=True)
fields_ = dataclasses.fields(cls)

# Initialize to a dummy value with initial offset + size = 0
Expand All @@ -457,6 +490,10 @@ class to be decorated
# NOTE: this is ensured by dataclasses but not by attr
assert field_.type is not None

# resolve all types
if isinstance(field_.type, str):
field_.type = types_[field_.name]

if bpack.typing.is_annotated(field_.type):
# check byteorder
_, params = bpack.typing.get_args(field_.type)
Expand All @@ -476,6 +513,7 @@ class to be decorated
except NotFieldDescriptorError:
field_descr = BinFieldDescriptor()
if isinstance(field_, Field):
# set "type", "size" and "signed"
field_descr.update_from_type(field_.type)

if field_descr.size is None:
Expand Down
2 changes: 1 addition & 1 deletion bpack/tests/test_field_descriptor.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,7 @@ class Record:

@staticmethod
def test_invalid_field_type():
with pytest.raises(TypeError):
with pytest.raises((TypeError, NameError)):

@bpack.descriptor
class Record:
Expand Down
81 changes: 81 additions & 0 deletions bpack/tests/test_future_annotations.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
"""Tests bpack with "from __future__ import annotations"."""

from __future__ import annotations

import enum

import pytest

import bpack
from bpack import T


class EEnumType(enum.IntEnum):
A = 1
B = 2
C = 3


@bpack.descriptor
class Record:
field_1: int = bpack.field(size=4, default=11)
field_2: float = bpack.field(size=4, default=22.22)
field_3: EEnumType = bpack.field(size=1, default=EEnumType.A)


@bpack.descriptor
class NestedRecord:
field_1: str = bpack.field(size=10, default="0123456789")
field_2: Record = bpack.field(
size=bpack.calcsize(Record), default_factory=Record
)
field_3: int = bpack.field(size=4, default=3)
field_4: T["f8"] = 0.1 # noqa: F821


def test_nested_records():
record = Record()
nested_record = NestedRecord()

assert nested_record.field_1 == "0123456789"
assert nested_record.field_2 == record
assert nested_record.field_2.field_1 == record.field_1
assert nested_record.field_2.field_2 == record.field_2
assert nested_record.field_2.field_3 is EEnumType.A
assert nested_record.field_3 == 3
assert nested_record.field_4 == 0.1


def test_nested_records_consistency_error():
with pytest.raises(bpack.descriptors.DescriptorConsistencyError):

@bpack.descriptor
class NestedRecord:
field_1: str = bpack.field(size=10, default="0123456789")
field_2: Record = bpack.field(
size=bpack.calcsize(Record) + 1, default_factory=Record
)
field_3: int = bpack.field(size=4, default=3)


def test_nested_records_autosize():
assert bpack.calcsize(NestedRecord) == 31


def test_unexisting_field_type():
with pytest.raises((TypeError, NameError)):

@bpack.descriptor
class Record:
field_1: "unexisting" = bpack.field(size=4) # noqa: F821


invalid = True


def test_invalid_field_type():
with pytest.raises((TypeError, NameError)):

@bpack.descriptor
class Record:
field_1: "invalid" = bpack.field(size=4) # noqa: F821
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ Issues = "https://github.com/avalentino/bpack/issues"
packages = [
"bpack",
"bpack.tests",
"bpack.tests.future_annotations",
"bpack.tools",
"bpack.tools.tests",
"bpack.tools.tests.data",
Expand Down

0 comments on commit 7d5a81a

Please sign in to comment.