from __future__ import annotations

import math
import os
from collections import OrderedDict
from collections.abc import Iterable, Mapping, Sequence
from copy import copy
from functools import partial
from pathlib import Path
from types import MappingProxyType
from typing import Any, Literal

import dask
import datashader as ds
import matplotlib
import matplotlib.patches as mpatches
import matplotlib.path as mpath
import matplotlib.pyplot as plt
import matplotlib.ticker
import matplotlib.transforms as mtransforms
import numpy as np
import numpy.ma as ma
import numpy.typing as npt
import pandas as pd
import shapely
import spatialdata as sd
from anndata import AnnData
from cycler import Cycler, cycler
from datashader.core import Canvas
from geopandas import GeoDataFrame
from matplotlib import colors, patheffects, rcParams
from matplotlib.axes import Axes
from matplotlib.collections import PatchCollection
from matplotlib.colors import (
    ColorConverter,
    Colormap,
    LinearSegmentedColormap,
    ListedColormap,
    Normalize,
    to_rgba,
)
from matplotlib.figure import Figure
from matplotlib.gridspec import GridSpec
from matplotlib.transforms import CompositeGenericTransform
from matplotlib_scalebar.scalebar import ScaleBar
from numpy.ma.core import MaskedArray
from numpy.random import default_rng
from pandas.api.types import CategoricalDtype, is_bool_dtype, is_numeric_dtype, is_string_dtype
from pandas.core.arrays.categorical import Categorical
from scanpy import settings
from scanpy.plotting._tools.scatterplots import _add_categorical_legend
from scanpy.plotting._utils import add_colors_for_categorical_sample_annotation
from scanpy.plotting.palettes import default_20, default_28, default_102
from scipy.spatial import ConvexHull
from shapely.errors import GEOSException
from skimage.color import label2rgb
from skimage.morphology import erosion, square
from skimage.segmentation import find_boundaries
from skimage.util import map_array
from spatialdata import (
    SpatialData,
    get_element_annotators,
    get_extent,
    get_values,
    rasterize,
)
from spatialdata._core.query.relational_query import _locate_value
from spatialdata._types import ArrayLike
from spatialdata.models import Image2DModel, Labels2DModel, SpatialElement, get_table_keys
from spatialdata.transformations.operations import get_transformation
from spatialdata.transformations.transformations import Scale
from xarray import DataArray, DataTree

from spatialdata_plot._logging import logger
from spatialdata_plot.pl.render_params import (
    CmapParams,
    Color,
    ColorbarSpec,
    FigParams,
    ImageRenderParams,
    LabelsRenderParams,
    OutlineParams,
    PointsRenderParams,
    ScalebarParams,
    ShapesRenderParams,
    _FontSize,
    _FontWeight,
)

# Hex converter that always keeps the alpha channel (``keep_alpha=True``), i.e. it
# produces "#rrggbbaa" strings as documented for matplotlib's ``colors.to_hex``.
to_hex = partial(colors.to_hex, keep_alpha=True)

# TODO: replace with
# from spatialdata._types import ColorLike
# once https://github.com/scverse/spatialdata/pull/689/ is in a release
# Type alias for a color specification: a tuple/list of floats or a color string.
ColorLike = tuple[float, ...] | list[float] | str


def _extract_scalar_value(value: Any, default: float = 0.0) -> float:
    """
    Extract a scalar float value from various data types.

    Handles pandas Series (anything with ``iloc``), 0-d numpy arrays, and other sized,
    indexable containers (arrays, lists, ...) by taking the first element. Strings and
    bytes are treated as scalars. Missing values (NaN/None) and values that cannot be
    converted to ``float`` yield ``default``.

    Parameters
    ----------
    value : Any
        The value to extract a scalar from
    default : float, default 0.0
        Default value to return if the container is empty, the value is missing,
        or conversion fails

    Returns
    -------
    float
        The extracted scalar value
    """
    try:
        # Handle pandas Series or similar objects with iloc
        if hasattr(value, "iloc"):
            if len(value) == 0:
                return default
            value = value.iloc[0]

        # 0-d arrays define ``__len__`` but ``len()`` raises TypeError on them, which
        # would silently turn a perfectly valid scalar into ``default``; unwrap instead.
        elif isinstance(value, np.ndarray) and value.ndim == 0:
            value = value.item()

        # Handle other array-like objects
        elif hasattr(value, "__len__") and not isinstance(value, (str, bytes)):
            if len(value) == 0:
                return default
            value = value[0]

        # Missing values map to the default
        if pd.isna(value):
            return default

        return float(value)

    except (TypeError, ValueError, IndexError):
        return default


def _verify_plotting_tree(sdata: SpatialData) -> SpatialData:
    """Ensure ``sdata`` carries a ``plotting_tree`` attribute and return the same object.

    An existing ``plotting_tree`` is left untouched; otherwise an empty ``OrderedDict``
    is attached.
    """
    if hasattr(sdata, "plotting_tree"):
        return sdata

    sdata.plotting_tree = OrderedDict()
    return sdata


def _get_coordinate_system_mapping(sdata: SpatialData) -> dict[str, list[str]]:
    """Map every coordinate system of ``sdata`` to the names of the elements that have a transformation to it.

    Within each coordinate system, element names are listed in the order images, labels,
    shapes, points (and in each group's own key order).

    Raises
    ------
    ValueError
        If ``sdata`` has no coordinate systems.
    """
    coordsys_keys = sdata.coordinate_systems

    if len(coordsys_keys) < 1:
        raise ValueError("SpatialData object must have at least one coordinate system to generate a mapping.")

    # Look up each element's transformations once, instead of once per coordinate system.
    element_transformations: list[tuple[str, Any]] = []
    for elements in (sdata.images, sdata.labels, sdata.shapes, sdata.points):
        if elements is None:
            continue
        for element_key in elements.keys():
            transformations = get_transformation(elements[element_key], get_all=True)
            element_transformations.append((element_key, transformations))

    return {
        key: [element_key for element_key, transformations in element_transformations if key in transformations]
        for key in coordsys_keys
    }


def _is_color_like(color: Any) -> bool:
    """Return whether ``color`` is a color specification we accept.

    This defers to matplotlib's ``is_color_like`` with two extra restrictions for strings
    (see https://github.com/scverse/spatialdata-plot/issues/327):

    - numeric strings in [0, 1] (which matplotlib treats as grey levels) are rejected;
    - hex strings must be in the long form #RRGGBB or #RRGGBBAA (short forms are rejected).
    """
    if not isinstance(color, str):
        return bool(colors.is_color_like(color))

    try:
        grey_level: float | None = float(color)
    except ValueError:
        # not something matplotlib would read as a grey level
        grey_level = None

    if grey_level is not None and 0 <= grey_level <= 1:
        return False

    if color.startswith("#") and len(color) not in (7, 9):
        return False

    return bool(colors.is_color_like(color))


def _prepare_params_plot(
    # this param is inferred when `pl.show` is called
    num_panels: int,
    # these args are passed at `pl.show`
    figsize: tuple[float, float] | None = None,
    dpi: int | None = None,
    fig: Figure | None = None,
    ax: Axes | Sequence[Axes] | None = None,
    wspace: float | None = None,
    hspace: float = 0.25,
    ncols: int = 4,
    frameon: bool | None = None,
    # these args will be inferred from coordinate system
    scalebar_dx: float | Sequence[float] | None = None,
    scalebar_units: str | Sequence[str] | None = None,
) -> tuple[FigParams, ScalebarParams]:
    """Resolve figure/axes layout and scalebar settings for ``pl.show``.

    Three layouts are handled:

    - several panels and no ``ax``: a new figure with a panel grid is created;
    - several panels with user-supplied ``ax``: ``ax`` must be a sequence of exactly
      ``num_panels`` axes and ``fig`` must be given;
    - a single panel: a new figure/axes is created if ``ax`` is None, otherwise the
      figure of the given ``Axes`` is used and its dpi set to ``dpi``.

    Raises
    ------
    TypeError
        If several panels are requested and ``ax`` is not a ``Sequence``.
    ValueError
        If the number of axes does not match ``num_panels``, or ``fig`` is missing
        while a sequence of axes is passed.
    """
    # handle axes and size; the default wspace is derived from the rcParams figure width
    wspace = 0.75 / rcParams["figure.figsize"][0] + 0.02 if wspace is None else wspace
    figsize = rcParams["figure.figsize"] if figsize is None else figsize
    dpi = rcParams["figure.dpi"] if dpi is None else dpi

    axs: None | Sequence[Axes]
    if num_panels > 1 and ax is None:
        fig, grid = _panel_grid(
            num_panels=num_panels,
            hspace=hspace,
            wspace=wspace,
            ncols=ncols,
            dpi=dpi,
            figsize=figsize,
        )
        axs = [plt.subplot(grid[c]) for c in range(num_panels)]
    elif num_panels > 1:
        if not isinstance(ax, Sequence):
            raise TypeError(f"Expected `ax` to be a `Sequence`, but got {type(ax).__name__}")
        if len(ax) != num_panels:
            raise ValueError(f"Len of `ax`: {len(ax)} is not equal to number of panels: {num_panels}.")
        if fig is None:
            raise ValueError(
                f"Invalid value of `fig`: {fig}. If a list of `Axes` is passed, a `Figure` must also be specified."
            )
        axs = ax
    else:
        axs = None
        if ax is None:
            fig, ax = plt.subplots(figsize=figsize, dpi=dpi, constrained_layout=True)
        elif isinstance(ax, Axes):
            # needed for rasterization if user provides Axes object
            fig = ax.get_figure()
            fig.set_dpi(dpi)

    # set scalebar
    if scalebar_dx is not None:
        scalebar_dx, scalebar_units = _get_scalebar(scalebar_dx, scalebar_units, num_panels)

    fig_params = FigParams(
        fig=fig,
        ax=ax,
        axs=axs,
        num_panels=num_panels,
        frameon=frameon,
    )
    scalebar_params = ScalebarParams(scalebar_dx=scalebar_dx, scalebar_units=scalebar_units)

    return fig_params, scalebar_params


def _get_cs_contents(sdata: sd.SpatialData) -> pd.DataFrame:
    """Check which coordinate systems contain which elements and return that info.

    Returns
    -------
    pd.DataFrame
        One row per coordinate system with columns ``cs`` (its name) and the boolean
        flags ``has_images``, ``has_labels``, ``has_points`` and ``has_shapes``.
        Every row carries the index label ``0`` (as produced by the former row-wise
        concatenation of one-row frames); NOTE(review): callers may rely on a
        label-based ``[0]`` lookup after filtering by ``cs`` — confirm before changing.
    """
    cs_mapping = _get_coordinate_system_mapping(sdata)
    content_flags = ["has_images", "has_labels", "has_points", "has_shapes"]

    # Collect plain records first and build the frame once: concatenating inside the loop
    # is quadratic and concatenating onto an empty frame is deprecated in pandas.
    records = [
        {
            "cs": cs_name,
            "has_images": any(e in sdata.images for e in element_ids),
            "has_labels": any(e in sdata.labels for e in element_ids),
            "has_points": any(e in sdata.points for e in element_ids),
            "has_shapes": any(e in sdata.shapes for e in element_ids),
        }
        for cs_name, element_ids in cs_mapping.items()
    ]

    cs_contents = pd.DataFrame(records, columns=["cs"] + content_flags, index=[0] * len(records))
    return cs_contents.astype({flag: "bool" for flag in content_flags})


def _get_centroid_of_pathpatch(pathpatch: mpatches.PathPatch) -> tuple[float, float]:
    """Return the area-weighted centroid of a patch's path using the shoelace formula.

    The path vertices are treated as one polygon ring traversed in order: consecutive
    vertex pairs ``(v[i], v[i + 1])`` contribute to the sums, so the ring is expected to
    be explicitly closed (last vertex repeating the first). NOTE(review): for compound
    paths holding several rings the pairs spanning two rings also contribute — confirm
    this is acceptable for the callers.

    If the signed area is zero or not finite (e.g. collinear vertices), the formula would
    divide by zero; the arithmetic mean of the vertices (without a repeated closing vertex)
    is returned instead. An empty path yields ``(nan, nan)``.
    """
    vertices = np.asarray(pathpatch.get_path().vertices, dtype=float)
    if len(vertices) == 0:
        return float("nan"), float("nan")

    x = vertices[:, 0]
    y = vertices[:, 1]

    # cross terms of consecutive vertices, shared by the area and centroid sums
    cross = x[:-1] * y[1:] - x[1:] * y[:-1]
    area = 0.5 * np.sum(cross)

    if area == 0.0 or not np.isfinite(area):
        is_closed = len(vertices) > 1 and np.array_equal(vertices[0], vertices[-1])
        ring = vertices[:-1] if is_closed else vertices
        mean_x, mean_y = ring.mean(axis=0)
        return float(mean_x), float(mean_y)

    # Calculate the centroid coordinates
    centroid_x = np.sum((x[:-1] + x[1:]) * cross) / (6 * area)
    centroid_y = np.sum((y[:-1] + y[1:]) * cross) / (6 * area)

    return float(centroid_x), float(centroid_y)


def _scale_pathpatch_around_centroid(pathpatch: mpatches.PathPatch, scale_factor: float) -> None:
    """Scale the vertices of ``pathpatch``'s path in place around the path centroid.

    ``scale_factor`` may be a scalar or a container; it is reduced to a float with
    ``_extract_scalar_value`` (falling back to 1.0, i.e. no scaling).
    """
    scale_value = _extract_scalar_value(scale_factor, default=1.0)
    centroid = np.asarray(_get_centroid_of_pathpatch(pathpatch), dtype=float)
    path = pathpatch.get_path()
    # Broadcast over all vertices at once; this also keeps the (N, 2) shape for empty paths.
    path.vertices = centroid + (np.asarray(path.vertices) - centroid) * scale_value


