Skip to content

Commit

Permalink
Fixup mypy
Browse files Browse the repository at this point in the history
  • Loading branch information
elchupanebrej committed Dec 10, 2024
1 parent d888f9d commit 11e3632
Showing 1 changed file with 34 additions and 53 deletions.
87 changes: 34 additions & 53 deletions python/src/cucumber_messages/json_converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,41 +6,33 @@
from dataclasses import fields, is_dataclass, Field, MISSING
from datetime import datetime, date
from enum import Enum
from typing import Any, Dict, List, Optional, Union, get_args, get_origin, TypeVar, Type as TypingType, Tuple
from typing import Any, Dict, List, Optional, Union, get_args, get_origin, TypeVar, Type, Tuple, cast

T = TypeVar('T')

def camel_to_snake(name: str) -> str:
"""Convert string from camelCase to snake_case."""
# Validate field name - must start with letter or underscore and contain only alphanumeric and underscore
if not name or not (name[0].isalpha() or name[0] == '_') or not all(c.isalnum() or c == '_' for c in name):
raise ValueError(f"Invalid field name: {name}")

# Convert camelCase to snake_case
pattern = re.compile(r'(?<!^)(?=[A-Z])')
return pattern.sub('_', name).lower()


def snake_to_camel(name: str) -> str:
"""Convert string from snake_case to camelCase."""
# Validate field name - must start with letter or underscore and contain only alphanumeric and underscore
if not name or not (name[0].isalpha() or name[0] == '_') or not all(c.isalnum() or c == '_' for c in name):
raise ValueError(f"Invalid field name: {name}")

# Convert snake_case to camelCase
components = name.split('_')
return components[0] + ''.join(x.title() for x in components[1:])


class GenericTypeResolver:
"""Handles resolution of generic type hints and container types."""

def __init__(self, module_scope: types.ModuleType):
def __init__(self, module_scope: types.ModuleType) -> None:
self.module_scope = module_scope
self.type_cache = {}
self.type_cache: Dict[str, Type[Any]] = {}

