Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve type annotations and use stricter pyright settings #291

Merged
merged 5 commits into from
Dec 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ repos:
- id: check-ast
- id: check-executables-have-shebangs
- id: check-json
exclude: ^pyrightconfig\.json$
- id: check-merge-conflict
- id: check-shebang-scripts-are-executable
- id: check-symlinks
Expand All @@ -33,7 +34,7 @@ repos:
- id: pretty-format-json
args: [ --autofix, --no-sort-keys ]
# ignore jupyter notebooks
exclude: ^.*\.ipynb$
exclude: (^.*\.ipynb$|^pyrightconfig\.json$)
- id: pretty-format-json
args: [ --autofix, --no-sort-keys, --indent=1, --no-ensure-ascii ]
# only jupyter notebooks
Expand Down
4 changes: 3 additions & 1 deletion cicd_utils/cicd/compile_plotly_charts.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,16 @@

from copy import deepcopy
from pathlib import Path
from typing import TYPE_CHECKING, Callable
from typing import TYPE_CHECKING

from minify_html import minify
from plotly.offline import get_plotlyjs

from ridgeplot_examples import ALL_EXAMPLES, tighten_margins

if TYPE_CHECKING:
from collections.abc import Callable

import plotly.graph_objects as go


Expand Down
5 changes: 4 additions & 1 deletion cicd_utils/ridgeplot_examples/__init__.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
from __future__ import annotations

from typing import Callable
from typing import TYPE_CHECKING

import plotly.graph_objects as go

if TYPE_CHECKING:
from collections.abc import Callable


def tighten_margins(fig: go.Figure, px: int = 0) -> go.Figure:
"""Tighten the margins of a Plotly figure."""
Expand Down
2 changes: 1 addition & 1 deletion docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
try:
import importlib.metadata as importlib_metadata
except ImportError:
import importlib_metadata # pyright: ignore[no-redef]
import importlib_metadata

try:
from cicd.compile_plotly_charts import compile_plotly_charts
Expand Down
4 changes: 3 additions & 1 deletion docs/reference/changelog.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,9 @@ This document outlines the list of changes to ridgeplot between each release. Fo
Unreleased changes
------------------

- ...
### Internal

- Improve type annotations and use stricter pyright settings ({gh-pr}`291`)

---

Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ dependencies = [
"numpy>=1",
"plotly>=5.20", # The `fillgradient` option was added in 5.20
"statsmodels>=0.12,!=0.14.2", # See GH197 for details
'typing-extensions; python_version<"3.13"',
"typing-extensions",
'importlib-resources; python_version<"3.10"',
]

