"""Store the results of scan calculations.

This module defines the :class:`.autojob.tasks.scan.BondScan`,
:class:`.autojob.tasks.scan.BondScanInputs`, and
:class:`.autojob.tasks.scan.BondScanOutputs` classes. Instances
of these classes represent the results of a scan calculation, its inputs, and
its outputs, respectively.

For building the respective documents from a folder, each class exposes a
``from_directory()`` method.

Example:
    .. code-block:: python

        from autojob.tasks.scan import BondScan

        src = "path/to/calculation/directory"
        results = BondScan.from_directory(src)
"""

from __future__ import annotations

from copy import deepcopy
import logging
from pathlib import Path
import re
from typing import TYPE_CHECKING
from typing import Annotated
from typing import Any
from typing import ClassVar
from typing import NamedTuple
from typing import Self

import ase.io
import numpy as np
from pydantic import BaseModel
from pydantic import ConfigDict
from pydantic import Field
from pydantic import FieldSerializationInfo
from pydantic import SerializerFunctionWrapHandler
from pydantic import ValidationError
from pydantic import ValidationInfo
from pydantic import ValidatorFunctionWrapHandler
from pydantic import WrapValidator
from pydantic import field_serializer
from pydantic import field_validator

from autojob import SETTINGS
from autojob.tasks.calculation import Calculation
from autojob.utils.schemas import PydanticAtoms

if TYPE_CHECKING:
    from ase import Atoms

# Module-level logger named after this module; used for the debug/info
# messages emitted while loading bond scan data from directories.
logger = logging.getLogger(__name__)


class BondScanParams(NamedTuple):
    """Parameters for modifying bond scans.

    ``a0``, ``a1``, ``mask``, ``indices``, and ``fix`` map to the parameters
    of the same names in :meth:`~ase.Atoms.set_distance`. ``bond_lims``
    specifies the minimum and maximum bond lengths for the bond scan. The
    endpoint may be excluded if the scan step does not evenly divide the
    difference between the minimum and maximum bond lengths.

    Only ``a0`` and ``a1`` are required, so an instance can be built from a
    tuple of 2 to 6 positional values (in field order).
    """

    # Index of the first atom of the scanned bond.
    a0: int
    # Index of the second atom of the scanned bond.
    a1: int
    mask: list[bool] | None = None
    indices: list[int] | None = None
    fix: float = 0.5
    # (minimum, maximum) bond length of the scan.
    bond_lims: tuple[float, float] = (0.7, 4.0)


class BondScanInputs(BaseModel):
    """The inputs for a bond scan calculation."""

    bond_scan_params: list[BondScanParams] | None = Field(
        default=None,
        description="A list of bond scan parameters specifying which and how "
        "bonds will be linearly scanned. List entries can be BondScanParams "
        "or tuples of length 6 or less.",
    )
    scan_step: float = Field(
        default=0.1,
        description="The step size (in Angstroms) to use for the linear scan "
        "calculations.",
    )
    write_traj: bool = Field(
        default=True,
        description="Whether or not to write trajectory files for each bond "
        "scan containing the images for which energies are calculated.",
    )
    traj_template: str = Field(
        default="scan_{}_{}.traj",
        description="A format string to be used to name the trajectory files "
        "of each linear scan. The format string must accept two fields, "
        "which correspond to the atomic indices of the atoms in the bond "
        "for which the linear bond scan is being performed.",
    )

    # Unknown keys are kept rather than rejected.
    model_config: ClassVar[ConfigDict] = ConfigDict(extra="allow")

    @field_validator("bond_scan_params", mode="wrap")
    @classmethod
    def validate_bond_scan_params(
        cls,
        v: Any,
        handler: ValidatorFunctionWrapHandler,
        _: ValidationInfo,
    ) -> list[BondScanParams] | None:
        """Validate the bond scan parameters.

        The raw value is first given to the default validator unchanged. If
        that fails, every entry is unpacked positionally into a
        :class:`BondScanParams` and the resulting list is validated again.

        Args:
            v: The raw value of ``bond_scan_params``.
            handler: The wrapped default pydantic validator.
            _: Unused validation info.

        Returns:
            The validated list of bond scan parameters, or None.

        Raises:
            ValueError: If ``v`` is not iterable or one of its entries cannot
                be unpacked into a :class:`BondScanParams` (which takes 2 to
                6 positional values).
        """
        try:
            return handler(v)
        except ValidationError:
            pass

        # Only the unpacking can raise TypeError (non-iterable input or wrong
        # number of values); keep the second validation outside the ``try``.
        try:
            bond_scan_params = [BondScanParams(*x) for x in v]
        except TypeError as err:
            msg = (
                "bond_scan_params must be a list of BondScanParams or of "
                "sequences containing 2 to 6 values"
            )
            raise ValueError(msg) from err

        return handler(bond_scan_params)


# This function is necessary because tuple keys are serialized as
# comma-separated strings by json.dump
def _validate_bond_tuple(
    v: Any, handler: ValidatorFunctionWrapHandler, _: ValidationInfo
) -> tuple[int, int]:
    try:
        return handler(v)  # type: ignore[no-any-return]
    except ValidationError:
        if isinstance(v, str):
            a0, a1 = v.split(",")
            return int(a0), int(a1)
        raise


# A bond key: a pair of atomic indices. The wrap validator additionally
# accepts the comma-separated string form that tuple dictionary keys take
# after a JSON round trip (see ``_validate_bond_tuple``).
BondTuple = Annotated[tuple[int, int], WrapValidator(_validate_bond_tuple)]

# Per-bond scan results: for each bond key, a list holding one calculator
# results dictionary per image of that bond's scan.
BondScanResults = dict[BondTuple, list[dict[str, Any]]]


class BondScanOutputs(BaseModel):
    """The raw data of a bond scan calculation.

    ``images`` and ``bond_scan_results`` are keyed by the same bonds (pairs of
    atomic indices). The results dictionary at position ``i`` under a bond in
    ``bond_scan_results`` belongs to the image at position ``i`` under the
    same bond in ``images``.
    """

    images: dict[BondTuple, list[PydanticAtoms]] = Field(
        default_factory=dict,
        description="A dictionary mapping paired atomic indices to Atoms "
        "objects generated in the bond scan.",
    )
    bond_scan_results: BondScanResults = Field(
        default_factory=dict,
        description="A dictionary mapping paired atomic indices to calculated "
        "results generated in the bond scan.",
    )

    @field_serializer("bond_scan_results", mode="wrap")
    def serialize_bond_scan_results(
        self,
        v: BondScanResults | None,
        _: SerializerFunctionWrapHandler,
        info: FieldSerializationInfo,
    ) -> BondScanResults | None:
        """Serialize the bond scan results.

        In JSON mode, numpy arrays and numpy scalars (e.g. ``np.int64``, which
        is not an ``int`` subclass) are converted to native Python lists and
        scalars. All other values are deep-copied so the output shares no
        mutable state with the model. In any other mode the results are
        returned unchanged.

        Args:
            v: The bond scan results to serialize.
            _: The unused default serializer.
            info: Serialization info; only ``info.mode`` is inspected.

        Returns:
            The serialized results, or None if ``v`` is None.
        """
        if v is None:
            return None

        if info.mode != "json":
            return v

        return {
            bond: [
                {key: self._to_jsonable(value) for key, value in result.items()}
                for result in bond_results
            ]
            for bond, bond_results in v.items()
        }

    @staticmethod
    def _to_jsonable(value: Any) -> Any:
        """Convert numpy arrays/scalars to native Python; deep-copy the rest."""
        if isinstance(value, (np.ndarray, np.generic)):
            return value.tolist()
        # Deep-copy so that mutable results are not shared with the model.
        return deepcopy(value)

    @property
    def bonds(self) -> list[tuple[int, int]]:
        """A list of paired atomic indices indicating bonded atom pairs."""
        return list(self.images)

    @staticmethod
    def _compile_traj_pattern(traj_template: str) -> re.Pattern[str]:
        """Build a regex matching file names produced by ``traj_template``.

        The first two ``{}`` fields become integer capture groups. Every other
        character (including any further ``{}``) is matched literally.

        Raises:
            ValueError: If ``traj_template`` has fewer than two ``{}`` fields.
        """
        pieces = traj_template.split("{}", 2)
        if len(pieces) < 3:
            msg = (
                "traj_template must contain two '{}' fields for the atomic "
                f"indices of the scanned bond; got {traj_template!r}"
            )
            raise ValueError(msg)

        prefix, middle, suffix = pieces
        index_group = r"(-?\d+)"
        return re.compile(
            "".join(
                (
                    re.escape(prefix),
                    index_group,
                    re.escape(middle),
                    index_group,
                    re.escape(suffix),
                )
            )
        )

    @classmethod
    def from_directory(
        cls,
        *,
        src: str | Path,
        traj_template: str = "scan_{}_{}.traj",
        strict_mode: bool | None = None,
    ) -> Self:
        """Retrieve bond scan outputs from a directory.

        Every regular file in ``src`` whose name exactly matches
        ``traj_template`` (with integers in its two fields) is read as a
        trajectory. Its images are stored under the bond given by the two
        integers, and each image's calculator results are stored under the
        same bond, in the same order.

        Args:
            src: The directory of a calculation.
            traj_template: A format string to be used to name the trajectory
                files of each linear scan. The format string must accept two
                fields, which correspond to the atomic indices of the atoms in
                the bond for which the linear bond scan is being performed.
                Defaults to ``"scan_{}_{}.traj"``.
            strict_mode: Whether or not to require all results. If True,
                errors will be thrown on missing results. Defaults to
                ``SETTINGS.STRICT_MODE``.

        Returns:
            A BondScanOutputs object.

        Raises:
            ValueError: If ``traj_template`` has fewer than two ``{}`` fields.
            AttributeError: In strict mode, if an image carries no calculator
                results.
        """
        if strict_mode is None:
            strict_mode = SETTINGS.STRICT_MODE

        src = Path(src)
        logger.debug("Loading bond scan outputs from directory: %s", src)
        logger.debug("Strict mode: %sabled", "en" if strict_mode else "dis")

        traj_re = cls._compile_traj_pattern(traj_template)
        images: dict[tuple[int, int], list[Atoms]] = {}
        results: BondScanResults = {}

        # Sort so the bond order does not depend on the arbitrary
        # ``iterdir`` order.
        for traj in sorted(src.iterdir()):
            match = traj_re.fullmatch(traj.name)

            if match is None or not traj.is_file():
                continue

            bond = (int(match[1]), int(match[2]))
            images[bond] = []
            results[bond] = []

            for i, image in enumerate(ase.io.read(traj, index=":")):
                try:
                    results[bond].append(image.calc.results.copy())
                except AttributeError:
                    if strict_mode:
                        raise
                    logger.info(
                        "Unable to load results from image %s of trajectory: %s",
                        i,
                        traj,
                    )
                    results[bond].append({})
                # Results were copied above; detach the calculator so only
                # the structure is stored in ``images``.
                image.calc = None
                images[bond].append(image)

        return cls(images=images, bond_scan_results=results)


class BondScan(Calculation):
    """A record representing a bond scan calculation."""

    bond_scan_inputs: BondScanInputs = Field(
        default_factory=BondScanInputs,
        description="The inputs of the bond scan calculation",
    )
    bond_scan_outputs: BondScanOutputs | None = Field(
        default=None, description="The outputs of the bond scan calculation"
    )

    @classmethod
    def from_directory(cls, src: str | Path, **kwargs) -> Self:
        """Generate a ``BondScan`` document from a task directory.

        Args:
            src: The directory of a bond scan calculation.
            kwargs: Additional keyword arguments:

                - strict_mode: Whether or not to fail on any error. Defaults
                  to ``SETTINGS.STRICT_MODE`` (also used when None is passed).
                - magic_mode: Whether or not to instantiate subclasses. If
                  True, the task returned must be an instance determined by
                  metadata in the directory. Defaults to False.

        Returns:
            A :class:`BondScan` or a subclass of a :class:`BondScan`.

        .. seealso::

            :meth:`.calculation.Calculation.from_directory`
        """
        strict_mode = kwargs.get("strict_mode")
        if strict_mode is None:
            strict_mode = SETTINGS.STRICT_MODE
        magic_mode = kwargs.get("magic_mode", False)
        logger.debug("Loading bond scan calculation from directory: %s", src)
        logger.debug("Magic mode: %sabled", "en" if magic_mode else "dis")
        logger.debug("Strict mode: %sabled", "en" if strict_mode else "dis")

        if magic_mode:
            return cls.load_magic(src, strict_mode=strict_mode)

        calc = Calculation.from_directory(
            src=src, strict_mode=strict_mode, magic_mode=False
        )
        # Move the bond scan inputs out of the generic task inputs before
        # ``calc`` is dumped below. ``model_extra`` may be None if the task
        # inputs model does not allow extra fields.
        extra = calc.task_inputs.model_extra or {}
        data = extra.pop("bond_scan_inputs", {})
        bond_scan_inputs = BondScanInputs(**data)
        bond_scan_outputs = BondScanOutputs.from_directory(
            src=src,
            traj_template=bond_scan_inputs.traj_template,
            strict_mode=strict_mode,
        )

        logger.debug(
            "Successfully loaded bond scan calculation from directory: %s", src
        )
        return cls(
            **calc.model_dump(),
            bond_scan_inputs=bond_scan_inputs,
            bond_scan_outputs=bond_scan_outputs,
        )

    def write_inputs_json(
        self,
        dest: str | Path,
        *,
        additional_data: dict[str, Any] | None = None,
        **kwargs,  # noqa: ARG002
    ) -> Path:
        """Write the inputs JSON to a file.

        The bond scan inputs are added under the ``"bond_scan_inputs"`` key;
        entries in ``additional_data`` take precedence over it on conflict.

        Args:
            dest: The directory in which to write the inputs JSON.
            additional_data: A dictionary mapping strings to JSON-serializable
                values to be merged with the bond scan inputs that will be
                written to the inputs JSON. Defaults to an empty dictionary.
            kwargs: Additional keyword arguments. Currently ignored; they are
                not forwarded to the parent implementation.

        Returns:
            The path of the file to which the inputs JSON was written.
        """
        additional_data = {
            "bond_scan_inputs": self.bond_scan_inputs.model_dump(mode="json"),
            **(additional_data or {}),
        }
        return super().write_inputs_json(dest, additional_data=additional_data)
