"""Harvest charge analysis results from completed task directory."""

import logging
import math
from pathlib import Path
from typing import Annotated
from typing import TypeVar
import warnings

import ase
from ase.geometry import geometry
import ase.io
from pydantic import PlainSerializer
from pymatgen.command_line.chargemol_caller import ChargemolAnalysis
from pymatgen.core.periodic_table import Element
from typing_extensions import TypedDict

from autojob.utils.files import extract_structure_name

# Module-level logger (defined once; a duplicate assignment was removed).
logger = logging.getLogger(__name__)


# Name of the Chargemol output file holding the DDEC6 net atomic charges;
# it is read from the calculation directory to recover the DDEC6 atom order.
DDEC_NET_ATOMIC_CHARGE_FILE = "DDEC6_even_tempered_net_atomic_charges.xyz"
# NOTE(review): presumably the VASP file names Chargemol works from; they are
# not referenced in the visible part of this module -- confirm before removing.
CHGCAR = "CHGCAR"
POSCAR = "POSCAR"
POTCAR = "POTCAR"


# Type variable standing in for the type wrapped by ``StringSerialized``.
_T = TypeVar("_T")

# Generic alias: a value of type ``_T`` carrying a pydantic ``PlainSerializer``
# that serializes it by calling ``str`` on it (declared return type ``str``).
StringSerialized = Annotated[
    _T, PlainSerializer(lambda x: str(x), return_type=str)
]


class _BondedToDict(TypedDict):
    """One entry of an atom's ``bonded_to`` list (see ``_BondOrders``)."""

    # Index of the bonded atom; ``load_ddec6_results`` rewrites this value to
    # an ASE atom index via the DDEC6 -> ASE index map.
    index: int
    # pymatgen ``Element``; serialized through ``str`` by ``StringSerialized``.
    element: StringSerialized[Element]
    bond_order: float
    # NOTE(review): presumably the periodic-image offset of the bond partner
    # as reported by Chargemol -- confirm against the pymatgen parser.
    direction: tuple[float, float, float]
    spin_polarization: float


class _BondOrders(TypedDict):
    """Bond-order record stored per atom index in ``BondOrderDict``."""

    # pymatgen ``Element``; serialized through ``str`` by ``StringSerialized``.
    element: StringSerialized[Element]
    # One ``_BondedToDict`` per bond partner listed for this atom.
    bonded_to: list[_BondedToDict]
    bond_order_sum: float


# Maps an integer atom index to that atom's ``_BondOrders`` record.
BondOrderDict = dict[int, _BondOrders]


class DDEC6Analysis(TypedDict):
    """DDEC6 analysis data.

    As built by ``load_ddec6_results``, each list holds one entry per atom and
    ``bond_order_dict`` is keyed by atom index, both following the ASE atom
    order rather than the DDEC6 output order.
    """

    partial_charges: list[float]
    spin_moments: list[float]
    # One inner list per atom.
    dipoles: list[list[float]]
    rsquared_moments: list[float]
    rcubed_moments: list[float]
    rfourth_moments: list[float]
    bond_order_dict: BondOrderDict


