Skip to content

Commit

Permalink
Merge pull request #233 from tpvasconcelos/plotly-colorscale-validator
Browse files Browse the repository at this point in the history
Update `validate_and_coerce_colorscale` to use Plotly's `ColorscaleValidator`
  • Loading branch information
tpvasconcelos authored Oct 18, 2024
2 parents 5bf9fda + edfd9a1 commit 1bde523
Show file tree
Hide file tree
Showing 11 changed files with 227 additions and 316 deletions.
2 changes: 1 addition & 1 deletion docs/_static/charts/basic.html

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion docs/_static/charts/lincoln_weather.html

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion docs/_static/charts/probly.html

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions docs/reference/changelog.md
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ Unreleased changes
- Eagerly validate input shapes in `RidgeplotFigureFactory` ({gh-pr}`222`)
- Vendor `_zip_equal()` from [more-itertools](https://github.com/more-itertools/more-itertools/) ({gh-pr}`222`)
- Improve overall test coverage ({gh-pr}`222`)
- Refactor color validation logic to use helpers provided by Plotly ({gh-pr}`233`)

### Bug fixes

Expand Down
4 changes: 3 additions & 1 deletion mypy.ini
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@ python_version = 3.9

# Import discovery ---
files = src/ridgeplot, tests, docs, cicd_utils
namespace_packages = False
mypy_path = cicd_utils
namespace_packages = True
explicit_package_bases = True
exclude = ^docs/build

# Disallow dynamic typing ---
Expand Down
1 change: 1 addition & 0 deletions requirements/tests.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# pytest and plugins
pytest
pytest-icdiff
pytest-socket

# Coverage
Expand Down
14 changes: 12 additions & 2 deletions src/ridgeplot/_colormodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,13 @@
from dataclasses import dataclass
from typing import TYPE_CHECKING, Literal, Protocol

from ridgeplot._colors import ColorScale, apply_alpha, interpolate_color, normalise_colorscale
from ridgeplot._colors import (
ColorScale,
apply_alpha,
interpolate_color,
round_color,
validate_and_coerce_colorscale,
)
from ridgeplot._types import CollectionL2
from ridgeplot._utils import get_xy_extrema, normalise_min_max

Expand Down Expand Up @@ -131,14 +137,18 @@ def compute_trace_colors(
coloralpha: float | None,
interpolation_ctx: InterpolationContext,
) -> ColorsArray:
colorscale = normalise_colorscale(colorscale)
colorscale = validate_and_coerce_colorscale(colorscale)
if coloralpha is not None:
coloralpha = float(coloralpha)

def _get_color(p: float) -> str:
color = interpolate_color(colorscale, p=p)
if coloralpha is not None:
color = apply_alpha(color, alpha=coloralpha)
# This helps us avoid floating point errors when making
# comparisons in our test suite. The user should not
# be able to notice *any* difference in the output
color = round_color(color, ndigits=12)
return color

if colormode not in COLORMODE_MAPS:
Expand Down
195 changes: 68 additions & 127 deletions src/ridgeplot/_colors.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,16 @@
from __future__ import annotations

import json
import sys
from collections.abc import Collection
from pathlib import Path
from typing import Union, cast
from typing import Any, Union, cast

from _plotly_utils.colors import validate_colors, validate_scale_values
from _plotly_utils.basevalidators import ColorscaleValidator as _ColorscaleValidator
from _plotly_utils.colors import validate_colors
from plotly.colors import find_intermediate_color, hex_to_rgb

if sys.version_info >= (3, 13):
from typing import TypeIs
else:
from typing_extensions import TypeIs

from ridgeplot._css_colors import CSS_NAMED_COLORS, CssNamedColor
from ridgeplot._utils import LazyMapping, get_collection_array_shape, normalise_min_max
from ridgeplot._utils import LazyMapping, normalise_min_max

_PATH_TO_COLORS_JSON = Path(__file__).parent.joinpath("colors.json")

Expand All @@ -25,7 +20,7 @@
CSS color string - including hex, rgb/a, hsl/a, hsv/a, and named CSS colors."""

ColorScale = Collection[tuple[float, Color]]
"""The canonical form for a color scale represented by a collection of tuples of
"""The canonical form for a color scale is represented by a list of tuples of
two elements:
0. the first element (a *scale value*) is a float bounded to the
Expand All @@ -34,17 +29,18 @@
For instance, the Viridis color scale can be represented as:
>>> get_colorscale("Viridis")
((0.0, 'rgb(68, 1, 84)'),
(0.1111111111111111, 'rgb(72, 40, 120)'),
(0.2222222222222222, 'rgb(62, 73, 137)'),
(0.3333333333333333, 'rgb(49, 104, 142)'),
(0.4444444444444444, 'rgb(38, 130, 142)'),
(0.5555555555555556, 'rgb(31, 158, 137)'),
(0.6666666666666666, 'rgb(53, 183, 121)'),
(0.7777777777777777, 'rgb(110, 206, 88)'),
(0.8888888888888888, 'rgb(181, 222, 43)'),
(1.0, 'rgb(253, 231, 37)'))
>>> viridis: ColorScale = [
(0.0, 'rgb(68, 1, 84)'),
(0.1111111111111111, 'rgb(72, 40, 120)'),
(0.2222222222222222, 'rgb(62, 73, 137)'),
(0.3333333333333333, 'rgb(49, 104, 142)'),
(0.4444444444444444, 'rgb(38, 130, 142)'),
(0.5555555555555556, 'rgb(31, 158, 137)'),
(0.6666666666666666, 'rgb(53, 183, 121)'),
(0.7777777777777777, 'rgb(110, 206, 88)'),
(0.8888888888888888, 'rgb(181, 222, 43)'),
(1.0, 'rgb(253, 231, 37)')
]
"""


Expand All @@ -58,32 +54,41 @@ def _colormap_loader() -> dict[str, ColorScale]:
_COLORSCALE_MAPPING: LazyMapping[str, ColorScale] = LazyMapping(loader=_colormap_loader)


def is_canonical_colorscale(
colorscale: ColorScale | Collection[Color] | str,
) -> TypeIs[ColorScale]:
if isinstance(colorscale, str) or not isinstance(colorscale, Collection):
return False
shape = get_collection_array_shape(colorscale)
if not (len(shape) == 2 and shape[1] == 2):
return False
scale, colors = zip(*colorscale)
return (
all(isinstance(s, (int, float)) for s in scale) and
all(isinstance(c, (str, tuple)) for c in colors)
) # fmt: skip


def validate_canonical_colorscale(colorscale: ColorScale) -> None:
"""Validate the structure, scale values, and colors of a colorscale in the
canonical format."""
if not is_canonical_colorscale(colorscale):
raise TypeError(
"The colorscale should be a collection of tuples of two elements: "
"a scale value and a color."
)
scale, colors = zip(*colorscale)
validate_scale_values(scale=scale)
validate_colors(colors=colors)
def list_all_colorscale_names() -> list[str]:
"""Get a list with all available colorscale names.
.. versionadded:: 0.1.21
Replaced the deprecated function ``get_all_colorscale_names()``.
Returns
-------
list[str]
A list with all available colorscale names.
"""
return sorted(_COLORSCALE_MAPPING.keys())


class ColorscaleValidator(_ColorscaleValidator): # type: ignore[misc]
def __init__(self) -> None:
super().__init__("colorscale", "ridgeplot")

@property
def named_colorscales(self) -> dict[str, list[Color]]:
return {
name: [c for _, c in colorscale] for name, colorscale in _COLORSCALE_MAPPING.items()
}

def validate_coerce(self, v: Any) -> ColorScale:
coerced = super().validate_coerce(v)
if coerced is None:
self.raise_invalid_val(coerced)
return cast(ColorScale, [tuple(c) for c in coerced])


def validate_and_coerce_colorscale(colorscale: ColorScale | Collection[Color] | str) -> ColorScale:
"""Convert mixed colorscale representations to the canonical
:data:`ColorScale` format."""
return ColorscaleValidator().validate_coerce(colorscale)


def _any_to_rgb(color: Color) -> str:
Expand Down Expand Up @@ -114,7 +119,7 @@ def _any_to_rgb(color: Color) -> str:
rgb = f"rgb({r}, {g}, {b})"
elif color.startswith("#"):
return _any_to_rgb(cast(str, hex_to_rgb(color)))
elif color.startswith("rgb("):
elif color.startswith(("rgb(", "rgba(")):
rgb = color
elif color in CSS_NAMED_COLORS:
color = cast(CssNamedColor, color)
Expand All @@ -128,84 +133,6 @@ def _any_to_rgb(color: Color) -> str:
return rgb


def list_all_colorscale_names() -> list[str]:
"""Get a list with all available colorscale names.
.. versionadded:: 0.1.21
Replaced the deprecated function ``get_all_colorscale_names()``.
Returns
-------
list[str]
A list with all available colorscale names.
"""
return sorted(_COLORSCALE_MAPPING.keys())


def get_colorscale(name: str) -> ColorScale:
"""Get a colorscale by name.
Parameters
----------
name
The colorscale name. This argument is case-insensitive. For instance,
``"YlOrRd"`` and ``"ylorrd"`` map to the same colorscale. Colorscale
names ending in ``_r`` represent a *reversed* colorscale.
Returns
-------
ColorScale
A colorscale.
Raises
------
:exc:`ValueError`
If an unknown name is provided
"""
name = name.lower()
if name not in _COLORSCALE_MAPPING:
raise ValueError(
f"Unknown color scale name: '{name}'. The available color scale "
f"names are {tuple(_COLORSCALE_MAPPING.keys())}."
)
return _COLORSCALE_MAPPING[name]


def canonical_colorscale_from_list(colors: Collection[Color]) -> ColorScale:
"""Infer a colorscale from a list of colors.
Parameters
----------
colors
An collection of :data:`Color` values.
Returns
-------
ColorScale
A colorscale with the same colors as the input list, but with
scale values evenly spaced between 0 and 1.
"""
colors = list(colors)
n_colors = len(colors)
scale = [i / (n_colors - 1) for i in range(n_colors)]
scale[-1] = 1.0 # Avoid floating point errors
return tuple(zip(scale, colors))


def normalise_colorscale(colorscale: ColorScale | Collection[Color] | str) -> ColorScale:
"""Convert mixed colorscale representations to the canonical
:data:`ColorScale` format."""
if isinstance(colorscale, str):
return get_colorscale(name=colorscale)
if is_canonical_colorscale(colorscale):
validate_canonical_colorscale(colorscale)
return colorscale
# There is a bug in mypy that results in the type narrowing not working
# properly here. See https://github.com/python/mypy/issues/17181
colorscale = canonical_colorscale_from_list(colors=colorscale) # type: ignore[unreachable]
return colorscale


def interpolate_color(colorscale: ColorScale, p: float) -> str:
"""Get a color from a colorscale at a given interpolation point ``p``."""
if not (0 <= p <= 1):
Expand All @@ -231,6 +158,20 @@ def interpolate_color(colorscale: ColorScale, p: float) -> str:
)


