from pathlib import Path
from typing import TypeAlias

import numpy as np
import nibabel as nib
import matplotlib.pyplot as plt
from matplotlib.axes import Axes

from surfplot import Plot

# 类型别名定义
Num: TypeAlias = float | int

__all__ = [
    "plot_brain_surface_figure",
]

# 路径常量
NEURODATA = Path(__file__).resolve().parent / "data" / "neurodata"


def _map_labels_to_values(data, gifti_file):
    gifti = nib.load(gifti_file)
    # 获取顶点标签编号数组，shape=(顶点数,)
    labels = gifti.darrays[0].data  
    # 构建标签编号到脑区名称的映射字典
    key_to_label = {label.key: label.label for label in gifti.labeltable.labels}
    # 检查数据中是否有在图集中找不到的脑区标签
    missing_labels = list(set(data.keys()) - set(key_to_label.values()))
    if missing_labels:
        raise ValueError(
            f"以下脑区标签在指定图集中未找到，请检查名称是否正确: {missing_labels}"
        )

    # 准备输出数组，初始化为NaN
    parc = np.full(labels.shape, np.nan, dtype=float)
    # 遍历所有标签，将数据映射到对应的顶点
    for key, label_name in key_to_label.items():
        if label_name in data:
            parc[labels == key] = data[label_name]
    return parc


