"""VASP calculation output utilities.

This module provides the :func:`load_calculation_results`
and :func:`get_output_atoms` functions for retrieving
calculation outputs and output atoms from the directory
of a VASP calculation.

Example:
    from pathlib import Path
    from autojob.calculation.vasp import vasp

    outputs = vasp.load_calculation_results(Path.cwd())
    atoms = vasp.get_output_atoms(Path.cwd())
"""

import logging
from pathlib import Path
from typing import Any
from xml.etree import ElementTree

from ase import Atoms
import ase.io
from emmet.core.tasks import TaskDoc
from emmet.core.tasks import TaskState
from pymatgen.io.ase import AseAtomsAdaptor
from pymatgen.io.vasp.outputs import Vasprun

from autojob import SETTINGS
from autojob.utils.atoms import copy_atom_metadata

# Module-level logger named after this module.
logger = logging.getLogger(__name__)

# Fallback structure files, indexed by ``alt_filename_index`` in
# ``get_output_atoms`` and tried in this order when an earlier file cannot
# be read.
ALTERNATE_OUTPUT_STRUCTURES = ("relax.traj", "vasprun.xml", "CONTCAR")
# NOTE(review): not referenced in this module; presumably the VASP files
# copied over to a follow-up calculation - confirm against callers.
FILES_TO_CARRYOVER = ("CHGCAR", "WAVECAR")
# NOTE(review): not referenced in this module; presumably the names of VASP
# volumetric-data output files - confirm against callers.
VOLUMETRIC_FILES = ("CHGCAR", "LOCPOT", "AECCAR0", "AECCAR1", "AECCAR2")


def load_calculation_results(
    src: str | Path,
) -> dict[str, Any]:
    """Load VASP calculation outputs from a directory.

    The directory is parsed into a :class:`emmet.core.tasks.TaskDoc` and
    flattened into a plain dictionary.

    Args:
        src: The directory from which to load VASP outputs.

    Returns:
        A dictionary with the keys:

        - ``"energy"`` and ``"forces"``: taken from the dumped task
          output; ``None`` when the task document has no output.
        - ``"converged"``: ``True`` if the task state is
          ``TaskState.SUCCESS``.
        - ``"calculator_results"``: the remaining dumped output fields
          merged with the remaining dumped task-document fields, plus
          ``"atoms"`` (the output structure converted to an ASE
          ``Atoms`` object, or ``None`` if there is no structure). If
          ``SETTINGS.VASP_KEEP_DOS`` is truthy and ``vasprun.xml``
          exists in ``src``, it also contains ``"complete_dos"`` and
          ``"vasprun"`` (dictionary representations of the complete DOS
          and of the :class:`pymatgen.io.vasp.outputs.Vasprun` object).

    Raises:
        FileNotFoundError: If building the task document fails with a
            ``TypeError`` whose message mentions
            ``Calculation.from_vasp_files``, which this function treats
            as missing VASP files.
        TypeError: Any other ``TypeError`` raised while building the
            task document is re-raised unchanged.
    """
    logger.info("Loading VASP calculation outputs from %s", src)

    # Only the document construction is guarded: it is the only call whose
    # TypeError is translated into a FileNotFoundError.
    try:
        doc = TaskDoc.from_directory(src)
    except TypeError as err:
        # ``str(err)`` instead of ``err.args[0]``: ``args`` may be empty or
        # hold a non-string, either of which would raise a new exception
        # here and hide the original one.
        if "Calculation.from_vasp_files" in str(err):
            msg = "Unable to find VASP file"
            raise FileNotFoundError(msg) from err
        raise

    output_doc = doc.output.model_dump() if doc.output else {}
    structure = output_doc.get("structure")
    atoms = AseAtomsAdaptor.get_atoms(structure) if structure else None
    dumped_doc = doc.model_dump(exclude={"output"})

    results: dict[str, Any] = {}
    # ``output_doc`` is empty when the task has no output, so the pops need
    # a default to avoid a KeyError on that (explicitly tolerated) path.
    results["energy"] = output_doc.pop("energy", None)
    results["forces"] = output_doc.pop("forces", None)
    results["converged"] = dumped_doc.pop("state") == TaskState.SUCCESS
    # On key collisions the task-document fields win over the output fields.
    results["calculator_results"] = {
        **output_doc,
        **dumped_doc,
        "atoms": atoms,
    }

    vasprun_xml = Path(src, "vasprun.xml")

    if SETTINGS.VASP_KEEP_DOS and vasprun_xml.exists():
        logger.info("Keeping VASP DOS outputs")
        vasprun = Vasprun(vasprun_xml)
        dos = vasprun.complete_dos
        results["calculator_results"]["complete_dos"] = dos.as_dict()
        results["calculator_results"]["vasprun"] = vasprun.as_dict()
    else:
        logger.info("Discarding VASP DOS outputs")

    logger.debug("Successfully loaded VASP calculation outputs from %s", src)
    return results


# TODO: Unit test
def _reorder_atoms(output_atoms: Atoms, src: str | Path) -> Atoms:
    """Create a new Atoms object reordered according to ase-sort.dat.

    This function assumes that the Atoms object passed is ordered in
    accordance to the POSCAR/POTCAR (i.e., in VASP order) and returns the
    same atoms in ASE order.

    Only the first column of ``ase-sort.dat`` is read: if the VASP index of
    an atom is ``i``, the index of the corresponding atom in the ASE Atoms
    object is the integer in row ``i``. Restoring the ASE order therefore
    requires the *inverse* of that table. For example, with a first column
    of ``[0, 3, 1, 2]`` the atom at ASE index 1 sits at VASP index 2, so the
    atoms must be gathered in the order ``[0, 2, 3, 1]``.

    Args:
        output_atoms: The atoms in VASP order.
        src: The directory containing ``ase-sort.dat``.

    Returns:
        A new Atoms object with the atoms in ASE order and the cell,
        periodic boundary conditions and cell displacement of
        ``output_atoms``.

    Raises:
        FileNotFoundError: If ``ase-sort.dat`` does not exist in ``src``.
        IndexError: If the first column of ``ase-sort.dat`` is not a
            permutation of ``range(len(output_atoms))``.
        ValueError: If an entry in the first column is not an integer.
    """
    logger.debug("Reordering atoms")
    sort_file = Path(src).joinpath("ase-sort.dat")

    with sort_file.open(mode="r", encoding="utf-8") as file:
        # Blank lines (e.g., trailing newlines) carry no mapping; skip them
        # rather than failing on ``split()[0]``.
        vasp_to_ase = [int(line.split()[0]) for line in file if line.strip()]

    n_atoms = len(output_atoms)

    # A table that is not a permutation of the atom indices belongs to a
    # different structure or is corrupt; using it would silently drop,
    # duplicate or misplace atoms.
    if sorted(vasp_to_ase) != list(range(n_atoms)):
        msg = (
            f"The first column of {sort_file} is not a permutation of the "
            f"indices of the {n_atoms} output atoms"
        )
        raise IndexError(msg)

    # Invert the table: ase_to_vasp[j] is the VASP index of the atom whose
    # ASE index is j.
    ase_to_vasp = [0] * n_atoms
    for vasp_index, ase_index in enumerate(vasp_to_ase):
        ase_to_vasp[ase_index] = vasp_index

    atoms = [output_atoms[i] for i in ase_to_vasp]

    logger.debug(
        "Successfully reordered atoms: %r -> %r",
        list(range(n_atoms)),
        ase_to_vasp,
    )
    return Atoms(
        atoms,
        cell=output_atoms.cell,
        pbc=output_atoms.pbc,
        celldisp=output_atoms.get_celldisp(),
    )


def get_output_atoms(
    src: str | Path,
    alt_filename_index: int | None = None,
    input_atoms: Atoms | None = None,
) -> Atoms:
    """Retrieve an Atoms object representing the output structure.

    Candidate files are tried in order until one can be read. If the first
    candidate is read successfully, its atoms are returned as-is. If a
    fallback file had to be used, the atoms are reordered once according to
    ``ase-sort.dat`` and tags/constraints are then copied from the input
    structure via ``copy_atom_metadata``. If ``ase-sort.dat`` is missing,
    a warning is logged and the atoms are returned without reordering and
    without copied metadata.

    Args:
        src: The directory from which to retrieve the output structure.
        alt_filename_index: An integer pointing to which alternative structure
            file should be used first. This number will be used to index
            `ALTERNATE_OUTPUT_STRUCTURES`; the entries after it serve as
            fallbacks. If None, ``SETTINGS.OUTPUT_ATOMS_FILE`` is tried
            first, followed by every entry of
            `ALTERNATE_OUTPUT_STRUCTURES`.
        input_atoms: An Atoms object representing the corresponding input
            structure.

    Returns:
        An Atoms object representing the output structure.

    Raises:
        FileNotFoundError: If none of the candidate files could be read.
        IndexError: If ``alt_filename_index`` is out of range for
            `ALTERNATE_OUTPUT_STRUCTURES`, or if ``_reorder_atoms`` fails
            with an IndexError.
    """
    if alt_filename_index is None:
        candidates = [SETTINGS.OUTPUT_ATOMS_FILE]
        next_index = 0
    else:
        # Deliberately unguarded: an explicit out-of-range index raises
        # IndexError to the caller.
        candidates = [ALTERNATE_OUTPUT_STRUCTURES[alt_filename_index]]
        next_index = alt_filename_index + 1

    # Walk forward from ``next_index`` to the end of the alternatives.
    candidates.extend(
        ALTERNATE_OUTPUT_STRUCTURES[i]
        for i in range(next_index, len(ALTERNATE_OUTPUT_STRUCTURES))
    )

    atoms = None
    used_fallback = False

    for filename in candidates:
        full_filename = Path(src).joinpath(filename)
        logger.debug("Retrieving output atoms from %s", full_filename)

        try:
            atoms = ase.io.read(full_filename)
        except (FileNotFoundError, AttributeError, ElementTree.ParseError):
            msg = (
                f"Unable to retrieve atoms from: {full_filename}.\n"
                "File not found."
            )
            logger.warning(msg)
            used_fallback = True
        else:
            break
    else:
        msg = (
            f"No output atoms found in {SETTINGS.OUTPUT_ATOMS_FILE} or "
            f"{ALTERNATE_OUTPUT_STRUCTURES!r}"
        )
        raise FileNotFoundError(msg)

    if used_fallback:
        # Reorder exactly once, no matter how many candidates failed before
        # one could be read; reordering an already reordered structure would
        # scramble it again.
        # NOTE(review): this applies to every fallback file, including
        # relax.traj - confirm that file is stored in VASP (POSCAR) order.
        try:
            atoms = _reorder_atoms(output_atoms=atoms, src=src)
        except FileNotFoundError:
            # No ase-sort.dat: keep the atoms as read. Metadata is not copied
            # because the atom order may not match the input structure.
            logger.warning("Unable to reorder atoms")
        else:
            copy_atom_metadata(
                input_atoms=input_atoms,
                output_atoms=atoms,
            )

    logger.debug("Successfully retrieved output atoms from %s", full_filename)
    return atoms
