diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 66acbcd..66bc0f3 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -1,3 +1,5 @@
+default_language_version:
+  python: python3.12
 repos:
   - repo: https://github.com/pre-commit/pre-commit-hooks
     rev: v5.0.0
@@ -13,11 +15,17 @@ repos:
       - id: no-commit-to-branch
         args: ["--branch=main"]
   - repo: https://github.com/astral-sh/ruff-pre-commit
-    rev: v0.8.6
+    rev: v0.9.1
    hooks:
       - id: ruff
         args: ["--fix"]
       - id: ruff-format
+  - repo: https://github.com/pre-commit/mirrors-mypy
+    rev: v1.14.1
+    hooks:
+      - id: mypy
+        additional_dependencies:
+          - pynndescent
   - repo: https://github.com/pre-commit/mirrors-prettier
     rev: v4.0.0-alpha.8
     hooks:
diff --git a/docs/conf.py b/docs/conf.py
index 66b5a00..e684e5b 100644
--- a/docs/conf.py
+++ b/docs/conf.py
@@ -3,6 +3,7 @@
 # This file only contains a selection of the most common options. For a full
 # list see the documentation:
 # https://www.sphinx-doc.org/en/master/usage/configuration.html
+from __future__ import annotations
 
 import os
diff --git a/examples/rnn_dbscan_big.py b/examples/rnn_dbscan_big.py
index d766981..3160b5b 100644
--- a/examples/rnn_dbscan_big.py
+++ b/examples/rnn_dbscan_big.py
@@ -7,6 +7,10 @@
 
 """
 
+from __future__ import annotations
+
+from typing import TYPE_CHECKING
+
 import numpy as np
 from joblib import Memory
 from sklearn import metrics
@@ -14,10 +18,15 @@
 
 from sklearn_ann.cluster.rnn_dbscan import simple_rnn_dbscan_pipeline
 
+if TYPE_CHECKING:
+    from typing import Any
+
+    from sklearn.utils import Bunch
+
 # #############################################################################
 # Generate sample data
-def fetch_mnist():
+def fetch_mnist() -> tuple[Any, Any]:
     print("Downloading mnist_784")
-    mnist = fetch_openml("mnist_784")
+    mnist: Bunch = fetch_openml("mnist_784")
     return mnist.data / 255, mnist.target
@@ -28,7 +37,9 @@ def fetch_mnist():
 X, y = memory.cache(fetch_mnist)()
 
 
-def run_rnn_dbscan(neighbor_transformer, n_neighbors, **kwargs):
+def run_rnn_dbscan(
+    neighbor_transformer: object, n_neighbors: int, **kwargs: Any
+) -> None:
     # #############################################################################
     # Compute RnnDBSCAN
diff --git a/examples/rnn_dbscan_simple.py b/examples/rnn_dbscan_simple.py
index 156bd51..4173b9a 100644
--- a/examples/rnn_dbscan_simple.py
+++ b/examples/rnn_dbscan_simple.py
@@ -8,6 +8,7 @@
 Mostly copypasted from sklearn's DBSCAN example.
 
 """
+from __future__ import annotations
 
 import numpy as np
 from sklearn import metrics
diff --git a/pyproject.toml b/pyproject.toml
index 9cdfaa9..d4040e4 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -74,6 +74,7 @@ select = [
     "PTH",  # Pathlib
     "RUF",  # Ruff’s own rules
     "T20",  # print statements
+    "TC",   # type checking
 ]
 ignore = [
     # Don’t complain about “confusables”
@@ -84,6 +85,10 @@ ignore = [
 "tests/*.py" = ["T20"]
 [tool.ruff.lint.isort]
 known-first-party = ["sklearn_ann"]
+required-imports = ["from __future__ import annotations"]
+[tool.ruff.lint.flake8-type-checking]
+exempt-modules = []
+strict = true
 
 [tool.hatch.envs.docs]
 installer = "uv"
@@ -97,6 +102,14 @@ features = ["tests", "annlibs"]
 [tool.hatch.build.targets.wheel]
 packages = ["src/sklearn_ann"]
 
+[tool.mypy]
+python_version = "3.11"
+mypy_path = ["src", "tests"]
+strict = true
+explicit_package_bases = true  # pytest doesn’t do __init__.py
+no_implicit_optional = true
+disallow_untyped_decorators = false  # e.g. pytest.mark.parametrize
+
 [build-system]
 requires = ["hatchling", "hatch-vcs", "hatch-fancy-pypi-readme"]
 build-backend = "hatchling.build"
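The ruff `TC` rules, the isort `required-imports` setting, and strict mypy together enforce one module shape across the codebase: typing-only imports live under `TYPE_CHECKING`, and every module starts with `from __future__ import annotations` so annotations are never evaluated at runtime. A minimal sketch of that shape (hypothetical module, not part of this diff):

```python
from __future__ import annotations

from typing import TYPE_CHECKING

import numpy as np

if TYPE_CHECKING:
    # Free at runtime; mypy and other checkers still resolve these.
    from numpy.typing import NDArray


def normalize(x: NDArray[np.float64]) -> NDArray[np.float64]:
    """Scale a vector to unit L2 norm."""
    return x / np.linalg.norm(x)
```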
diff --git a/src/sklearn_ann/cluster/rnn_dbscan.py b/src/sklearn_ann/cluster/rnn_dbscan.py
index 0ce7e3b..7eed264 100644
--- a/src/sklearn_ann/cluster/rnn_dbscan.py
+++ b/src/sklearn_ann/cluster/rnn_dbscan.py
@@ -1,7 +1,10 @@
+from __future__ import annotations
+
 from collections import deque
-from typing import cast
+from typing import TYPE_CHECKING, cast
 
 import numpy as np
+from scipy.sparse import csr_matrix
 from sklearn.base import BaseEstimator, ClusterMixin
 from sklearn.neighbors import KNeighborsTransformer
 from sklearn.utils import Tags
@@ -9,6 +12,14 @@
 
 from ..utils import get_sparse_row
 
+if TYPE_CHECKING:
+    from collections.abc import Iterator
+    from typing import Literal, Self
+
+    from numpy.typing import NDArray
+    from sklearn.pipeline import Pipeline
+
+
 UNCLASSIFIED = -2
 NOISE = -1
 
@@ -37,7 +48,9 @@ def join(it1, it2):
         cur_it2 = next(it2, None)
 
 
-def neighborhood(is_core, knns, rev_knns, idx):
+def neighborhood(
+    is_core: NDArray[np.bool_], knns: csr_matrix, rev_knns: csr_matrix, idx: int
+) -> Iterator[tuple[int, float]]:
     # TODO: Make this inner bit faster
     knn_it = get_sparse_row(knns, idx)
     rev_core_knn_it = (
@@ -52,10 +65,12 @@ def neighborhood(is_core, knns, rev_knns, idx):
     )
 
 
-def rnn_dbscan_inner(is_core, knns, rev_knns, labels):
+def rnn_dbscan_inner(
+    is_core: NDArray[np.bool_],
+    knns: csr_matrix,
+    rev_knns: csr_matrix,
+    labels: NDArray[np.int32],
+) -> list[float]:
     cluster = 0
-    cur_dens = 0
-    dens = []
+    cur_dens = 0.0
+    dens: list[float] = []
     for x_idx in range(len(labels)):
         if labels[x_idx] == UNCLASSIFIED:
             # Expand cluster
@@ -81,7 +96,7 @@ def rnn_dbscan_inner(is_core, knns, rev_knns, labels):
                 elif labels[z_idx] == NOISE:
                     labels[z_idx] = cluster
             dens.append(cur_dens)
-            cur_dens = 0
+            cur_dens = 0.0
             cluster += 1
         else:
             labels[x_idx] = NOISE
@@ -138,15 +153,20 @@ class RnnDBSCAN(ClusterMixin, BaseEstimator):
     """
 
     def __init__(
-        self, n_neighbors=5, *, input_guarantee="none", n_jobs=None, keep_knns=False
-    ):
+        self,
+        n_neighbors: int = 5,
+        *,
+        input_guarantee: Literal["none", "kneighbors"] = "none",
+        n_jobs: int | None = None,
+        keep_knns: bool = False,
+    ) -> None:
         self.n_neighbors = n_neighbors
         self.input_guarantee = input_guarantee
         self.n_jobs = n_jobs
         self.keep_knns = keep_knns
 
-    def fit(self, X, y=None):
-        X = validate_data(self, X, accept_sparse="csr")
+    def fit(self, X: NDArray[np.float64] | csr_matrix, y: None = None) -> Self:
+        X = cast(csr_matrix, validate_data(self, X, accept_sparse="csr"))
         if self.input_guarantee == "none":
             algorithm = KNeighborsTransformer(n_neighbors=self.n_neighbors)
             X = algorithm.fit_transform(X)
@@ -157,7 +177,7 @@ def fit(self, X, y=None):
                 "Expected input_guarantee to be one of 'none', 'kneighbors'"
             )
 
-        XT = X.transpose().tocsr(copy=True)
+        XT = cast(csr_matrix, X.transpose().tocsr(copy=True))
         if self.keep_knns:
             self.knns_ = X
             self.rev_knns_ = XT
@@ -176,11 +196,11 @@ def fit(self, X, y=None):
 
         return self
 
-    def fit_predict(self, X, y=None):
+    def fit_predict(self, X, y=None) -> NDArray[np.int32]:
         self.fit(X, y=y)
         return self.labels_
 
-    def drop_knns(self):
+    def drop_knns(self) -> None:
         del self.knns_
         del self.rev_knns_
 
@@ -191,22 +211,28 @@ class RnnDBSCAN(ClusterMixin, BaseEstimator):
 
 
 def simple_rnn_dbscan_pipeline(
-    neighbor_transformer, n_neighbors, n_jobs=None, keep_knns=None, **kwargs
-):
+    neighbor_transformer: object,
+    n_neighbors: int,
+    *,
+    n_jobs: int | None = None,
+    keep_knns: bool = False,
+    input_guarantee: Literal["none", "kneighbors"] = "none",
+) -> Pipeline:
     """
     Create a simple pipeline comprising a transformer and RnnDBSCAN.
 
     Parameters
     ----------
-    neighbor_transformer : class implementing KNeighborsTransformer interface
-    n_neighbors:
+    neighbor_transformer
+        Class implementing the KNeighborsTransformer interface
+    n_neighbors
         Passed to neighbor_transformer and RnnDBSCAN
-    n_jobs:
+    n_jobs
         Passed to neighbor_transformer and RnnDBSCAN
-    keep_knns:
+    keep_knns
+        Passed to RnnDBSCAN
+    input_guarantee
         Passed to RnnDBSCAN
-    kwargs:
-        Passed to neighbor_transformer
     """
     from sklearn.pipeline import make_pipeline
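With the new keyword-only signature, a typical call looks like this. A minimal usage sketch with synthetic data (any class implementing the KNeighborsTransformer interface works in place of `KNeighborsTransformer`):

```python
from __future__ import annotations

import numpy as np
from sklearn.neighbors import KNeighborsTransformer

from sklearn_ann.cluster.rnn_dbscan import simple_rnn_dbscan_pipeline

# Transformer class goes in positionally; the rest is keyword-only now.
pipeline = simple_rnn_dbscan_pipeline(KNeighborsTransformer, n_neighbors=10)
X = np.random.default_rng(0).random((100, 8))
labels = pipeline.fit_predict(X)
n_noise = int((labels == -1).sum())  # -1 marks noise points
```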
diff --git a/src/sklearn_ann/kneighbors/annoy.py b/src/sklearn_ann/kneighbors/annoy.py
index 67505eb..f975dfe 100644
--- a/src/sklearn_ann/kneighbors/annoy.py
+++ b/src/sklearn_ann/kneighbors/annoy.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 import annoy
 import numpy as np
 from scipy.sparse import csr_matrix
diff --git a/src/sklearn_ann/kneighbors/faiss.py b/src/sklearn_ann/kneighbors/faiss.py
index c349132..e696870 100644
--- a/src/sklearn_ann/kneighbors/faiss.py
+++ b/src/sklearn_ann/kneighbors/faiss.py
@@ -1,6 +1,7 @@
 from __future__ import annotations
 
 import math
+from typing import TYPE_CHECKING, TypedDict
 
 import faiss
 import numpy as np
@@ -13,10 +14,22 @@
 
 from ..utils import TransformerChecksMixin, postprocess_knn_csr
 
+if TYPE_CHECKING:
+    from typing import Self
+
+    from numpy.typing import ArrayLike, NDArray
+
+
+class MetricInfo(TypedDict):
+    metric: int
+    normalize: bool
+    negate: bool
+
+
 L2_INFO = {"metric": faiss.METRIC_L2, "sqrt": True}
 
-METRIC_MAP = {
+METRIC_MAP: dict[str, MetricInfo] = {
     "cosine": {
         "metric": faiss.METRIC_INNER_PRODUCT,
         "normalize": True,
@@ -34,7 +47,12 @@
 }
 
 
-def mk_faiss_index(feats, inner_metric, index_key="", nprobe=128) -> faiss.Index:
+def mk_faiss_index(
+    feats: NDArray[np.float32],
+    inner_metric: int,
+    index_key: str = "",
+    nprobe: int = 128,
+) -> faiss.Index:
     size, dim = feats.shape
     if not index_key:
         if inner_metric == faiss.METRIC_INNER_PRODUCT:
@@ -64,15 +82,15 @@ def mk_faiss_index(feats, inner_metric, index_key="", nprobe=128) -> faiss.Index
 class FAISSTransformer(TransformerChecksMixin, TransformerMixin, BaseEstimator):
     def __init__(
         self,
-        n_neighbors=5,
+        n_neighbors: int = 5,
         *,
-        metric="euclidean",
-        index_key="",
-        n_probe=128,
-        n_jobs=-1,
-        include_fwd=True,
-        include_rev=False,
-    ):
+        metric: str = "euclidean",
+        index_key: str = "",
+        n_probe: int = 128,
+        n_jobs: int = -1,
+        include_fwd: bool = True,
+        include_rev: bool = False,
+    ) -> None:
         self.n_neighbors = n_neighbors
         self.metric = metric
         self.index_key = index_key
@@ -82,10 +100,10 @@ def __init__(
         self.include_rev = include_rev
 
     @property
-    def _metric_info(self):
+    def _metric_info(self) -> MetricInfo:
         return METRIC_MAP[self.metric]
 
-    def fit(self, X, y=None):
+    def fit(self, X: ArrayLike, y: None = None) -> Self:
         normalize = self._metric_info.get("normalize", False)
         X = validate_data(self, X, dtype=np.float32, copy=normalize)
         self.n_samples_fit_ = X.shape[0]
@@ -100,14 +118,14 @@ def fit(self, X, y=None):
         self.faiss_ = mk_faiss_index(X, inner_metric, self.index_key, self.n_probe)
         return self
 
-    def transform(self, X):
+    def transform(self, X: NDArray[np.number]) -> csr_matrix:
         normalize = self._metric_info.get("normalize", False)
         X = self._transform_checks(X, "faiss_", dtype=np.float32, copy=normalize)
         if normalize:
             normalize_L2(X)
         return self._transform(X)
 
-    def _transform(self, X):
+    def _transform(self, X: NDArray[np.float32] | None) -> csr_matrix:
         n_samples_transform = self.n_samples_fit_ if X is None else X.shape[0]
         n_neighbors = self.n_neighbors + 1
         if X is None:
@@ -156,7 +174,7 @@ def _transform(self, X):
             mat, include_fwd=self.include_fwd, include_rev=self.include_rev
         )
 
-    def fit_transform(self, X, y=None):
+    def fit_transform(self, X: ArrayLike, y: None = None) -> csr_matrix:
         return self.fit(X, y=y)._transform(X=None)
 
     def __sklearn_tags__(self) -> Tags:
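The `MetricInfo` TypedDict means mypy now checks each `METRIC_MAP` entry key by key. A sketch of what it accepts and rejects (hypothetical `l2` entry, assuming faiss is installed and the declared keys above):

```python
from typing import TypedDict

import faiss


class MetricInfo(TypedDict):
    metric: int
    normalize: bool
    negate: bool


# OK: matches the TypedDict exactly.
l2: MetricInfo = {"metric": faiss.METRIC_L2, "normalize": False, "negate": False}

# Rejected by mypy (typeddict-unknown-key): "sqrt" is not a declared key.
# bad: MetricInfo = {"metric": faiss.METRIC_L2, "sqrt": True}
```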
diff --git a/src/sklearn_ann/kneighbors/nmslib.py b/src/sklearn_ann/kneighbors/nmslib.py
index 4736947..9463759 100644
--- a/src/sklearn_ann/kneighbors/nmslib.py
+++ b/src/sklearn_ann/kneighbors/nmslib.py
@@ -1,5 +1,10 @@
+from __future__ import annotations
+
+from typing import TYPE_CHECKING, cast
+
 import nmslib
 import numpy as np
+from numpy.typing import NDArray
 from scipy.sparse import csr_matrix
 from sklearn.base import BaseEstimator, TransformerMixin
 from sklearn.utils import Tags, TargetTags, TransformerTags
@@ -7,6 +12,11 @@
 
 from ..utils import TransformerChecksMixin, check_metric
 
+if TYPE_CHECKING:
+    from typing import Self
+
+    from numpy.typing import ArrayLike
+
 # see more metric in the manual
 # https://github.com/nmslib/nmslib/tree/master/manual
 METRIC_MAP = {
@@ -22,15 +32,20 @@ class NMSlibTransformer(TransformerChecksMixin, TransformerMixin, BaseEstimator)
     """Wrapper for using nmslib as sklearn's KNeighborsTransformer"""
 
     def __init__(
-        self, n_neighbors=5, *, metric="euclidean", method="sw-graph", n_jobs=1
-    ):
+        self,
+        n_neighbors: int = 5,
+        *,
+        metric: str = "euclidean",
+        method: str = "sw-graph",
+        n_jobs: int = 1,
+    ) -> None:
         self.n_neighbors = n_neighbors
         self.method = method
         self.metric = metric
         self.n_jobs = n_jobs
 
-    def fit(self, X, y=None):
-        X = validate_data(self, X)
+    def fit(self, X: ArrayLike, y: None = None) -> Self:
+        X = cast(NDArray[np.float64], validate_data(self, X))
         self.n_samples_fit_ = X.shape[0]
 
         check_metric(self.metric, METRIC_MAP)
@@ -41,7 +56,7 @@ def fit(self, X, y=None):
         self.nmslib_.createIndex()
         return self
 
-    def transform(self, X):
+    def transform(self, X: NDArray[np.float64]) -> csr_matrix:
         X = self._transform_checks(X, "nmslib_")
         n_samples_transform = X.shape[0]
diff --git a/src/sklearn_ann/kneighbors/pynndescent.py b/src/sklearn_ann/kneighbors/pynndescent.py
index 712261f..dd525b2 100644
--- a/src/sklearn_ann/kneighbors/pynndescent.py
+++ b/src/sklearn_ann/kneighbors/pynndescent.py
@@ -1,12 +1,20 @@
+from __future__ import annotations
+
+from typing import TYPE_CHECKING
+
 from pynndescent import PyNNDescentTransformer as PyNNDescentTransformerBase
 
+if TYPE_CHECKING:
+    from typing import Self
+
+    from numpy.typing import ArrayLike
+
 
-def no_op():
-    pass
+def no_op() -> None: ...
 
 
 class PyNNDescentTransformer(PyNNDescentTransformerBase):
-    def fit(self, X, compress_index=True):
+    def fit(self, X: ArrayLike, compress_index: bool = True) -> Self:
         super().fit(X, compress_index=compress_index)
         self.index_.compress_index = no_op
         return self
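The `no_op` assignment works because an instance attribute shadows the bound method of the same name, so any later internal call to `compress_index()` on that index does nothing. A standalone sketch of the mechanism (hypothetical `Index` class, not the pynndescent one):

```python
class Index:
    def compress_index(self) -> None:
        print("compressing (destructive)")


def no_op() -> None: ...


index = Index()
index.compress_index()  # the real method runs
index.compress_index = no_op  # instance attribute shadows the method
index.compress_index()  # silent: the plain function is called with no args
```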
diff --git a/src/sklearn_ann/kneighbors/sklearn.py b/src/sklearn_ann/kneighbors/sklearn.py
index 242374e..d338908 100644
--- a/src/sklearn_ann/kneighbors/sklearn.py
+++ b/src/sklearn_ann/kneighbors/sklearn.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 from functools import partial
 
 from sklearn.neighbors import KNeighborsTransformer
diff --git a/src/sklearn_ann/py.typed b/src/sklearn_ann/py.typed
new file mode 100644
index 0000000..e69de29
diff --git a/src/sklearn_ann/test_utils.py b/src/sklearn_ann/test_utils.py
index 5a1193a..f85c1bd 100644
--- a/src/sklearn_ann/test_utils.py
+++ b/src/sklearn_ann/test_utils.py
@@ -1,9 +1,30 @@
+from __future__ import annotations
+
 from enum import Enum
 from importlib.util import find_spec
+from typing import TYPE_CHECKING, overload
+
+from scipy.sparse import csr_matrix
+
+if TYPE_CHECKING:
+    from collections.abc import Callable
+    from typing import TypeVar
+
+    import numpy as np
+    import pytest
+    from numpy.typing import NDArray
+
+    Markable = TypeVar("Markable", bound=Callable[..., object] | type)
 
 
-def assert_row_close(sp_mat, actual_pdist, row=42, thresh=0.01):
+def assert_row_close(
+    sp_mat: csr_matrix,
+    actual_pdist: NDArray[np.float64],
+    row: int = 42,
+    thresh: float = 0.01,
+) -> None:
     row_mat = sp_mat.getrow(row)
+    assert isinstance(row_mat, csr_matrix)
     for col, val in zip(row_mat.indices, row_mat.data):
         assert abs(actual_pdist[row, col] - val) < thresh
 
@@ -27,7 +48,11 @@ class needs(Enum):
     nmslib = ("nmslib",)
     pynndescent = ("pynndescent",)
 
-    def __call__(self, fn=None):
+    @overload
+    def __call__(self, fn: None = None) -> pytest.MarkDecorator: ...
+    @overload
+    def __call__(self, fn: Markable) -> Markable: ...
+    def __call__(self, fn: Markable | None = None) -> Markable | pytest.MarkDecorator:
         import pytest
 
         what = (
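The overloads encode the two ways `needs` members are used in the test suite: applied bare as a decorator, or called to get a reusable `pytest.MarkDecorator`. Usage sketch (hypothetical tests):

```python
from sklearn_ann.test_utils import needs


@needs.annoy  # bare form: second overload, returns the decorated function
def test_bare() -> None: ...


@needs.faiss()  # call form: first overload, returns a pytest.MarkDecorator
def test_called() -> None: ...
```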
diff --git a/src/sklearn_ann/utils.py b/src/sklearn_ann/utils.py
index e275e27..093ab83 100644
--- a/src/sklearn_ann/utils.py
+++ b/src/sklearn_ann/utils.py
@@ -1,30 +1,34 @@
+from __future__ import annotations
+
+from typing import TYPE_CHECKING
+
 import numpy as np
 from scipy.sparse import csr_matrix
 from sklearn.utils.validation import validate_data
 
+if TYPE_CHECKING:
+    from collections.abc import Container, Iterable
+    from typing import TypeVar
 
-def check_metric(metric, metrics):
-    if metric not in metrics:
-        raise ValueError(f"Unknown metric {metric!r}. Valid metrics are {metrics!r}")
+    T = TypeVar("T")
 
 
-def get_sparse_indices(mat, idx):
-    start_idx = mat.indptr[idx]
-    end_idx = mat.indptr[idx + 1]
-    return mat.indices[start_idx:end_idx]
+def check_metric(metric: str, metrics: Container[str]) -> None:
+    if metric not in metrics:
+        raise ValueError(f"Unknown metric {metric!r}. Valid metrics are {metrics!r}")
 
 
-def get_sparse_row(mat, idx):
+def get_sparse_row(mat: csr_matrix, idx: int) -> Iterable[tuple[int, float]]:
     start_idx = mat.indptr[idx]
     end_idx = mat.indptr[idx + 1]
     return zip(mat.indices[start_idx:end_idx], mat.data[start_idx:end_idx])
 
 
-def trunc_csr(csr, k):
+def trunc_csr(csr: csr_matrix, k: int) -> csr_matrix:
     indptr = np.empty_like(csr.indptr)
     num_rows = len(csr.indptr) - 1
-    indices = [None] * num_rows
-    data = [None] * num_rows
+    indices = [np.empty(0, dtype=np.intp)] * num_rows
+    data = [np.empty(0, dtype=np.float64)] * num_rows
     cur_indptr = 0
     for row_idx in range(num_rows):
         indptr[row_idx] = cur_indptr
@@ -39,13 +43,13 @@ def trunc_csr(csr, k):
     return csr_matrix((np.concatenate(data), np.concatenate(indices), indptr))
 
 
-def or_else_csrs(csr1, csr2):
+def or_else_csrs(csr1: csr_matrix, csr2: csr_matrix) -> csr_matrix:
     # Possible TODO: Could use numba/Cython to speed this up?
     if csr1.shape != csr2.shape:
         raise ValueError("csr1 and csr2 must be the same shape")
     indptr = np.empty_like(csr1.indptr)
-    indices = []
-    data = []
+    indices: list[int] = []
+    data: list[float] = []
     for row_idx in range(len(indptr) - 1):
         indptr[row_idx] = len(indices)
         csr1_it = iter(get_sparse_row(csr1, row_idx))
@@ -53,29 +57,29 @@ def or_else_csrs(csr1, csr2):
         cur_csr1 = next(csr1_it, None)
         cur_csr2 = next(csr2_it, None)
         while 1:
-            if cur_csr1 is None and cur_csr2 is None:
-                break
-            elif cur_csr1 is None:
-                cur_index, cur_datum = cur_csr2
-            elif cur_csr2 is None:
-                cur_index, cur_datum = cur_csr1
-            elif cur_csr1[0] < cur_csr2[0]:
-                cur_index, cur_datum = cur_csr1
-                cur_csr1 = next(csr1_it, None)
-            elif cur_csr2[0] < cur_csr1[0]:
-                cur_index, cur_datum = cur_csr2
-                cur_csr2 = next(csr2_it, None)
-            else:
-                cur_index, cur_datum = cur_csr1
-                cur_csr1 = next(csr1_it, None)
-                cur_csr2 = next(csr2_it, None)
-            indices.append(cur_index)
+            match cur_csr1, cur_csr2:
+                case None, None:
+                    break
+                case None, (cur_index, cur_datum):
+                    cur_csr2 = next(csr2_it, None)
+                case (cur_index, cur_datum), None:
+                    cur_csr1 = next(csr1_it, None)
+                case (cur_index, cur_datum), (i2, _) if cur_index < i2:
+                    cur_csr1 = next(csr1_it, None)
+                case (i1, _), (cur_index, cur_datum) if cur_index < i1:
+                    cur_csr2 = next(csr2_it, None)
+                case (cur_index, cur_datum), _:  # they are equal
+                    cur_csr1 = next(csr1_it, None)
+                    cur_csr2 = next(csr2_it, None)
+            indices.append(cur_index)  # type: ignore[arg-type]  # mypy bug
             data.append(cur_datum)
     indptr[-1] = len(indices)
     return csr_matrix((data, indices, indptr), shape=csr1.shape)
 
 
-def postprocess_knn_csr(knns, include_fwd=True, include_rev=False):
+def postprocess_knn_csr(
+    knns: csr_matrix, *, include_fwd: bool = True, include_rev: bool = False
+) -> csr_matrix:
     if not include_fwd and not include_rev:
         raise ValueError("One of include_fwd or include_rev must be True")
     elif include_rev and not include_fwd:
@@ -88,7 +92,7 @@ def postprocess_knn_csr(knns, include_fwd=True, include_rev=False):
 
 
 class TransformerChecksMixin:
-    def _transform_checks(self, X, *fitted_props, **check_params):
+    def _transform_checks(self, X: T, *fitted_props: str, **check_params: object) -> T:
         from sklearn.utils.validation import check_is_fitted
 
         X = validate_data(self, X, reset=False, **check_params)
diff --git a/tests/test_cluster.py b/tests/test_cluster.py
index b3d9a75..c150557 100644
--- a/tests/test_cluster.py
+++ b/tests/test_cluster.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 import pytest
 from sklearn.utils.estimator_checks import check_estimator
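Per row, `or_else_csrs` performs a two-pointer merge of two index-sorted `(index, value)` streams, with `csr1` winning ties; the `match` statement above is that merge with explicit cursor advancement. An equivalent pure-Python sketch of the same merge (hypothetical helper, for illustration only):

```python
from __future__ import annotations

from collections.abc import Iterator


def merge_first_wins(
    it1: Iterator[tuple[int, float]], it2: Iterator[tuple[int, float]]
) -> Iterator[tuple[int, float]]:
    """Merge two index-sorted streams; on equal indices, it1's entry wins."""
    cur1, cur2 = next(it1, None), next(it2, None)
    while (cur1, cur2) != (None, None):
        if cur2 is None or (cur1 is not None and cur1[0] <= cur2[0]):
            assert cur1 is not None
            yield cur1
            if cur2 is not None and cur1[0] == cur2[0]:
                cur2 = next(it2, None)  # drop the shadowed entry from it2
            cur1 = next(it1, None)
        else:
            yield cur2
            cur2 = next(it2, None)


# {0: 1.0, 2: 3.0} ∪ {2: 9.0, 5: 7.0} -> [(0, 1.0), (2, 3.0), (5, 7.0)]
merged = list(merge_first_wins(iter([(0, 1.0), (2, 3.0)]), iter([(2, 9.0), (5, 7.0)])))
```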
diff --git a/tests/test_kneighbors/conftest.py b/tests/test_kneighbors/conftest.py
index 6b45d83..aed524d 100644
--- a/tests/test_kneighbors/conftest.py
+++ b/tests/test_kneighbors/conftest.py
@@ -1,17 +1,29 @@
+from __future__ import annotations
+
+from typing import TYPE_CHECKING
+
 import pytest
 from numpy.random import default_rng
 from scipy.spatial.distance import pdist, squareform
 
+if TYPE_CHECKING:
+    from typing import Literal
+
+    import numpy as np
+    from numpy.typing import NDArray
+
 
 @pytest.fixture(scope="module")
-def random_small():
+def random_small() -> NDArray[np.float64]:
     gen = default_rng(42)
     return 2 * gen.random((64, 128)) - 1
 
 
 @pytest.fixture(scope="module")
-def random_small_pdists(random_small):
+def random_small_pdists(
+    random_small: NDArray[np.float64],
+) -> dict[Literal["euclidean", "cosine"], NDArray[np.float64]]:
+    metrics: list[Literal["euclidean", "cosine"]] = ["euclidean", "cosine"]
     return {
-        metric: squareform(pdist(random_small, metric=metric))
-        for metric in ["euclidean", "cosine"]
+        metric: squareform(pdist(random_small, metric=metric)) for metric in metrics
     }
diff --git a/tests/test_kneighbors/test_annoy.py b/tests/test_kneighbors/test_annoy.py
index d1c637f..c921496 100644
--- a/tests/test_kneighbors/test_annoy.py
+++ b/tests/test_kneighbors/test_annoy.py
@@ -1,8 +1,17 @@
+from __future__ import annotations
+
+from typing import TYPE_CHECKING
+
 import numpy as np
 import pytest
 
 from sklearn_ann.test_utils import assert_row_close, needs
 
+if TYPE_CHECKING:
+    from collections.abc import Mapping
+
+    from numpy.typing import NDArray
+
 try:
     from sklearn_ann.kneighbors.annoy import AnnoyTransformer
 except ImportError:
@@ -10,7 +19,10 @@
 
 
 @needs.annoy
-def test_euclidean(random_small, random_small_pdists):
+def test_euclidean(
+    random_small: NDArray[np.float64],
+    random_small_pdists: Mapping[str, NDArray[np.float64]],
+) -> None:
     trans = AnnoyTransformer(metric="euclidean")
     mat = trans.fit_transform(random_small)
     euclidean_dist = random_small_pdists["euclidean"]
@@ -19,7 +31,10 @@
 
 @needs.annoy
 @pytest.mark.xfail(reason="not sure why this isn't working")
-def test_angular(random_small, random_small_pdists):
+def test_angular(
+    random_small: NDArray[np.float64],
+    random_small_pdists: Mapping[str, NDArray[np.float64]],
+) -> None:
     trans = AnnoyTransformer(metric="angular")
     mat = trans.fit_transform(random_small)
     angular_dist = np.arccos(1 - random_small_pdists["cosine"])
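The same annotation pattern applies to any typed fixture: annotate the fixture's return type once, and every test that requests it gets a checked parameter type. A minimal sketch (hypothetical fixture and test, not from this diff):

```python
from __future__ import annotations

from typing import TYPE_CHECKING

import numpy as np
import pytest

if TYPE_CHECKING:
    from numpy.typing import NDArray


@pytest.fixture
def unit_vectors() -> NDArray[np.float64]:
    rng = np.random.default_rng(0)
    x = rng.random((8, 3))
    return x / np.linalg.norm(x, axis=1, keepdims=True)


def test_norms(unit_vectors: NDArray[np.float64]) -> None:
    assert np.allclose(np.linalg.norm(unit_vectors, axis=1), 1.0)
```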
diff --git a/tests/test_kneighbors/test_common.py b/tests/test_kneighbors/test_common.py
index 65e453c..1046a4b 100644
--- a/tests/test_kneighbors/test_common.py
+++ b/tests/test_kneighbors/test_common.py
@@ -1,28 +1,42 @@
+from __future__ import annotations
+
+from importlib.util import find_spec
+from typing import TYPE_CHECKING
+
 import numpy as np
 import pytest
+from sklearn.neighbors import KNeighborsTransformer
 from sklearn.utils.estimator_checks import check_estimator
 
 from sklearn_ann.test_utils import needs
 
-try:
+if not TYPE_CHECKING:
+    AnnoyTransformer = FAISSTransformer = None
+    NMSlibTransformer = PyNNDescentTransformer = None
+if find_spec("annoy") or TYPE_CHECKING:
     from sklearn_ann.kneighbors.annoy import AnnoyTransformer
-except ImportError:
-    AnnoyTransformer = "AnnoyTransformer"
-try:
+if find_spec("faiss") or TYPE_CHECKING:
     from sklearn_ann.kneighbors.faiss import FAISSTransformer
-except ImportError:
-    FAISSTransformer = "FAISSTransformer"
-try:
+if find_spec("nmslib") or TYPE_CHECKING:
     from sklearn_ann.kneighbors.nmslib import NMSlibTransformer
-except ImportError:
-    NMSlibTransformer = "NMSlibTransformer"
-try:
+if find_spec("pynndescent") or TYPE_CHECKING:
     from sklearn_ann.kneighbors.pynndescent import PyNNDescentTransformer
-except ImportError:
-    PyNNDescentTransformer = "PyNNDescentTransformer"
+
 from sklearn_ann.kneighbors.sklearn import BallTreeTransformer, KDTreeTransformer
 
-ESTIMATORS = [
+if TYPE_CHECKING:
+    from _pytest.mark import ParameterSet
+    from numpy.typing import NDArray
+
+    Estimator = (
+        AnnoyTransformer
+        | FAISSTransformer
+        | NMSlibTransformer
+        | PyNNDescentTransformer
+        | KNeighborsTransformer
+    )
+
+ESTIMATORS: list[ParameterSet] = [
     pytest.param(AnnoyTransformer, marks=[needs.annoy()]),
     pytest.param(FAISSTransformer, marks=[needs.faiss()]),
     pytest.param(NMSlibTransformer, marks=[needs.nmslib()]),
@@ -41,7 +55,7 @@
 }
 
 
-def add_mark(param, mark):
+def add_mark(param: ParameterSet, mark: pytest.MarkDecorator) -> ParameterSet:
     return pytest.param(*param.values, marks=[*param.marks, mark], id=param.id)
 
 
@@ -59,9 +73,9 @@ def add_mark(param, mark):
         for est in ESTIMATORS
     ],
 )
-def test_all_estimators(Estimator):
+def test_all_estimators(estim_cls: type[Estimator]) -> None:
     check_estimator(
-        Estimator(),
-        expected_failed_checks=PER_ESTIMATOR_XFAIL_CHECKS.get(Estimator, {}),
+        estim_cls(),
+        expected_failed_checks=PER_ESTIMATOR_XFAIL_CHECKS.get(estim_cls, {}),
     )
@@ -81,7 +95,7 @@ def test_all_estimators(Estimator):
 # * only explicitly store nearest neighborhoods of each sample with respect to the
 #   training data. This should include those at 0 distance from a query point,
 #   including the matrix diagonal when computing the nearest neighborhoods
 #   between the training data and itself
 # (or k+1, as explained in the following note).
 
 
-def mark_diagonal_0_xfail(est):
+def mark_diagonal_0_xfail(est: ParameterSet) -> ParameterSet:
     """Mark flaky tests as xfail(strict=False)."""
     # Should probably postprocess these...
     reasons = {
@@ -96,16 +110,18 @@ def mark_diagonal_0_xfail(est):
 
 
 @pytest.mark.parametrize(
-    "Estimator", [mark_diagonal_0_xfail(est) for est in ESTIMATORS]
+    "estim_cls", [mark_diagonal_0_xfail(est) for est in ESTIMATORS]
 )
-def test_all_return_diagonal_0(random_small, Estimator):
+def test_all_return_diagonal_0(
+    random_small: NDArray[np.float64], estim_cls: type[Estimator]
+) -> None:
     # Check: do we always get an "extra" neighbour (diagonal/self)
-    est = Estimator(n_neighbors=3)
+    est = estim_cls(n_neighbors=3)
     knns = est.fit_transform(random_small)
     assert (knns.getnnz(1) == 4).all()
@@ -127,10 +143,10 @@
 
 
-@pytest.mark.parametrize("Estimator", ESTIMATORS)
-def test_all_same(random_small, Estimator):
+@pytest.mark.parametrize("estim_cls", ESTIMATORS)
+def test_all_same(estim_cls: type[Estimator]) -> None:
     # Again but for the case of the same element
     ones = np.ones((64, 4))
-    est = Estimator(n_neighbors=3)
+    est = estim_cls(n_neighbors=3)
     knns = est.fit_transform(ones)
     print("knns", knns)
     assert (knns.getnnz(1) == 4).all()
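`add_mark` rebuilds a `pytest.param`, preserving its values, existing marks, and id while appending one more mark. A standalone sketch of the pattern (hypothetical parameters):

```python
import pytest


def add_mark(param, mark: pytest.MarkDecorator):
    """Return a copy of a pytest.param with one extra mark attached."""
    return pytest.param(*param.values, marks=[*param.marks, mark], id=param.id)


params = [pytest.param(1, id="one"), pytest.param(2, id="two")]
flaky = [add_mark(p, pytest.mark.xfail(strict=False)) for p in params]


@pytest.mark.parametrize("n", flaky)
def test_positive(n: int) -> None:
    assert n > 0
```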
diff --git a/tests/test_kneighbors/test_faiss.py b/tests/test_kneighbors/test_faiss.py
index 7f76a96..5bfd7a0 100644
--- a/tests/test_kneighbors/test_faiss.py
+++ b/tests/test_kneighbors/test_faiss.py
@@ -1,5 +1,15 @@
+from __future__ import annotations
+
+from typing import TYPE_CHECKING
+
 from sklearn_ann.test_utils import assert_row_close, needs
 
+if TYPE_CHECKING:
+    from collections.abc import Mapping
+
+    import numpy as np
+    from numpy.typing import NDArray
+
 try:
     from sklearn_ann.kneighbors.faiss import FAISSTransformer
 except ImportError:
@@ -7,7 +17,10 @@
 
 
 @needs.faiss
-def test_euclidean(random_small, random_small_pdists):
+def test_euclidean(
+    random_small: NDArray[np.float64],
+    random_small_pdists: Mapping[str, NDArray[np.float64]],
+) -> None:
     trans = FAISSTransformer(metric="euclidean")
     mat = trans.fit_transform(random_small)
     euclidean_dist = random_small_pdists["euclidean"]
@@ -15,7 +28,10 @@
 
 
 @needs.faiss
-def test_cosine(random_small, random_small_pdists):
+def test_cosine(
+    random_small: NDArray[np.float64],
+    random_small_pdists: Mapping[str, NDArray[np.float64]],
+) -> None:
     trans = FAISSTransformer(metric="cosine")
     mat = trans.fit_transform(random_small)
     cosine_dist = random_small_pdists["cosine"]
diff --git a/tests/test_kneighbors/test_nmslib.py b/tests/test_kneighbors/test_nmslib.py
index a46e4b1..cbc0d80 100644
--- a/tests/test_kneighbors/test_nmslib.py
+++ b/tests/test_kneighbors/test_nmslib.py
@@ -1,5 +1,15 @@
+from __future__ import annotations
+
+from typing import TYPE_CHECKING
+
 from sklearn_ann.test_utils import assert_row_close, needs
 
+if TYPE_CHECKING:
+    from collections.abc import Mapping
+
+    import numpy as np
+    from numpy.typing import NDArray
+
 try:
     from sklearn_ann.kneighbors.nmslib import NMSlibTransformer
 except ImportError:
@@ -7,7 +17,10 @@
 
 
 @needs.nmslib
-def test_euclidean(random_small, random_small_pdists):
+def test_euclidean(
+    random_small: NDArray[np.float64],
+    random_small_pdists: Mapping[str, NDArray[np.float64]],
+) -> None:
     trans = NMSlibTransformer(metric="euclidean")
     mat = trans.fit_transform(random_small)
     euclidean_dist = random_small_pdists["euclidean"]
@@ -15,7 +28,10 @@
 
 
 @needs.nmslib
-def test_cosine(random_small, random_small_pdists):
+def test_cosine(
+    random_small: NDArray[np.float64],
+    random_small_pdists: Mapping[str, NDArray[np.float64]],
+) -> None:
     trans = NMSlibTransformer(metric="cosine")
     mat = trans.fit_transform(random_small)
     cosine_dist = random_small_pdists["cosine"]
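Because the package now ships a `py.typed` marker, these signatures become visible to type checkers in downstream code as well. A hypothetical consumer snippet:

```python
from sklearn_ann.kneighbors.nmslib import NMSlibTransformer

trans = NMSlibTransformer(n_neighbors=5, metric="cosine")  # checked against fit/transform above
# mypy would now reject e.g. NMSlibTransformer(n_neighbors="five"):
# error: Argument "n_neighbors" has incompatible type "str"; expected "int"
```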