def _get_collection_shape(
    shapes: list[GeoDataFrame],
    c: Any,
    s: float,
    norm: Any,
    render_params: ShapesRenderParams,
    fill_alpha: None | float = None,
    outline_alpha: None | float = None,
    outline_color: None | str | list[float] = "white",
    linewidth: float = 0.0,
    **kwargs: Any,
) -> PatchCollection:
    """
    Build a PatchCollection for shapes with consistent color handling.

    Supported forms of ``c``:

      - continuous numeric vectors (one value per shape, NaNs allowed),
      - per-row RGB/RGBA arrays of shape ``(n_shapes, 3 | 4)``,
      - object vectors mixing numeric values and color specs,
      - a single color or a list of color specs.

    Only NaNs (and missing entries) are painted with ``na_color``; finite values are mapped
    via ``cmap(norm(.))``. Missing or empty geometries are dropped, and per-row colors are
    dropped together with them so the remaining shapes keep their own colors.

    Parameters
    ----------
    shapes
        Shapes table with a ``geometry`` column (and ``radius`` for point geometries).
    c
        Fill color specification, see above.
    s
        Scale factor applied to every patch (polygons around their mean exterior vertex,
        multipolygon parts around their centroid, circles via their radius).
    norm
        Normalization for numeric ``c``; if it is not a ``Normalize``, one spanning the
        finite values is built (falling back to [0, 1]).
    render_params
        Provides ``cmap_params.na_color``.
    fill_alpha
        If given, replaces the alpha of every fill color whose alpha is > 0.
    outline_alpha
        If > 0, outlines are drawn in ``outline_color`` with this alpha.
    outline_color
        Outline color spec(s); a single spec is applied to every shape.
    linewidth
        Line width of the outlines.
    **kwargs
        Forwarded to ``PatchCollection``; must contain ``cmap``.

    Returns
    -------
    PatchCollection
        One patch per Polygon/Point, plus all patches returned by
        ``_make_patch_from_multipolygon`` for each MultiPolygon. Other geometry types
        produce no patch; if no patch remains, an empty collection is returned.
    """
    cmap = kwargs["cmap"]

    # Resolve na color once
    na_rgba = colors.to_rgba(render_params.cmap_params.na_color.get_hex_with_alpha())

    # Try to interpret c as numpy array
    c_arr = np.asarray(c)
    fill_c: np.ndarray

    def _as_rgba_array(x: Any) -> np.ndarray:
        return np.asarray(ColorConverter().to_rgba_array(x))

    def _resolve_norm(finite_values: np.ndarray) -> Normalize:
        # Use the caller's norm if given, else one spanning the finite values ([0, 1] if degenerate).
        if isinstance(norm, Normalize):
            return norm
        vmin, vmax = 0.0, 1.0
        if finite_values.size > 0:
            lo = float(np.min(finite_values))
            hi = float(np.max(finite_values))
            if np.isfinite(lo) and np.isfinite(hi) and lo != hi:
                vmin, vmax = lo, hi
        return colors.Normalize(vmin=vmin, vmax=vmax, clip=False)

    # Case A: per-row numeric colors given as Nx3 or Nx4 float array
    if (
        c_arr.ndim == 2
        and c_arr.shape[0] == len(shapes)
        and c_arr.shape[1] in (3, 4)
        and np.issubdtype(c_arr.dtype, np.number)
    ):
        fill_c = _as_rgba_array(c_arr)

    # Case B: continuous numeric vector len == n_shapes (possibly with NaNs)
    elif c_arr.ndim == 1 and len(c_arr) == len(shapes) and np.issubdtype(c_arr.dtype, np.number):
        finite_mask = np.isfinite(c_arr)
        used_norm = _resolve_norm(c_arr[finite_mask])

        # Map finite values through cmap(norm(.)); NaNs get na_color
        fill_c = np.empty((len(c_arr), 4), dtype=float)
        fill_c[:] = na_rgba
        if finite_mask.any():
            fill_c[finite_mask] = cmap(used_norm(c_arr[finite_mask]))

    # Case B': object vector mixing numbers and color specs
    elif c_arr.ndim == 1 and len(c_arr) == len(shapes) and c_arr.dtype == object:
        # Split into numeric vs color-like
        c_series = pd.Series(c_arr, copy=False)
        num = pd.to_numeric(c_series, errors="coerce").to_numpy()
        is_num = np.isfinite(num)

        # init with na color
        fill_c = np.empty((len(c_series), 4), dtype=float)
        fill_c[:] = na_rgba

        # numeric entries via cmap(norm)
        if is_num.any():
            used_norm = _resolve_norm(num[is_num])
            fill_c[is_num] = cmap(used_norm(num[is_num]))

        # non-numeric, non-NaN entries as explicit colors
        non_numeric_color_mask = (~is_num) & c_series.notna().to_numpy()
        if non_numeric_color_mask.any():
            fill_c[non_numeric_color_mask] = ColorConverter().to_rgba_array(c_series[non_numeric_color_mask].tolist())

    # Case C: single color or list of color-like specs (strings or tuples)
    else:
        fill_c = _as_rgba_array(c)

    # Apply optional fill alpha without destroying existing transparency
    if fill_alpha is not None:
        nonzero_alpha = fill_c[..., -1] > 0
        fill_c[nonzero_alpha, -1] = fill_alpha

    # Outline handling
    outline_c: list[Any]
    if outline_alpha and outline_alpha > 0.0:
        outline_c_array = _as_rgba_array(outline_color)
        outline_c_array[..., -1] = outline_alpha
        outline_c = outline_c_array.tolist()
    else:
        outline_c = [None] * fill_c.shape[0]

    if isinstance(shapes, GeoDataFrame):
        shapes_df: GeoDataFrame | pd.DataFrame = shapes.copy()
    else:
        shapes_df = pd.DataFrame(shapes, copy=True)

    # Robustly normalise geometries to a canonical representation.
    # This ensures consistent exterior/interior ring orientation so that
    # matplotlib's fill rules handle holes correctly regardless of user input.
    if "geometry" in shapes_df.columns:

        def _normalize_geom(geom: Any) -> Any:
            if geom is None or getattr(geom, "is_empty", False):
                return geom
            # shapely.normalize is available in shapely>=2; fall back to geom.normalize()
            normalize_func = getattr(shapely, "normalize", None)
            if callable(normalize_func):
                try:
                    return normalize_func(geom)
                except (GEOSException, TypeError, ValueError):
                    return geom
            if hasattr(geom, "normalize"):
                try:
                    return geom.normalize()
                except (GEOSException, TypeError, ValueError):
                    return geom
            return geom

        shapes_df["geometry"] = shapes_df["geometry"].apply(_normalize_geom)

    # Drop missing/empty geometries. Objects without ``is_empty`` (e.g. None) count as empty.
    n_rows = len(shapes_df)
    keep_mask = np.asarray([not getattr(geom, "is_empty", True) for geom in shapes_df["geometry"]], dtype=bool)
    shapes_df = shapes_df[keep_mask].reset_index(drop=True)

    # Per-row colors must be filtered with the same mask; otherwise, after ``reset_index``,
    # every shape following a dropped one would receive its predecessor's color.
    fill_colors: list[Any] = fill_c.tolist()
    outline_colors: list[Any] = list(outline_c)
    if not keep_mask.all():
        if len(fill_colors) == n_rows:
            fill_colors = [col for col, keep in zip(fill_colors, keep_mask, strict=True) if keep]
        if len(outline_colors) == n_rows:
            outline_colors = [col for col, keep in zip(outline_colors, keep_mask, strict=True) if keep]

    def _pick_color(values: list[Any], idx: int) -> Any:
        # A single entry is broadcast to every shape; otherwise entries are per row.
        # Fill and outline are resolved independently so e.g. per-row fills can be
        # combined with one outline color.
        return values[0] if len(values) == 1 else values[idx]

    def _process_polygon(row: pd.Series, scale: float) -> dict[str, Any]:
        # NOTE(review): only the exterior ring is drawn here; interior rings (holes) are ignored.
        coords = np.array(row["geometry"].exterior.coords)
        centroid = np.mean(coords, axis=0)
        scale_value = _extract_scalar_value(scale, default=1.0)
        scaled = (centroid + (coords - centroid) * scale_value).tolist()
        return {**row.to_dict(), "geometry": mpatches.Polygon(scaled, closed=True)}

    def _process_multipolygon(row: pd.Series, scale: float) -> list[dict[str, Any]]:
        mp = _make_patch_from_multipolygon(row["geometry"])
        row_dict = row.to_dict()
        for m in mp:
            _scale_pathpatch_around_centroid(m, scale)
        return [{**row_dict, "geometry": m} for m in mp]

    def _process_point(row: pd.Series, scale: float) -> dict[str, Any]:
        radius_value = _extract_scalar_value(row["radius"], default=0.0)
        scale_value = _extract_scalar_value(scale, default=1.0)
        radius = radius_value * scale_value

        return {
            **row.to_dict(),
            "geometry": mpatches.Circle((row["geometry"].x, row["geometry"].y), radius=radius),
        }

    def _create_patches(
        shapes_df_: GeoDataFrame, fill_colors_: list[Any], outline_colors_: list[Any], scale: float
    ) -> pd.DataFrame:
        rows: list[dict[str, Any]] = []
        for idx, row in shapes_df_.iterrows():
            geom_type = row["geometry"].geom_type
            processed: list[dict[str, Any]] = []
            if geom_type == "Polygon":
                processed.append(_process_polygon(row, scale))
            elif geom_type == "MultiPolygon":
                processed.extend(_process_multipolygon(row, scale))
            elif geom_type == "Point":
                processed.append(_process_point(row, scale))
            for pr in processed:
                pr["fill_c"] = _pick_color(fill_colors_, idx)
                pr["outline_c"] = _pick_color(outline_colors_, idx)
                rows.append(pr)
        return pd.DataFrame(rows)

    patches = _create_patches(shapes_df, fill_colors, outline_colors, s)

    if patches.empty:
        # Nothing drawable (all geometries empty or unsupported): return an empty collection
        # instead of failing on the missing "geometry" column.
        return PatchCollection([], snap=False, lw=linewidth, **kwargs)

    # Use the per-patch outline colors so they stay aligned with the patches
    # (multipolygons expand into several patches, empty geometries were dropped).
    patch_outlines = patches["outline_c"].tolist()

    return PatchCollection(
        patches["geometry"].values.tolist(),
        snap=False,
        lw=linewidth,
        facecolor=patches["fill_c"],
        edgecolor=None if all(o is None for o in patch_outlines) else patch_outlines,
        **kwargs,
    )


def _panel_grid(
    num_panels: int,
    hspace: float,
    wspace: float,
    ncols: int,
    figsize: tuple[float, float],
    dpi: int | None = None,
) -> tuple[Figure, GridSpec]:
    """Create a figure and a GridSpec laying out ``num_panels`` panels in at most ``ncols`` columns.

    The figure size scales with the number of panels per row/column; the margins shrink
    accordingly so every panel keeps comparable spacing.
    """
    cols = min(ncols, num_panels)
    rows = np.ceil(num_panels / cols).astype(int)

    fig_width = figsize[0] * cols * (1 + wspace)
    fig_height = figsize[1] * rows
    fig = plt.figure(figsize=(fig_width, fig_height), dpi=dpi)

    left_margin = 0.2 / cols
    bottom_margin = 0.13 / rows
    right_edge = 1 - (cols - 1) * left_margin - 0.01 / cols
    top_edge = 1 - (rows - 1) * bottom_margin - 0.1 / rows

    grid = GridSpec(
        nrows=rows,
        ncols=cols,
        left=left_margin,
        right=right_edge,
        bottom=bottom_margin,
        top=top_edge,
        hspace=hspace,
        wspace=wspace,
    )
    return fig, grid


def _get_scalebar(
    scalebar_dx: float | Sequence[float] | None = None,
    scalebar_units: str | Sequence[str] | None = None,
    len_lib: int | None = None,
) -> tuple[Sequence[float] | None, Sequence[str] | None]:
    """Expand scalebar sizes and units into per-panel lists.

    Returns ``(None, None)`` when no ``scalebar_dx`` is given. Units default to ``"um"``.
    """
    if scalebar_dx is None:
        return None, None

    dx_list = _get_list(scalebar_dx, _type=float, ref_len=len_lib, name="scalebar_dx")
    units = scalebar_units if scalebar_units is not None else "um"
    units_list = _get_list(units, _type=str, ref_len=len_lib, name="scalebar_units")
    return dx_list, units_list


def _prepare_cmap_norm(
    cmap: Colormap | str | None = None,
    norm: Normalize | None = None,
    na_color: Color | None = None,
) -> CmapParams:
    """Resolve colormap, normalization and NA color into a ``CmapParams`` bundle.

    Parameters
    ----------
    cmap
        Colormap or registered colormap name. ``None`` uses ``rcParams["image.cmap"]``
        and marks the result as the default colormap (``cmap_is_default=True``).
    norm
        Normalization to use. ``None`` creates ``Normalize(vmin=None, vmax=None, clip=False)``.
    na_color
        Color for bad (NaN/masked) values; set on the colormap via ``set_bad``.
        ``None`` uses a fresh ``Color()``.

    Returns
    -------
    CmapParams
        Holds a copy of the colormap (so ``set_bad`` never alters the caller's colormap),
        the norm, the NA color and whether the default colormap was used.

    Raises
    ------
    TypeError
        If ``cmap`` does not resolve to a ``Colormap``.
    """
    # TODO: check refactoring norm out here as it gets overwritten later
    if na_color is None:
        # Build the default per call instead of sharing one instance created at import time.
        na_color = Color()

    cmap_is_default = cmap is None
    if cmap is None:
        cmap = rcParams["image.cmap"]
    if isinstance(cmap, str):
        cmap = matplotlib.colormaps[cmap]

    if not isinstance(cmap, Colormap):
        raise TypeError(f"Invalid type of `cmap`: {type(cmap)}, expected `Colormap`.")

    # Copy so that `set_bad` below does not modify a colormap owned by the caller.
    cmap = copy(cmap)

    if norm is None:
        norm = Normalize(vmin=None, vmax=None, clip=False)

    cmap.set_bad(na_color.get_hex_with_alpha())

    return CmapParams(
        cmap=cmap,
        norm=norm,
        na_color=na_color,
        cmap_is_default=cmap_is_default,
    )


def _set_outline(
    outline_alpha: float | int | tuple[float | int, float | int] | None,
    outline_width: int | float | tuple[float | int, float | int] | None,
    outline_color: Color | tuple[Color, Color | None] | None,
    **kwargs: Any,
) -> tuple[tuple[float, float], OutlineParams]:
    """Create OutlineParams object for shapes, including possibility of double outline.

    Rules for outline rendering:
    1) outline_alpha always takes precedence if given by the user.
    In absence of outline_alpha:
    2) If outline_color is specified and implying an alpha (e.g. RGBA array or #RRGGBBAA): that alpha is used
    3) If outline_color (w/o implying an alpha) and/or outline_width is specified: alpha of outlines set to 1.0

    Parameters
    ----------
    outline_alpha
        Alpha of the outer outline, or a pair (outer, inner). An explicit 0 (or a pair of
        zeros) disables outlines.
    outline_width
        Width of the outer outline, or a pair (outer, inner) to request two outlines.
    outline_color
        Color of the outer outline, or a pair (outer, inner) to request two outlines.
    **kwargs
        Accepted for backward compatibility; not used.

    Returns
    -------
    tuple[tuple[float, float], OutlineParams]
        The (outer, inner) alphas, where 0.0 means the respective outline is not rendered,
        and the corresponding OutlineParams.

    Raises
    ------
    ValueError
        If a width/color tuple does not have exactly two entries, or if two outlines are
        requested through one argument while the other is given as a non-tuple.
    TypeError
        If an outline width is not an ``int`` or ``float``.
    """
    # A) User doesn't want to see outlines. An explicit alpha of 0 must win (rule 1) even if
    # a width or color was given; note that `outline_alpha and outline_alpha == 0.0` would
    # always be falsy, so the scalar case has to be checked by type.
    alpha_is_zero = (isinstance(outline_alpha, int | float) and outline_alpha == 0.0) or (
        isinstance(outline_alpha, tuple) and bool(np.all(np.array(outline_alpha) == 0.0))
    )
    if alpha_is_zero or not (outline_alpha or outline_width or outline_color):
        return (0.0, 0.0), OutlineParams(None, 1.5, None, 0.5)

    # B) User wants to see at least 1 outline
    if isinstance(outline_width, tuple):
        if len(outline_width) != 2:
            raise ValueError(
                f"Tuple of length {len(outline_width)} was passed for outline_width. When specifying multiple outlines,"
                " please pass a tuple of exactly length 2."
            )
        if not outline_color:
            outline_color = (Color("#000000"), Color("#ffffff"))
        elif not isinstance(outline_color, tuple):
            raise ValueError(
                "No tuple was passed for outline_color, while two outlines were specified by using the outline_width "
                "argument. Please specify the outline colors in a tuple of length two."
            )

    if isinstance(outline_color, tuple):
        if len(outline_color) != 2:
            raise ValueError(
                f"Tuple of length {len(outline_color)} was passed for outline_color. When specifying multiple outlines,"
                " please pass a tuple of exactly length 2."
            )
        if not outline_width:
            outline_width = (1.5, 0.5)
        elif not isinstance(outline_width, tuple):
            raise ValueError(
                "No tuple was passed for outline_width, while two outlines were specified by using the outline_color "
                "argument. Please specify the outline widths in a tuple of length two."
            )

    # Normalise to (outer, inner) pairs; a width of 0.0 / color of None means "no inner outline".
    if isinstance(outline_width, float | int):
        outline_width = (outline_width, 0.0)
    elif not outline_width:
        outline_width = (1.5, 0.0)
    if isinstance(outline_color, Color):
        outline_color = (outline_color, None)
    elif not outline_color:
        outline_color = (Color("#000000ff"), None)

    assert isinstance(outline_color, tuple), "outline_color is not a tuple"  # shut up mypy
    assert isinstance(outline_width, tuple), "outline_width is not a tuple"

    for ow in outline_width:
        if not isinstance(ow, int | float):
            raise TypeError(f"Invalid type of `outline_width`: {type(ow)}, expected `int` or `float`.")

    if outline_alpha:
        if isinstance(outline_alpha, int | float):
            # for a single outline: second width value is 0.0
            outline_alpha = (outline_alpha, 0.0) if outline_width[1] == 0.0 else (outline_alpha, outline_alpha)
    else:
        # if alpha wasn't explicitly specified by the user, take it from the colors (rules 2 and 3)
        outer_ol_alpha = outline_color[0].get_alpha_as_float() if isinstance(outline_color[0], Color) else 1.0
        inner_ol_alpha = outline_color[1].get_alpha_as_float() if isinstance(outline_color[1], Color) else 1.0
        outline_alpha = (outer_ol_alpha, inner_ol_alpha)

    # handle possible linewidths of 0.0 => outline won't be rendered in the first place
    if outline_width[0] == 0.0:
        outline_alpha = (0.0, outline_alpha[1])
    if outline_width[1] == 0.0:
        outline_alpha = (outline_alpha[0], 0.0)

    # NOTE: `kwargs` is a dict local to this call, so popping "edgecolor"/"alpha" from it (as
    # done previously) had no effect on the caller; those no-op pops were removed.

    return outline_alpha, OutlineParams(
        outline_color[0],
        outline_width[0],
        outline_color[1],
        outline_width[1],
    )


def _get_subplots(num_images: int, ncols: int = 4, width: int = 4, height: int = 3) -> tuple[Figure, np.ndarray]:
    """Set up a grid of subplots large enough for ``num_images`` panels.

    Parameters
    ----------
    num_images
        Number of images to plot. Must be at least 1.
    ncols
        Maximum number of columns in the subplot grid, by default 4. Fewer columns are
        used when ``num_images`` is smaller.
    width
        Width of each subplot, by default 4.
    height
        Height of each subplot, by default 3.

    Returns
    -------
    tuple[Figure, np.ndarray]
        Matplotlib figure and array of axes; axes beyond ``num_images`` are switched off.

    Raises
    ------
    ValueError
        If ``num_images`` or ``ncols`` is smaller than 1.
    """
    if num_images < 1:
        raise ValueError(f"`num_images` must be at least 1, got {num_images}.")
    if ncols < 1:
        raise ValueError(f"`ncols` must be at least 1, got {ncols}.")

    if num_images < ncols:
        nrows, ncols = 1, num_images
    else:
        full_rows, remainder = divmod(num_images, ncols)
        # an incomplete last row still needs its own row of axes
        nrows = full_rows + (1 if remainder else 0)

    fig, axes = plt.subplots(nrows, ncols, figsize=(width * ncols, height * nrows))

    # plt.subplots returns a bare Axes for a 1x1 grid; wrap it so callers can always flatten/index.
    if not isinstance(axes, Iterable):
        axes = np.array([axes])

    # get rid of the empty axes
    for unused_ax in axes.flatten()[num_images:]:
        unused_ax.axis("off")
    return fig, axes