def _unpack_rgb(rgb: str) -> tuple[float, float, float, float] | tuple[float, float, float]:
prefix = rgb.split("(")[0] + "("
values_str = map(str.strip, rgb.removeprefix(prefix).removesuffix(")").split(","))
values_num = tuple(int(v) if v.isdecimal() else float(v) for v in values_str)
return values_num # type: ignore[return-value]


def apply_alpha(color: Color, alpha: float) -> str:
values = _unpack_rgb(_any_to_rgb(color))
return f"rgba({', '.join(map(str, values[:3]))}, {alpha})"


def round_color(color: Color, ndigits: int) -> str:
color = _any_to_rgb(color)
return f"rgba({color[4:-1]}, {alpha})"
prefix = color.split("(")[0] + "("
values_round = tuple(v if isinstance(v, int) else round(v, ndigits) for v in _unpack_rgb(color))
return f"{prefix}{', '.join(map(str, values_round))})"
69 changes: 69 additions & 0 deletions tests/unit/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
from __future__ import annotations

from typing import TYPE_CHECKING

import pytest

if TYPE_CHECKING:
from collections.abc import Collection

from ridgeplot._colors import Color, ColorScale

VIRIDIS = (
(0.0, "rgb(68, 1, 84)"),
(0.1111111111111111, "rgb(72, 40, 120)"),
(0.2222222222222222, "rgb(62, 73, 137)"),
(0.3333333333333333, "rgb(49, 104, 142)"),
(0.4444444444444444, "rgb(38, 130, 142)"),
(0.5555555555555556, "rgb(31, 158, 137)"),
(0.6666666666666666, "rgb(53, 183, 121)"),
(0.7777777777777777, "rgb(110, 206, 88)"),
(0.8888888888888888, "rgb(181, 222, 43)"),
(1.0, "rgb(253, 231, 37)"),
)