def plot_brain_surface_figure(
    data: dict[str, float],
    species: str = "human",
    atlas: str = "glasser",
    surf: str = "veryinflated",
    ax: Axes | None = None,
    vmin: Num | None = None,
    vmax: Num | None = None,
    cmap: str = "Reds",
    colorbar: bool = True,
    colorbar_location: str = "right",
    colorbar_label_name: str = "",
    colorbar_label_rotation: int = 0,
    colorbar_decimals: int = 1,
    colorbar_fontsize: int = 8,
    colorbar_nticks: int = 2,
    colorbar_shrink: float = 0.15,
    colorbar_aspect: int = 8,
    colorbar_draw_border: bool = False,
    title_name: str = "",
    title_fontsize: int = 12,
    as_outline: bool = False,
) -> Axes: 
    """在大脑皮层表面绘制数值数据的函数。

    脑区图是一种用于在大脑皮层表面可视化数值数据的图表。它能够将不同脑区的数值映射到大脑皮层的相应区域，
    并以颜色编码的方式展示这些数值的分布情况。这种图表常用于展示神经科学研究中的各种脑区指标，
    如功能连接强度、激活程度或其他数值化的大脑特征。

    本函数基于 surfplot 库开发，提供了一个统一且简化的接口来绘制人脑、猕猴脑和黑猩猩脑的表面图。
    支持多种脑图谱，包括人脑的 Glasser 和 BNA 图谱、猕猴脑的 CHARM5/6、BNA 和 D99 图谱，
    以及黑猩猩脑的 BNA 图谱。

    Args:
        data (dict[str, float]): 包含脑区名称和对应数值的字典，键为脑区名称（如"lh_bankssts"），值为数值
        species (str, optional): 物种名称，支持"human"、"chimpanzee"、"macaque". Defaults to "human".
        atlas (str, optional): 脑图集名称，根据物种不同可选不同图集. Defaults to "glasser".
        surf (str, optional): 大脑皮层表面类型，如"inflated"、"veryinflated"、"midthickness"等. Defaults to "veryinflated".
        ax (Axes | None, optional): matplotlib的坐标轴对象，如果为None则使用当前坐标轴. Defaults to None.
        vmin (Num | None, optional): 颜色映射的最小值，None表示使用数据中的最小值. Defaults to None.
        vmax (Num | None, optional): 颜色映射的最大值，None表示使用数据中的最大值. Defaults to None.
        cmap (str, optional): 颜色映射方案，如"Reds"、"Blues"、"viridis"等. Defaults to "Reds".
        colorbar (bool, optional): 是否显示颜色条. Defaults to True.
        colorbar_location (str, optional): 颜色条位置，可选"left"、"right"、"top"、"bottom". Defaults to "right".
        colorbar_label_name (str, optional): 颜色条标签名称. Defaults to "".
        colorbar_label_rotation (int, optional): 颜色条标签旋转角度. Defaults to 0.
        colorbar_decimals (int, optional): 颜色条刻度标签的小数位数. Defaults to 1.
        colorbar_fontsize (int, optional): 颜色条字体大小. Defaults to 8.
        colorbar_nticks (int, optional): 颜色条刻度数量. Defaults to 2.
        colorbar_shrink (float, optional): 颜色条收缩比例. Defaults to 0.15.
        colorbar_aspect (int, optional): 颜色条宽高比. Defaults to 8.
        colorbar_draw_border (bool, optional): 是否绘制颜色条边框. Defaults to False.
        title_name (str, optional): 图形标题. Defaults to "".
        title_fontsize (int, optional): 标题字体大小. Defaults to 12.
        as_outline (bool, optional): 是否以轮廓线形式显示. Defaults to False.

    Raises:
        ValueError: 当指定的物种不支持时抛出
        ValueError: 当指定的图集不支持时抛出
        ValueError: 当数据为空时抛出
        ValueError: 当vmin大于vmax时抛出

    Returns:
        Axes: 包含绘制图像的matplotlib坐标轴对象
    """
    # 获取或创建坐标轴对象
    ax = ax or plt.gca()

    # 提取所有数值用于确定vmin和vmax
    values = list(data.values())
    if not values:
        raise ValueError("data 不能为空")
    vmin = min(values) if vmin is None else vmin
    vmax = max(values) if vmax is None else vmax
    if vmin == vmax:
        vmin, vmax = min(0, vmin), max(0, vmax)
    if vmin > vmax:
        raise ValueError("vmin必须小于等于vmax")

    # 设置数据文件路径
    # 定义不同物种、表面类型和图集的文件路径信息
    atlas_info = {
        "human": {
            "surf": {
                "lh": f"surfaces/human_fsLR/tpl-fsLR_den-32k_hemi-L_{surf}.surf.gii",
                "rh": f"surfaces/human_fsLR/tpl-fsLR_den-32k_hemi-R_{surf}.surf.gii",
            },
            "atlas": {
                "glasser": {
                    "lh": "atlases/human_Glasser/fsaverage.L.Glasser.32k_fs_LR.label.gii",
                    "rh": "atlases/human_Glasser/fsaverage.R.Glasser.32k_fs_LR.label.gii",
                },
                "bna": {
                    "lh": "atlases/human_BNA/fsaverage.L.BNA.32k_fs_LR.label.gii",
                    "rh": "atlases/human_BNA/fsaverage.R.BNA.32k_fs_LR.label.gii",
                },
            },
        },
        "chimpanzee": {
            "surf": {
                "lh": f"surfaces/chimpanzee_BNA/ChimpYerkes29_v1.2.L.{surf}.32k_fs_LR.surf.gii",
                "rh": f"surfaces/chimpanzee_BNA/ChimpYerkes29_v1.2.R.{surf}.32k_fs_LR.surf.gii",
            },
            "atlas": {
                "bna": {
                    "lh": "atlases/chimpanzee_BNA/ChimpBNA.L.32k_fs_LR.label.gii",
                    "rh": "atlases/chimpanzee_BNA/ChimpBNA.R.32k_fs_LR.label.gii",
                },
            }
        },
        "macaque": {
            "surf": {
                "lh": f"surfaces/macaque_BNA/civm.L.{surf}.32k_fs_LR.surf.gii",
                "rh": f"surfaces/macaque_BNA/civm.R.{surf}.32k_fs_LR.surf.gii",
            },
            "atlas": {
                "charm5": {
                    "lh": "atlases/macaque_CHARM5/L.charm5.label.gii",
                    "rh": "atlases/macaque_CHARM5/R.charm5.label.gii",
                },
                "charm6": {
                    "lh": "atlases/macaque_CHARM6/L.charm6.label.gii",
                    "rh": "atlases/macaque_CHARM6/R.charm6.label.gii",
                },
                "bna": {
                    "lh": "atlases/macaque_BNA/MBNA_124_32k_L.label.gii",
                    "rh": "atlases/macaque_BNA/MBNA_124_32k_R.label.gii",
                },
                "d99": {
                    "lh": "atlases/macaque_D99/L.d99.label.gii",
                    "rh": "atlases/macaque_D99/R.d99.label.gii",
                },
            }
        }
    }

    # 检查物种是否支持
    if species not in atlas_info:
        raise ValueError(f"Unsupported species: {species}. Supported species are: {list(atlas_info.keys())}")
    else:
        # 检查指定物种的图集是否支持
        if atlas not in atlas_info[species]["atlas"]:
            raise ValueError(f"Unsupported {atlas} atlas for {species}")

    # 创建Plot对象，用于绘制大脑皮层
    p = Plot(
        NEURODATA / atlas_info[species]["surf"]["lh"],
        NEURODATA / atlas_info[species]["surf"]["rh"],
    )

    # 分离左半球和右半球的数据
    hemisphere_data = {}
    for hemi in ["lh", "rh"]:
        hemi_data = {k: v for k, v in data.items() if k.startswith(f"{hemi}_")}
        hemi_parc = _map_labels_to_values(
            hemi_data, NEURODATA / atlas_info[species]["atlas"][atlas][hemi]
        )
        hemisphere_data[hemi] = hemi_parc

    # 画图
    # colorbar参数设置（用列表统一管理，便于维护）
    colorbar_params = [
        ("location", colorbar_location),
        ("label_direction", colorbar_label_rotation),
        ("decimals", colorbar_decimals),
        ("fontsize", colorbar_fontsize),
        ("n_ticks", colorbar_nticks),
        ("shrink", colorbar_shrink),
        ("aspect", colorbar_aspect),
        ("draw_border", colorbar_draw_border),
    ]
    colorbar_kws = {k: v for k, v in colorbar_params}
    # 添加图层到绘图对象
    p.add_layer(
        {"left": hemisphere_data["lh"], "right": hemisphere_data["rh"]},
        cbar=colorbar,
        cmap=cmap,
        color_range=(vmin, vmax),
        cbar_label=colorbar_label_name,
        zero_transparent=False,
        as_outline=as_outline,
    )
    # 构建坐标轴并应用颜色条设置
    ax = p.build_axis(ax=ax, cbar_kws=colorbar_kws)
    # 设置图形标题
    ax.set_title(title_name, fontsize=title_fontsize)

    return ax