def _get_colors_for_categorical_obs(
    categories: Sequence[str | int],
    palette: ListedColormap | str | list[str] | None = None,
    alpha: float = 1.0,
    cmap_params: CmapParams | None = None,
) -> list[str]:
    """
    Return a list of hex colors (with alpha) for the categories of a categorical observation.

    Parameters
    ----------
    categories
        Categories of the categorical observation; only their number is used.
    palette
        Colors to use. ``None`` picks, in order: the non-default colormap of ``cmap_params``,
        the matplotlib color cycle (if long enough), scanpy's 20/28/102-color palettes, or
        uniform grey for more categories. A string is a single color, a list must contain one
        color per category, and a ``Colormap`` is sampled evenly (at 0.7 for one category).
    alpha
        Alpha used when sampling a colormap; not applied to string or list palettes.
    cmap_params
        Colormap parameters whose colormap is used when ``palette`` is None and the
        colormap is not the default one.

    Returns
    -------
    list[str]
        Hex colors; at most ``len(categories)`` entries (a string palette yields one).

    Raises
    ------
    ValueError
        If ``palette`` is a list whose length differs from the number of categories.
    TypeError
        If ``palette`` is not a string, list or ``Colormap``.
    """
    len_cat = len(categories)

    # check if default matplotlib palette has enough colors
    if palette is None:
        if cmap_params is not None and not cmap_params.cmap_is_default:
            palette = cmap_params.cmap
        elif len(rcParams["axes.prop_cycle"].by_key()["color"]) >= len_cat:
            cc = rcParams["axes.prop_cycle"]()
            palette = [next(cc)["color"] for _ in range(len_cat)]
        elif len_cat <= 20:
            palette = default_20
        elif len_cat <= 28:
            palette = default_28
        elif len_cat <= len(default_102):  # 103 colors
            palette = default_102
        else:
            palette = ["grey" for _ in range(len_cat)]
            logger.info("input has more than 103 categories. Uniform 'grey' color will be used for all categories.")
    elif isinstance(palette, list) and len(palette) != len_cat:
        # raise error when user didn't provide the right number of colors in palette
        raise ValueError(
            f"The number of provided values in the palette ({len(palette)}) doesn't agree with the number of "
            f"categories that should be colored ({len_cat})."
        )

    # otherwise, single channels turn out grey
    color_idx = np.linspace(0, 1, len_cat) if len_cat > 1 else [0.7]

    if isinstance(palette, str):
        palette = [to_hex(palette)]
    elif isinstance(palette, list):
        palette = [to_hex(x) for x in palette]
    elif isinstance(palette, Colormap):
        # Covers ListedColormap, LinearSegmentedColormap and any other Colormap subclass.
        palette = [to_hex(x) for x in palette(color_idx, alpha=alpha)]
    else:
        raise TypeError(f"Palette is {type(palette)} but should be string or list.")

    return palette[:len_cat]  # type: ignore[return-value]


def _format_element_names(element_name: list[str] | str | None) -> str:
    """Quote element name(s) for user-facing messages.

    ``None`` becomes a generic placeholder; a single name or every name of a sequence is
    wrapped in single quotes, sequences joined by ", ".
    """
    if element_name is None:
        return "the requested element"
    names = [element_name] if isinstance(element_name, str) else element_name
    return ", ".join(f"'{name}'" for name in names)


def _format_element_name(element_name: list[str] | str | None) -> str:
    """Render element name(s) unquoted; non-empty lists are joined by ", ", anything else is ``<unknown>``."""
    if isinstance(element_name, list):
        return ", ".join(element_name) if element_name else "<unknown>"
    return element_name if isinstance(element_name, str) else "<unknown>"


def _preview_values(values: Sequence[Any], limit: int = 5) -> str:
    """Join the first ``limit`` values as strings, appending ", ..." if more values exist."""
    items = list(values)
    head = ", ".join(str(item) for item in items[:limit])
    return f"{head}, ..." if len(items) > limit else head


def _ensure_one_to_one_mapping(
    sdata: SpatialData,
    element: SpatialElement | None,
    element_name: list[str] | str | None,
    table_name: str | None,
) -> None:
    """Validate that table rows and shapes map one-to-one for the given element(s).

    Nothing is checked unless both ``table_name`` and ``element_name`` are given and the
    table exists in ``sdata``; otherwise the table instances and the element index are
    checked for duplicates (raising ``ValueError`` from the validators).
    """
    table = None if table_name is None or element_name is None else sdata.get(table_name, None)
    if table is not None:
        _validate_table_instance_uniqueness(table, element_name, table_name)
        _validate_shape_index_uniqueness(element, element_name, table_name)


def _validate_shape_index_uniqueness(
    element: SpatialElement | None,
    element_name: list[str] | str | None,
    table_name: str,
) -> None:
    """Raise ``ValueError`` if a GeoDataFrame element has duplicate index values.

    Elements that are not GeoDataFrames are not checked.
    """
    if isinstance(element, GeoDataFrame):
        index = element.index
        duplicated_labels = index[index.duplicated(keep=False)]
        if not duplicated_labels.empty:
            element_label = _format_element_names(element_name)
            preview = _preview_values(pd.Index(duplicated_labels).unique())
            raise ValueError(
                f"{element_label} contains duplicate index values ({preview}) while table '{table_name}' "
                "requires a one-to-one mapping between shapes and annotations. "
                "Please ensure each spatial element has a unique index."
            )


def _validate_table_instance_uniqueness(
    table: AnnData,
    element_name: list[str] | str | None,
    table_name: str,
) -> None:
    """
    Raise if the table annotates the same instance id more than once.

    Returns silently when the table keys cannot be read, or when the instance key is missing from ``table.obs``.
    When the region key is present and ``element_name`` is given, only rows of those regions are checked.

    Raises
    ------
    ValueError
        With a short preview of the repeated instance ids.
    """
    try:
        _, region_key, instance_key = get_table_keys(table)
    except (AttributeError, KeyError, ValueError):
        return

    if instance_key is None or instance_key not in table.obs.columns:
        return

    rows = table.obs
    restrict_to_regions = region_key is not None and region_key in rows.columns and element_name is not None
    if restrict_to_regions:
        wanted_regions = [element_name] if isinstance(element_name, str) else list(element_name)
        rows = rows[rows[region_key].isin(wanted_regions)]

    if rows.empty:
        return

    is_repeated = rows[instance_key].duplicated(keep=False)
    if is_repeated.any():
        label = _format_element_names(element_name)
        preview = _preview_values(rows.loc[is_repeated, instance_key].astype(str).unique())
        raise ValueError(
            f"Table '{table_name}' contains duplicate '{instance_key}' values for {label}: {preview}. "
            "Each observation must annotate a single spatial element. Please deduplicate the table or subset it "
            "before plotting."
        )


def _infer_color_data_kind(
    series: pd.Series,
    value_to_plot: str,
    element_name: list[str] | str | None,
    table_name: str | None,
    warn_on_object_to_categorical: bool = False,
) -> tuple[Literal["numeric", "categorical"], pd.Series | pd.Categorical]:
    """
    Decide whether ``series`` is colored as numeric or categorical data.

    Categorical dtypes stay categorical. Booleans become floats, and numeric dtypes are coerced with
    ``pd.to_numeric(errors="coerce")``. For string/object columns:

    - all non-missing values parse as numbers: the column is treated as numeric;
    - none of them parse: the column is treated as categorical, optionally logging a warning;
    - some parse and some do not: a TypeError is raised.

    All-missing string/object columns and any other dtype are treated as numeric.

    Returns
    -------
    ``(kind, values)``: a ``pd.Categorical`` for ``"categorical"``, otherwise the coerced series.

    Raises
    ------
    TypeError
        If a string/object column mixes numeric and non-numeric values.
    """
    element_label = _format_element_name(element_name)
    dtype = series.dtype

    if isinstance(dtype, pd.CategoricalDtype):
        return "categorical", pd.Categorical(series)
    if is_bool_dtype(dtype):
        return "numeric", series.astype(float)

    # Only non-numeric string/object columns need value-level inspection.
    text_like = not is_numeric_dtype(dtype) and (is_string_dtype(dtype) or dtype == object)
    if not text_like:
        return "numeric", pd.to_numeric(series, errors="coerce")

    present = series[~pd.isna(series)]
    if len(present) == 0:
        return "numeric", pd.to_numeric(series, errors="coerce")

    unparsable = pd.to_numeric(present, errors="coerce").isna()
    if not unparsable.any():
        return "numeric", pd.to_numeric(series, errors="coerce")

    if not unparsable.all():
        invalid_examples = present[unparsable].astype(str).unique()[:3]
        location = "" if table_name is None else f" in table '{table_name}'"
        raise TypeError(
            f"Column '{value_to_plot}' for element '{element_label}'{location} contains both numeric and "
            f"non-numeric values (e.g. {', '.join(invalid_examples)}). "
            "Please ensure that the column stores consistent data."
        )

    if warn_on_object_to_categorical:
        logger.warning(
            f"Converting copy of '{value_to_plot}' column to categorical dtype for categorical plotting. "
            "Consider converting before plotting."
        )
    return "categorical", pd.Categorical(series)


def _set_color_source_vec(
    sdata: sd.SpatialData,
    element: SpatialElement | None,
    value_to_plot: str | None,
    na_color: Color,
    element_name: list[str] | str | None = None,
    groups: list[str] | str | None = None,
    palette: list[str] | str | None = None,
    cmap_params: CmapParams | None = None,
    alpha: float = 1.0,
    table_name: str | None = None,
    table_layer: str | None = None,
    render_type: Literal["points"] | None = None,
    coordinate_system: str | None = None,
) -> tuple[ArrayLike | pd.Series | None, ArrayLike, bool]:
    """
    Resolve the source values and colors used to draw an element.

    Parameters
    ----------
    sdata
        SpatialData object holding the element and its annotating tables.
    element
        The element being rendered. Used for its length when ``value_to_plot`` is None, and to check whether it
        is a ``GeoDataFrame``.
    value_to_plot
        Key to color by, located with ``_locate_value`` and read with ``get_values``.
    na_color
        Color used when nothing is colored by value, and for missing values.
    element_name, table_name, table_layer
        Forwarded to the lookup helpers. ``table_name`` also selects the table used for color mapping.
    groups, palette
        Restrict or recolor categories (see ``_modify_categorical_color_mapping``).
    cmap_params, alpha, render_type
        Forwarded to ``_get_categorical_color_mapping``.
    coordinate_system
        Currently unused.

    Returns
    -------
    ``(color_source_vector, color_vector, is_categorical)``:

    - no value to plot (with an element): the ``na_color`` hex array twice, and ``False``;
    - numeric data: ``None``, the numeric values, and ``False``;
    - categorical data: the ``Categorical`` with categories set to the mapping keys, the mapped colors, and
      ``True``.

    Raises
    ------
    ValueError
        If the key is found in several locations, or no color mapping could be built.
    KeyError
        If the key cannot be located.
    """
    if value_to_plot is None and element is not None:
        color = np.full(len(element), na_color.get_hex_with_alpha())
        return color, color, False

    # Figure out where to get the color from
    origins = _locate_value(
        value_key=value_to_plot,
        sdata=sdata,
        element_name=element_name,
        table_name=table_name,
    )

    if len(origins) > 1:
        raise ValueError(
            f"Color key '{value_to_plot}' for element '{element_name}' has been found in multiple locations: {origins}."
        )

    if len(origins) == 1 and value_to_plot is not None:
        if table_name is not None:
            _ensure_one_to_one_mapping(
                sdata=sdata,
                element=element,
                element_name=element_name,
                table_name=table_name,
            )
        color_source_vector = get_values(
            value_key=value_to_plot,
            sdata=sdata,
            element_name=element_name,
            table_name=table_name,
            table_layer=table_layer,
        )[value_to_plot]

        color_series = (
            color_source_vector if isinstance(color_source_vector, pd.Series) else pd.Series(color_source_vector)
        )

        kind, processed = _infer_color_data_kind(
            series=color_series,
            value_to_plot=value_to_plot,
            element_name=element_name,
            table_name=table_name,
            warn_on_object_to_categorical=table_name is not None,
        )

        if kind == "numeric":
            # A categorical palette (list) has no meaning for continuous values. For non-GeoDataFrame elements a
            # placeholder palette whose first entry is None is not reported; an empty list is never reported.
            if isinstance(palette, list) and (
                isinstance(element, GeoDataFrame) or (len(palette) > 0 and palette[0] is not None)
            ):
                logger.warning(
                    "Ignoring categorical palette which is given for a continuous variable. "
                    "Consider using `cmap` to pass a ColorMap."
                )
            return None, processed, False

        assert isinstance(processed, pd.Categorical)
        color_source_vector = processed  # convert, e.g., `pd.Series`

        # Use the provided table_name parameter, fall back to only one present
        table_to_use: str | None
        if table_name is not None and table_name in sdata.tables:
            table_to_use = table_name
        elif table_name is not None and table_name not in sdata.tables:
            logger.warning(f"Table '{table_name}' not found in `sdata.tables`. Falling back to default behavior.")
            table_to_use = None
        else:
            table_keys = list(sdata.tables.keys())
            if table_keys:
                table_to_use = table_keys[0]
                logger.warning(f"No table name provided, using '{table_to_use}' as fallback for color mapping.")
            else:
                table_to_use = None

        adata_for_mapping = sdata[table_to_use] if table_to_use is not None else None

        color_mapping: Mapping[str, str] | None = None
        # Prefer custom colors stored as `<value_to_plot>_colors` in a table's `.uns` slot
        if _has_colors_in_uns(sdata, table_name, value_to_plot):
            # Convert Color to ColorLike (str) for the extraction helper
            na_color_like: ColorLike = na_color.get_hex() if isinstance(na_color, Color) else na_color
            color_mapping = _extract_colors_from_table_uns(
                sdata=sdata,
                table_name=table_name,
                col_to_colorby=value_to_plot,
                color_source_vector=color_source_vector,
                na_color=na_color_like,
            )
            if color_mapping is None:
                logger.warning(f"Failed to extract colors for '{value_to_plot}', falling back to default mapping.")
            else:
                if isinstance(palette, str):
                    palette = [palette]
                color_mapping = _modify_categorical_color_mapping(
                    mapping=color_mapping,
                    groups=groups,
                    palette=palette,
                )

        if color_mapping is None:
            # Default mapping; also the fallback when `.uns` extraction failed. `adata_for_mapping` is None-safe,
            # unlike indexing `sdata` with an unresolved table name.
            color_mapping = _get_categorical_color_mapping(
                adata=adata_for_mapping,
                cluster_key=value_to_plot,
                color_source_vector=color_source_vector,
                cmap_params=cmap_params,
                alpha=alpha,
                groups=groups,
                palette=palette,
                na_color=na_color,
                render_type=render_type,
            )

        # Validate before the mapping is used below.
        if color_mapping is None:
            raise ValueError("Unable to create color palette.")

        color_source_vector = color_source_vector.set_categories(color_mapping.keys())

        # do not rename categories, as colors need not be unique
        color_vector = color_source_vector.map(color_mapping)

        return color_source_vector, color_vector, True

    if table_name is None:
        raise KeyError(
            f"Unable to locate color key '{value_to_plot}' for element '{element_name}'. "
            "Please ensure the key exists in a table annotating this element."
        )
    raise KeyError(
        f"Unable to locate color key '{value_to_plot}' in table '{table_name}' for element '{element_name}'."
    )


def _map_color_seg(
    seg: ArrayLike,
    cell_id: ArrayLike,
    color_vector: ArrayLike | pd.Series[CategoricalDtype],
    color_source_vector: pd.Series[CategoricalDtype],
    cmap_params: CmapParams,
    na_color: Color,
    seg_erosionpx: int | None = None,
    seg_boundaries: bool = False,
) -> ArrayLike:
    """
    Color a label image according to per-cell values.

    Parameters
    ----------
    seg
        Label image; 0 is treated as background (``bg_label=0`` below).
    cell_id
        Label ids, aligned with ``color_vector`` / ``color_source_vector``.
    color_vector
        Per-cell categorical colors, numeric values, or color-like values.
    color_source_vector
        Per-cell source values. In the categorical case, ids of NaN entries are replaced by 0.
    cmap_params
        Colormap and normalization used for numeric (or non color-like) values.
    na_color
        Used to detect the case where nothing was colored by value.
    seg_erosionpx
        If given, pixels whose label equals the label after erosion with a ``square(seg_erosionpx)`` footprint
        are cleared.
    seg_boundaries
        If True, boundary pixels from ``find_boundaries`` are subtracted from the RGB image.

    Returns
    -------
    The RGB image from ``label2rgb`` with one extra channel stacked on, derived from the mapped label image.
    """
    cell_id = np.array(cell_id)

    # `pd.api.types.is_categorical_dtype` is deprecated since pandas 2.1; the isinstance check is its replacement.
    if isinstance(color_vector.dtype, pd.CategoricalDtype):
        # Case A: users wants to plot a categorical column
        if np.any(color_source_vector.isna()):
            cell_id[color_source_vector.isna()] = 0
        val_im: ArrayLike = map_array(seg.copy(), cell_id, color_vector.codes + 1)
        cols = colors.to_rgba_array(color_vector.categories)
    elif pd.api.types.is_numeric_dtype(color_vector.dtype):
        # Case B: user wants to plot a continous column
        if isinstance(color_vector, pd.Series):
            color_vector = color_vector.to_numpy()
        cols = cmap_params.cmap(cmap_params.norm(color_vector))
        val_im = map_array(seg.copy(), cell_id, cell_id)
    else:
        # Case C: User didn't specify any colors
        if color_source_vector is not None and (
            set(color_vector) == set(color_source_vector)
            and len(set(color_vector)) == 1
            and set(color_vector) == {na_color.get_hex_with_alpha()}
            and not na_color.color_modified_by_user()
        ):
            val_im = map_array(seg.copy(), cell_id, cell_id)
            RNG = default_rng(42)
            cols = RNG.random((len(color_vector), 3))
        else:
            # Case D: User didn't specify a column to color by, but modified the na_color
            val_im = map_array(seg.copy(), cell_id, cell_id)
            first_value = color_vector.iloc[0] if isinstance(color_vector, pd.Series) else color_vector[0]
            if _is_color_like(first_value):
                # we have color-like values (e.g., hex or named colors)
                assert all(_is_color_like(c) for c in color_vector), "Not all values are color-like."
                cols = colors.to_rgba_array(color_vector)
            else:
                cols = cmap_params.cmap(cmap_params.norm(color_vector))

    if seg_erosionpx is not None:
        val_im[val_im == erosion(val_im, square(seg_erosionpx))] = 0

    seg_im: ArrayLike = label2rgb(
        label=val_im,
        colors=cols,
        bg_label=0,
        bg_color=(1, 1, 1),  # transparency doesn't really work
        image_alpha=0,
    )

    if seg_boundaries:
        if seg.shape[0] == 1:
            seg = np.squeeze(seg, axis=0)
        seg_bound: ArrayLike = np.clip(seg_im - find_boundaries(seg)[:, :, None], 0, 1)
        return np.dstack((seg_bound, np.where(val_im > 0, 1, 0)))  # add transparency here

    if len(val_im.shape) != len(seg_im.shape):
        val_im = np.expand_dims((val_im > 0).astype(int), axis=-1)
    return np.dstack((seg_im, val_im))