@pytest.fixture(scope="session")
def viridis_colorscale() -> ColorScale:
return VIRIDIS


VALID_COLOR_SCALES = [
(VIRIDIS, VIRIDIS),
("viridis", VIRIDIS),
(list(zip(*VIRIDIS))[-1], VIRIDIS),
# List of colors
(["red", "green"], [[0, "red"], [1, "green"]]),
# List of lists
tuple([[[0, "red"], [1, "green"]]] * 2),
# Tuple of tuples
tuple([((0, "red"), (1, "green"))] * 2),
# List of tuples
tuple([[(0, "red"), (0.5, "blue"), (1, "green")]] * 2),
]


@pytest.fixture(scope="session", params=VALID_COLOR_SCALES)
def valid_colorscale(
request: pytest.FixtureRequest,
) -> tuple[ColorScale | Collection[Color] | str, ColorScale]:
return request.param # type: ignore[no-any-return]


INVALID_COLOR_SCALES = [
None,
1,
(1, 2, 3),
VIRIDIS[0],
((1, 2, 3), (4, 5, 6)),
(("a", 1), ("b", 2)),
[(0, "red"), (1.2, "green")],
[(0, "red"), (1, "green", "blue")],
["red", "invalid"],
[[0, "red"], [1, "whodis"]],
]


@pytest.fixture(scope="session", params=INVALID_COLOR_SCALES)
def invalid_colorscale(request: pytest.FixtureRequest) -> ColorScale | Collection[Color] | str:
return request.param # type: ignore[no-any-return]
Loading

0 comments on commit 1bde523

Please sign in to comment.