"""Store the results of a calculation.

This module defines the :class:`.autojob.calculation.calculation.Calculation`,
:class:`.autojob.calculation.calculation.CalculationInputs`, and
:class:`.autojob.calculation.calculation.CalculationOutputs` classes. Instances
of these classes represent the results of a calculation, its inputs, and its
outputs, respectively.

For building the respective documents from a folder, the
:class:`Calculation` and :class:`CalculationOutputs` classes each expose a
``from_directory()`` class method.

Example:
    .. code-block:: python

        from autojob.calculation.calculation import Calculation

        src = "path/to/calculation/directory"
        results = Calculation.from_directory(src)
"""

from __future__ import annotations

import logging
from pathlib import Path
from typing import TYPE_CHECKING
from typing import Any
from typing import ClassVar
from typing import Self

import jinja2
from pydantic import BaseModel
from pydantic import ConfigDict
from pydantic import Field

from autojob import SETTINGS
from autojob.bases.task_base import TaskOutcome
from autojob.hpc import JobState
from autojob.hpc import SchedulerInputs
from autojob.hpc import SchedulerOutputs
from autojob.plugins import get_harvester
from autojob.tasks.task import Task
from autojob.tasks.task import TaskOutputs
from autojob.utils.atoms import copy_atom_metadata
from autojob.utils.files import get_loader

if TYPE_CHECKING:
    from ase import Atoms
    from ase.calculators.calculator import Calculator

logger = logging.getLogger(__name__)

# Filenames and glob-style patterns grouped under "copy".
# NOTE(review): not referenced in the visible part of this module; presumably
# consumed elsewhere to select the files carried over between calculation
# directories -- confirm against the callers.
# NOTE(review): entries such as "*py", "*cif", "*xyz" and "*xml" have no dot
# before the extension, so (if used as globs) they match any name that merely
# ends in those characters.
FILES_TO_COPY = [
    "CHGCAR",
    "*py",
    "*cif",
    "POSCAR",
    "coord",
    "*xyz",
    "*.traj",
    "CONTCAR",
    "*.pkl",
    "*xml",
    "WAVECAR",
    "*.com",
    "*.chk",
]
# Filenames and glob-style patterns grouped under "delete".
# NOTE(review): also not referenced in the visible part of this module;
# presumably the files removed when cleaning up a calculation directory --
# confirm against the callers.
FILES_TO_DELETE = [
    "*.d2e",
    "*.int",
    "*.rwf",
    "*.skr",
    "*.inp",
    "EIGENVAL",
    "IBZKPT",
    "PCDAT",
    "PROCAR",
    "ELFCAR",
    "LOCPOT",
    "PROOUT",
    "TMPCAR",
    "vasp.dipcor",
]


# Type alias for the arguments of a post-calculation analysis: a 2-tuple of
# positional arguments and keyword arguments. Used as the value type of
# ``CalculationInputs.analyses``.
ArgSpec = tuple[
    # (input posargs, input kwargs)
    list[str], dict[str, Any]
]


class CalculationInputs(BaseModel):
    """Describe how a calculation is configured.

    The model records the ASE calculator and optimizer together with their
    parameters, the requested post-calculation analyses, and the filename and
    template of the calculation script. Fields that are not declared here are
    accepted as well (``extra="allow"``).
    """

    model_config: ClassVar[ConfigDict] = ConfigDict(extra="allow")

    calculator: str = Field(
        description=(
            "The name of the ASE Calculator used to perform this calculation"
        ),
        validate_default=True,
        default="vasp",
    )
    optimizer: str | None = Field(
        description=(
            "The name of the ASE optimizer used to perform this calculation"
        ),
        default=None,
    )
    calc_params: dict[str, Any] = Field(
        description="The parameters used to configure the ASE calculator",
        default={},
    )
    opt_params: dict[str, Any] | None = Field(
        description="The parameters used to configure the ASE optimizer",
        default=None,
    )
    analyses: dict[str, ArgSpec] = Field(
        description=(
            "A dictionary specifying the post-calculation "
            "analyses. Keys correspond to analysis names and map to a 2-tuple "
            "whose first and second elements indicate positional and keyword "
            "arguments, respectively"
        ),
        default={},
    )
    calculation_script: str = Field(
        description="The default filename for the calculation script",
        default=SETTINGS.DEFAULT_CALCULATION_SCRIPT_FILE,
    )
    calculation_script_template: str = Field(
        description="The name of the default calculation script template",
        default=SETTINGS.CALCULATION_SCRIPT_TEMPLATE,
    )