def _generate_base_categorial_color_mapping(
    adata: AnnData | None,
    cluster_key: str,
    color_source_vector: ArrayLike | pd.Series[CategoricalDtype],
    na_color: Color,
    cmap_params: CmapParams | None = None,
) -> Mapping[str, str]:
    """
    Build the base category -> color mapping.

    Stored colors from ``adata.uns[f"{cluster_key}_colors"]`` are used, with alpha dropped, when both
    ``cluster_key`` and that color key are present in ``adata.uns``. Otherwise the defaults from
    ``_get_default_categorial_color_mapping`` are returned.

    The stored-colors mapping also has a ``"NaN"`` entry. It gets ``na_color`` when exactly one color per category
    is stored, and the last stored color when one extra color is stored.

    Raises
    ------
    ValueError
        If the number of stored colors is neither the number of categories nor one more than that.
    """
    color_key = f"{cluster_key}_colors"
    if adata is None or cluster_key not in adata.uns or color_key not in adata.uns:
        return _get_default_categorial_color_mapping(color_source_vector=color_source_vector, cmap_params=cmap_params)

    categories = color_source_vector.categories.tolist() + ["NaN"]
    # Named `hex_colors` (not `colors`) so the module-level `matplotlib.colors` import is not shadowed.
    hex_colors = [to_hex(to_rgba(color)[:3]) for color in adata.uns[color_key]]
    n_stored = len(hex_colors)

    if len(categories) > n_stored:
        hex_colors.append(na_color.get_hex_with_alpha())

    if len(hex_colors) != len(categories):
        raise ValueError(
            f"Found {n_stored} colors in `adata.uns['{color_key}']`, but '{cluster_key}' has "
            f"{len(categories) - 1} categories."
        )

    return dict(zip(categories, hex_colors, strict=True))


def _has_colors_in_uns(
    sdata: sd.SpatialData,
    table_name: str | None,
    col_to_colorby: str,
) -> bool:
    """
    Check if <column_name>_colors exists in the specified table's .uns slot.

    Parameters
    ----------
    sdata
        SpatialData object containing tables
    table_name
        Name of the table to check. If None, uses the first available table.
    col_to_colorby
        Name of the categorical column (e.g., "celltype")

    Returns
    -------
    True if <col_to_colorby>_colors exists in the table's .uns, False otherwise
    """
    color_key = f"{col_to_colorby}_colors"

    # Determine which table to use
    if table_name is not None:
        if table_name not in sdata.tables:
            return False
        table_to_use = table_name
    else:
        if len(sdata.tables.keys()) == 0:
            return False
        # When no table is specified, check all tables for the color key
        return any(color_key in adata.uns for adata in sdata.tables.values())

    adata = sdata.tables[table_to_use]
    return color_key in adata.uns


def _extract_colors_from_table_uns(
    sdata: sd.SpatialData,
    table_name: str | None,
    col_to_colorby: str,
    color_source_vector: ArrayLike | pd.Series[CategoricalDtype],
    na_color: ColorLike,
) -> Mapping[str, str] | None:
    """
    Extract categorical colors from the <column_name>_colors pattern in adata.uns.

    Colors stored under ``<col_to_colorby>_colors`` in the chosen table's ``.uns`` are mapped onto the categories
    of ``color_source_vector``. Two storage forms are supported:

    - a mapping of category to color;
    - a sequence of colors in category order. A single color string counts as a one-element sequence.

    Alpha is dropped from every color.

    Parameters
    ----------
    sdata
        SpatialData object containing tables
    table_name
        Name of the table to look in. If None, all tables are searched. The first table whose ``.uns`` holds the
        color key is used, and a warning is logged if several do.
    col_to_colorby
        Name of the categorical column (e.g., "celltype")
    color_source_vector
        Categorical vector containing the categories to map
    na_color
        Color for NaN/missing values and for categories without a usable stored color.

    Returns
    -------
    Mapping from category names, plus a ``"NaN"`` entry, to ``#rrggbb`` colors. Returns None if the table or the
    color key is missing, or if the stored colors are neither a mapping nor iterable.
    """
    color_key = f"{col_to_colorby}_colors"

    # Determine which table to use
    if table_name is not None:
        if table_name not in sdata.tables:
            logger.warning(f"Table '{table_name}' not found in sdata. Available tables: {list(sdata.tables.keys())}")
            return None
        table_to_use = table_name
    else:
        if len(sdata.tables) == 0:
            logger.warning("No tables found in sdata.")
            return None
        # No explicit table provided: search all tables for the color key
        candidate_tables: list[str] = [
            name
            for name, ad in sdata.tables.items()
            if color_key in ad.uns  # type: ignore[union-attr]
        ]
        if not candidate_tables:
            logger.debug(f"Color key '{color_key}' not found in any table uns.")
            return None
        table_to_use = candidate_tables[0]
        if len(candidate_tables) > 1:
            logger.warning(
                f"Color key '{color_key}' found in multiple tables {candidate_tables}; using table '{table_to_use}'."
            )
        logger.info(f"No table name provided, using '{table_to_use}' for color extraction.")

    adata = sdata.tables[table_to_use]

    # Check if the color pattern exists
    if color_key not in adata.uns:
        logger.debug(f"Color key '{color_key}' not found in table '{table_to_use}' uns.")
        return None

    # Extract colors and categories
    stored_colors = adata.uns[color_key]
    if isinstance(stored_colors, str):
        # A lone color string would otherwise be iterated character by character.
        stored_colors = [stored_colors]
    categories = color_source_vector.categories.tolist()

    # Validate na_color format and convert to hex string
    if isinstance(na_color, Color):
        na_color_hex = na_color.get_hex()
    else:
        na_color_str = str(na_color)
        if "#" not in na_color_str:
            logger.warning("Expected `na_color` to be a hex color, converting...")
            na_color_hex = to_hex(to_rgba(na_color)[:3])
        else:
            na_color_hex = na_color_str

    # Strip alpha channel from na_color if present
    if len(na_color_hex) == 9:  # #rrggbbaa format
        na_color_hex = na_color_hex[:7]  # Keep only #rrggbb

    def _to_hex_no_alpha(color_value: Any) -> str | None:
        # `to_hex` of an RGB triple is always `#rrggbb`, so there is no alpha to strip afterwards.
        try:
            hex_color: str = to_hex(to_rgba(color_value)[:3])
        except (TypeError, ValueError) as e:
            logger.warning(f"Error converting color '{color_value}' to hex format: {e}")
            return None
        return hex_color

    color_mapping: dict[str, str] = {}

    if isinstance(stored_colors, Mapping):
        for category in categories:
            raw_color = stored_colors.get(category)
            if raw_color is None:
                logger.warning(f"No color specified for '{category}' in '{color_key}', using na_color.")
                color_mapping[category] = na_color_hex
                continue
            hex_color = _to_hex_no_alpha(raw_color)
            color_mapping[category] = hex_color if hex_color is not None else na_color_hex
        logger.info(f"Successfully extracted {len(color_mapping)} colors from '{color_key}' in table '{table_to_use}'.")
    else:
        try:
            hex_colors = [_to_hex_no_alpha(color) for color in stored_colors]
        except TypeError:
            logger.warning(f"Unsupported color storage for '{color_key}'. Expected sequence or mapping.")
            return None

        for i, category in enumerate(categories):
            hex_color = hex_colors[i] if i < len(hex_colors) else None
            if hex_color is not None:
                color_mapping[category] = hex_color
            else:
                logger.warning(f"Not enough colors provided for category '{category}', using na_color.")
                color_mapping[category] = na_color_hex
        logger.info(f"Successfully extracted {len(hex_colors)} colors from '{color_key}' in table '{table_to_use}'.")

    color_mapping["NaN"] = na_color_hex
    return color_mapping


def _modify_categorical_color_mapping(
    mapping: Mapping[str, str],
    groups: list[str] | str | None = None,
    palette: list[str] | str | None = None,
) -> Mapping[str, str]:
    if groups is None or isinstance(groups, list) and groups[0] is None:
        return mapping

    if palette is None or isinstance(palette, list) and palette[0] is None:
        # subset base mapping to only those specified in groups
        modified_mapping = {key: mapping[key] for key in mapping if key in groups or key == "NaN"}
    elif len(palette) == len(groups) and isinstance(groups, list) and isinstance(palette, list):
        modified_mapping = dict(zip(groups, palette, strict=True))
    else:
        raise ValueError(f"Expected palette to be of length `{len(groups)}`, found `{len(palette)}`.")

    return modified_mapping


def _get_default_categorial_color_mapping(
    color_source_vector: ArrayLike | pd.Series[CategoricalDtype],
    cmap_params: CmapParams | None = None,
) -> Mapping[str, str]:
    """
    Assign a default color to every category of ``color_source_vector``.

    A user-supplied (non-default) ``ListedColormap`` or ``LinearSegmentedColormap`` is sampled at evenly spaced
    positions. Otherwise the smallest scanpy default palette that fits is used (20, 28, then 102 colors). Beyond
    that, every category is colored grey.
    """
    n_categories = len(color_source_vector.categories.unique())

    palette: list[str] | None = None
    if cmap_params is not None and cmap_params.cmap is not None and not cmap_params.cmap_is_default:
        cmap = cmap_params.cmap
        positions = np.linspace(0, 1, n_categories)
        if isinstance(cmap, ListedColormap):
            palette = [to_hex(rgba) for rgba in cmap(positions)]
        elif isinstance(cmap, LinearSegmentedColormap):
            palette = [to_hex(cmap(position)) for position in positions]
        # any other colormap type falls through to the default palettes

    if palette is None:
        for capacity, candidate in ((20, default_20), (28, default_28), (len(default_102), default_102)):
            if n_categories <= capacity:
                palette = candidate
                break
        else:
            palette = ["grey"] * n_categories
            logger.info("input has more than 103 categories. Uniform 'grey' color will be used for all categories.")

    return dict(zip(color_source_vector.categories, palette[:n_categories], strict=True))


def _get_categorical_color_mapping(
    adata: AnnData | None,
    na_color: Color,
    cluster_key: str | None = None,
    color_source_vector: ArrayLike | pd.Series[CategoricalDtype] | None = None,
    cmap_params: CmapParams | None = None,
    alpha: float = 1,
    groups: list[str] | str | None = None,
    palette: list[str] | str | None = None,
    render_type: Literal["points"] | None = None,
) -> Mapping[str, str]:
    """
    Map each category of ``color_source_vector`` to a color.

    For points (``render_type == "points"``) with no ``palette`` and a non-default colormap, the colormap is
    sampled at evenly spaced positions and that mapping is returned directly; ``groups`` is not applied on this
    path. Otherwise a base mapping is built, either from ``adata.uns`` colors or from defaults. That base mapping
    is then restricted or recolored by ``groups``/``palette`` via ``_modify_categorical_color_mapping``.

    Raises
    ------
    TypeError
        If ``color_source_vector`` is not a ``Categorical``.
    """
    if not isinstance(color_source_vector, Categorical):
        raise TypeError(f"Expected `categories` to be a `Categorical`, but got {type(color_source_vector).__name__}")

    if isinstance(groups, str):
        groups = [groups]

    if not palette and render_type == "points" and cmap_params is not None and not cmap_params.cmap_is_default:
        cmap = cmap_params.cmap
        positions = np.linspace(0, 1, len(color_source_vector.categories))
        if isinstance(cmap, ListedColormap):
            cmap_colors = [to_hex(rgba) for rgba in cmap(positions, alpha=alpha)]
        elif isinstance(cmap, Colormap):
            # Any other Colormap (LinearSegmentedColormap included) can be sampled per position; previously such
            # colormaps were zipped directly and raised a TypeError.
            cmap_colors = [to_hex(cmap(position, alpha=alpha)) for position in positions]
        else:
            # Not a Colormap: assume it is already an iterable of colors.
            cmap_colors = list(cmap)
        return dict(zip(color_source_vector.categories, cmap_colors, strict=True))

    if isinstance(palette, str):
        palette = [palette]

    if cluster_key is None:
        # user didn't specify a column to use for coloring
        base_mapping = _get_default_categorial_color_mapping(
            color_source_vector=color_source_vector, cmap_params=cmap_params
        )
    else:
        base_mapping = _generate_base_categorial_color_mapping(
            adata=adata,
            cluster_key=cluster_key,
            color_source_vector=color_source_vector,
            na_color=na_color,
            cmap_params=cmap_params,
        )

    return _modify_categorical_color_mapping(mapping=base_mapping, groups=groups, palette=palette)


def _maybe_set_colors(
    source: AnnData,
    target: AnnData,
    key: str,
    palette: str | ListedColormap | Cycler | Sequence[Any] | None = None,
) -> None:
    """
    Ensure ``target.uns[f"{key}_colors"]`` is populated.

    Without an explicit ``palette``, colors already stored on ``source`` are copied over. Otherwise, or when
    ``source`` has none, colors are generated with scanpy's ``add_colors_for_categorical_sample_annotation``,
    using ``palette`` when one is given.

    Before generation, a single color string is wrapped in a one-color ``ListedColormap``. A ``ListedColormap`` is
    then converted to a ``Cycler``, since scanpy requires it. Previously the palette was reset to None right before
    generation, which silently discarded any explicitly given palette.
    """
    color_key = f"{key}_colors"
    if palette is None and color_key in source.uns:
        target.uns[color_key] = source.uns[color_key]
        return

    if isinstance(palette, str):
        palette = ListedColormap([palette])
    if isinstance(palette, ListedColormap):  # `scanpy` requires it
        palette = cycler(color=palette.colors)
    add_colors_for_categorical_sample_annotation(target, key=key, force_update_colors=True, palette=palette)


def _decorate_axs(
    ax: Axes,
    cax: PatchCollection,
    fig_params: FigParams,
    value_to_plot: str | None,
    color_source_vector: pd.Series[CategoricalDtype] | Categorical,
    color_vector: pd.Series[CategoricalDtype] | Categorical,
    adata: AnnData | None = None,
    palette: ListedColormap | str | list[str] | None = None,
    alpha: float = 1.0,
    na_color: Color = Color("default"),
    legend_fontsize: int | float | _FontSize | None = None,
    legend_fontweight: int | _FontWeight = "bold",
    legend_loc: str | None = "right margin",
    legend_fontoutline: int | None = None,
    na_in_legend: bool = True,
    colorbar: bool = True,
    colorbar_params: dict[str, object] | None = None,
    colorbar_requests: list[ColorbarSpec] | None = None,
    colorbar_label: str | None = None,
    scalebar_dx: Sequence[float] | None = None,
    scalebar_units: Sequence[str] | None = None,
    scalebar_kwargs: Mapping[str, Any] = MappingProxyType({}),
) -> Axes:
    """
    Add a legend, queue a colorbar, and optionally add a scale bar on ``ax``.

    When ``value_to_plot`` is None, only the scale bar may be added. Otherwise:

    - For a categorical ``color_source_vector``, a legend is drawn with scanpy's ``_add_categorical_legend``. Each
      category takes the color of its first occurrence in ``color_vector``.
    - Otherwise, when ``colorbar`` is set and both ``colorbar_requests`` and ``cax`` are given, a ``ColorbarSpec``
      is appended to ``colorbar_requests``. The colorbar itself is not drawn here.

    A ``ScaleBar`` is added when both ``scalebar_dx`` and ``scalebar_units`` are lists.

    Returns
    -------
    The same ``ax``.
    """
    if value_to_plot is not None:
        path_effect = (
            [patheffects.withStroke(linewidth=legend_fontoutline, foreground="w")]
            if legend_fontoutline is not None
            else []
        )

        has_categories = color_source_vector is not None and isinstance(
            color_source_vector.dtype, pd.CategoricalDtype
        )
        if has_categories:
            observed = color_source_vector.remove_unused_categories()
            # legend categories follow the order of `unique()`, so they agree with the palette derived below
            clusters = observed.unique()
            clusters = clusters[~clusters.isnull()]
            first_color_per_category = (
                pd.DataFrame({"cats": observed, "color": color_vector})
                .drop_duplicates("cats")
                .set_index("cats")["color"]
                .to_dict()
            )
            _add_categorical_legend(
                ax,
                pd.Categorical(values=color_source_vector, categories=clusters),
                palette=first_color_per_category,
                legend_loc=legend_loc,
                legend_fontweight=legend_fontweight,
                legend_fontsize=legend_fontsize,
                legend_fontoutline=path_effect,
                na_color=[na_color.get_hex()],
                na_in_legend=na_in_legend,
                multi_panel=fig_params.axs is not None,
            )
        elif colorbar and colorbar_requests is not None and cax is not None:
            spec = ColorbarSpec(
                ax=ax,
                mappable=cax,
                params=colorbar_params,
                label=colorbar_label,
                alpha=alpha,
            )
            colorbar_requests.append(spec)

    if isinstance(scalebar_dx, list) and isinstance(scalebar_units, list):
        ax.add_artist(ScaleBar(scalebar_dx, units=scalebar_units, **scalebar_kwargs))

    return ax


