diff --git a/src/phantom/_base.py b/src/phantom/_base.py index 63f36a0..8ce7c80 100644 --- a/src/phantom/_base.py +++ b/src/phantom/_base.py @@ -2,6 +2,7 @@ import abc from typing import Any +from typing import Optional from typing import Callable from typing import ClassVar from typing import Generic @@ -107,6 +108,11 @@ class Phantom(PhantomBase, Generic[T]): abstract. * ``abstract: bool`` - Set to ``True`` to create an abstract phantom type. This allows deferring definitions of ``predicate`` and ``bound`` to concrete subtypes. + * ``use_docstring: bool`` - Set to ``True`` to override the schema description + with the class docstring. This will take precedence over any descriptions + given in ``__schema__``. The behavior is inherited (i.e. the respective class + docstring will be used as description) until a subclass sets ``use_docstring`` + back to ``False``. """ __predicate__: Predicate[T] @@ -125,13 +131,26 @@ def __init_subclass__( predicate: Predicate[T] | None = None, bound: type[T] | None = None, abstract: bool = False, + use_docstring: Optional[bool] = None, **kwargs: Any, ) -> None: + if kwargs: + raise RuntimeError("Unknown phantom type argument(s): {kwargs}") + super().__init_subclass__(**kwargs) resolve_class_attr(cls, "__abstract__", abstract) resolve_class_attr(cls, "__predicate__", predicate) cls._resolve_bound(bound) + if use_docstring is not None: # manual override + setattr(cls, "__use_docstring__", use_docstring) + elif not hasattr(cls, "__use_docstring__"): # missing, set default + setattr(cls, "__use_docstring__", False) + if getattr(cls, "__use_docstring__") and not cls.__doc__: + msg = f"{cls} has no docstring, but use_docstring is set or inherited!" + raise RuntimeError(msg) + + @classmethod def _interpret_implicit_bound(cls) -> BoundType: def discover_bounds() -> Iterable[type]: diff --git a/src/phantom/schema.py b/src/phantom/schema.py index 58c2cfd..d919a3c 100644 --- a/src/phantom/schema.py +++ b/src/phantom/schema.py @@ -1,10 +1,13 @@ from typing import Optional from typing import Sequence +from typing_extensions import ClassVar from typing_extensions import Literal from typing_extensions import TypedDict from typing_extensions import final +import textwrap + class Schema(TypedDict, total=False): title: str @@ -22,7 +25,26 @@ class Schema(TypedDict, total=False): maxLength: Optional[int] +def desc_from_docstring(cls) -> str: + # ds = next(filter(lambda x: x.__doc__, cls.__mro__), cls).__doc__ + # ds = (ds or "").strip() + ds = cls.__doc__ + + if not ds: + return "" + lines = ds.split("\n", maxsplit=1) + if len(lines) == 1: + return ds + rest = textwrap.dedent(lines[1]) + if fst := lines[0].strip(): + return f"{fst}\n{rest}" + else: + return rest + + class SchemaField: + __use_docstring__: ClassVar[str] + @classmethod @final def __modify_schema__(cls, field_schema: dict) -> None: @@ -35,6 +57,9 @@ def __modify_schema__(cls, field_schema: dict) -> None: field_schema.update( {key: value for key, value in cls.__schema__().items() if value is not None} ) + if cls.__use_docstring__: + if ds := desc_from_docstring(cls): + field_schema["description"] = ds @classmethod def __schema__(cls) -> Schema: