"""Store the results of a molecular dynamics simulation."""

import contextlib
import logging
from pathlib import Path
from typing import TYPE_CHECKING
from typing import Any
from typing import Self

import ase.io
import jinja2
from pydantic import Field

from autojob import SETTINGS
from autojob.hpc import ScheduledMixin
from autojob.hpc import SchedulerInputs
from autojob.hpc import SchedulerOutputs
from autojob.tasks.calculation import Calculation
from autojob.tasks.calculation import CalculationInputs
from autojob.tasks.calculation import CalculationOutputs
from autojob.tasks.task import Task
from autojob.utils.files import get_loader
from autojob.utils.schemas import PydanticAtoms

if TYPE_CHECKING:
    from ase.calculators.calculator import Calculator

logger = logging.getLogger(name=__name__)

# Reserved key of ``MDInputs.md_params`` whose value names the trajectory
# file of a run; ``MolecularDynamics.from_directory`` uses it to locate the
# trajectory when loading outputs.
_TRAJECTORY_KEY = "_trajectory_file"


class MDInputs(CalculationInputs):
    """Inputs for a molecular dynamics simulation.

    Extends the generic calculation inputs with the parameters used to set up
    the molecular dynamics object.
    """

    md_params: dict[str, Any] = Field(
        description=(
            "The parameters used to configure the molecular dynamics "
            f"object. The special key '{_TRAJECTORY_KEY}' should be used to "
            "specify the trajectory file for run."
        ),
        default={},
    )


class MDOutputs(CalculationOutputs):
    """The outputs of a molecular dynamics simulation."""

    trajectory: list[PydanticAtoms] | None = Field(
        default=None,
        description="A list of atoms representing the trajectory of the "
        "system throughout a molecular dynamics simulation.",
    )

    # TODO: Implement strict_mode for trajectory loading
    @classmethod
    def from_directory(
        cls,
        src: str | Path,
        *,
        calculator: str | None = None,
        analyses: list[str] | None = None,
        strict_mode: bool | None = None,
        trajectory: str | Path | None = None,
    ) -> Self:
        """Load the outputs of a molecular dynamics run from a directory.

        Args:
            src: The directory from which to load the results.
            calculator: The name of the ASE calculator used to perform the
                calculation. This will be used to determine which harvester
                plugin will be used to retrieve the calculator-specific
                results. Defaults to the harvester defined in
                mod:`autojob.harvest.harvester.default`.
            analyses: A list of post-calculation analyses whose results are to
                be harvested. Defaults to an empty list.
            strict_mode: Whether or not to require all outputs. If True,
                errors will be thrown on missing outputs. Defaults to
                ``SETTINGS.STRICT_MODE``. Note that this is only forwarded to
                the calculation-output loader: a missing trajectory file is
                logged and skipped regardless of this flag.
            trajectory: The name of the trajectory file, relative to ``src``.
                Defaults to None in which case no trajectory will be loaded.

        Returns:
            An instance of this class containing the calculation outputs and,
            if it could be read, the trajectory as a list of Atoms.
        """
        logger.info(
            "Loading molecular dynamics outputs from directory: %s", src
        )
        calc_outputs = super().from_directory(
            src=src,
            calculator=calculator,
            analyses=analyses,
            strict_mode=strict_mode,
        )
        md_outputs = calc_outputs.model_dump()

        # Handle "no trajectory requested" explicitly rather than relying on
        # ``Path(src, None)`` raising a TypeError; suppressing TypeError would
        # also hide genuine errors raised while parsing the trajectory.
        if trajectory is None:
            logger.debug("No trajectory file specified. Skipping trajectory.")
        else:
            traj_file = Path(src, trajectory)
            try:
                # The ":" index selects every frame, so a list is returned.
                md_outputs["trajectory"] = ase.io.read(traj_file, ":")
            except FileNotFoundError:
                logger.info("Trajectory file not found: %s", traj_file)

        return cls(**md_outputs)