def _get_list(
    var: Any,
    _type: type[Any] | tuple[type[Any], ...],
    ref_len: int | None = None,
    name: str | None = None,
) -> list[Any]:
    """
    Get a list from a variable.

    Parameters
    ----------
    var
        Variable to convert to a list.
    _type
        Type of the elements in the list.
    ref_len
        Reference length of the list.
    name
        Name of the variable.

    Returns
    -------
    List
    """
    if isinstance(var, _type):
        return [var] if ref_len is None else ([var] * ref_len)
    if isinstance(var, list):
        if ref_len is not None and ref_len != len(var):
            raise ValueError(
                f"Variable: `{name}` has length: {len(var)}, which is not equal to reference length: {ref_len}."
            )
        for v in var:
            if not isinstance(v, _type):
                raise ValueError(f"Variable: `{name}` has invalid type: {type(v)}, expected: {_type}.")
        return var

    raise ValueError(f"Can't make a list from variable: `{var}`")


def save_fig(
    fig: Figure,
    path: str | Path,
    make_dir: bool = True,
    ext: str = "png",
    **kwargs: Any,
) -> None:
    """
    Save a figure.

    Parameters
    ----------
    fig
        Figure to save.
    path
        Destination. If it has no extension, ``ext`` is appended. Relative paths are resolved under
        :attr:`scanpy.settings.figdir`.
    make_dir
        Whether to try creating the parent directory. A failure is only logged at debug level.
    ext
        Extension to use if none is provided.
    kwargs
        Keyword arguments for :func:`matplotlib.figure.Figure.savefig`. ``bbox_inches`` defaults to ``"tight"``
        and ``transparent`` to ``True``.

    Returns
    -------
    None
        Just saves the plot.
    """
    has_extension = os.path.splitext(path)[1] != ""
    target = Path(path if has_extension else f"{path}.{ext}")

    if not target.is_absolute():
        target = Path(settings.figdir) / target

    if make_dir:
        try:
            target.parent.mkdir(parents=True, exist_ok=True)
        except OSError as e:
            logger.debug(f"Unable to create directory `{target.parent}`. Reason: `{e}`")

    logger.debug(f"Saving figure to `{target!r}`")

    savefig_kwargs = {"bbox_inches": "tight", "transparent": True, **kwargs}
    fig.savefig(target, **savefig_kwargs)


def _get_linear_colormap(colors: list[str], background: str) -> list[LinearSegmentedColormap]:
    """Build one 256-level colormap per color, ramping from ``background`` to that color and named after it."""

    def _ramp(target: str) -> LinearSegmentedColormap:
        return LinearSegmentedColormap.from_list(target, [background, target], N=256)

    return [_ramp(target) for target in colors]


def _validate_polygons(shapes: GeoDataFrame) -> GeoDataFrame:
    """
    Convert Polygons with holes to MultiPolygons to keep interior rings during rendering.

    ``shapes`` is modified in place and also returned.

    Parameters
    ----------
    shapes
        GeoDataFrame containing a `geometry` column.

    Returns
    -------
    GeoDataFrame
        ``shapes`` with holed Polygons converted to MultiPolygons.
    """
    if "geometry" not in shapes:
        return shapes

    geometry_col = shapes.columns.get_loc("geometry")
    converted_count = 0
    # Assign by position: label-based `.at[idx, ...]` would overwrite every row sharing `idx` when the index
    # contains duplicate labels, replacing unrelated geometries.
    for position, geom in enumerate(list(shapes["geometry"])):
        if isinstance(geom, shapely.Polygon) and len(geom.interiors) > 0:
            shapes.iat[position, geometry_col] = shapely.MultiPolygon([geom])
            converted_count += 1

    if converted_count > 0:
        logger.info(
            "Converted %d Polygon(s) with holes to MultiPolygon(s) for correct rendering.",
            converted_count,
        )

    return shapes


def _make_patch_from_multipolygon(mp: shapely.MultiPolygon) -> list[mpatches.Patch]:
    """
    Create patches from a MultiPolygon, preserving holes.

    Each non-empty part becomes one patch:

    - a plain ``mpatches.Polygon`` when the part has no interior rings;
    - otherwise a ``PathPatch`` over a compound path made of the exterior ring followed by all interior rings.
      This mirrors GeoPandas' Polygon plotting strategy.

    Ring orientation is not changed here. NOTE(review): whether holes render unfilled depends on ring orientation
    relative to matplotlib's fill rule; presumably geometries are normalized before this is called — confirm.

    Returns
    -------
    One patch per non-empty part. The annotation is ``list[Patch]`` because hole-free parts are
    ``mpatches.Polygon`` instances, which are not ``PathPatch`` subclasses.
    """
    patches: list[mpatches.Patch] = []

    for poly in mp.geoms:
        if poly.is_empty:
            continue

        # Ensure 2D vertices in case geometries carry Z
        exterior = np.asarray(poly.exterior.coords)[..., :2]
        interiors = [np.asarray(ring.coords)[..., :2] for ring in poly.interiors]

        if len(interiors) == 0:
            # Simple polygon without holes
            patches.append(mpatches.Polygon(exterior, closed=True))
            continue

        # Build a compound path: exterior + all interior rings
        compound_path = mpath.Path.make_compound_path(
            mpath.Path(exterior, closed=True),
            *[mpath.Path(ring, closed=True) for ring in interiors],
        )
        patches.append(mpatches.PathPatch(compound_path))

    return patches


def _mpl_ax_contains_elements(ax: Axes) -> bool:
    """Check if any objects have been plotted on the axes object.

    While extracting the extent, we need to know if the axes object has just been
    initialised and therefore has extent (0, 1), (0,1) or if it has been plotted on
    and therefore has a different extent.

    Based on: https://stackoverflow.com/a/71966295
    """
    return (
        len(ax.lines) > 0 or len(ax.collections) > 0 or len(ax.images) > 0 or len(ax.patches) > 0 or len(ax.tables) > 0
    )


def _get_valid_cs(
    sdata: sd.SpatialData,
    coordinate_systems: list[str],
    render_images: bool,
    render_labels: bool,
    render_points: bool,
    render_shapes: bool,
    elements: list[str],
) -> list[str]:
    """Return the coordinate systems that have something to render.

    A coordinate system is kept when:

    1. ``elements`` is non-empty and at least one element mapped to that coordinate
       system is listed in ``elements``; or
    2. ``elements`` is empty and ``sdata`` holds at least one element of a type that is
       being rendered (images, labels, points or shapes). Note that this fallback looks
       at the whole ``sdata`` object, not at the individual coordinate system.

    Each dropped coordinate system is reported via ``logger.info``. The order of
    ``coordinate_systems`` is preserved in the result.
    """
    cs_mapping = _get_coordinate_system_mapping(sdata)

    def _has_relevant_elements(cs_name: str) -> bool:
        if elements:
            return any(e in elements for e in cs_mapping[cs_name])
        return (
            (len(sdata.images.keys()) > 0 and render_images)
            or (len(sdata.labels.keys()) > 0 and render_labels)
            or (len(sdata.points.keys()) > 0 and render_points)
            or (len(sdata.shapes.keys()) > 0 and render_shapes)
        )

    kept: list[str] = []
    for cs in coordinate_systems:
        if not _has_relevant_elements(cs):
            logger.info(f"Dropping coordinate system '{cs}' since it doesn't have relevant elements.")
            continue
        kept.append(cs)
    return kept


def _rasterize_if_necessary(
    image: DataArray,
    dpi: float,
    width: float,
    height: float,
    coordinate_system: str,
    extent: dict[str, tuple[float, float]],
) -> DataArray:
    """Ensure fast rendering by adapting the resolution if necessary.

    A DataArray is prepared for plotting. To improve performance, large images are rasterized.

    Parameters
    ----------
    image
        Input spatial image that should be rendered
    dpi
        Resolution of the figure
    width
        Width (in inches) of the figure
    height
        Height (in inches) of the figure
    coordinate_system
        name of the coordinate system the image belongs to
    extent
        extent of the (full size) image. Must be a dict containing a tuple with min and
        max extent for the keys "x" and "y".

    Returns
    -------
    DataArray
        Spatial image ready for rendering
    """
    has_c_dim = len(image.shape) == 3
    if has_c_dim:
        y_dims = image.shape[1]
        x_dims = image.shape[2]
    else:
        y_dims = image.shape[0]
        x_dims = image.shape[1]

    target_y_dims = dpi * height
    target_x_dims = dpi * width

    # Heuristics for when to rasterize
    do_rasterization = y_dims > target_y_dims + 100 or x_dims > target_x_dims + 100
    if x_dims < 2000 and y_dims < 2000:
        do_rasterization = False

    if do_rasterization:
        logger.info("Rasterizing image for faster rendering.")
        target_unit_to_pixels = min(target_y_dims / y_dims, target_x_dims / x_dims)
        image = rasterize(
            image,
            ("y", "x"),
            [extent["y"][0], extent["x"][0]],
            [extent["y"][1], extent["x"][1]],
            coordinate_system,
            target_unit_to_pixels=target_unit_to_pixels,
        )

    return image


def _multiscale_to_spatial_image(
    multiscale_image: DataTree,
    dpi: float,
    width: float,
    height: float,
    scale: str | None = None,
    is_label: bool = False,
) -> DataArray:
    """Extract the DataArray to be rendered from a multiscale image.

    Without an explicit ``scale``, the levels are ordered by their x size. For each of x and
    y, the smallest level that is at least as large as the target pixel size
    (``dpi`` * inches) is found, clamped to the largest level. The higher-resolution of the
    two candidates is chosen, so both target dimensions are met when possible. Any excess
    resolution can be removed by a later rasterization step.

    Parameters
    ----------
    multiscale_image
        `DataTree` that should be rendered.
    dpi
        dpi of the target image.
    width
        Width of the target image in inches.
    height
        Height of the target image in inches.
    scale
        Specific scale that the user chose. ``"full"`` selects the level with the largest
        x size. If None, the heuristic above is used.
    is_label
        When True, the multiscale image contains labels, which don't have the `c` dimension.

    Returns
    -------
    DataArray
        To be rendered, extracted from the DataTree respecting the dpi and size of the target image.

    Raises
    ------
    ValueError
        If ``scale`` is a string that is neither ``"full"`` nor one of the available scales.
    """
    scales = [leaf.name for leaf in multiscale_image.leaves]
    x_dims = [multiscale_image[s].dims["x"] for s in scales]
    y_dims = [multiscale_image[s].dims["y"] for s in scales]

    if isinstance(scale, str):
        if scale not in scales and scale != "full":
            raise ValueError(f'Scale {scale} does not exist. Please select one of {scales} or set scale = "full"!')
        # "full": use the scale with the highest resolution
        optimal_scale = scales[int(np.argmax(x_dims))] if scale == "full" else scale
    else:
        # sort all levels from lowest to highest resolution
        order = np.argsort(x_dims)
        scales = [scales[i] for i in order]
        x_dims = [x_dims[i] for i in order]
        y_dims = [y_dims[i] for i in order]

        optimal_x = width * dpi
        optimal_y = height * dpi

        # smallest level that reaches the target size per dimension (clamped to the largest level)
        optimal_index_y = min(int(np.searchsorted(y_dims, optimal_y)), len(y_dims) - 1)
        optimal_index_x = min(int(np.searchsorted(x_dims, optimal_x)), len(x_dims) - 1)

        # Levels are sorted ascending, so the LARGER index is the higher-resolution level that
        # satisfies both dimensions (worst case: downscaled afterwards).
        optimal_scale = scales[max(optimal_index_x, optimal_index_y)]

    # NOTE: problematic if there are cases with > 1 data variable
    data_var_keys = list(multiscale_image[optimal_scale].data_vars)
    image = multiscale_image[optimal_scale][data_var_keys[0]]

    return Labels2DModel.parse(image) if is_label else Image2DModel.parse(image, c_coords=image.coords["c"].values)


def _get_elements_to_be_rendered(
    render_cmds: list[
        tuple[
            str,
            ImageRenderParams | LabelsRenderParams | PointsRenderParams | ShapesRenderParams,
        ]
    ],
    cs_contents: pd.DataFrame,
    cs: str,
) -> list[str]:
    """
    Get the names of the elements to be rendered in the plot.

    Parameters
    ----------
    render_cmds
        List of tuples containing the commands and their respective parameters.
    cs_contents
        The dataframe indicating for each coordinate system which SpatialElements it contains.
    cs
        The name of the coordinate system to query cs_contents for.

    Returns
    -------
    List of names of the SpatialElements to be rendered in the plot.
    """
    elements_to_be_rendered: list[str] = []
    render_cmds_map = {
        "render_images": "has_images",
        "render_shapes": "has_shapes",
        "render_points": "has_points",
        "render_labels": "has_labels",
    }

    cs_query = cs_contents.query(f"cs == '{cs}'")

    for cmd, params in render_cmds:
        key = render_cmds_map.get(cmd)
        if key and cs_query[key][0]:
            elements_to_be_rendered += [params.element]

    return elements_to_be_rendered


def _validate_show_parameters(
    coordinate_systems: list[str] | str | None,
    legend_fontsize: int | float | _FontSize | None,
    legend_fontweight: int | _FontWeight,
    legend_loc: str | None,
    legend_fontoutline: int | None,
    na_in_legend: bool,
    colorbar: bool,
    colorbar_params: dict[str, object] | None,
    wspace: float | None,
    hspace: float,
    ncols: int,
    frameon: bool | None,
    figsize: tuple[float, float] | None,
    dpi: int | None,
    fig: Figure | None,
    title: list[str] | str | None,
    share_extent: bool,
    pad_extent: int | float,
    ax: list[Axes] | Axes | None,
    return_ax: bool,
    save: str | Path | None,
) -> None:
    if coordinate_systems is not None and not isinstance(coordinate_systems, list | str):
        raise TypeError("Parameter 'coordinate_systems' must be a string or a list of strings.")

    font_weights = ["light", "normal", "medium", "semibold", "bold", "heavy", "black"]
    if legend_fontweight is not None and (
        not isinstance(legend_fontweight, int | str)
        or (isinstance(legend_fontweight, str) and legend_fontweight not in font_weights)
    ):
        readable_font_weights = ", ".join(font_weights[:-1]) + ", or " + font_weights[-1]
        raise TypeError(
            "Parameter 'legend_fontweight' must be an integer or one of",
            f"the following strings: {readable_font_weights}.",
        )

    font_sizes = [
        "xx-small",
        "x-small",
        "small",
        "medium",
        "large",
        "x-large",
        "xx-large",
    ]

    if legend_fontsize is not None and (
        not isinstance(legend_fontsize, int | float | str)
        or (isinstance(legend_fontsize, str) and legend_fontsize not in font_sizes)
    ):
        readable_font_sizes = ", ".join(font_sizes[:-1]) + ", or " + font_sizes[-1]
        raise TypeError(
            "Parameter 'legend_fontsize' must be an integer, a float, or ",
            f"one of the following strings: {readable_font_sizes}.",
        )

    if legend_loc is not None and not isinstance(legend_loc, str):
        raise TypeError("Parameter 'legend_loc' must be a string.")

    if legend_fontoutline is not None and not isinstance(legend_fontoutline, int):
        raise TypeError("Parameter 'legend_fontoutline' must be an integer.")

    if not isinstance(na_in_legend, bool):
        raise TypeError("Parameter 'na_in_legend' must be a boolean.")

    if not isinstance(colorbar, bool):
        raise TypeError("Parameter 'colorbar' must be a boolean.")

    if colorbar_params is not None and not isinstance(colorbar_params, dict):
        raise TypeError("Parameter 'colorbar_params' must be a dictionary or None.")

    if wspace is not None and not isinstance(wspace, float):
        raise TypeError("Parameter 'wspace' must be a float.")

    if not isinstance(hspace, float):
        raise TypeError("Parameter 'hspace' must be a float.")

    if not isinstance(ncols, int):
        raise TypeError("Parameter 'ncols' must be an integer.")

    if frameon is not None and not isinstance(frameon, bool):
        raise TypeError("Parameter 'frameon' must be a boolean.")

    if figsize is not None and not isinstance(figsize, tuple):
        raise TypeError("Parameter 'figsize' must be a tuple of two floats.")

    if dpi is not None and not isinstance(dpi, int):
        raise TypeError("Parameter 'dpi' must be an integer.")

    if fig is not None and not isinstance(fig, Figure):
        raise TypeError("Parameter 'fig' must be a matplotlib.figure.Figure.")

    if title is not None and not isinstance(title, list | str):
        raise TypeError("Parameter 'title' must be a string or a list of strings.")

    if not isinstance(share_extent, bool):
        raise TypeError("Parameter 'share_extent' must be a boolean.")

    if not isinstance(pad_extent, int | float):
        raise TypeError("Parameter 'pad_extent' must be numeric.")

    if ax is not None and not isinstance(ax, Axes | list):
        raise TypeError("Parameter 'ax' must be a matplotlib.axes.Axes or a list of Axes.")

    if not isinstance(return_ax, bool):
        raise TypeError("Parameter 'return_ax' must be a boolean.")

    if save is not None and not isinstance(save, str | Path):
        raise TypeError("Parameter 'save' must be a string or a pathlib.Path.")


