from pathlib import Path
import shutil

import ase
import ase.io
import pytest

from autojob import SETTINGS
from autojob.harvest.harvesters import vasp
from autojob.utils.vasp import reorder_vasp_sequence

# Job identifier string. Not referenced in the visible part of this module;
# presumably names the Bader-charge job used as test data - TODO confirm.
BADER_CHARGE_JOB_ID = "jJPQg3Puwz"


@pytest.fixture(name="ase_atoms")
def fixture_ase_atoms(datadir: Path) -> ase.Atoms:
    """Load the structure stored in ``atoms.traj`` under ``datadir``.

    ``ase.io.read`` may hand back either a single structure or a list of
    them; when a list comes back, only its first entry is returned.
    """
    loaded = ase.io.read(Path(datadir, "atoms.traj"))
    if isinstance(loaded, list):
        return loaded[0]
    return loaded


@pytest.fixture(name="vasp_atoms")
def fixture_vasp_atoms(datadir: Path) -> ase.Atoms:
    """Load the structure stored in the ``POSCAR`` file under ``datadir``.

    If ``ase.io.read`` returns a list, the first entry is used; otherwise
    the returned object is passed through unchanged.
    """
    poscar_contents = ase.io.read(Path(datadir, "POSCAR"))
    return (
        poscar_contents[0]
        if isinstance(poscar_contents, list)
        else poscar_contents
    )


@pytest.fixture(name="reordered_vasp_atoms")
def fixture_reordered_vasp_atoms(
    vasp_atoms: ase.Atoms, datadir: Path
) -> ase.Atoms:
    """Return ``vasp_atoms`` after ``reorder_vasp_sequence`` (``"to_ase"``).

    ``datadir`` is forwarded as the ``dir_name`` argument.
    """
    reordered = reorder_vasp_sequence(
        vasp_atoms,
        dir_name=datadir,
        direction="to_ase",
    )
    return reordered


@pytest.fixture(name="reordered_ase_atoms")
def fixture_reordered_ase_atoms(
    ase_atoms: ase.Atoms, datadir: Path
) -> ase.Atoms:
    """Return ``ase_atoms`` after ``reorder_vasp_sequence`` (``"to_vasp"``).

    ``datadir`` is forwarded as the ``dir_name`` argument.
    """
    reordered = reorder_vasp_sequence(
        ase_atoms,
        dir_name=datadir,
        direction="to_vasp",
    )
    return reordered


class TestReorderVASPIterableToASE:
    """Compare ``reordered_vasp_atoms`` with ``ase_atoms`` index by index."""

    @staticmethod
    def test_should_reorder_vasp_atoms_to_match_atomic_symbols(
        ase_atoms: ase.Atoms, reordered_vasp_atoms: ase.Atoms
    ) -> None:
        """Each ASE atom has the same symbol as its reordered counterpart."""
        # A list (not a generator) so every index is evaluated before the
        # assertion runs.
        same_symbol = [
            atom.symbol == reordered_vasp_atoms[index].symbol
            for index, atom in enumerate(ase_atoms)
        ]

        assert all(same_symbol)

    @staticmethod
    def test_should_reorder_vasp_atoms_to_match_atomic_positions(
        ase_atoms: ase.Atoms, reordered_vasp_atoms: ase.Atoms
    ) -> None:
        """Each ASE atom has the same position as its reordered counterpart."""

        def components_equal(left, right) -> bool:
            # Exact, component-wise comparison of two position vectors.
            return all(p == q for p, q in zip(left, right, strict=False))

        same_position = [
            components_equal(
                atom.position, reordered_vasp_atoms[index].position
            )
            for index, atom in enumerate(ase_atoms)
        ]

        assert all(same_position)


class TestReorderVASPIterableToVASP:
    """Compare ``reordered_ase_atoms`` with ``vasp_atoms`` index by index."""

    @staticmethod
    def test_should_reorder_vasp_atoms_to_match_atomic_symbols(
        vasp_atoms: ase.Atoms, reordered_ase_atoms: ase.Atoms
    ) -> None:
        """Each VASP atom has the same symbol as its reordered counterpart."""
        # A list (not a generator) so every index is evaluated before the
        # assertion runs.
        same_symbol = [
            atom.symbol == reordered_ase_atoms[index].symbol
            for index, atom in enumerate(vasp_atoms)
        ]

        assert all(same_symbol)

    @staticmethod
    def test_should_reorder_vasp_atoms_to_match_atomic_positions(
        vasp_atoms: ase.Atoms, reordered_ase_atoms: ase.Atoms
    ) -> None:
        """Each VASP atom has the same position as its reordered counterpart."""

        def components_equal(left, right) -> bool:
            # Exact, component-wise comparison of two position vectors.
            return all(p == q for p, q in zip(left, right, strict=False))

        same_position = [
            components_equal(
                atom.position, reordered_ase_atoms[index].position
            )
            for index, atom in enumerate(vasp_atoms)
        ]

        assert all(same_position)