# TODO: Subclass Calculation
class MolecularDynamics(Task, ScheduledMixin):
    """A molecular dynamics simulation.

    Alongside the generic task fields (``task_metadata``, ``task_inputs``,
    ``task_outputs``) and scheduler fields (``scheduler_inputs``,
    ``scheduler_outputs``), this task carries MD-specific inputs and outputs.
    """

    md_inputs: MDInputs = Field(
        default_factory=MDInputs,
        description="The inputs of a molecular dynamics simulation",
    )
    md_outputs: MDOutputs = Field(
        default_factory=MDOutputs,
        description="The outputs of a molecular dynamics simulation",
    )

    @classmethod
    def from_directory(cls, src: str | Path, **kwargs):
        """Generate a ``MolecularDynamics`` document from a task directory.

        Args:
            src: The directory of a molecular dynamics simulation.
            kwargs: Additional keyword arguments:

            - strict_mode: Whether or not to fail on any error. Defaults to
                `SETTINGS.STRICT_MODE`.
            - magic_mode: Whether or not to instantiate subclasses. If
                True, the task returned must be an instance determined by
                metadata in the directory. Defaults to False.

        Returns:
            A :class:`MolecularDynamics` or a subclass of a
            :class:`MolecularDynamics`.

        .. seealso::

            :meth:`.calculation.Calculation.from_directory`
        """
        strict_mode = kwargs.get("strict_mode", SETTINGS.STRICT_MODE)
        magic_mode = kwargs.get("magic_mode", False)
        logger.debug(
            "Loading molecular dynamics simulation from directory: %s", src
        )
        logger.debug("Magic mode: %sabled", "en" if magic_mode else "dis")
        logger.debug("Strict mode: %sabled", "en" if strict_mode else "dis")

        if magic_mode:
            return cls.load_magic(src, strict_mode=strict_mode)

        task = Task.from_directory(
            src=src, strict_mode=strict_mode, magic_mode=False
        )
        # The MD inputs are written into the inputs JSON alongside the task
        # inputs (see ``write_inputs_json``), so they show up as an extra
        # field; pop them so they are not duplicated in ``task_inputs``.
        # ``model_extra`` is None for models that do not allow extras.
        extra = task.task_inputs.model_extra
        data = extra.pop("md_inputs", {}) if extra is not None else {}
        md_inputs = MDInputs(**data)
        md_outputs = MDOutputs.from_directory(
            src,
            # Forward the calculator so that the matching harvester plugin is
            # used instead of always falling back to the default harvester.
            calculator=md_inputs.calculator,
            trajectory=md_inputs.md_params.get(_TRAJECTORY_KEY),
            strict_mode=strict_mode,
        )
        sched_inputs = SchedulerInputs.from_directory(src=src)
        sched_outputs = SchedulerOutputs.from_directory(src=src)

        if md_outputs.calculator_results:
            output_atoms = md_outputs.calculator_results.get("atoms")
        else:
            output_atoms = None

        Calculation.patch_task(
            task_outputs=task.task_outputs,
            input_atoms=task.task_inputs.atoms,
            output_atoms=output_atoms,
            state=sched_outputs.state,
            converged=md_outputs.converged,
        )

        logger.debug(
            "Successfully loaded molecular dynamics simulation from directory: %s",
            src,
        )
        return cls(
            task_metadata=task.task_metadata,
            task_inputs=task.task_inputs,
            task_outputs=task.task_outputs,
            md_inputs=md_inputs,
            md_outputs=md_outputs,
            scheduler_inputs=sched_inputs,
            scheduler_outputs=sched_outputs,
        )

    def prepare_input_atoms(self) -> None:
        """Copy the final magnetic moments to initial magnetic moments.

        This function modifies atoms in place. Note that if atoms were obtained
        from a ``vasprun.xml`` via ``ase.io.read("vasprun.xml")``, no magnetic
        moments will be read. In order to ensure continuity between runs, it is
        a good idea to retain the ``WAVECAR`` between runs.

        Nothing is changed (and a message is logged) if there are no input
        atoms, the atoms have no calculator, or the calculator results contain
        no ``"magmoms"`` entry.
        """
        logger.debug("Preparing atoms for next run.")

        atoms = self.task_inputs.atoms

        if atoms is None:
            logger.info("No input atoms found.")
            return None

        calc: Calculator | None = atoms.calc

        if calc is None:
            logger.info("No calculator found.")
            return None

        magmoms = calc.results.get("magmoms", None)

        if magmoms is None:
            logger.info(
                "No magnetic moments to copy found. Using the initial "
                "magnetic moments: %r",
                atoms.get_initial_magnetic_moments(),
            )
            return None

        atoms.set_initial_magnetic_moments(magmoms)
        logger.debug("Copied magnetic moments to initial magnetic moments")

    def write_input_atoms(self, dest: str | Path) -> Path | None:
        """Write the input atoms to a file.

        The final magnetic moments are first copied to the initial magnetic
        moments via :meth:`prepare_input_atoms`.

        Args:
            dest: The directory in which to write the Atoms file.

        Returns:
            The filename in which the Atoms were written.
        """
        self.prepare_input_atoms()
        return super().write_input_atoms(dest)

    def write_inputs_json(
        self,
        dest: str | Path,
        *,
        additional_data: dict[str, Any] | None = None,
        **kwargs,  # noqa: ARG002
    ) -> Path:
        """Write the inputs JSON to a file.

        The MD inputs are added under the ``"md_inputs"`` key.

        Args:
            dest: The directory in which to write the inputs JSON.
            additional_data: A dictionary mapping strings to JSON-serializable
                values to be merged with the task inputs that will be written
                to the inputs JSON. Entries take precedence over the
                ``"md_inputs"`` entry added here. Defaults to an empty
                dictionary.
            kwargs: Additional keyword arguments. Currently ignored.

        Returns:
            The filename in which the inputs JSON was written.
        """
        additional_data = additional_data or {}
        additional_data = {
            "md_inputs": self.md_inputs.model_dump(mode="json"),
            **additional_data,
        }
        return super().write_inputs_json(dest, additional_data=additional_data)

    def write_calculation_script(
        self,
        dest: str | Path,
        *,
        template: str = SETTINGS.CALCULATION_SCRIPT_TEMPLATE,  # noqa: ARG002
        structure_name: str | None = None,
    ) -> Path:
        """Write the calculation script used to run the task.

        Args:
            dest: The directory in which to write the Python script.
            template: Currently unused. The template named by
                ``md_inputs.calculation_script_template`` is always rendered.
            structure_name: The filename of the input structure to be read to
                load the :class:`~ase.atoms.Atoms` object for the calculation.
                Defaults to the value of the ``"filename"`` key in the input
                Atoms object's ``info``, if present. Defaults to the value of
                ``SETTINGS.INPUT_ATOMS_FILE`` otherwise. Take care to ensure
                that this matches the name of the file to which the structure
                is written.

        Returns:
            A Path object representing the filename of the written task
            script.

        Raises:
            FileExistsError: If the calculation script already exists in
                ``dest``. The file is opened in exclusive-creation mode.
        """
        dest = Path(dest)
        logger.debug("Writing calculation script to directory: %s", dest)
        calculator = self.md_inputs.calculator
        parameters = self.md_inputs.calc_params

        env = jinja2.Environment(
            loader=get_loader(),
            autoescape=False,  # noqa: S701
            trim_blocks=True,
            lstrip_blocks=True,
            keep_trailing_newline=True,
        )

        to_render = env.get_template(
            self.md_inputs.calculation_script_template
        )
        filename = Path(dest, self.md_inputs.calculation_script)

        if structure_name is None:
            atoms = self.task_inputs.atoms
            info_name = None if atoms is None else atoms.info.get("filename")
            structure_name = (
                SETTINGS.INPUT_ATOMS_FILE if info_name is None else info_name
            )

        # TODO: Change context to include more parameters (e.g., opt_params,
        # TODO: analyses, etc.)
        with filename.open(mode="x", encoding="utf-8") as file:
            file.write(
                to_render.render(
                    calculator=calculator,
                    structure=structure_name,
                    parameters=parameters,
                    settings=SETTINGS.model_dump(),
                )
            )

        logger.debug(
            "Successfully wrote calculation script to file: %s", filename
        )
        return filename

    def write_task_script(
        self,
        dest: str | Path,
        *,
        additional_data: dict[str, Any] | None = None,
        **kwargs,  # noqa: ARG002
    ) -> Path:
        """Write the SLURM input script using the template given.

        Args:
            dest: The directory in which to write the SLURM file.
            additional_data: A dictionary mapping strings to values to be
                merged into the data passed to the parent implementation for
                rendering the script. Entries take precedence over the
                ``"md_inputs"`` and ``"scheduler_inputs"`` entries added here.
                Defaults to an empty dictionary.
            kwargs: Additional keyword arguments. Currently ignored.

        Returns:
            A Path representing the filename of the written SLURM script.
        """
        additional_data = additional_data or {}
        md_inputs = self.md_inputs.model_dump()
        raw_sched_inputs = self.scheduler_inputs.model_dump(
            mode="json", exclude_none=True, by_alias=True
        )
        # Prefix each scheduler option with "--" (long-option style).
        formatted_sched_inputs = {
            f"--{k}": v for k, v in raw_sched_inputs.items()
        }

        additional_data = {
            "md_inputs": md_inputs,
            "scheduler_inputs": formatted_sched_inputs,
            **additional_data,
        }

        return super().write_task_script(dest, additional_data=additional_data)

    def write_inputs(
        self,
        dest: str | Path,
        **kwargs,
    ) -> list[Path]:
        """Write the required inputs for an MD simulation to a directory.

        Args:
            dest: The directory in which to write the inputs.
            kwargs: Additional keyword arguments forwarded to the parent
                implementation.

        Returns:
            A list of Path objects where each Path represents the filename of
            an input written to ``dest``.
        """
        logger.debug(
            "Writing %s inputs to directory: %s", self.__class__.__name__, dest
        )
        inputs = super().write_inputs(dest, **kwargs)
        inputs.append(self.write_calculation_script(dest))
        logger.debug(
            "Successfully wrote %s inputs to directory: %s",
            self.__class__.__name__,
            dest,
        )
        return inputs