class CalculationOutputs(BaseModel):
    """The outputs of a calculation.

    Every field is optional and defaults to None, except ``converged``, which
    defaults to False. Fields that are not declared here are accepted as well
    (``extra="allow"``).
    """

    # Total energy (eV); None when not available
    energy: float | None = Field(
        default=None,
        description="Total energy in units of eV.",
    )
    # One inner list of force components per atom (eV/Å)
    forces: list[list[float]] | None = Field(
        default=None,
        description="The force on each atom in units of eV/Å.",
    )
    # ? Rename to completed and use calculator_results to store calculator-
    # ? specific convergence results?
    converged: bool = Field(
        default=False,
        description="Whether or not the calculaton has converged",
    )
    # Calculator-specific results that have no dedicated field on this model
    calculator_results: dict[str, Any] | None = Field(
        default=None,
        description="Calculator-specific results in excess of "
        "model-level fields",
    )
    # One dictionary of optimizer results per optimization step
    optimizer_results: list[dict[str, Any]] | None = Field(
        default=None,
        description="A list of dictionaries, each containing "
        "optimizer results from a step in the optimization",
    )
    # Results of the post-calculation analyses, keyed by analysis name
    analysis_results: dict[str, dict[str, Any]] | None = Field(
        default=None,
        description="A dictionary mapping post-calculation analysis names to "
        "their results",
    )
    # Accept and keep fields that are not declared above
    model_config: ClassVar[ConfigDict] = ConfigDict(extra="allow")

    # TODO: Add optimizer harvester
    @classmethod
    def from_directory(
        cls,
        *,
        src: str | Path,
        calculator: str | None = None,
        analyses: list[str] | None = None,
        strict_mode: bool | None = None,
    ) -> Self:
        """Retrieve calculation outputs from a calculation directory.

        The calculator results are collected with the harvester plugin
        registered under ``calculator``; if no such harvester can be loaded,
        the ``"default"`` harvester is used instead. Each requested analysis
        is then harvested with the plugin registered under its own name.

        Args:
            src: The directory of a calculation.
            calculator: The name of the ASE calculator used to perform the
                calculation. This will be used to determine which harvester
                plugin will be used to retrieve the calculator-specific
                results. Defaults to the harvester defined in
                :mod:`autojob.harvest.harvester.default`.
            analyses: A list of post-calculation analyses whose results are to
                be harvested. Defaults to an empty list.
            strict_mode: Whether or not to require all outputs. If True,
                errors will be thrown on missing outputs. Defaults to
                ``SETTINGS.STRICT_MODE``. Note that the value is currently
                only logged by this method.

        Returns:
            An instance of the class on which this method was called.
        """
        analyses = analyses or []
        calculator = calculator or "default"
        if strict_mode is None:
            strict_mode = SETTINGS.STRICT_MODE
        logger.debug("Loading calculation outputs from directory: %s", src)
        logger.debug("Strict mode: %sabled", "en" if strict_mode else "dis")

        # ? How to implement strict_mode
        try:
            calc_harvester = get_harvester(calculator)
        except ValueError:
            logger.info(
                "Unable to load calculator harvester: %s. Using default harvester",
                calculator,
            )
            calc_harvester = get_harvester("default")

        calc_results = calc_harvester(src)

        # Analyses have no fallback harvester: a failure to load one
        # propagates to the caller
        analysis_results = {
            analysis: get_harvester(analysis)(src) for analysis in analyses
        }

        calculation_outputs = cls(
            **calc_results, analysis_results=analysis_results
        )

        logger.debug(
            "Successfully loaded calculation outputs from directory: %s", src
        )
        return calculation_outputs