def _type_check_params(param_dict: dict[str, Any], element_type: str) -> dict[str, Any]:
    colorbar = param_dict.get("colorbar", "auto")
    if colorbar not in {True, False, None, "auto"}:
        raise TypeError("Parameter 'colorbar' must be one of True, False or 'auto'.")

    colorbar_params = param_dict.get("colorbar_params")
    if colorbar_params is not None and not isinstance(colorbar_params, dict):
        raise TypeError("Parameter 'colorbar_params' must be a dictionary or None.")

    element = param_dict.get("element")
    if element is not None and not isinstance(element, str):
        raise ValueError(
            "Parameter 'element' must be a string. If you want to display more elements, pass `element` "
            "as `None` or chain pl.render(...).pl.render(...).pl.show()"
        )
    if element_type == "images":
        param_dict["element"] = [element] if element is not None else list(param_dict["sdata"].images.keys())
    elif element_type == "labels":
        param_dict["element"] = [element] if element is not None else list(param_dict["sdata"].labels.keys())
    elif element_type == "points":
        param_dict["element"] = [element] if element is not None else list(param_dict["sdata"].points.keys())
    elif element_type == "shapes":
        param_dict["element"] = [element] if element is not None else list(param_dict["sdata"].shapes.keys())

    channel = param_dict.get("channel")
    if channel is not None and not isinstance(channel, list | str | int):
        raise TypeError("Parameter 'channel' must be a string, an integer, or a list of strings or integers.")
    if isinstance(channel, list):
        if not all(isinstance(c, str | int) for c in channel):
            raise TypeError("Each item in 'channel' list must be a string or an integer.")
        if not all(isinstance(c, type(channel[0])) for c in channel):
            raise TypeError("Each item in 'channel' list must be of the same type, either string or integer.")

    elif "channel" in param_dict:
        param_dict["channel"] = [channel] if channel is not None else None

    contour_px = param_dict.get("contour_px")
    if contour_px and not isinstance(contour_px, int):
        raise TypeError("Parameter 'contour_px' must be an integer.")

    color = param_dict.get("color")
    if color and element_type in {
        "shapes",
        "points",
        "labels",
    }:
        if not isinstance(color, str | tuple | list):
            raise TypeError("Parameter 'color' must be a string or a tuple/list of floats.")
        if element_type in {"shapes", "points"}:
            if _is_color_like(color):
                logger.info("Value for parameter 'color' appears to be a color, using it as such.")
                param_dict["col_for_color"] = None
                param_dict["color"] = Color(color)
                if param_dict["color"].alpha_is_user_defined():
                    if element_type == "points" and param_dict.get("alpha") is None:
                        param_dict["alpha"] = param_dict["color"].get_alpha_as_float()
                    elif element_type == "shapes" and param_dict.get("fill_alpha") is None:
                        param_dict["fill_alpha"] = param_dict["color"].get_alpha_as_float()
                    else:
                        logger.info(
                            f"Alpha implied by color '{color}' is ignored since the parameter 'alpha' or 'fill_alpha' "
                            "is set and its value takes precedence."
                        )
            elif isinstance(color, str):
                param_dict["col_for_color"] = color
                param_dict["color"] = None
            else:
                raise ValueError(f"{color} is not a valid RGB(A) array and therefore can't be used as 'color' value.")
    elif "color" in param_dict and element_type != "labels":
        param_dict["col_for_color"] = None

    outline_width = param_dict.get("outline_width")
    if outline_width:
        # outline_width only exists for shapes at the moment
        if isinstance(outline_width, tuple):
            for ow in outline_width:
                if isinstance(ow, float | int):
                    if ow < 0:
                        raise ValueError("Parameter 'outline_width' cannot contain negative values.")
                else:
                    raise TypeError("Parameter 'outline_width' must contain only numerics when it is a tuple.")
        elif not isinstance(outline_width, float | int):
            raise TypeError("Parameter 'outline_width' must be numeric or a tuple of two numerics.")
        if isinstance(outline_width, float | int) and outline_width < 0:
            raise ValueError("Parameter 'outline_width' cannot be negative.")

    outline_alpha = param_dict.get("outline_alpha")
    if outline_alpha:
        if isinstance(outline_alpha, tuple):
            if element_type != "shapes":
                raise ValueError("Parameter 'outline_alpha' must be a single numeric.")
            if len(outline_alpha) == 1:
                if not isinstance(outline_alpha[0], float | int) or not 0 <= outline_alpha[0] <= 1:
                    raise TypeError("Parameter 'outline_alpha' must be numeric and between 0 and 1.")
                param_dict["outline_alpha"] = outline_alpha[0]
            elif len(outline_alpha) < 1:
                raise ValueError("Empty tuple is not supported as input for outline_alpha!")
            else:
                if len(outline_alpha) > 2:
                    logger.warning(
                        f"Tuple of length {len(outline_alpha)} was passed for outline_alpha, only first two positions "
                        "are used since more than 2 outlines are not supported!"
                    )
                if (
                    not isinstance(outline_alpha[0], float | int)
                    or not isinstance(outline_alpha[1], float | int)
                    or not 0 <= outline_alpha[0] <= 1
                    or not 0 <= outline_alpha[1] <= 1
                ):
                    raise TypeError("Parameter 'outline_alpha' must contain numeric values between 0 and 1.")
                param_dict["outline_alpha"] = (outline_alpha[0], outline_alpha[1])
        elif not isinstance(outline_alpha, float | int) or not 0 <= outline_alpha <= 1:
            raise TypeError("Parameter 'outline_alpha' must be numeric and between 0 and 1.")

    outline_color = param_dict.get("outline_color")
    if outline_color:
        if not isinstance(outline_color, str | tuple | list):
            raise TypeError("Parameter 'color' must be a string or a tuple/list of floats or colors.")
        if isinstance(outline_color, tuple | list):
            if len(outline_color) < 1:
                raise ValueError("Empty tuple is not supported as input for outline_color!")
            if len(outline_color) == 1:
                param_dict["outline_color"] = Color(outline_color[0])
            elif len(outline_color) == 2:
                # assuming the case of 2 outlines
                param_dict["outline_color"] = (Color(outline_color[0]), Color(outline_color[1]))
            elif len(outline_color) in [3, 4]:
                # assuming RGB(A) array
                param_dict["outline_color"] = Color(outline_color)
            else:
                raise ValueError(
                    f"Tuple/List of length {len(outline_color)} was passed for outline_color. Valid options would be: "
                    "tuple of 2 colors (for 2 outlines) or an RGB(A) array, aka a list/tuple of 3-4 floats."
                )
        else:
            param_dict["outline_color"] = Color(outline_color)

    if contour_px is not None and contour_px <= 0:
        raise ValueError("Parameter 'contour_px' must be a positive number.")

    alpha = param_dict.get("alpha")
    if alpha is not None:
        if not isinstance(alpha, float | int):
            raise TypeError("Parameter 'alpha' must be numeric.")
        if not 0 <= alpha <= 1:
            raise ValueError("Parameter 'alpha' must be between 0 and 1.")
    elif element_type == "points":
        # set default alpha for points if not given by user explicitly or implicitly (as part of color)
        param_dict["alpha"] = 1.0

    fill_alpha = param_dict.get("fill_alpha")
    if fill_alpha is not None:
        if not isinstance(fill_alpha, float | int):
            raise TypeError("Parameter 'fill_alpha' must be numeric.")
        if fill_alpha < 0:
            raise ValueError("Parameter 'fill_alpha' cannot be negative.")
    elif element_type == "shapes":
        # set default fill_alpha for shapes if not given by user explicitly or implicitly (as part of color)
        param_dict["fill_alpha"] = 1.0

    cmap = param_dict.get("cmap")
    palette = param_dict.get("palette")
    if cmap is not None and palette is not None:
        raise ValueError("Both `palette` and `cmap` are specified. Please specify only one of them.")
    param_dict["cmap"] = cmap

    groups = param_dict.get("groups")
    if groups is not None:
        if not isinstance(groups, list | str):
            raise TypeError("Parameter 'groups' must be a string or a list of strings.")
        if isinstance(groups, str):
            param_dict["groups"] = [groups]
        elif not all(isinstance(g, str) for g in groups):
            raise TypeError("Each item in 'groups' must be a string.")

    palette = param_dict["palette"]

    if isinstance(palette, list):
        if not all(isinstance(p, str) for p in palette):
            raise ValueError("If specified, parameter 'palette' must contain only strings.")
    elif isinstance(palette, str | type(None)) and "palette" in param_dict:
        param_dict["palette"] = [palette] if palette is not None else None

    palette_group = param_dict.get("palette")
    if element_type in ["shapes", "points", "labels"] and palette_group is not None:
        groups = param_dict.get("groups")
        if groups is None:
            raise ValueError("When specifying 'palette', 'groups' must also be specified.")
        if len(groups) != len(palette_group):
            raise ValueError(
                f"The length of 'palette' and 'groups' must be the same, length is {len(palette_group)} and"
                f"{len(groups)} respectively."
            )

    if isinstance(cmap, list):
        if not all(isinstance(c, Colormap | str) for c in cmap):
            raise TypeError("Each item in 'cmap' list must be a string or a Colormap.")
    elif isinstance(cmap, Colormap | str | type(None)):
        if "cmap" in param_dict:
            param_dict["cmap"] = [cmap] if cmap is not None else None
    else:
        raise TypeError("Parameter 'cmap' must be a string, a Colormap, or a list of these types.")

    # validation happens within Color constructor
    param_dict["na_color"] = Color(param_dict.get("na_color"))

    norm = param_dict.get("norm")
    if norm is not None:
        if element_type in {"images", "labels"} and not isinstance(norm, Normalize):
            raise TypeError("Parameter 'norm' must be of type Normalize.")
        if element_type in {"shapes", "points"} and not isinstance(norm, bool | Normalize):
            raise TypeError("Parameter 'norm' must be a boolean or a mpl.Normalize.")

    scale = param_dict.get("scale")
    if scale is not None:
        if element_type in {"images", "labels"} and not isinstance(scale, str):
            raise TypeError("Parameter 'scale' must be a string if specified.")
        if element_type == "shapes":
            if not isinstance(scale, float | int):
                raise TypeError("Parameter 'scale' must be numeric.")
            if scale < 0:
                raise ValueError("Parameter 'scale' must be a positive number.")

    size = param_dict.get("size")
    if size:
        if not isinstance(size, float | int):
            raise TypeError("Parameter 'size' must be numeric.")
        if size < 0:
            raise ValueError("Parameter 'size' must be a positive number.")

    shape = param_dict.get("shape")
    if element_type == "shapes" and shape is not None:
        valid_shapes = {"circle", "hex", "visium_hex", "square"}
        if not isinstance(shape, str):
            raise TypeError(f"Parameter 'shape' must be a String from {valid_shapes} if not None.")
        if shape not in valid_shapes:
            raise ValueError(f"'{shape}' is not supported for 'shape', please choose from {valid_shapes}.")

    table_name = param_dict.get("table_name")
    table_layer = param_dict.get("table_layer")
    if table_name and not isinstance(param_dict["table_name"], str):
        raise TypeError("Parameter 'table_name' must be a string.")

    if table_layer and not isinstance(param_dict["table_layer"], str):
        raise TypeError("Parameter 'table_layer' must be a string.")

    def _ensure_table_and_layer_exist_in_sdata(
        sdata: SpatialData, table_name: str | None, table_layer: str | None
    ) -> bool:
        """Ensure that table_name and table_layer are valid; throw error if not."""
        if table_name:
            if table_layer:
                if table_layer in sdata.tables[table_name].layers:
                    return True
                raise ValueError(f"Layer '{table_layer}' not found in table '{table_name}'.")
            return True  # using sdata.tables[table_name].X

        if table_layer:
            # user specified a layer but we have no tables => invalid
            if len(sdata.tables) == 0:
                raise ValueError("Trying to use 'table_layer' but no tables are present in the SpatialData object.")
            if len(sdata.tables) == 1:
                single_table_name = list(sdata.tables.keys())[0]
                if table_layer in sdata.tables[single_table_name].layers:
                    return True
                raise ValueError(f"Layer '{table_layer}' not found in table '{single_table_name}'.")
            # more than one tables, try to find which one has the given layer
            found_table = False
            for tname in sdata.tables:
                if table_layer in sdata.tables[tname].layers:
                    if found_table:
                        raise ValueError(
                            "Trying to guess 'table_name' based on 'table_layer', but found multiple matches."
                        )
                    found_table = True

            if found_table:
                return True

            raise ValueError(f"Layer '{table_layer}' not found in any table.")

        return True  # not using any table

    assert _ensure_table_and_layer_exist_in_sdata(param_dict.get("sdata"), table_name, table_layer)

    method = param_dict.get("method")
    if method not in ["matplotlib", "datashader", None]:
        raise ValueError("If specified, parameter 'method' must be either 'matplotlib' or 'datashader'.")

    valid_ds_reduction_methods = [
        "sum",
        "mean",
        "any",
        "count",
        # "m2", -> not intended to be used alone (see https://datashader.org/api.html#datashader.reductions.m2)
        # "mode", -> not supported for points (see https://datashader.org/api.html#datashader.reductions.mode)
        "std",
        "var",
        "max",
        "min",
    ]
    ds_reduction = param_dict.get("ds_reduction")
    if ds_reduction and (ds_reduction not in valid_ds_reduction_methods):
        raise ValueError(f"Parameter 'ds_reduction' must be one of the following: {valid_ds_reduction_methods}.")

    if method == "datashader" and ds_reduction is None:
        param_dict["ds_reduction"] = "sum"

    return param_dict


def _validate_label_render_params(
    sdata: sd.SpatialData,
    element: str | None,
    cmap: list[Colormap | str] | Colormap | str | None,
    color: str | None,
    fill_alpha: float | int,
    contour_px: int | None,
    groups: list[str] | str | None,
    palette: list[str] | str | None,
    na_color: ColorLike | None,
    norm: Normalize | None,
    outline_alpha: float | int,
    scale: str | None,
    table_name: str | None,
    table_layer: str | None,
    colorbar: bool | str | None,
    colorbar_params: dict[str, object] | None,
) -> dict[str, dict[str, Any]]:
    """Validate the ``render_labels`` arguments and expand them into per-element settings.

    The arguments are checked and normalized by ``_type_check_params``. Then, for each
    selected labels element, a settings dict is built. When ``color`` is given, it is resolved
    via ``_validate_col_for_column_table``. ``palette`` and ``groups`` are only kept for an
    element whose color resolves to a table; otherwise they are None.

    Returns
    -------
    dict mapping each element name to its render settings.
    """
    checked = _type_check_params(
        {
            "sdata": sdata,
            "element": element,
            "fill_alpha": fill_alpha,
            "contour_px": contour_px,
            "groups": groups,
            "palette": palette,
            "color": color,
            "na_color": na_color,
            "outline_alpha": outline_alpha,
            "cmap": cmap,
            "norm": norm,
            "scale": scale,
            "table_name": table_name,
            "table_layer": table_layer,
            "colorbar": colorbar,
            "colorbar_params": colorbar_params,
        },
        "labels",
    )

    element_params: dict[str, dict[str, Any]] = {}
    for el in checked["element"]:
        # Indexing raises if `el` does not exist in the SpatialData object.
        _ = checked["sdata"][el]

        resolved_color = None
        resolved_table = None
        requested_color = checked["color"]
        if requested_color is not None:
            resolved_color, resolved_table = _validate_col_for_column_table(
                sdata, el, requested_color, checked["table_name"], labels=True
            )
        colored_by_table = resolved_table is not None

        element_params[el] = {
            "na_color": checked["na_color"],
            "cmap": checked["cmap"],
            "norm": checked["norm"],
            "fill_alpha": checked["fill_alpha"],
            "scale": checked["scale"],
            "outline_alpha": checked["outline_alpha"],
            "contour_px": checked["contour_px"],
            "table_layer": checked["table_layer"],
            "table_name": resolved_table,
            "color": resolved_color,
            "palette": checked["palette"] if colored_by_table else None,
            "groups": checked["groups"] if colored_by_table else None,
            "colorbar": checked["colorbar"],
            "colorbar_params": checked["colorbar_params"],
        }

    return element_params