Expand Down
21 changes: 19 additions & 2 deletions pyrightconfig.json
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
// https://microsoft.github.io/pyright/#/configuration
{
"include": [
"cicd_utils",
"docs",
"misc",
"src",
"tests",
"docs",
"cicd_utils"
"*.py"
],
"exclude": [
"docs/build",
Expand All @@ -12,7 +15,21 @@
"extraPaths": [
"cicd_utils"
],
"pythonVersion": "3.9",
"pythonPlatform": "All",
"typeCheckingMode": "strict",
// stricter settings
"deprecateTypingAliases": true,
"reportMissingModuleSource": "error",
"reportCallInDefaultInitializer": "error",
"reportImplicitOverride": "error",
"reportImportCycles": "error",
"reportMissingSuperCall": "warning",
"reportPropertyTypeMismatch": "error",
"reportShadowedImports": "error",
"reportUninitializedInstanceVariable": "error",
"reportUnnecessaryTypeIgnoreComment": "error",
// turn off some of the strictest settings
"reportMissingTypeStubs": "none",
"reportUnknownMemberType": "none",
"reportUnknownArgumentType": "none",
Expand Down
5 changes: 4 additions & 1 deletion src/ridgeplot/_color/colorscale.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
from __future__ import annotations

import warnings
from typing import TYPE_CHECKING, Any, cast
from typing import TYPE_CHECKING, cast

import plotly.express as px
from _plotly_utils.basevalidators import ColorscaleValidator as _ColorscaleValidator
from typing_extensions import Any, override

from ridgeplot._color.utils import default_plotly_template
from ridgeplot._types import Color, ColorScale
Expand All @@ -18,13 +19,15 @@ def __init__(self) -> None:
super().__init__("colorscale", "ridgeplot")

@property
@override
def named_colorscales(self) -> dict[str, list[str]]:
named_colorscales = cast(dict[str, list[str]], super().named_colorscales)
if "default" not in named_colorscales:
# Add 'default' for backwards compatibility
named_colorscales["default"] = px.colors.DEFAULT_PLOTLY_COLORS
return named_colorscales

@override
def validate_coerce(self, v: Any) -> ColorScale:
coerced = super().validate_coerce(v)
if coerced is None: # pragma: no cover
Expand Down
2 changes: 1 addition & 1 deletion src/ridgeplot/_color/css_colors.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from __future__ import annotations

from typing import Literal
from typing_extensions import Literal

# Taken from https://www.w3.org/TR/css-color-3/#svg-color

Expand Down
4 changes: 3 additions & 1 deletion src/ridgeplot/_color/interpolation.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
from __future__ import annotations

from dataclasses import dataclass
from typing import TYPE_CHECKING, Literal, Protocol
from typing import TYPE_CHECKING

from typing_extensions import Literal, Protocol

from ridgeplot._color.utils import apply_alpha, round_color, to_rgb, unpack_rgb
from ridgeplot._types import CollectionL2, ColorScale
Expand Down
3 changes: 2 additions & 1 deletion src/ridgeplot/_figure_factory.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
from __future__ import annotations

from typing import TYPE_CHECKING, Literal, cast
from typing import TYPE_CHECKING, cast

from plotly import graph_objects as go
from typing_extensions import Literal

from ridgeplot._color.colorscale import validate_coerce_colorscale
from ridgeplot._color.interpolation import (
Expand Down
11 changes: 3 additions & 8 deletions src/ridgeplot/_kde.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,14 @@
from __future__ import annotations

import sys
from collections.abc import Collection
from collections.abc import Callable, Collection
from functools import partial
from typing import TYPE_CHECKING, Any, Callable, Union, cast
from typing import TYPE_CHECKING, Union, cast

import numpy as np
import numpy.typing as npt
import statsmodels.api as sm
from statsmodels.sandbox.nonparametric.kernels import CustomKernel as StatsmodelsKernel

if sys.version_info >= (3, 13):
from typing import TypeIs
else:
from typing_extensions import TypeIs
from typing_extensions import Any, TypeIs

from ridgeplot._types import (
CollectionL1,
Expand Down
10 changes: 3 additions & 7 deletions src/ridgeplot/_missing.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,9 @@
from __future__ import annotations

import sys
from enum import Enum
from typing import Final, Literal

if sys.version_info >= (3, 10):
from typing import TypeAlias
else:
from typing_extensions import TypeAlias
from typing import Final

from typing_extensions import Literal, TypeAlias, override

if "_is_loaded" in globals():
raise RuntimeError("Reloading ridgeplot._missing is not allowed")
Expand Down Expand Up @@ -44,6 +39,7 @@ class _Missing(Enum):

MISSING = "MISSING"

@override
def __repr__(self) -> str:
return "<MISSING>"

Expand Down
4 changes: 3 additions & 1 deletion src/ridgeplot/_obj/traces/area.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
from __future__ import annotations

from typing import Any, ClassVar
from typing import ClassVar

from plotly import graph_objects as go
from typing_extensions import Any, override

from ridgeplot._color.interpolation import slice_colorscale
from ridgeplot._color.utils import apply_alpha
Expand Down Expand Up @@ -49,6 +50,7 @@ def _get_coloring_kwargs(self, ctx: ColoringContext) -> dict[str, Any]:
)
return color_kwargs

@override
def draw(self, fig: go.Figure, coloring_ctx: ColoringContext) -> go.Figure:
# Draw an invisible trace at constance y=y_base so that we
# can set fill="tonexty" below and get a filled area plot
Expand Down
4 changes: 3 additions & 1 deletion src/ridgeplot/_obj/traces/bar.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
from __future__ import annotations

from typing import Any, ClassVar
from typing import ClassVar

from plotly import graph_objects as go
from typing_extensions import Any, override

from ridgeplot._color.interpolation import interpolate_color
from ridgeplot._obj.traces.base import DEFAULT_HOVERTEMPLATE, ColoringContext, RidgeplotTrace
Expand Down Expand Up @@ -33,6 +34,7 @@ def _get_coloring_kwargs(self, ctx: ColoringContext) -> dict[str, Any]:
)
return color_kwargs

@override
def draw(self, fig: go.Figure, coloring_ctx: ColoringContext) -> go.Figure:
fig.add_trace(
go.Bar(
Expand Down
4 changes: 3 additions & 1 deletion src/ridgeplot/_obj/traces/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@

from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import TYPE_CHECKING, ClassVar, Literal
from typing import TYPE_CHECKING, ClassVar

from typing_extensions import Literal

from ridgeplot._vendor.more_itertools import zip_strict

Expand Down
2 changes: 1 addition & 1 deletion src/ridgeplot/_ridgeplot.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,9 @@

if TYPE_CHECKING:
from collections.abc import Collection
from typing import Literal

import plotly.graph_objects as go
from typing_extensions import Literal

from ridgeplot._color.interpolation import SolidColormode
from ridgeplot._kde import (
Expand Down
11 changes: 3 additions & 8 deletions src/ridgeplot/_types.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,10 @@
from __future__ import annotations

import sys
from collections.abc import Collection
from typing import Any, Literal, Optional, TypeVar, Union
from typing import Optional, Union

import numpy as np

if sys.version_info >= (3, 13):
from typing import TypeIs
else:
from typing_extensions import TypeIs
from typing_extensions import Any, Literal, TypeIs, TypeVar

# Snippet used to generate and store the image artefacts:
# >>> def save_fig(fig, name):
Expand Down Expand Up @@ -547,7 +542,7 @@ def is_trace_type(obj: Any) -> TypeIs[TraceType]:
>>> is_trace_type(42)
False
"""
from typing import get_args
from typing_extensions import get_args

return isinstance(obj, str) and obj in get_args(TraceType)

Expand Down
7 changes: 4 additions & 3 deletions src/ridgeplot/_utils.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
from __future__ import annotations

from collections.abc import Collection
from typing import (
TYPE_CHECKING,
from typing import TYPE_CHECKING

from typing_extensions import (
TypeVar,
)

if TYPE_CHECKING:
from typing import Any
from typing_extensions import Any

from ridgeplot._types import CollectionL2, Densities, NormalisationOption, Numeric

Expand Down
4 changes: 3 additions & 1 deletion src/ridgeplot/_vendor/more_itertools.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,9 @@
import functools
import sys
from itertools import zip_longest
from typing import TYPE_CHECKING, Any
from typing import TYPE_CHECKING

from typing_extensions import Any

if TYPE_CHECKING:
from collections.abc import Iterable
Expand Down
2 changes: 1 addition & 1 deletion src/ridgeplot/datasets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from importlib_resources import as_file, files

if TYPE_CHECKING:
from typing import Literal
from typing_extensions import Literal

import pandas as pd

Expand Down
4 changes: 3 additions & 1 deletion tests/e2e/test_examples.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
from __future__ import annotations

from pathlib import Path
from typing import TYPE_CHECKING, Callable
from typing import TYPE_CHECKING

import pytest

from ridgeplot_examples import ALL_EXAMPLES

if TYPE_CHECKING:
from collections.abc import Callable

import plotly.graph_objects as go

ROOT_DIR = Path(__file__).parents[2]
Expand Down
2 changes: 1 addition & 1 deletion tests/unit/color/test_interpolation.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def test_interpolate_color_p_not_in_scale(viridis_colorscale: ColorScale) -> Non
@pytest.mark.parametrize("p", [-10.0, -1.3, 1.9, 100.0])
def test_interpolate_color_fails_for_p_out_of_bounds(p: float) -> None:
with pytest.raises(ValueError, match="should be a float value between 0 and 1"):
interpolate_color(colorscale=..., p=p) # pyright: ignore[reportArgumentType]
interpolate_color(colorscale=..., p=p)


# ==============================================================
Expand Down
20 changes: 10 additions & 10 deletions tests/unit/test_figure_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,14 +25,14 @@ def test_densities_must_be_4d(self, densities: Densities) -> None:
with pytest.raises(ValueError, match="Expected a 4D array of densities"):
create_ridgeplot(
densities=densities,
trace_types=..., # pyright: ignore[reportArgumentType]
colorscale=..., # pyright: ignore[reportArgumentType]
opacity=..., # pyright: ignore[reportArgumentType]
colormode=..., # pyright: ignore[reportArgumentType]
trace_labels=..., # pyright: ignore[reportArgumentType]
line_color=..., # pyright: ignore[reportArgumentType]
line_width=..., # pyright: ignore[reportArgumentType]
spacing=..., # pyright: ignore[reportArgumentType]
show_yticklabels=..., # pyright: ignore[reportArgumentType]
xpad=..., # pyright: ignore[reportArgumentType]
trace_types=...,
colorscale=...,
opacity=...,
colormode=...,
trace_labels=...,
line_color=...,
line_width=...,
spacing=...,
show_yticklabels=...,
xpad=...,
)
Loading
Loading