Skip to content

Commit

Permalink
📝 docs(arrayish): refactor docstring examples (#124)
Browse files Browse the repository at this point in the history
Signed-off-by: Nathaniel Starkman <nstarman@users.noreply.github.com>
  • Loading branch information
nstarman authored Jan 21, 2025
1 parent 011b8b0 commit 3c5fbae
Show file tree
Hide file tree
Showing 8 changed files with 488 additions and 795 deletions.
776 changes: 280 additions & 496 deletions src/quaxed/experimental/_arrayish/binary.py

Large diffs are not rendered by default.

64 changes: 25 additions & 39 deletions src/quaxed/experimental/_arrayish/container.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,10 @@
]
# fmt: on

from typing import Protocol, TypeVar, runtime_checkable
from typing import Protocol, runtime_checkable

import quaxed.numpy as qnp

T = TypeVar("T")


@runtime_checkable
class HasShape(Protocol):
Expand All @@ -31,25 +29,22 @@ class LaxLenMixin:
Examples
--------
>>> import jax
>>> import jax.numpy as jnp
>>> from jaxtyping import Array
>>> from quax import ArrayValue
>>> from quaxed.experimental.arrayish import AbstractVal, LaxLenMixin
>>> class MyArray(ArrayValue, LaxLenMixin):
... value: Array
... def aval(self): return jax.core.ShapedArray(self.value.shape, self.value.dtype)
... def materialise(self): return self.value
>>> class Val(AbstractVal, LaxLenMixin):
... v: Array
>>> x = MyArray(jnp.array([1, 2, 3]))
>>> x = Val(jnp.array([1, 2, 3]))
>>> len(x)
3
>>> x = MyArray(jnp.array(1))
>>> x = Val(jnp.array(1))
>>> len(x)
0
""" # noqa: E501
"""

def __len__(self: HasShape) -> int:
return self.shape[0] if self.shape else 0
Expand All @@ -60,25 +55,22 @@ class NumpyLenMixin:
Examples
--------
>>> import jax
>>> import jax.numpy as jnp
>>> from jaxtyping import Array
>>> from quax import ArrayValue
>>> from quaxed.experimental.arrayish import AbstractVal, NumpyLenMixin
>>> class MyArray(ArrayValue, NumpyLenMixin):
... value: Array
... def aval(self): return jax.core.ShapedArray(self.value.shape, self.value.dtype)
... def materialise(self): return self.value
>>> class Val(AbstractVal, NumpyLenMixin):
... v: Array
>>> x = MyArray(jnp.array([1, 2, 3]))
>>> x = Val(jnp.array([1, 2, 3]))
>>> len(x)
3
>>> x = MyArray(jnp.array(1))
>>> x = Val(jnp.array(1))
>>> len(x)
0
""" # noqa: E501
"""

def __len__(self) -> int:
shape = qnp.shape(self)
Expand All @@ -94,25 +86,22 @@ class LaxLengthHintMixin:
Examples
--------
>>> import jax
>>> import jax.numpy as jnp
>>> from jaxtyping import Array
>>> from quax import ArrayValue
>>> from quaxed.experimental.arrayish import AbstractVal, LaxLengthHintMixin
>>> class MyArray(ArrayValue, LaxLengthHintMixin):
... value: Array
... def aval(self): return jax.core.ShapedArray(self.value.shape, self.value.dtype)
... def materialise(self): return self.value
>>> class Val(AbstractVal, LaxLengthHintMixin):
... v: Array
>>> x = MyArray(jnp.array([1, 2, 3]))
>>> x = Val(jnp.array([1, 2, 3]))
>>> x.__length_hint__()
3
>>> x = MyArray(jnp.array(0))
>>> x = Val(jnp.array(0))
>>> x.__length_hint__()
0
""" # noqa: E501
"""

def __length_hint__(self: HasShape) -> int:
return self.shape[0] if self.shape else 0
Expand All @@ -123,25 +112,22 @@ class NumpyLengthHintMixin:
Examples
--------
>>> import jax
>>> import jax.numpy as jnp
>>> from jaxtyping import Array
>>> from quax import ArrayValue
>>> from quaxed.experimental.arrayish import AbstractVal, NumpyLengthHintMixin
>>> class MyArray(ArrayValue, NumpyLengthHintMixin):
... value: Array
... def aval(self): return jax.core.ShapedArray(self.value.shape, self.value.dtype)
... def materialise(self): return self.value
>>> class Val(AbstractVal, NumpyLengthHintMixin):
... v: Array
>>> x = MyArray(jnp.array([1, 2, 3]))
>>> x = Val(jnp.array([1, 2, 3]))
>>> x.__length_hint__()
3
>>> x = MyArray(jnp.array(1))
>>> x = Val(jnp.array(1))
>>> x.__length_hint__()
0
""" # noqa: E501
"""

def __length_hint__(self) -> int:
shape = qnp.shape(self)
Expand Down
27 changes: 10 additions & 17 deletions src/quaxed/experimental/_arrayish/copy.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,21 +25,18 @@ class NumpyCopyMixin(Generic[RCopy]):
Examples
--------
>>> import copy
>>> import jax
>>> import jax.numpy as jnp
>>> from jaxtyping import Array
>>> from quax import ArrayValue
>>> from quaxed.experimental.arrayish import AbstractVal, NumpyCopyMixin
>>> class MyArray(ArrayValue, NumpyCopyMixin[Any]):
... value: Array
... def aval(self): return jax.core.ShapedArray(self.value.shape, self.value.dtype)
... def materialise(self): return self.value
>>> class Val(AbstractVal, NumpyCopyMixin[Any]):
... v: Array
>>> x = MyArray(jnp.array([1, 2, 3]))
>>> x = Val(jnp.array([1, 2, 3]))
>>> copy.copy(x)
Array([1, 2, 3], dtype=int32)
""" # noqa: E501
"""

def __copy__(self) -> RCopy:
return qnp.copy(self)
Expand All @@ -54,22 +51,18 @@ class NumpyDeepCopyMixin(Generic[RDeepcopy]):
Examples
--------
>>> import copy
>>> import jax
>>> import jax.numpy as jnp
>>> from jaxtyping import Array
>>> from quax import ArrayValue
>>> from quaxed.experimental.arrayish import AbstractVal, NumpyDeepCopyMixin
>>> class MyArray(ArrayValue, NumpyDeepCopyMixin[Any]):
... value: Array
... def aval(self): return jax.core.ShapedArray(self.value.shape, self.value.dtype)
... def materialise(self): return self.value
>>> class Val(AbstractVal, NumpyDeepCopyMixin[Any]):
... v: Array
>>> x = MyArray(jnp.array([1, 2, 3]))
>>> x = Val(jnp.array([1, 2, 3]))
>>> copy.deepcopy(x)
Array([1, 2, 3], dtype=int32)
""" # noqa: E501
"""

def __deepcopy__(self, memo: dict[Any, Any], /) -> RDeepcopy:
return qnp.copy(self)
22 changes: 22 additions & 0 deletions src/quaxed/experimental/_arrayish/example.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
"""Arrayish."""

__all__: list[str] = ["AbstractVal"]


import equinox as eqx
import jax
from jaxtyping import Array
from quax import ArrayValue


class AbstractVal(ArrayValue): # type: ignore[misc]
"""ABC for example arrayish object."""

#: The array.
v: eqx.AbstractVar[Array]

def aval(self) -> jax.core.ShapedArray:
return jax.core.ShapedArray(self.v.shape, self.v.dtype)

def materialise(self) -> Array:
return self.v
Loading

0 comments on commit 3c5fbae

Please sign in to comment.