def _validate_points_render_params(
    sdata: sd.SpatialData,
    element: str | None,
    alpha: float | int | None,
    color: ColorLike | None,
    groups: list[str] | str | None,
    palette: list[str] | str | None,
    na_color: ColorLike | None,
    cmap: list[Colormap | str] | Colormap | str | None,
    norm: Normalize | None,
    size: float | int,
    table_name: str | None,
    table_layer: str | None,
    ds_reduction: str | None,
    colorbar: bool | str | None,
    colorbar_params: dict[str, object] | None,
) -> dict[str, dict[str, Any]]:
    """Validate the ``render_points`` arguments and expand them into per-element settings.

    The arguments are checked and normalized by ``_type_check_params``, which also splits
    ``color`` into a literal color and ``col_for_color``. Then, for each selected points
    element, a settings dict is built. When a column was requested, it is resolved via
    ``_validate_col_for_column_table``. ``palette`` and ``groups`` are only kept when a column
    to color by was requested; otherwise they are None.

    Returns
    -------
    dict mapping each element name to its render settings.
    """
    checked = _type_check_params(
        {
            "sdata": sdata,
            "element": element,
            "alpha": alpha,
            "color": color,
            "groups": groups,
            "palette": palette,
            "na_color": na_color,
            "cmap": cmap,
            "norm": norm,
            "size": size,
            "table_name": table_name,
            "table_layer": table_layer,
            "ds_reduction": ds_reduction,
            "colorbar": colorbar,
            "colorbar_params": colorbar_params,
        },
        "points",
    )

    element_params: dict[str, dict[str, Any]] = {}
    for el in checked["element"]:
        # Indexing raises if `el` does not exist in the SpatialData object.
        _ = checked["sdata"][el]

        requested_col = checked["col_for_color"]
        resolved_col = None
        resolved_table = None
        if requested_col is not None:
            resolved_col, resolved_table = _validate_col_for_column_table(
                sdata, el, requested_col, checked["table_name"]
            )
        column_requested = checked["col_for_color"] is not None

        element_params[el] = {
            "na_color": checked["na_color"],
            "cmap": checked["cmap"],
            "norm": checked["norm"],
            "color": checked["color"],
            "size": checked["size"],
            "alpha": checked["alpha"],
            "table_layer": checked["table_layer"],
            "table_name": resolved_table,
            "col_for_color": resolved_col,
            "palette": checked["palette"] if column_requested else None,
            "groups": checked["groups"] if column_requested else None,
            "ds_reduction": checked["ds_reduction"],
            "colorbar": checked["colorbar"],
            "colorbar_params": checked["colorbar_params"],
        }

    return element_params


def _validate_shape_render_params(
    sdata: sd.SpatialData,
    element: str | None,
    fill_alpha: float | int | None,
    groups: list[str] | str | None,
    palette: list[str] | str | None,
    color: ColorLike | None,
    na_color: ColorLike | None,
    outline_width: float | int | tuple[float | int, float | int] | None,
    outline_color: ColorLike | tuple[ColorLike] | None,
    outline_alpha: float | int | tuple[float | int, float | int] | None,
    cmap: list[Colormap | str] | Colormap | str | None,
    norm: Normalize | None,
    scale: float | int,
    table_name: str | None,
    table_layer: str | None,
    shape: Literal["circle", "hex", "visium_hex", "square"] | None,
    method: str | None,
    ds_reduction: str | None,
    colorbar: bool | str | None,
    colorbar_params: dict[str, object] | None,
) -> dict[str, dict[str, Any]]:
    """Validate the ``render_shapes`` arguments and expand them into one parameter dict per element.

    All arguments are first passed through ``_type_check_params(..., "shapes")``. Every element listed in
    the checked ``element`` entry is looked up in ``sdata``, so a missing element fails early. If the
    checked parameters carry a ``col_for_color``, it is resolved per element by
    ``_validate_col_for_column_table``. Otherwise ``table_name``, ``col_for_color``, ``palette`` and
    ``groups`` are all ``None`` for that element.
    NOTE(review): ``col_for_color`` is not passed in here, so it is presumably added by
    ``_type_check_params`` (likely derived from ``color``). Confirm.

    Returns
    -------
    dict[str, dict[str, Any]]
        Mapping from element name to that element's render parameters.
    """
    checked: dict[str, Any] = _type_check_params(
        {
            "sdata": sdata,
            "element": element,
            "fill_alpha": fill_alpha,
            "groups": groups,
            "palette": palette,
            "color": color,
            "na_color": na_color,
            "outline_width": outline_width,
            "outline_color": outline_color,
            "outline_alpha": outline_alpha,
            "cmap": cmap,
            "norm": norm,
            "scale": scale,
            "table_name": table_name,
            "table_layer": table_layer,
            "shape": shape,
            "method": method,
            "ds_reduction": ds_reduction,
            "colorbar": colorbar,
            "colorbar_params": colorbar_params,
        },
        "shapes",
    )

    # values copied verbatim to every element, in the order they appear in the result dict
    leading_keys = (
        "fill_alpha",
        "na_color",
        "outline_width",
        "outline_color",
        "outline_alpha",
        "cmap",
        "norm",
        "scale",
        "table_layer",
        "shape",
        "color",
    )
    trailing_keys = ("method", "ds_reduction", "colorbar", "colorbar_params")

    per_element: dict[str, dict[str, Any]] = {}
    for el in checked["element"]:
        # looking the element up fails early if it is not part of the SpatialData object
        _ = checked["sdata"][el]

        el_params: dict[str, Any] = {key: checked[key] for key in leading_keys}
        el_params["table_name"] = None
        el_params["col_for_color"] = None

        color_column = checked["col_for_color"]
        colored_by_column = color_column is not None
        if colored_by_column:
            resolved_col, resolved_table = _validate_col_for_column_table(
                sdata, el, color_column, checked["table_name"]
            )
            el_params["table_name"] = resolved_table
            el_params["col_for_color"] = resolved_col

        # palette / groups only make sense when coloring by a column
        el_params["palette"] = checked["palette"] if colored_by_column else None
        el_params["groups"] = checked["groups"] if colored_by_column else None
        el_params.update((key, checked[key]) for key in trailing_keys)

        per_element[el] = el_params

    return per_element


def _validate_col_for_column_table(
    sdata: SpatialData,
    element_name: str,
    col_for_color: str | None,
    table_name: str | None,
    labels: bool = False,
) -> tuple[str | None, str | None]:
    if col_for_color is None:
        return None, None

    if not labels and col_for_color in sdata[element_name].columns:
        table_name = None
    elif table_name is not None:
        tables = get_element_annotators(sdata, element_name)
        if table_name not in tables:
            raise KeyError(f"Table '{table_name}' does not annotate element '{element_name}'.")
        if col_for_color not in sdata[table_name].obs.columns and col_for_color not in sdata[table_name].var_names:
            raise KeyError(
                f"Column '{col_for_color}' not found in obs/var of table '{table_name}' for element '{element_name}'."
            )
    else:
        tables = get_element_annotators(sdata, element_name)
        if len(tables) == 0:
            raise KeyError(
                f"Element '{element_name}' has no annotating tables. "
                f"Cannot use column '{col_for_color}' for coloring. "
                "Please ensure the element is annotated by at least one table."
            )
        # Now check which tables contain the column
        for annotates in tables.copy():
            if col_for_color not in sdata[annotates].obs.columns and col_for_color not in sdata[annotates].var_names:
                tables.remove(annotates)
        if len(tables) == 0:
            raise KeyError(
                f"Unable to locate color key '{col_for_color}' for element '{element_name}'. "
                "Please ensure the key exists in a table annotating this element."
            )
        table_name = next(iter(tables))
        if len(tables) > 1:
            logger.warning(f"Multiple tables contain column '{col_for_color}', using table '{table_name}'.")
    return col_for_color, table_name


def _validate_image_render_params(
    sdata: sd.SpatialData,
    element: str | None,
    channel: list[str] | list[int] | str | int | None,
    alpha: float | int | None,
    palette: list[str] | str | None,
    na_color: ColorLike | None,
    cmap: list[Colormap | str] | Colormap | str | None,
    norm: Normalize | None,
    scale: str | None,
    colorbar: bool | str | None,
    colorbar_params: dict[str, object] | None,
) -> dict[str, dict[str, Any]]:
    """Validate the ``render_images`` arguments and expand them into one parameter dict per element.

    Arguments are type-checked by ``_type_check_params(..., "images")``. Afterwards ``channel``,
    ``palette`` and ``cmap`` are checked against the channels of every requested image element.

    Returns
    -------
    dict[str, dict[str, Any]]
        Mapping from element name to a dict with keys ``channel``, ``alpha``, ``palette``, ``na_color``,
        ``cmap``, ``norm``, ``scale``, ``colorbar`` and ``colorbar_params``.

    Raises
    ------
    TypeError
        If the items of ``channel`` are not all of the same type.
    ValueError
        If ``channel`` names channels the element does not have, or if ``palette`` does not provide one
        entry per rendered channel.
    """
    param_dict: dict[str, Any] = {
        "sdata": sdata,
        "element": element,
        "channel": channel,
        "alpha": alpha,
        "palette": palette,
        "na_color": na_color,
        "cmap": cmap,
        "norm": norm,
        "scale": scale,
        "colorbar": colorbar,
        "colorbar_params": colorbar_params,
    }
    param_dict = _type_check_params(param_dict, "images")

    element_params: dict[str, dict[str, Any]] = {}
    for el in param_dict["element"]:
        element_params[el] = {}
        spatial_element = param_dict["sdata"][el]

        # channel names of a single-scale image (DataArray), or of level "scale0" of a multiscale image
        spatial_element_ch = (
            spatial_element.c.values if isinstance(spatial_element, DataArray) else spatial_element["scale0"].c.values
        )
        channel = param_dict["channel"]
        if channel is not None:
            # normalize a single channel name / index to a one-item list
            if isinstance(channel, str | int):
                channel = [channel]

            # all items must share one type (all names or all indices)
            if not (isinstance(channel, list) and channel and all(isinstance(c, type(channel[0])) for c in channel)):
                raise TypeError("Each item in 'channel' list must be of the same type, either string or integer.")

            invalid = [c for c in channel if c not in spatial_element_ch]
            if invalid:
                raise ValueError(
                    f"Invalid channel(s): {', '.join(str(c) for c in invalid)}. Valid choices are: {spatial_element_ch}"
                )
            element_params[el]["channel"] = channel
        else:
            element_params[el]["channel"] = None

        # number of channels that will actually be rendered: the selection if given, otherwise all of them
        n_rendered_channels = len(channel) if channel is not None else len(spatial_element_ch)

        element_params[el]["alpha"] = param_dict["alpha"]

        palette = param_dict["palette"]
        assert isinstance(palette, list | type(None))  # if present, was converted to list, just to make sure

        if isinstance(palette, list):
            # case A: a single palette entry is reused for every rendered channel
            if len(palette) == 1:
                palette = palette * n_rendered_channels
            # case B: exactly one palette entry per rendered channel is required
            channels_to_use = spatial_element_ch if channel is None else channel
            if len(palette) != len(channels_to_use):
                raise ValueError(
                    f"Palette length ({len(palette)}) does not match channel length "
                    f"({', '.join(str(c) for c in channels_to_use)})."
                )
        element_params[el]["palette"] = palette
        element_params[el]["na_color"] = param_dict["na_color"]

        cmap = param_dict["cmap"]
        if cmap is not None:
            # a single colormap is broadcast to every rendered channel
            if len(cmap) == 1:
                cmap = cmap * n_rendered_channels
            # Keep the cmaps only if there is exactly one per rendered channel. Compare against the
            # selected channels (not the element's total channel count) so a valid cmap for a channel
            # subset is not discarded.
            if len(cmap) != n_rendered_channels:
                cmap = None
        element_params[el]["cmap"] = cmap
        element_params[el]["norm"] = param_dict["norm"]

        # an unknown scale level of a multiscale image falls back to None; "full" is always accepted
        scale = param_dict["scale"]
        if (
            scale
            and isinstance(spatial_element, DataTree)
            and scale not in list(spatial_element.keys())
            and scale != "full"
        ):
            element_params[el]["scale"] = None
        else:
            element_params[el]["scale"] = scale
        element_params[el]["colorbar"] = param_dict["colorbar"]
        element_params[el]["colorbar_params"] = param_dict["colorbar_params"]

    return element_params


def _get_wanted_render_elements(
    sdata: SpatialData,
    sdata_wanted_elements: list[str],
    params: ImageRenderParams | LabelsRenderParams | PointsRenderParams | ShapesRenderParams,
    cs: str,
    element_type: Literal["images", "labels", "points", "shapes"],
) -> tuple[list[str], list[str], bool]:
    wants_elements = True
    if element_type in [
        "images",
        "labels",
        "points",
        "shapes",
    ]:  # Prevents eval security risk
        wanted_elements: list[str] = [params.element]
        wanted_elements_on_cs = [
            element for element in wanted_elements if cs in set(get_transformation(sdata[element], get_all=True).keys())
        ]

        sdata_wanted_elements.extend(wanted_elements_on_cs)
        return sdata_wanted_elements, wanted_elements_on_cs, wants_elements

    raise ValueError(f"Unknown element type {element_type}")


def _ax_show_and_transform(
    array: MaskedArray[tuple[int, ...], Any] | npt.NDArray[Any],
    trans_data: CompositeGenericTransform,
    ax: Axes,
    alpha: float | None = None,
    cmap: ListedColormap | LinearSegmentedColormap | None = None,
    zorder: int = 0,
    extent: list[float] | None = None,
    norm: Normalize | None = None,
) -> matplotlib.image.AxesImage:
    # default extent in mpl:
    image_extent = [-0.5, array.shape[1] - 0.5, array.shape[0] - 0.5, -0.5]
    if extent is not None:
        # make sure extent is [x_min, x_max, y_min, y_max]
        if extent[3] < extent[2]:
            extent[2], extent[3] = extent[3], extent[2]
        if extent[0] < 0:
            x_factor = array.shape[1] / (extent[1] - extent[0])
            image_extent[0] = image_extent[0] + (extent[0] * x_factor)
            image_extent[1] = image_extent[1] + (extent[0] * x_factor)
        if extent[2] < 0:
            y_factor = array.shape[0] / (extent[3] - extent[2])
            image_extent[2] = image_extent[2] + (extent[2] * y_factor)
            image_extent[3] = image_extent[3] + (extent[2] * y_factor)

    if not cmap and alpha is not None:
        im = ax.imshow(
            array,
            alpha=alpha,
            zorder=zorder,
            extent=tuple(image_extent),
            norm=norm,
        )
        im.set_transform(trans_data)
    else:
        im = ax.imshow(
            array,
            cmap=cmap,
            zorder=zorder,
            extent=tuple(image_extent),
            norm=norm,
        )
        im.set_transform(trans_data)
    return im


def set_zero_in_cmap_to_transparent(cmap: Colormap | str, steps: int | None = None) -> ListedColormap:
    """
    Return a new colormap whose first entry is fully transparent.

    Parameters
    ----------
    cmap : Colormap | str
        A matplotlib ``Colormap`` instance or the name of a registered colormap.
    steps : int | None
        Number of entries in the new colormap. If ``None`` (or 0), ``cmap.N`` is used.
        Integer indices ``0 .. steps - 1`` are passed to ``cmap``, so the entries are taken from the
        start of its lookup table. NOTE(review): for ``steps < cmap.N`` this truncates rather than
        resamples the colormap. Confirm that this is intended.

    Returns
    -------
    ListedColormap
        A new colormap; entry 0 is transparent white, all others come from ``cmap``.
    """
    base_cmap = plt.get_cmap(cmap) if isinstance(cmap, str) else cmap

    n_entries = steps or base_cmap.N
    rgba_table = base_cmap(np.arange(n_entries))
    # first entry -> transparent white, so zero-valued pixels are not drawn
    rgba_table[0, :] = [1.0, 1.0, 1.0, 0.0]

    return ListedColormap(rgba_table)


def _get_extent_and_range_for_datashader_canvas(
    spatial_element: SpatialElement,
    coordinate_system: str,
    ax: Axes,
    fig_params: FigParams,
) -> tuple[Any, Any, list[Any], list[Any], Any]:
    """Compute the datashader canvas size and data ranges for rendering ``spatial_element``.

    The x/y ranges always include 0. They are widened to the current axis limits when the axes
    already contain drawn elements. The canvas size in pixels is chosen close to the figure's pixel
    size, keeping the data aspect ratio.

    Returns
    -------
    tuple
        ``(plot_width, plot_height, x_ext, y_ext, factor)``. The first two are the canvas size in
        pixels, ``x_ext`` / ``y_ext`` are ``[min, max]`` data ranges, and ``factor`` is the
        data-units-per-pixel scale used to derive the canvas size.
    """
    extent = get_extent(spatial_element, coordinate_system=coordinate_system)
    x_ext = [min(0, extent["x"][0]), extent["x"][1]]
    y_ext = [min(0, extent["y"][0]), extent["y"][1]]
    prev_x = ax.get_xlim()
    prev_y = ax.get_ylim()

    # grow the range if something larger was already rendered on this axis
    if _mpl_ax_contains_elements(ax):
        x_ext = [min(x_ext[0], prev_x[0]), max(x_ext[1], prev_x[1])]
        # with an inverted y axis, get_ylim() returns (top, bottom)
        y_low, y_high = (prev_y[1], prev_y[0]) if ax.yaxis_inverted() else (prev_y[0], prev_y[1])
        y_ext = [min(y_ext[0], y_low), max(y_ext[1], y_high)]

    # choose a canvas close to the figure's pixel size to keep the computation cheap
    span_x = x_ext[1] - x_ext[0]
    span_y = y_ext[1] - y_ext[0]
    width_in, height_in = fig_params.fig.get_size_inches()
    dpi = fig_params.fig.dpi
    fig_width_px = int(round(width_in * dpi))
    fig_height_px = int(round(height_in * dpi))

    factor: float = np.min([span_x / fig_width_px, span_y / fig_height_px])
    canvas_width = int(np.round(span_x / factor))
    canvas_height = int(np.round(span_y / factor))

    return canvas_width, canvas_height, x_ext, y_ext, factor


def _create_image_from_datashader_result(
    ds_result: ds.transfer_functions.Image | np.ndarray[Any, np.dtype[np.uint8]],
    factor: float,
    ax: Axes,
) -> tuple[MaskedArray[tuple[int, ...], Any], matplotlib.transforms.Transform]:
    """Turn a datashader result into a (y, x, c) masked RGBA array plus its matplotlib transform.

    The result is wrapped in an ``Image2DModel`` with a ``Scale([1, factor, factor])`` transformation
    to the "global" coordinate system. This brings the reduced-size canvas back to the original data
    size; the transform for ``ax`` is then derived from that model.
    """
    if isinstance(ds_result, np.ndarray):
        rgba_yxc = ds_result.copy()
    else:
        # NOTE(review): ``.base`` is presumably the uint8 RGBA view behind datashader's packed image. Confirm.
        rgba_yxc = ds_result.to_numpy().base

    rgba_model = Image2DModel.parse(
        np.transpose(rgba_yxc, (2, 0, 1)),
        dims=("c", "y", "x"),
        transformations={"global": Scale([1, factor, factor], ("c", "y", "x"))},
    )

    _, trans_data = _prepare_transformation(rgba_model, "global", ax)

    pixels = np.transpose(rgba_model.data.compute(), (1, 2, 0))  # type: ignore[attr-defined]
    return ma.masked_array(pixels), trans_data  # masked_array: type conversion for mypy


def _datashader_aggregate_with_function(
    reduction: Literal["sum", "mean", "any", "count", "std", "var", "max", "min"] | None,
    cvs: Canvas,
    spatial_element: GeoDataFrame | dask.dataframe.core.DataFrame,
    col_for_color: str | None,
    element_type: Literal["points", "shapes"],
) -> DataArray:
    """
    Aggregate points or shapes on a datashader canvas using a continuous coloring column.

    Parameters
    ----------
    reduction: name of the datashader reduction; ``None`` means ``"sum"``.
    cvs: canvas created with ``ds.Canvas()``.
    spatial_element: geo dataframe (shapes) or dask dataframe (points) to aggregate.
    col_for_color: column whose values are aggregated.
    element_type: ``"points"`` (aggregated with ``cvs.points`` on columns ``x``/``y``) or
        ``"shapes"`` (aggregated with ``cvs.polygons`` on column ``geometry``).

    Raises
    ------
    ValueError: for an unknown reduction or element type.
    """
    if reduction is None:
        reduction = "sum"

    reducers = {
        "sum": ds.sum,
        "mean": ds.mean,
        "any": ds.any,
        "count": ds.count,
        "std": ds.std,
        "var": ds.var,
        "max": ds.max,
        "min": ds.min,
    }
    try:
        aggregator = reducers[reduction](column=col_for_color)
    except KeyError as err:
        raise ValueError(
            f"Reduction '{reduction}' is not supported. Please use one of: {', '.join(reducers.keys())}."
        ) from err

    aggregate_fns = {"points": cvs.points, "shapes": cvs.polygons}
    try:
        aggregate_fn = aggregate_fns[element_type]
    except KeyError as err:
        raise ValueError(f"Element type '{element_type}' is not supported. Use 'points' or 'shapes'.") from err

    if element_type != "points":
        # shapes
        return aggregate_fn(spatial_element, geometry="geometry", agg=aggregator)

    aggregated = aggregate_fn(spatial_element, "x", "y", agg=aggregator)
    if reduction == "any":
        # boolean result -> 1 where any point was hit, NaN elsewhere
        aggregated = aggregated.astype(int)
        aggregated = aggregated.where(aggregated > 0)
    return aggregated


def _datshader_get_how_kw_for_spread(
    reduction: Literal["sum", "mean", "any", "count", "std", "var", "max", "min"] | None,
) -> str:
    # Get the best input for the how argument of ds.tf.spread(), needed for numerical values
    reduction = reduction or "sum"

    reduction_to_how_map = {
        "sum": "add",
        "mean": "source",
        "any": "source",
        "count": "add",
        "std": "source",
        "var": "source",
        "max": "max",
        "min": "min",
    }

    if reduction not in reduction_to_how_map:
        raise ValueError(
            f"Reduction {reduction} is not supported, please use one of the following: sum, mean, any, count"
            ", std, var, max, min."
        )

    return reduction_to_how_map[reduction]


def _prepare_transformation(
    element: DataArray | GeoDataFrame | dask.dataframe.core.DataFrame,
    coordinate_system: str,
    ax: Axes | None = None,
) -> tuple[
    matplotlib.transforms.Affine2D,
    matplotlib.transforms.CompositeGenericTransform | None,
]:
    """Build matplotlib transforms for ``element`` in ``coordinate_system``.

    Returns
    -------
    tuple
        The element's 2D (x, y) affine transformation as ``Affine2D``, and that transform composed
        with ``ax.transData`` (``None`` when no axes are given).
    """
    to_cs = get_transformation(element, get_all=True)[coordinate_system]
    matrix = to_cs.to_affine_matrix(input_axes=("x", "y"), output_axes=("x", "y"))
    affine = mtransforms.Affine2D(matrix=matrix)

    if ax is None:
        return affine, None
    return affine, affine + ax.transData


def _datashader_map_aggregate_to_color(
    agg: DataArray,
    cmap: str | list[str] | ListedColormap,
    color_key: None | list[str] = None,
    min_alpha: float = 40,
    span: None | list[float] = None,
    clip: bool = True,
) -> ds.tf.Image | np.ndarray[Any, np.dtype[np.uint8]]:
    """ds.tf.shade() part, ensuring correct clipping behavior.

    If necessary (norm.clip=False), split shading in 3 parts and in the end, stack results.
    This ensures the correct clipping behavior, because else datashader would always automatically clip.

    The split path is taken only when ``clip`` is False, ``cmap`` is a matplotlib ``Colormap`` and
    ``span`` is given. In that case:

    - values inside ``span`` are shaded linearly with ``cmap``;
    - values below ``span[0]`` get the colormap's "under" color;
    - values above ``span[1]`` get the colormap's "over" color.

    The three layers are merged into one RGBA array, which is returned as a numpy array. Otherwise a
    single ``ds.tf.shade`` call is made and its image is returned.

    Parameters
    ----------
    agg: aggregate produced by a datashader canvas.
    cmap: colormap / list of colors forwarded to ``ds.tf.shade``.
    color_key: forwarded to ``ds.tf.shade``.
    min_alpha: forwarded to ``ds.tf.shade``.
    span: ``[vmin, vmax]`` for linear shading.
    clip: mirrors ``Normalize.clip``; only ``False`` triggers the split path.
    """
    if not clip and isinstance(cmap, Colormap) and span is not None:
        # in case we use datashader together with a Normalize object where clip=False
        # why we need this is documented in https://github.com/scverse/spatialdata-plot/issues/372
        agg_in = agg.where((agg >= span[0]) & (agg <= span[1]))
        img_in = ds.tf.shade(
            agg_in,
            cmap=cmap,
            span=(span[0], span[1]),
            how="linear",
            color_key=color_key,
            min_alpha=min_alpha,
        )

        # values below the span: a single color, the colormap's "under" color (first 7 chars = "#rrggbb")
        agg_under = agg.where(agg < span[0])
        img_under = ds.tf.shade(
            agg_under,
            cmap=[to_hex(cmap.get_under())[:7]],
            min_alpha=min_alpha,
            color_key=color_key,
        )

        # values above the span: a single color, the colormap's "over" color
        agg_over = agg.where(agg > span[1])
        img_over = ds.tf.shade(
            agg_over,
            cmap=[to_hex(cmap.get_over())[:7]],
            min_alpha=min_alpha,
            color_key=color_key,
        )

        # stack the 3 arrays manually: go from under, through in to over and always overlay the values where alpha=0
        # NOTE(review): ``.to_numpy().base`` is presumably the (y, x, 4) uint8 RGBA buffer behind the
        # packed datashader image, and may be None; confirm. ``stack`` aliases img_under's buffer and
        # is modified in place.
        stack = img_under.to_numpy().base
        if stack is None:
            stack = img_in.to_numpy().base
        else:
            # fill still-transparent pixels (alpha channel index 3 == 0) from the in-range layer
            stack[stack[:, :, 3] == 0] = img_in.to_numpy().base[stack[:, :, 3] == 0]
        img_over = img_over.to_numpy().base
        if img_over is not None:
            # then fill remaining transparent pixels from the over-range layer
            stack[stack[:, :, 3] == 0] = img_over[stack[:, :, 3] == 0]
        return stack

    # default path: a single linear shade, letting datashader clip to span
    return ds.tf.shade(
        agg,
        cmap=cmap,
        color_key=color_key,
        min_alpha=min_alpha,
        span=span,
        how="linear",
    )


def _hex_no_alpha(hex: str) -> str:
    """
    Return a hex color string without an alpha component.

    Parameters
    ----------
    hex : str
        The input hex color string. Must be in one of the following formats:
        - "#RRGGBB": a hex color without an alpha channel.
        - "#RRGGBBAA": a hex color with an alpha channel that will be removed.

    Returns
    -------
    str
        The hex color string in "#RRGGBB" format.
    """
    if not isinstance(hex, str):
        raise TypeError("Input must be a string")
    if not hex.startswith("#"):
        raise ValueError("Invalid hex color: must start with '#'")

    hex_digits = hex[1:]
    length = len(hex_digits)

    if length == 6:
        if not all(c in "0123456789abcdefABCDEF" for c in hex_digits):
            raise ValueError("Invalid hex color: contains non-hex characters")
        return hex  # Already in #RRGGBB format.

    if length == 8:
        if not all(c in "0123456789abcdefABCDEF" for c in hex_digits):
            raise ValueError("Invalid hex color: contains non-hex characters")
        # Return only the first 6 characters, stripping the alpha.
        return "#" + hex_digits[:6]

    raise ValueError("Invalid hex color length: must be either '#RRGGBB' or '#RRGGBBAA'")


def _convert_shapes(
    shapes: GeoDataFrame,
    target_shape: str,
    max_extent: float,
    warn_above_extent_fraction: float = 0.5,
) -> GeoDataFrame:
    """Convert shapes in a GeoDataFrame to the target_shape, using positional indexing.

    Parameters
    ----------
    shapes
        Shapes to convert. Supported geometry types are ``Point``, ``Polygon`` and ``MultiPolygon``.
        The size of a Point is read from the optional ``radius`` column; a missing or non-finite radius
        counts as 0.
    target_shape
        ``"circle"``, ``"hex"`` or ``"visium_hex"``; any other value converts to squares.
        For ``"visium_hex"``, all Points get the same hexagon size, derived from the smallest distance
        between two Point centres. (Multi)Polygons, or all geometries when that distance cannot be
        determined, use the regular ``"hex"`` conversion.
    max_extent
        Reference size for the "converted shape is large" check.
    warn_above_extent_fraction
        An info message is logged if the diameter of a circle fitted to any (Multi)Polygon exceeds
        ``max_extent * warn_above_extent_fraction``. Values outside ``[0, 1]`` are replaced by 0.5.

    Returns
    -------
    GeoDataFrame
        A copy of ``shapes`` with a fresh ``RangeIndex``, converted geometries and a ``radius`` column.
        The radius is set for circle outputs and left unchanged (NaN if the column was absent) otherwise.

    Raises
    ------
    ValueError
        If a geometry type other than Point, Polygon or MultiPolygon is encountered.
    """
    if warn_above_extent_fraction < 0.0 or warn_above_extent_fraction > 1.0:
        warn_above_extent_fraction = 0.5
    warn_shape_size = False

    # work on a copy with a clean positional index
    shapes = shapes.reset_index(drop=True).copy()

    def _circle_to_hexagon(center: shapely.Point, radius: float) -> tuple[shapely.Polygon, None]:
        # hexagon with its six vertices on the circle, starting at 30 degrees
        verts = [
            (
                center.x + radius * math.cos(math.radians(a)),
                center.y + radius * math.sin(math.radians(a)),
            )
            for a in range(30, 390, 60)
        ]
        return shapely.Polygon(verts), None

    def _circle_to_square(center: shapely.Point, radius: float) -> tuple[shapely.Polygon, None]:
        # axis-aligned square with its four corners on the circle
        verts = [
            (
                center.x + radius * math.cos(math.radians(a)),
                center.y + radius * math.sin(math.radians(a)),
            )
            for a in range(45, 360, 90)
        ]
        return shapely.Polygon(verts), None

    def _circle_to_circle(center: shapely.Point, radius: float) -> tuple[shapely.Point, float]:
        return center, radius

    def _hull_circle(coords: npt.NDArray[Any]) -> tuple[shapely.Point, float]:
        """Fit a circle to ``coords``: centre = mean of the convex-hull vertices, radius = farthest vertex."""
        nonlocal warn_shape_size
        hull_pts = coords[ConvexHull(coords).vertices]
        center = np.mean(hull_pts, axis=0)
        radius = float(np.max(np.linalg.norm(hull_pts - center, axis=1)))
        if 2 * radius > max_extent * warn_above_extent_fraction:
            warn_shape_size = True
        return shapely.Point(center), radius

    def _polygon_to_circle(polygon: shapely.Polygon) -> tuple[shapely.Point, float]:
        return _hull_circle(np.array(polygon.exterior.coords))

    def _polygon_to_hexagon(polygon: shapely.Polygon) -> tuple[shapely.Polygon, None]:
        return _circle_to_hexagon(*_polygon_to_circle(polygon))

    def _polygon_to_square(polygon: shapely.Polygon) -> tuple[shapely.Polygon, None]:
        return _circle_to_square(*_polygon_to_circle(polygon))

    def _multipolygon_to_circle(multipolygon: shapely.MultiPolygon) -> tuple[shapely.Point, float]:
        # one circle around the exterior rings of all member polygons
        return _hull_circle(np.array([pt for poly in multipolygon.geoms for pt in poly.exterior.coords]))

    def _multipolygon_to_hexagon(multipolygon: shapely.MultiPolygon) -> tuple[shapely.Polygon, None]:
        return _circle_to_hexagon(*_multipolygon_to_circle(multipolygon))

    def _multipolygon_to_square(multipolygon: shapely.MultiPolygon) -> tuple[shapely.Polygon, None]:
        return _circle_to_square(*_multipolygon_to_circle(multipolygon))

    def _min_center_distance(centers: list[tuple[float, float]]) -> float | None:
        """Smallest distance between two centres; None if < 2 finite centres or two centres coincide."""
        from scipy.spatial import cKDTree

        pts = np.asarray(centers, dtype=float).reshape(-1, 2)
        pts = pts[np.isfinite(pts).all(axis=1)]
        if len(pts) < 2:
            return None
        # k=2: column 0 is each point itself, column 1 its nearest other point (O(n log n) vs. pairwise O(n^2))
        nn_dist, _ = cKDTree(pts).query(pts, k=2)
        dmin = float(np.min(nn_dist[:, 1]))
        return dmin if np.isfinite(dmin) and dmin > 0 else None

    hexagon_methods: dict[str, Any] = {
        "Point": _circle_to_hexagon,
        "Polygon": _polygon_to_hexagon,
        "MultiPolygon": _multipolygon_to_hexagon,
    }

    # choose conversion methods
    conversion_methods: dict[str, Any]
    if target_shape == "circle":
        conversion_methods = {
            "Point": _circle_to_circle,
            "Polygon": _polygon_to_circle,
            "MultiPolygon": _multipolygon_to_circle,
        }
    elif target_shape == "hex":
        conversion_methods = hexagon_methods
    elif target_shape == "visium_hex":
        # estimate hex radius from point spacing when possible
        point_centers = [(geom.x, geom.y) for geom in shapes.geometry if geom.geom_type == "Point"]
        if len(point_centers) < len(shapes):
            logger.warning("visium_hex supports Points best. Non-Point geometries will use regular hex conversion.")

        conversion_methods = hexagon_methods
        dmin = _min_center_distance(point_centers)
        if dmin is not None:
            # adjacent hexagons in a tiling are sqrt(3) * circumradius apart
            hex_radius = dmin / math.sqrt(3.0)

            def _circle_to_visium_hex(center: shapely.Point, radius: float) -> tuple[shapely.Polygon, None]:
                # the per-spot radius is ignored: all spots share the grid-derived hexagon size
                return _circle_to_hexagon(center, hex_radius)

            conversion_methods = {**hexagon_methods, "Point": _circle_to_visium_hex}
    else:
        conversion_methods = {
            "Point": _circle_to_square,
            "Polygon": _polygon_to_square,
            "MultiPolygon": _multipolygon_to_square,
        }

    # ensure radius column exists if needed
    if "radius" not in shapes.columns:
        shapes["radius"] = np.nan

    # snapshot inputs first, then write results back by position
    geometries = list(shapes.geometry)
    radii = list(shapes["radius"])
    for i, (geom, point_radius) in enumerate(zip(geometries, radii)):
        gtype = geom.geom_type
        if gtype not in ("Point", "Polygon", "MultiPolygon"):
            raise ValueError(f"Converting shape {gtype} to {target_shape} is not supported.")
        if gtype == "Point":
            r = float(point_radius) if np.isfinite(point_radius) else 0.0
            converted, radius = conversion_methods["Point"](geom, r)
        else:
            converted, radius = conversion_methods[gtype](geom)
        shapes.at[i, "geometry"] = converted
        if radius is not None:
            shapes.at[i, "radius"] = radius

    if warn_shape_size:
        logger.info(
            f"At least one converted shape spans >= {warn_above_extent_fraction * 100:.0f}% of the "
            "original total bound. Results may be suboptimal."
        )

    return shapes


def _convert_alpha_to_datashader_range(alpha: float) -> float:
    """Convert alpha from the range [0, 1] to the range [0, 255] used in datashader."""
    # prevent a value of 255, bc that led to fully colored test plots instead of just colored points/shapes
    return min([254, alpha * 255])