def get_ddec6_index_map(src: str | Path, *, tol: float = 1e-3) -> list[int]:
    """Return a list of integers mapping DDEC6 indices to ASE indices.

    The ASE index of the atom at index i in the DDEC6 structure can be found
    as follows:

        index_map = get_ddec6_index_map(dir_name)
        index = index_map[i]

    The input structure is located by reading its file name out of ``run.py``
    in ``src``; the DDEC6 structure is read from the net atomic charge file
    in the same directory. Both structures are centred in the cell of the
    input structure and each DDEC6 atom is then paired, in DDEC6 order, with
    the closest ASE atom that has not already been paired (a greedy
    assignment, so every ASE index appears at most once).

    Args:
        src: The directory containing a calculation.
        tol: The maximum allowed deviation in atomic positions in Angstroms.

    Returns:
        A list with one ASE index per DDEC6 atom.

    Raises:
        RuntimeError: If a DDEC6 atom cannot be paired because no unused ASE
            atom remains (e.g. the DDEC6 structure has more atoms than the
            input structure).

    Warns:
        UserWarning: If any paired atoms are further apart than ``tol``.
    """
    with Path(src, "run.py").open(mode="r", encoding="utf-8") as python_script:
        input_traj = extract_structure_name(python_script).removeprefix("./")

    ase_atoms = ase.io.read(Path(src).joinpath(input_traj))
    # ase.io.read may return a sequence of images; keep only the last one.
    ase_atoms = (
        ase_atoms if isinstance(ase_atoms, ase.Atoms) else ase_atoms[-1]
    )
    ase_atoms.center()

    xyz_file = Path(src, DDEC_NET_ATOMIC_CHARGE_FILE)
    ddec_atoms = ase.io.read(xyz_file)
    ddec_atoms = (
        ddec_atoms if isinstance(ddec_atoms, ase.Atoms) else ddec_atoms[-1]
    )
    # Give both structures the same cell and centring so that positions are
    # directly comparable.
    ddec_atoms.cell = ase_atoms.cell
    ddec_atoms.center()

    # (i, j): distance between ith atom in ddec_atoms and jth atom in ase_atoms
    _, distances = geometry.get_distances(
        ddec_atoms.positions,
        ase_atoms.positions,
        cell=ddec_atoms.cell,
        pbc=True,
    )

    ddec6_index_map: list[int] = []
    # Set mirror of ddec6_index_map for O(1) "already paired" checks.
    assigned: set[int] = set()
    min_distances = []
    for atom_distances in distances:
        # Reset per row so a row without any free candidate is detected
        # instead of silently reusing the previous row's match.
        closest_ase_atom_index = None
        min_distance = math.inf
        for j, distance in enumerate(atom_distances):
            if distance <= min_distance and j not in assigned:
                closest_ase_atom_index = j
                min_distance = distance

        if closest_ase_atom_index is None:
            msg = (
                "DDEC6 map is incomplete for calculation in directory: "
                f"{src}"
            )
            raise RuntimeError(msg)

        ddec6_index_map.append(closest_ase_atom_index)
        assigned.add(closest_ase_atom_index)
        min_distances.append(min_distance)

    if any(d > tol for d in min_distances):
        max_distance = max(min_distances)
        msg = f"Atomic displacement exceeds tolerance: {max_distance} > {tol}"
        # stacklevel=2 attributes the warning to the caller of this function.
        warnings.warn(msg, category=UserWarning, stacklevel=2)

    return ddec6_index_map


def load_ddec6_results(src: str | Path) -> DDEC6Analysis:
    """Extract the DDEC6 data from the job directory.

    The quantities exposed by ``ChargemolAnalysis`` follow the DDEC6 atom
    order. They are reordered here with the map from
    ``get_ddec6_index_map`` so that entry ``i`` of every returned list, and
    key ``i`` of the bond order dictionary, refers to ASE atom ``i``.

    Args:
        src: The directory of the completed calculation.

    Returns:
        The DDEC6 data in ASE atom order. There is no ``None`` fallback: any
        error raised while parsing the Chargemol output or building the
        index map propagates to the caller.
    """
    logger.debug("Loading DDEC6 data for %s", src)
    analysis = ChargemolAnalysis(path=src, run_chargemol=False)
    ddec6_index_map = get_ddec6_index_map(src)
    n_atoms = len(ddec6_index_map)

    # Pre-sized because the loop below fills entries out of order.
    charges = [0.0] * n_atoms
    spin_densities = [0.0] * n_atoms
    bond_orders: BondOrderDict = {}
    # Build a distinct inner list per atom; ``[[0.0] * 3] * n`` would alias
    # one list n times.
    dipoles = [[0.0] * 3 for _ in range(n_atoms)]
    rsquared_moments = [0.0] * n_atoms
    rcubed_moments = [0.0] * n_atoms
    rfourth_moments = [0.0] * n_atoms

    for i_ddec, i_ase in enumerate(ddec6_index_map):
        charges[i_ase] = analysis.ddec_charges[i_ddec]
        spin_densities[i_ase] = analysis.ddec_spin_moments[i_ddec]
        bond_orders[i_ase] = analysis.bond_order_dict[i_ddec]
        dipoles[i_ase] = analysis.dipoles[i_ddec]
        rsquared_moments[i_ase] = analysis.ddec_rsquared_moments[i_ddec]
        rcubed_moments[i_ase] = analysis.ddec_rcubed_moments[i_ddec]
        rfourth_moments[i_ase] = analysis.ddec_rfourth_moments[i_ddec]

    # Rewrite each bond partner's index from DDEC6 order to ASE order. The
    # records are the ones owned by ``analysis`` and are modified in place.
    for bond_order in bond_orders.values():
        for bonded_to in bond_order["bonded_to"]:
            # DDEC6 Data is 1-indexed
            # NOTE(review): the keys of ``analysis.bond_order_dict`` are used
            # as 0-based above; confirm that the installed pymatgen leaves
            # ``bonded_to["index"]`` 1-based, otherwise this is an off-by-one.
            bonded_to["index"] = ddec6_index_map[bonded_to["index"] - 1]

    ddec6_data = DDEC6Analysis(
        partial_charges=charges,
        spin_moments=spin_densities,
        dipoles=dipoles,
        rsquared_moments=rsquared_moments,
        rcubed_moments=rcubed_moments,
        rfourth_moments=rfourth_moments,
        bond_order_dict=bond_orders,
    )
    logger.info("Successfully loaded DDEC6 data for %s", src)
    return ddec6_data
