diff --git a/python/src/cucumber_messages/json_converter.py b/python/src/cucumber_messages/json_converter.py index 8b31c583..fe9b74bf 100644 --- a/python/src/cucumber_messages/json_converter.py +++ b/python/src/cucumber_messages/json_converter.py @@ -6,41 +6,33 @@ from dataclasses import fields, is_dataclass, Field, MISSING from datetime import datetime, date from enum import Enum -from typing import Any, Dict, List, Optional, Union, get_args, get_origin, TypeVar, Type as TypingType, Tuple +from typing import Any, Dict, List, Optional, Union, get_args, get_origin, TypeVar, Type, Tuple, cast T = TypeVar('T') def camel_to_snake(name: str) -> str: """Convert string from camelCase to snake_case.""" - # Validate field name - must start with letter or underscore and contain only alphanumeric and underscore if not name or not (name[0].isalpha() or name[0] == '_') or not all(c.isalnum() or c == '_' for c in name): raise ValueError(f"Invalid field name: {name}") - - # Convert camelCase to snake_case pattern = re.compile(r'(? str: """Convert string from snake_case to camelCase.""" - # Validate field name - must start with letter or underscore and contain only alphanumeric and underscore if not name or not (name[0].isalpha() or name[0] == '_') or not all(c.isalnum() or c == '_' for c in name): raise ValueError(f"Invalid field name: {name}") - - # Convert snake_case to camelCase components = name.split('_') return components[0] + ''.join(x.title() for x in components[1:]) - class GenericTypeResolver: """Handles resolution of generic type hints and container types.""" - def __init__(self, module_scope: types.ModuleType): + def __init__(self, module_scope: types.ModuleType) -> None: self.module_scope = module_scope - self.type_cache = {} + self.type_cache: Dict[str, Type[Any]] = {} - def resolve_container_type(self, container_name: str) -> TypingType: - container_types = { + def resolve_container_type(self, container_name: str) -> Type[Any]: + container_types: Dict[str, Type[Any]] = { 'Sequence': typing.Sequence, 'List': list, 'Set': typing.Set, @@ -51,7 +43,7 @@ def resolve_container_type(self, container_name: str) -> TypingType: raise ValueError(f"Unsupported container type: {container_name}") return container_types[container_name] - def parse_generic_type(self, type_str: str) -> Tuple[TypingType, List[TypingType]]: + def parse_generic_type(self, type_str: str) -> Tuple[Type[Any], List[Type[Any]]]: container_name = type_str[:type_str.index('[')].strip() args_str = type_str[type_str.index('[') + 1:type_str.rindex(']')] @@ -64,34 +56,33 @@ def parse_generic_type(self, type_str: str) -> Tuple[TypingType, List[TypingType return container_type, arg_types - def resolve_type(self, type_str: str) -> TypingType: + def resolve_type(self, type_str: str) -> Type[Any]: if type_str in self.type_cache: return self.type_cache[type_str] if type_str in {'str', 'int', 'float', 'bool', 'Any'}: - resolved = Any if type_str == 'Any' else eval(type_str) + resolved: Type[Any] = Any if type_str == 'Any' else eval(type_str) elif '[' in type_str: container_type, arg_types = self.parse_generic_type(type_str) resolved = container_type[tuple(arg_types) if len(arg_types) > 1 else arg_types[0]] elif '|' in type_str: - resolved = self._resolve_union_type(type_str) + resolved = cast(Type[Any], self._resolve_union_type(type_str)) elif hasattr(self.module_scope, type_str): - resolved = getattr(self.module_scope, type_str) + resolved = cast(Type[Any], getattr(self.module_scope, type_str)) else: raise ValueError(f"Could not resolve type: {type_str}") self.type_cache[type_str] = resolved return resolved - def _resolve_union_type(self, type_str: str) -> TypingType: + def _resolve_union_type(self, type_str: str) -> Type[Any]: types_str = [t.strip() for t in type_str.split('|')] resolved_types = [ self.resolve_type(t) for t in types_str if t != 'None' ] - return resolved_types[0] if len(resolved_types) == 1 else Union[tuple(resolved_types)] - + return cast(Type[Any], resolved_types[0] if len(resolved_types) == 1 else Union[tuple(resolved_types)]) class DataclassJSONEncoder: """Handles encoding of dataclass instances to JSON-compatible dictionaries.""" @@ -118,7 +109,7 @@ def _encode_value(cls, value: Any) -> Any: @classmethod def _encode_dataclass(cls, obj: Any) -> Dict[str, Any]: - result = {} + result: Dict[str, Any] = {} for field in fields(obj): value = getattr(obj, field.name) if value is not None: @@ -128,11 +119,10 @@ def _encode_dataclass(cls, obj: Any) -> Dict[str, Any]: raise ValueError(f"Error encoding field {field.name}: {str(e)}") return result - class DataclassJSONDecoder: """Handles decoding of JSON data to dataclass instances.""" - def __init__(self, module_scope: Optional[types.ModuleType] = None): + def __init__(self, module_scope: Optional[types.ModuleType] = None) -> None: self.module_scope = module_scope or sys.modules[__name__] self.type_resolver = GenericTypeResolver(self.module_scope) @@ -158,10 +148,10 @@ def decode(self, data: Any, target_type: Any) -> Any: return self._decode_dataclass(data, target_type) if origin is datetime or target_type is datetime: - return datetime.fromisoformat(data) + return datetime.fromisoformat(data) if data else None if origin is date or target_type is date: - return date.fromisoformat(data) + return date.fromisoformat(data) if data else None if isinstance(target_type, type) and issubclass(target_type, Enum): try: @@ -171,7 +161,7 @@ def decode(self, data: Any, target_type: Any) -> Any: return data - def _get_type_info(self, type_hint: Any) -> Tuple[Any, Tuple]: + def _get_type_info(self, type_hint: Any) -> Tuple[Any, Tuple[Any, ...]]: origin = get_origin(type_hint) args = get_args(type_hint) if origin is None and isinstance(type_hint, type): @@ -181,18 +171,18 @@ def _get_type_info(self, type_hint: Any) -> Tuple[Any, Tuple]: def _is_sequence_type(self, type_hint: Any) -> bool: origin = get_origin(type_hint) return ( - origin is list or - origin is typing.Sequence or - (isinstance(origin, type) and issubclass(origin, typing.Sequence)) + origin is list or + origin is typing.Sequence or + (isinstance(origin, type) and issubclass(origin, typing.Sequence)) ) - def _decode_sequence(self, data: Any, type_args: Tuple) -> List: + def _decode_sequence(self, data: Any, type_args: Tuple[Any, ...]) -> List[Any]: item_type = type_args[0] if type_args else Any if not isinstance(data, list): data = [data] if data is not None else [] return [self.decode(item, item_type) for item in data] - def _decode_union(self, data: Any, union_types: Tuple) -> Any: + def _decode_union(self, data: Any, union_types: Tuple[Any, ...]) -> Any: types = [t for t in union_types if t is not type(None)] if not types: return data @@ -206,7 +196,7 @@ def _decode_union(self, data: Any, union_types: Tuple) -> Any: continue raise ValueError(f"Could not decode {data} as any type in {types}") - def _decode_dict(self, data: Dict, type_args: Tuple) -> Dict: + def _decode_dict(self, data: Dict[Any, Any], type_args: Tuple[Any, ...]) -> Dict[Any, Any]: if not isinstance(data, dict): raise TypeError(f"Expected dict but got {type(data)}") key_type = type_args[0] if type_args else str @@ -216,16 +206,14 @@ def _decode_dict(self, data: Dict, type_args: Tuple) -> Dict: for k, v in data.items() } - def _decode_dataclass(self, data: Dict[str, Any], target_class: Any) -> Any: - """Decode dictionary into a dataclass instance.""" + def _decode_dataclass(self, data: Dict[str, Any], target_class: Type[Any]) -> Any: if not isinstance(data, dict): raise TypeError(f"Expected dict but got {type(data)}") - field_values = {} + field_values: Dict[str, Any] = {} class_fields = {field.name: field for field in fields(target_class)} - # Create a mapping for both camelCase and original field names - field_mapping = {} + field_mapping: Dict[str, str] = {} for field_name in class_fields: try: camel_name = snake_to_camel(field_name) @@ -234,16 +222,12 @@ def _decode_dataclass(self, data: Dict[str, Any], target_class: Any) -> Any: except ValueError: continue - # Process input fields for key, value in data.items(): - # Try to map the input key to a field name field_name = None - # First try direct mapping if key in field_mapping: field_name = field_mapping[key] else: - # Try converting the key try: snake_key = camel_to_snake(key) if snake_key in class_fields: @@ -258,30 +242,27 @@ def _decode_dataclass(self, data: Dict[str, Any], target_class: Any) -> Any: except Exception as e: raise TypeError(f"Error decoding field {key}: {str(e)}") - # Check for missing required fields missing_required = [ name for name, field in class_fields.items() if name not in field_values - and field.default is MISSING - and field.default_factory is MISSING + and field.default is MISSING + and field.default_factory is MISSING ] if missing_required: raise TypeError(f"Missing required arguments: {', '.join(missing_required)}") - # Add default values for missing optional fields self._apply_default_values(field_values, class_fields) return target_class(**field_values) def _apply_default_values(self, field_values: Dict[str, Any], class_fields: Dict[str, Field]) -> None: for field_name, field in class_fields.items(): if field_name not in field_values: - if field.default is not Field: + if field.default is not MISSING: field_values[field_name] = field.default - elif field.default_factory is not Field: + elif field.default_factory is not MISSING: field_values[field_name] = field.default_factory() - class JSONDataclassMixin: """Mixin class to add JSON conversion capabilities to dataclasses.""" @@ -289,14 +270,14 @@ def to_json(self) -> str: return json.dumps(self.to_dict()) def to_dict(self) -> Dict[str, Any]: - return DataclassJSONEncoder.encode(self) + return cast(Dict[str, Any], DataclassJSONEncoder.encode(self)) @classmethod - def from_json(cls: TypingType[T], json_str: str, module_scope: Optional[types.ModuleType] = None) -> T: + def from_json(cls: Type[T], json_str: str, module_scope: Optional[types.ModuleType] = None) -> T: data = json.loads(json_str) return cls.from_dict(data, module_scope) @classmethod - def from_dict(cls: TypingType[T], data: Dict[str, Any], module_scope: Optional[types.ModuleType] = None) -> T: + def from_dict(cls: Type[T], data: Dict[str, Any], module_scope: Optional[types.ModuleType] = None) -> T: decoder = DataclassJSONDecoder(module_scope or sys.modules[cls.__module__]) - return decoder.decode(data, cls) \ No newline at end of file + return cast(T, decoder.decode(data, cls)) \ No newline at end of file