class TestGetOutputAtoms:
    """Tests for ``vasp.get_output_atoms`` called on a task directory."""

    @staticmethod
    # TODO: Replace with appropriate fixtures
    # @pytest.mark.output_files("structure")
    def test_should_retrieve_existing_output_atoms_named_final_traj(
        task_directory: Path,
        output_atoms: ase.Atoms,
        write_task_inputs: list[str],  # noqa: ARG004
        calculate: None,  # noqa: ARG004
    ) -> None:
        """A structure saved as ``SETTINGS.OUTPUT_ATOMS_FILE`` is returned."""
        destination = Path(task_directory) / SETTINGS.OUTPUT_ATOMS_FILE
        output_atoms.write(destination)

        assert vasp.get_output_atoms(src=task_directory) == output_atoms

    @staticmethod
    @pytest.mark.vasp
    def test_should_retrieve_output_atoms_under_alternate_name1(
        task_directory: Path,
        output_atoms: ase.Atoms,  # noqa: ARG004
        write_task_inputs: list[str],  # noqa: ARG004
        datadir: Path,
    ) -> None:
        """A ``CONTCAR`` copied from the data ``POSCAR`` is returned."""
        poscar = Path(datadir) / "POSCAR"
        shutil.copy(poscar, Path(task_directory) / "CONTCAR")
        # The expectation is read from the same POSCAR that was copied in,
        # not taken from the ``output_atoms`` fixture.
        expected_atoms = ase.io.read(Path(datadir) / "POSCAR")

        retrieved = vasp.get_output_atoms(src=task_directory)

        assert retrieved == expected_atoms

    @staticmethod
    @pytest.mark.vasp
    def test_should_retrieve_output_atoms_under_alternate_name2(
        task_directory: Path,
        output_atoms: ase.Atoms,
        write_task_inputs: list[str],  # noqa: ARG004
    ) -> None:
        """A structure saved as ``relax.traj`` is returned."""
        output_atoms.write(Path(task_directory) / "relax.traj")

        assert vasp.get_output_atoms(src=task_directory) == output_atoms

    @staticmethod
    # TODO: Replace with appropriate fixtures
    # @pytest.mark.output_files
    # @pytest.mark.parametrize("output_structure_name", "fake.traj")
    def test_should_raise_file_not_found_error_if_no_structure_found(
        task_directory: Path,
        write_task_inputs: list[str],  # noqa: ARG004
        datadir: Path,
    ) -> None:
        """Only ``datadir`` contents copied in: ``FileNotFoundError`` raised."""
        for data_file in datadir.iterdir():
            shutil.copy(data_file, task_directory)

        with pytest.raises(FileNotFoundError):
            vasp.get_output_atoms(src=task_directory)


@pytest.mark.vasp
class TestVaspLoadCalculationOutputs:
    """Tests for ``vasp.load_calculation_results``.

    Every test in this class is marked ``xfail``: two because the data
    directory is missing an OUTCAR file, two because the behaviour under
    test is not implemented yet.
    """

    # ``@staticmethod`` is kept outermost, as in every other test in this
    # file. A mark stacked on top of the staticmethod object may not reach
    # the underlying function on older pytest releases, which would silently
    # drop the ``xfail``.
    @staticmethod
    @pytest.mark.xfail(reason="Missing OUTCAR file")
    def test_should_load_outputs(shared_datadir: Path) -> None:
        """Results loaded from ``shared_datadir`` (positional) are truthy."""
        assert vasp.load_calculation_results(shared_datadir)

    @staticmethod
    @pytest.mark.xfail(reason="Missing OUTCAR file")
    def test_should_load_vasprun_xml(shared_datadir: Path) -> None:
        """Results loaded from ``shared_datadir`` (as ``src=``) are truthy."""
        assert vasp.load_calculation_results(src=shared_datadir)

    @staticmethod
    @pytest.mark.xfail(reason="Not implemented")
    def test_should_not_load_dos_when_vasp_keep_dos_is_false() -> None:
        """Placeholder that always fails until the check is implemented."""
        not_load_dos_when_vasp_keep_dos_is_false = False
        assert not_load_dos_when_vasp_keep_dos_is_false

    @staticmethod
    @pytest.mark.xfail(reason="Not implemented")
    def test_should_load_dos_when_vasp_keep_dos_is_true() -> None:
        """Placeholder that always fails until the check is implemented."""
        load_dos_when_vasp_keep_dos_is_true = False
        assert load_dos_when_vasp_keep_dos_is_true