def resolve_container_type(self, container_name: str) -> TypingType:
container_types = {
def resolve_container_type(self, container_name: str) -> Type[Any]:
container_types: Dict[str, Type[Any]] = {
'Sequence': typing.Sequence,
'List': list,
'Set': typing.Set,
Expand All @@ -51,7 +43,7 @@ def resolve_container_type(self, container_name: str) -> TypingType:
raise ValueError(f"Unsupported container type: {container_name}")
return container_types[container_name]

def parse_generic_type(self, type_str: str) -> Tuple[TypingType, List[TypingType]]:
def parse_generic_type(self, type_str: str) -> Tuple[Type[Any], List[Type[Any]]]:
container_name = type_str[:type_str.index('[')].strip()
args_str = type_str[type_str.index('[') + 1:type_str.rindex(']')]

Expand All @@ -64,34 +56,33 @@ def parse_generic_type(self, type_str: str) -> Tuple[TypingType, List[TypingType

return container_type, arg_types

def resolve_type(self, type_str: str) -> TypingType:
def resolve_type(self, type_str: str) -> Type[Any]:
if type_str in self.type_cache:
return self.type_cache[type_str]

if type_str in {'str', 'int', 'float', 'bool', 'Any'}:
resolved = Any if type_str == 'Any' else eval(type_str)
resolved: Type[Any] = Any if type_str == 'Any' else eval(type_str)
elif '[' in type_str:
container_type, arg_types = self.parse_generic_type(type_str)
resolved = container_type[tuple(arg_types) if len(arg_types) > 1 else arg_types[0]]
elif '|' in type_str:
resolved = self._resolve_union_type(type_str)
resolved = cast(Type[Any], self._resolve_union_type(type_str))
elif hasattr(self.module_scope, type_str):
resolved = getattr(self.module_scope, type_str)
resolved = cast(Type[Any], getattr(self.module_scope, type_str))
else:
raise ValueError(f"Could not resolve type: {type_str}")

self.type_cache[type_str] = resolved
return resolved

def _resolve_union_type(self, type_str: str) -> TypingType:
def _resolve_union_type(self, type_str: str) -> Type[Any]:
types_str = [t.strip() for t in type_str.split('|')]
resolved_types = [
self.resolve_type(t)
for t in types_str
if t != 'None'
]
return resolved_types[0] if len(resolved_types) == 1 else Union[tuple(resolved_types)]

return cast(Type[Any], resolved_types[0] if len(resolved_types) == 1 else Union[tuple(resolved_types)])

class DataclassJSONEncoder:
"""Handles encoding of dataclass instances to JSON-compatible dictionaries."""
Expand All @@ -118,7 +109,7 @@ def _encode_value(cls, value: Any) -> Any:

@classmethod
def _encode_dataclass(cls, obj: Any) -> Dict[str, Any]:
result = {}
result: Dict[str, Any] = {}
for field in fields(obj):
value = getattr(obj, field.name)
if value is not None:
Expand All @@ -128,11 +119,10 @@ def _encode_dataclass(cls, obj: Any) -> Dict[str, Any]:
raise ValueError(f"Error encoding field {field.name}: {str(e)}")
return result


class DataclassJSONDecoder:
"""Handles decoding of JSON data to dataclass instances."""

def __init__(self, module_scope: Optional[types.ModuleType] = None):
def __init__(self, module_scope: Optional[types.ModuleType] = None) -> None:
self.module_scope = module_scope or sys.modules[__name__]
self.type_resolver = GenericTypeResolver(self.module_scope)

Expand All @@ -158,10 +148,10 @@ def decode(self, data: Any, target_type: Any) -> Any:
return self._decode_dataclass(data, target_type)

if origin is datetime or target_type is datetime:
return datetime.fromisoformat(data)
return datetime.fromisoformat(data) if data else None

if origin is date or target_type is date:
return date.fromisoformat(data)
return date.fromisoformat(data) if data else None

if isinstance(target_type, type) and issubclass(target_type, Enum):
try:
Expand All @@ -171,7 +161,7 @@ def decode(self, data: Any, target_type: Any) -> Any:

return data

def _get_type_info(self, type_hint: Any) -> Tuple[Any, Tuple]:
def _get_type_info(self, type_hint: Any) -> Tuple[Any, Tuple[Any, ...]]:
origin = get_origin(type_hint)
args = get_args(type_hint)
if origin is None and isinstance(type_hint, type):
Expand All @@ -181,18 +171,18 @@ def _get_type_info(self, type_hint: Any) -> Tuple[Any, Tuple]:
def _is_sequence_type(self, type_hint: Any) -> bool:
origin = get_origin(type_hint)
return (
origin is list or
origin is typing.Sequence or
(isinstance(origin, type) and issubclass(origin, typing.Sequence))
origin is list or
origin is typing.Sequence or
(isinstance(origin, type) and issubclass(origin, typing.Sequence))
)

def _decode_sequence(self, data: Any, type_args: Tuple) -> List:
def _decode_sequence(self, data: Any, type_args: Tuple[Any, ...]) -> List[Any]:
item_type = type_args[0] if type_args else Any
if not isinstance(data, list):
data = [data] if data is not None else []
return [self.decode(item, item_type) for item in data]

def _decode_union(self, data: Any, union_types: Tuple) -> Any:
def _decode_union(self, data: Any, union_types: Tuple[Any, ...]) -> Any:
types = [t for t in union_types if t is not type(None)]
if not types:
return data
Expand All @@ -206,7 +196,7 @@ def _decode_union(self, data: Any, union_types: Tuple) -> Any:
continue
raise ValueError(f"Could not decode {data} as any type in {types}")

def _decode_dict(self, data: Dict, type_args: Tuple) -> Dict:
def _decode_dict(self, data: Dict[Any, Any], type_args: Tuple[Any, ...]) -> Dict[Any, Any]:
if not isinstance(data, dict):
raise TypeError(f"Expected dict but got {type(data)}")
key_type = type_args[0] if type_args else str
Expand All @@ -216,16 +206,14 @@ def _decode_dict(self, data: Dict, type_args: Tuple) -> Dict:
for k, v in data.items()
}

def _decode_dataclass(self, data: Dict[str, Any], target_class: Any) -> Any:
"""Decode dictionary into a dataclass instance."""
def _decode_dataclass(self, data: Dict[str, Any], target_class: Type[Any]) -> Any:
if not isinstance(data, dict):
raise TypeError(f"Expected dict but got {type(data)}")

field_values = {}
field_values: Dict[str, Any] = {}
class_fields = {field.name: field for field in fields(target_class)}

# Create a mapping for both camelCase and original field names
field_mapping = {}
field_mapping: Dict[str, str] = {}
for field_name in class_fields:
try:
camel_name = snake_to_camel(field_name)
Expand All @@ -234,16 +222,12 @@ def _decode_dataclass(self, data: Dict[str, Any], target_class: Any) -> Any:
except ValueError:
continue

# Process input fields
for key, value in data.items():
# Try to map the input key to a field name
field_name = None

# First try direct mapping
if key in field_mapping:
field_name = field_mapping[key]
else:
# Try converting the key
try:
snake_key = camel_to_snake(key)
if snake_key in class_fields:
Expand All @@ -258,45 +242,42 @@ def _decode_dataclass(self, data: Dict[str, Any], target_class: Any) -> Any:
except Exception as e:
raise TypeError(f"Error decoding field {key}: {str(e)}")

# Check for missing required fields
missing_required = [
name for name, field in class_fields.items()
if name not in field_values
and field.default is MISSING
and field.default_factory is MISSING
and field.default is MISSING
and field.default_factory is MISSING
]

if missing_required:
raise TypeError(f"Missing required arguments: {', '.join(missing_required)}")

# Add default values for missing optional fields
self._apply_default_values(field_values, class_fields)
return target_class(**field_values)

def _apply_default_values(self, field_values: Dict[str, Any], class_fields: Dict[str, Field]) -> None:
for field_name, field in class_fields.items():
if field_name not in field_values:
if field.default is not Field:
if field.default is not MISSING:
field_values[field_name] = field.default
elif field.default_factory is not Field:
elif field.default_factory is not MISSING:
field_values[field_name] = field.default_factory()


class JSONDataclassMixin:
"""Mixin class to add JSON conversion capabilities to dataclasses."""

def to_json(self) -> str:
return json.dumps(self.to_dict())

def to_dict(self) -> Dict[str, Any]:
return DataclassJSONEncoder.encode(self)
return cast(Dict[str, Any], DataclassJSONEncoder.encode(self))

@classmethod
def from_json(cls: TypingType[T], json_str: str, module_scope: Optional[types.ModuleType] = None) -> T:
def from_json(cls: Type[T], json_str: str, module_scope: Optional[types.ModuleType] = None) -> T:
data = json.loads(json_str)
return cls.from_dict(data, module_scope)

@classmethod
def from_dict(cls: TypingType[T], data: Dict[str, Any], module_scope: Optional[types.ModuleType] = None) -> T:
def from_dict(cls: Type[T], data: Dict[str, Any], module_scope: Optional[types.ModuleType] = None) -> T:
decoder = DataclassJSONDecoder(module_scope or sys.modules[cls.__module__])
return decoder.decode(data, cls)
return cast(T, decoder.decode(data, cls))

0 comments on commit 11e3632

Please sign in to comment.