typing #74
Draft · wants to merge 1 commit into base: main
10 changes: 9 additions & 1 deletion .pre-commit-config.yaml
@@ -1,3 +1,5 @@
+default_language_version:
+  python: python3.12
 repos:
   - repo: https://github.com/pre-commit/pre-commit-hooks
     rev: v5.0.0
@@ -13,11 +15,17 @@ repos:
       - id: no-commit-to-branch
         args: ["--branch=main"]
   - repo: https://github.com/astral-sh/ruff-pre-commit
-    rev: v0.8.6
+    rev: v0.9.1
     hooks:
       - id: ruff
         args: ["--fix"]
       - id: ruff-format
+  - repo: https://github.com/pre-commit/mirrors-mypy
+    rev: v1.14.1
+    hooks:
+      - id: mypy
+        additional_dependencies:
+          - pynndescent
   - repo: https://github.com/pre-commit/mirrors-prettier
     rev: v4.0.0-alpha.8
     hooks:
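As a rough check of what the new hook enforces: after `pre-commit install`, running `pre-commit run mypy --all-files` should flag code like the following minimal sketch (a hypothetical module, not part of this diff) under the strict configuration added to pyproject.toml further down.

    # Hypothetical module: illustrates what mypy in strict mode rejects.

    def scale(values: list[float], factor: float) -> list[float]:
        # Fully annotated: accepted under strict settings.
        return [v * factor for v in values]


    def unscaled(values):  # error: function is missing type annotations
        return values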
1 change: 1 addition & 0 deletions docs/conf.py
@@ -3,6 +3,7 @@
 # This file only contains a selection of the most common options. For a full
 # list see the documentation:
 # https://www.sphinx-doc.org/en/master/usage/configuration.html
+from __future__ import annotations

 import os
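This one-line change enables postponed evaluation of annotations (PEP 563). A small sketch of the behaviour it buys, with an invented class that is not from the PR:

    from __future__ import annotations


    class Node:
        def clone(self) -> Node:
            # Without the __future__ import, evaluating the `-> Node`
            # annotation at definition time would raise NameError, since
            # Node is not yet bound while its class body executes; with it,
            # annotations are stored as strings and never evaluated at runtime.
            return Node()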
15 changes: 13 additions & 2 deletions examples/rnn_dbscan_big.py
@@ -7,17 +7,26 @@

 """

+from __future__ import annotations
+
+from typing import TYPE_CHECKING
+
 import numpy as np
 from joblib import Memory
 from sklearn import metrics
 from sklearn.datasets import fetch_openml

 from sklearn_ann.cluster.rnn_dbscan import simple_rnn_dbscan_pipeline

+if TYPE_CHECKING:
+    from typing import Any
+
+    from sklearn.utils import Bunch
+

 # #############################################################################
 # Generate sample data
-def fetch_mnist():
+def fetch_mnist() -> Bunch:
     print("Downloading mnist_784")
     mnist = fetch_openml("mnist_784")
     return mnist.data / 255, mnist.target
@@ -28,7 +37,9 @@ def fetch_mnist():
 X, y = memory.cache(fetch_mnist)()


-def run_rnn_dbscan(neighbor_transformer, n_neighbors, **kwargs):
+def run_rnn_dbscan(
+    neighbor_transformer: object, n_neighbors: int, **kwargs: Any
+) -> None:
     # #############################################################################
     # Compute RnnDBSCAN

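The `if TYPE_CHECKING:` block above keeps `Any` and `Bunch` out of the runtime import path. A self-contained sketch of the same pattern (the function is illustrative, not from the example script):

    from __future__ import annotations

    from typing import TYPE_CHECKING

    if TYPE_CHECKING:
        # Evaluated only by static type checkers; skipped entirely at runtime.
        from sklearn.utils import Bunch


    def n_keys(b: Bunch) -> int:
        # Thanks to the __future__ import the annotation is a plain string
        # at runtime, so Bunch never has to be importable to call this.
        return len(b.keys())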
1 change: 1 addition & 0 deletions examples/rnn_dbscan_simple.py
@@ -8,6 +8,7 @@
 Mostly copypasted from sklearn's DBSCAN example.

 """
+from __future__ import annotations

 import numpy as np
 from sklearn import metrics
13 changes: 13 additions & 0 deletions pyproject.toml
@@ -74,6 +74,7 @@ select = [
     "PTH", # Pathlib
     "RUF", # Ruff’s own rules
     "T20", # print statements
+    "TC", # type checking
 ]
 ignore = [
     # Don’t complain about “confusables”
@@ -84,6 +85,10 @@ ignore = [
 "tests/*.py" = ["T20"]
 [tool.ruff.lint.isort]
 known-first-party = ["sklearn_ann"]
+required-imports = ["from __future__ import annotations"]
+[tool.ruff.lint.flake8-type-checking]
+exempt-modules = []
+strict = true

 [tool.hatch.envs.docs]
 installer = "uv"
@@ -97,6 +102,14 @@ features = ["tests", "annlibs"]
 [tool.hatch.build.targets.wheel]
 packages = ["src/sklearn_ann"]

+[tool.mypy]
+python_version = "3.11"
+mypy_path = ["src", "tests"]
+strict = true
+explicit_package_bases = true # pytest doesn’t do __init__.py
+no_implicit_optional = true
+disallow_untyped_decorators = false # e.g. pytest.mark.parametrize
+
 [build-system]
 requires = ["hatchling", "hatch-vcs", "hatch-fancy-pypi-readme"]
 build-backend = "hatchling.build"
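Together, `required-imports` and the strict flake8-type-checking settings make ruff enforce the import style used throughout this PR. A before/after sketch of roughly what `ruff check --fix` does under these settings (module contents invented):

    # Before the fix, numpy is imported at runtime although it is only
    # needed for an annotation:
    #
    #     import numpy as np
    #
    #     def total(xs: np.ndarray) -> float:
    #         return float(xs.sum())
    #
    # After `ruff check --fix` with the settings above:

    from __future__ import annotations

    from typing import TYPE_CHECKING

    if TYPE_CHECKING:
        import numpy as np  # typing-only: moved here by the TC rules


    def total(xs: np.ndarray) -> float:
        return float(xs.sum())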
68 changes: 47 additions & 21 deletions src/sklearn_ann/cluster/rnn_dbscan.py
@@ -1,14 +1,25 @@
+from __future__ import annotations
+
 from collections import deque
-from typing import cast
+from typing import TYPE_CHECKING, cast

 import numpy as np
 from scipy.sparse import csr_matrix
 from sklearn.base import BaseEstimator, ClusterMixin
 from sklearn.neighbors import KNeighborsTransformer
 from sklearn.utils import Tags
 from sklearn.utils.validation import validate_data

 from ..utils import get_sparse_row

+if TYPE_CHECKING:
+    from collections.abc import Iterator
+    from typing import Literal, Self
+
+    from numpy.typing import NDArray
+    from sklearn.pipeline import Pipeline
+
+
 UNCLASSIFIED = -2
 NOISE = -1
@@ -37,7 +48,9 @@ def join(it1, it2):
         cur_it2 = next(it2, None)


-def neighborhood(is_core, knns, rev_knns, idx):
+def neighborhood(
+    is_core: NDArray[np.bool_], knns: csr_matrix, rev_knns: csr_matrix, idx: int
+) -> Iterator[tuple[int, float]]:
     # TODO: Make this inner bit faster
     knn_it = get_sparse_row(knns, idx)
     rev_core_knn_it = (
@@ -52,10 +65,12 @@ def neighborhood(is_core, knns, rev_knns, idx):
     )


-def rnn_dbscan_inner(is_core, knns, rev_knns, labels):
+def rnn_dbscan_inner(
+    is_core: NDArray[np.bool_], knns: csr_matrix, rev_knns: csr_matrix, labels
+) -> list[float]:
     cluster = 0
-    cur_dens = 0
-    dens = []
+    cur_dens = 0.0
+    dens: list[float] = []
     for x_idx in range(len(labels)):
         if labels[x_idx] == UNCLASSIFIED:
             # Expand cluster
@@ -81,7 +96,7 @@ def rnn_dbscan_inner(is_core, knns, rev_knns, labels):
                 elif labels[z_idx] == NOISE:
                     labels[z_idx] = cluster
             dens.append(cur_dens)
-            cur_dens = 0
+            cur_dens = 0.0
             cluster += 1
         else:
             labels[x_idx] = NOISE
@@ -138,15 +153,20 @@ class RnnDBSCAN(ClusterMixin, BaseEstimator):
     """

     def __init__(
-        self, n_neighbors=5, *, input_guarantee="none", n_jobs=None, keep_knns=False
-    ):
+        self,
+        n_neighbors: int = 5,
+        *,
+        input_guarantee: Literal["none", "kneighbors"] = "none",
+        n_jobs: int | None = None,
+        keep_knns: bool = False,
+    ) -> None:
         self.n_neighbors = n_neighbors
         self.input_guarantee = input_guarantee
         self.n_jobs = n_jobs
         self.keep_knns = keep_knns

-    def fit(self, X, y=None):
-        X = validate_data(self, X, accept_sparse="csr")
+    def fit(self, X: NDArray[np.float64] | csr_matrix, y: None = None) -> Self:
+        X = cast(csr_matrix, validate_data(self, X, accept_sparse="csr"))
         if self.input_guarantee == "none":
             algorithm = KNeighborsTransformer(n_neighbors=self.n_neighbors)
             X = algorithm.fit_transform(X)
@@ -157,7 +177,7 @@ def fit(self, X, y=None):
                 "Expected input_guarantee to be one of 'none', 'kneighbors'"
             )

-        XT = X.transpose().tocsr(copy=True)
+        XT = cast(csr_matrix, X.transpose().tocsr(copy=True))
         if self.keep_knns:
             self.knns_ = X
             self.rev_knns_ = XT
@@ -176,11 +196,11 @@ def fit(self, X, y=None):

         return self

-    def fit_predict(self, X, y=None):
+    def fit_predict(self, X, y=None) -> NDArray[np.int32]:
         self.fit(X, y=y)
         return self.labels_

-    def drop_knns(self):
+    def drop_knns(self) -> None:
         del self.knns_
         del self.rev_knns_

@@ -191,22 +211,28 @@ def __sklearn_tags__(self) -> Tags:


 def simple_rnn_dbscan_pipeline(
-    neighbor_transformer, n_neighbors, n_jobs=None, keep_knns=None, **kwargs
-):
+    neighbor_transformer: object,
+    n_neighbors: int,
+    *,
+    n_jobs: int | None = None,
+    keep_knns: bool = False,
+    input_guarantee: Literal["none", "kneighbors"] = "none",
+) -> Pipeline:
     """
     Create a simple pipeline comprising a transformer and RnnDBSCAN.

     Parameters
     ----------
-    neighbor_transformer : class implementing KNeighborsTransformer interface
-    n_neighbors:
+    neighbor_transformer
+        class implementing KNeighborsTransformer interface
+    n_neighbors
         Passed to neighbor_transformer and RnnDBSCAN
-    n_jobs:
+    n_jobs
         Passed to neighbor_transformer and RnnDBSCAN
-    keep_knns:
+    keep_knns
         Passed to RnnDBSCAN
+    input_guarantee
+        Passed to RnnDBSCAN
-    kwargs:
-        Passed to neighbor_transformer
     """
     from sklearn.pipeline import make_pipeline

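As a usage sketch of the freshly annotated API (synthetic data and arbitrary parameters, not part of the diff):

    from __future__ import annotations

    from sklearn.datasets import make_blobs
    from sklearn.neighbors import KNeighborsTransformer

    from sklearn_ann.cluster.rnn_dbscan import RnnDBSCAN, simple_rnn_dbscan_pipeline

    X, _ = make_blobs(n_samples=300, centers=3, random_state=0)

    # Direct use: the estimator computes its own kNN graph internally.
    labels = RnnDBSCAN(n_neighbors=10).fit_predict(X)

    # Pipeline use: neighbor_transformer is the transformer *class*; note
    # that n_jobs, keep_knns and the newly plumbed-through input_guarantee
    # are all keyword-only after this change.
    pipe = simple_rnn_dbscan_pipeline(KNeighborsTransformer, n_neighbors=10)
    clusters = pipe.fit_predict(X)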
2 changes: 2 additions & 0 deletions src/sklearn_ann/kneighbors/annoy.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 import annoy
 import numpy as np
 from scipy.sparse import csr_matrix
48 changes: 33 additions & 15 deletions src/sklearn_ann/kneighbors/faiss.py
@@ -1,6 +1,7 @@
+from __future__ import annotations
+
 import math
+from typing import TYPE_CHECKING, TypedDict

 import faiss
 import numpy as np
@@ -13,10 +14,22 @@

 from ..utils import TransformerChecksMixin, postprocess_knn_csr

+if TYPE_CHECKING:
+    from typing import Self
+
+    from numpy.typing import ArrayLike, NDArray
+
+
+class MetricInfo(TypedDict):
+    metric: int
+    normalize: bool
+    negate: bool
+
+
 L2_INFO = {"metric": faiss.METRIC_L2, "sqrt": True}


-METRIC_MAP = {
+METRIC_MAP: dict[str, MetricInfo] = {
     "cosine": {
         "metric": faiss.METRIC_INNER_PRODUCT,
         "normalize": True,
@@ -34,7 +47,12 @@
 }


-def mk_faiss_index(feats, inner_metric, index_key="", nprobe=128) -> faiss.Index:
+def mk_faiss_index(
+    feats: NDArray[np.float32],
+    inner_metric: int,
+    index_key: str = "",
+    nprobe: int = 128,
+) -> faiss.Index:
     size, dim = feats.shape
     if not index_key:
         if inner_metric == faiss.METRIC_INNER_PRODUCT:
@@ -64,15 +82,15 @@ def mk_faiss_index(feats, inner_metric, index_key="", nprobe=128) -> faiss.Index:
 class FAISSTransformer(TransformerChecksMixin, TransformerMixin, BaseEstimator):
     def __init__(
         self,
-        n_neighbors=5,
+        n_neighbors: int = 5,
         *,
-        metric="euclidean",
-        index_key="",
-        n_probe=128,
-        n_jobs=-1,
-        include_fwd=True,
-        include_rev=False,
-    ):
+        metric: str = "euclidean",
+        index_key: str = "",
+        n_probe: int = 128,
+        n_jobs: int = -1,
+        include_fwd: bool = True,
+        include_rev: bool = False,
+    ) -> None:
         self.n_neighbors = n_neighbors
         self.metric = metric
         self.index_key = index_key
@@ -82,10 +100,10 @@ def __init__(
         self.include_rev = include_rev

     @property
-    def _metric_info(self):
+    def _metric_info(self) -> MetricInfo:
         return METRIC_MAP[self.metric]

-    def fit(self, X, y=None):
+    def fit(self, X: ArrayLike, y: None = None) -> Self:
         normalize = self._metric_info.get("normalize", False)
         X = validate_data(self, X, dtype=np.float32, copy=normalize)
         self.n_samples_fit_ = X.shape[0]
@@ -100,14 +118,14 @@ def __init__(
         self.faiss_ = mk_faiss_index(X, inner_metric, self.index_key, self.n_probe)
         return self

-    def transform(self, X):
+    def transform(self, X: NDArray[np.number]) -> csr_matrix:
         normalize = self._metric_info.get("normalize", False)
         X = self._transform_checks(X, "faiss_", dtype=np.float32, copy=normalize)
         if normalize:
             normalize_L2(X)
         return self._transform(X)

-    def _transform(self, X):
+    def _transform(self, X: NDArray[np.float32]) -> csr_matrix:
         n_samples_transform = self.n_samples_fit_ if X is None else X.shape[0]
         n_neighbors = self.n_neighbors + 1
         if X is None:
@@ -156,7 +174,7 @@ def _transform(self, X):
             mat, include_fwd=self.include_fwd, include_rev=self.include_rev
         )

-    def fit_transform(self, X, y=None):
+    def fit_transform(self, X: ArrayLike, y: None = None) -> csr_matrix:
         return self.fit(X, y=y)._transform(X=None)

     def __sklearn_tags__(self) -> Tags:
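And a quick sketch exercising the typed FAISSTransformer (random data; the expected nnz follows from the `n_neighbors + 1` self-inclusion convention visible in `_transform`, assuming the defaults `include_fwd=True, include_rev=False`):

    from __future__ import annotations

    import numpy as np

    from sklearn_ann.kneighbors.faiss import FAISSTransformer

    rng = np.random.default_rng(0)
    X = rng.normal(size=(200, 32)).astype(np.float32)

    # fit_transform now advertises csr_matrix: a sparse kNN graph with
    # n_neighbors + 1 stored entries per row (each point plus its neighbours).
    graph = FAISSTransformer(n_neighbors=5, metric="euclidean").fit_transform(X)
    print(graph.shape, graph.nnz)  # expected: (200, 200) 1200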