class Calculation(Task):
    """A record representing a calculation.

    Extends :class:`~autojob.tasks.task.Task` with the inputs and outputs of
    the calculation and of the scheduler job.
    """

    # A default-constructed ``CalculationInputs`` is used when none is given
    calculation_inputs: CalculationInputs = Field(
        default_factory=CalculationInputs,
        description="The inputs of the calculation",
    )
    # None until outputs have been set
    calculation_outputs: CalculationOutputs | None = Field(
        default=None, description="The calculation outputs"
    )
    # A default-constructed ``SchedulerInputs`` is used when none is given
    scheduler_inputs: SchedulerInputs = Field(
        default_factory=SchedulerInputs, description="The scheduler intputs"
    )
    # None until outputs have been set
    scheduler_outputs: SchedulerOutputs | None = Field(
        default=None, description="The scheduler statistics and outputs"
    )

    # TODO: Write unit test for copying metadata
    @staticmethod
    def patch_task(
        *,
        task_outputs: TaskOutputs | None,
        input_atoms: Atoms | None,
        output_atoms: Atoms | None,
        state: JobState,
        converged: bool,
    ) -> None:
        """Patch task outputs using Calculation values.

        Note that this method modifies ``task_outputs`` in place. The
        following attributes are patched:

        - ``task_outputs.atoms``: if unset, replaced with ``output_atoms``
          (when provided) with metadata inherited from ``input_atoms`` (when
          provided)
        - ``task_outputs.outcome``: set to ``TaskOutcome.SUCCESS`` if
          ``converged`` is True and ``state`` is ``JobState.COMPLETED`` or
          ``JobState.UNKNOWN``; set to ``TaskOutcome.FAILED`` otherwise

        Nothing is done if ``task_outputs`` is None.

        Args:
            task_outputs: The :class:`~base_task.TaskOutputs` to be patched.
            input_atoms: An Atoms object representing the input geometry.
            output_atoms: An Atoms object representing the output geometry.
            state: The state of the scheduler job.
            converged: Whether or not the Calculation converged.
        """
        if task_outputs is None:
            logger.info("No task outputs to patch in task")
            return

        if task_outputs.atoms is None and output_atoms:
            logger.debug("Patching output atoms")
            # Without an input geometry there is no metadata to inherit
            if input_atoms is not None:
                copy_atom_metadata(input_atoms, output_atoms)
            task_outputs.atoms = output_atoms

        job_finished = state in (JobState.COMPLETED, JobState.UNKNOWN)
        if job_finished and converged:
            task_outputs.outcome = TaskOutcome.SUCCESS
        else:
            task_outputs.outcome = TaskOutcome.FAILED

        logger.debug(f"Task outcome: {task_outputs.outcome}")

    @classmethod
    def from_directory(cls, src: str | Path, **kwargs) -> Self:
        """Generate a ``Calculation`` document from a calculation directory.

        The directory is first loaded as a plain :class:`Task`. The
        ``"calculation_inputs"`` entry is then removed from the extra fields
        of the task inputs and used to build the calculation inputs, the
        calculation and scheduler documents are loaded from ``src``, and the
        task outputs are patched via :meth:`patch_task`.

        Args:
            src: The directory of a calculation.
            kwargs: Additional keyword arguments:

            - strict_mode: Whether or not to fail on any error. Defaults to
                `SETTINGS.STRICT_MODE`.
            - magic_mode: Whether or not to instantiate subclasses. If
                True, the task returned must be an instance determined by
                metadata in the directory. Defaults to False.

        Returns:
            A :class:`Calculation` or a subclass of a :class:`Calculation`.

        Raises:
            KeyError: If strict mode is enabled and the task inputs contain
                no ``"calculation_inputs"`` entry.

        .. seealso::

            :meth:`.task.Task.from_directory`
        """
        strict_mode = kwargs.get("strict_mode", SETTINGS.STRICT_MODE)
        magic_mode = kwargs.get("magic_mode", False)
        logger.debug("Loading calculation from directory: %s", src)
        logger.debug("Magic mode: %sabled", "en" if magic_mode else "dis")
        logger.debug("Strict mode: %sabled", "en" if strict_mode else "dis")

        if magic_mode:
            return cls.load_magic(src, strict_mode=strict_mode)

        task = Task.from_directory(
            src=src, strict_mode=strict_mode, magic_mode=False
        )

        # ``model_extra`` may be None, and the entry may be absent or None;
        # outside of strict mode all of these fall back to the defaults of
        # ``CalculationInputs``
        extras = task.task_inputs.model_extra or {}
        if "calculation_inputs" in extras:
            data = extras.pop("calculation_inputs") or {}
        elif strict_mode:
            msg = f"No calculation inputs found in directory: {src}"
            raise KeyError(msg)
        else:
            logger.info(
                "No calculation inputs found in directory: %s. Using "
                "default calculation inputs",
                src,
            )
            data = {}

        calc_inputs = CalculationInputs(**data)
        calc_outputs = CalculationOutputs.from_directory(
            src=src,
            calculator=calc_inputs.calculator,
            analyses=list(calc_inputs.analyses),
            strict_mode=strict_mode,
        )
        sched_inputs = SchedulerInputs.from_directory(src=src)
        sched_outputs = SchedulerOutputs.from_directory(src=src)

        # The output geometry, if any, is stored with the calculator results
        if calc_outputs.calculator_results:
            output_atoms = calc_outputs.calculator_results.get("atoms")
        else:
            output_atoms = None

        cls.patch_task(
            task_outputs=task.task_outputs,
            input_atoms=task.task_inputs.atoms,
            output_atoms=output_atoms,
            state=sched_outputs.state,
            converged=calc_outputs.converged,
        )

        logger.debug("Successfully loaded calculation from directory: %s", src)
        return cls(
            task_metadata=task.task_metadata,
            task_inputs=task.task_inputs,
            task_outputs=task.task_outputs,
            calculation_inputs=calc_inputs,
            calculation_outputs=calc_outputs,
            scheduler_inputs=sched_inputs,
            scheduler_outputs=sched_outputs,
        )

    def prepare_input_atoms(self) -> None:
        """Seed the initial magnetic moments with the calculated ones.

        The input atoms are modified in place: the ``"magmoms"`` entry of the
        results of the attached calculator, if present, is set as the initial
        magnetic moments. Nothing is changed if there are no input atoms, no
        attached calculator, or no ``"magmoms"`` result.

        Note that if atoms were obtained from a ``vasprun.xml`` via
        ``ase.io.read("vasprun.xml")``, no magnetic moments will be read. In
        order to ensure continuity between runs, it is a good idea to retain
        the ``WAVECAR`` between runs.
        """
        logger.debug("Preparing atoms for next run.")

        input_atoms = self.task_inputs.atoms
        if input_atoms is None:
            logger.info("No input atoms found.")
            return None

        calculator: Calculator = input_atoms.calc
        if calculator is None:
            logger.info("No calculator found.")
            return None

        final_magmoms = calculator.results.get("magmoms")
        if final_magmoms is not None:
            input_atoms.set_initial_magnetic_moments(final_magmoms)
            logger.debug("Copied magnetic moments to initial magnetic moments")
            return None

        # Nothing to copy: report the moments that will be used instead
        logger.info(
            "No magnetic moments to copy found. Using the initial "
            "magnetic moments: "
            f"{input_atoms.get_initial_magnetic_moments()!r}"
        )
        return None

    def write_input_atoms(self, dest: str | Path) -> Path | None:
        """Prepare the input atoms and write them to a file.

        The atoms are first updated by :meth:`prepare_input_atoms`; the
        writing itself is delegated to the parent class.

        Args:
            dest: The directory in which to write the Atoms file.

        Returns:
            The filename in which the Atoms were written, as returned by the
            parent implementation.
        """
        self.prepare_input_atoms()
        written_to = super().write_input_atoms(dest)
        return written_to

    def write_inputs_json(
        self,
        dest: str | Path,
        *,
        additional_data: dict[str, Any] | None = None,
        **kwargs,  # noqa: ARG002
    ) -> Path:
        """Write the inputs JSON, including the calculation inputs.

        The JSON-mode dump of ``self.calculation_inputs`` is added under the
        ``"calculation_inputs"`` key before delegating to the parent class.
        Entries of ``additional_data`` take precedence over that key.

        Args:
            dest: The directory in which to write the inputs JSON.
            additional_data: A dictionary mapping strings to JSON-serializable
                values to be merged with the task inputs that will be written
                to the inputs JSON. Defaults to an empty dictionary.
            kwargs: Additional keyword arguments. Ignored by this method.

        Returns:
            The filename in which the inputs JSON was written.
        """
        payload: dict[str, Any] = {
            "calculation_inputs": self.calculation_inputs.model_dump(
                mode="json"
            ),
        }
        if additional_data:
            # Caller-supplied entries override the calculation inputs
            payload.update(additional_data)
        return super().write_inputs_json(dest, additional_data=payload)

    def write_calculation_script(
        self,
        dest: str | Path,
        *,
        template: str = SETTINGS.CALCULATION_SCRIPT_TEMPLATE,  # noqa: ARG002
        structure_name: str | None = None,
    ) -> Path:
        """Write the Python script used to run the Calculation.

        The script is rendered from the template named by
        ``self.calculation_inputs.calculation_script_template`` and written
        to ``dest`` under the filename
        ``self.calculation_inputs.calculation_script``. The template is
        rendered before the file is created, so that a rendering error does
        not leave an empty script behind.

        Args:
            dest: The directory in which to write the Python script.
            template: Currently ignored; the template is always taken from
                ``self.calculation_inputs.calculation_script_template``.
            structure_name: The filename of the input structure to be read to
                load the :class:`~ase.atoms.Atoms` object for the calculation.
                Defaults to the value of the ``"filename"`` key in the
                ``info`` dictionary of the input Atoms object, if present.
                Defaults to the value of ``SETTINGS.INPUT_ATOMS_FILE``
                otherwise. Take care to ensure that this matches the name of
                the file to which the structure is written.

        Returns:
            A Path object representing the filename of the written
            calculation script.

        Raises:
            FileExistsError: If the calculation script already exists in
                ``dest``.
        """
        dest = Path(dest)
        logger.debug("Writing calculation script to directory: %s", dest)
        calculator = self.calculation_inputs.calculator
        parameters = self.calculation_inputs.calc_params

        env = jinja2.Environment(
            loader=get_loader(),
            autoescape=False,  # noqa: S701
            trim_blocks=True,
            lstrip_blocks=True,
            keep_trailing_newline=True,
        )

        to_render = env.get_template(
            self.calculation_inputs.calculation_script_template
        )
        filename = Path(dest, self.calculation_inputs.calculation_script)

        if structure_name is None:
            atoms = self.task_inputs.atoms
            recorded = None if atoms is None else atoms.info.get("filename")
            structure_name = (
                SETTINGS.INPUT_ATOMS_FILE if recorded is None else recorded
            )

        # Render first: the file is opened in exclusive-creation mode, so an
        # empty file left behind by a failed render would block any retry
        script = to_render.render(
            calculator=calculator,
            structure=structure_name,
            parameters=parameters,
            settings=SETTINGS.model_dump(),
        )

        with filename.open(mode="x", encoding="utf-8") as file:
            file.write(script)

        logger.debug(
            "Successfully wrote calculation script to file: %s", filename
        )
        return filename

    def write_task_script(
        self,
        dest: str | Path,
        *,
        additional_data: dict[str, Any] | None = None,
        **kwargs,  # noqa: ARG002
    ) -> Path:
        """Write the SLURM input script via the parent class.

        The data passed on to the parent implementation contains the dumped
        calculation inputs under ``"calculation_inputs"`` and the scheduler
        inputs under ``"scheduler_inputs"``. The scheduler inputs are dumped
        in JSON mode by alias without None values, and every key is prefixed
        with ``--``. Entries of ``additional_data`` take precedence over both
        keys.

        Args:
            dest: The directory in which to write the SLURM file.
            additional_data: A dictionary mapping strings to values to be
                merged with the data described above. Defaults to an empty
                dictionary.
            kwargs: Additional keyword arguments. Ignored by this method.

        Returns:
            A Path representing the filename of the written SLURM script.
        """
        script_data: dict[str, Any] = {
            "calculation_inputs": self.calculation_inputs.model_dump(),
        }

        scheduler_options = self.scheduler_inputs.model_dump(
            mode="json", exclude_none=True, by_alias=True
        )
        # Turn each scheduler option into a command-line style flag
        script_data["scheduler_inputs"] = {
            f"--{option}": value for option, value in scheduler_options.items()
        }

        if additional_data:
            # Caller-supplied entries override the entries built above
            script_data.update(additional_data)

        return super().write_task_script(dest, additional_data=script_data)

    def write_inputs(
        self,
        dest: str | Path,
        **kwargs,
    ) -> list[Path]:
        """Write the required inputs for a Calculation to a directory.

        The inputs written by the parent class are extended with the
        calculation script.

        Args:
            dest: The directory in which to write the inputs.
            kwargs: Additional keyword arguments, forwarded to the parent
                implementation.

        Returns:
            A list of Path objects where each Path represents the filename of
            an input written to ``dest``.
        """
        task_name = self.__class__.__name__
        logger.debug("Writing %s inputs to directory: %s", task_name, dest)
        inputs = super().write_inputs(dest, **kwargs)
        inputs.append(self.write_calculation_script(dest))
        # Report the destination directory, as the message states
        logger.debug(
            "Successfully wrote %s inputs to directory: %s", task_name, dest
        )
        return inputs
