from pathlib import Path

import ase
import ase.io
import pytest

from autojob.harvest.harvesters import vasp
from autojob.utils.vasp import reorder_vasp_sequence

# Name of the job sub-directory under ``datadir`` whose stored outputs are
# loaded by ``TestVaspLoadCalculationOutputs`` (presumably the outputs of a
# Bader-charge calculation, judging by the name — TODO confirm).
BADER_CHARGE_JOB_ID = "jJPQg3Puwz"


@pytest.fixture(name="ase_atoms")
def fixture_ase_atoms(datadir: Path) -> ase.Atoms:
    """Load the reference structure stored in ``atoms.traj``.

    If ``ase.io.read`` hands back a list of images, only the first image is
    used.
    """
    trajectory_path = Path(datadir, "atoms.traj")
    structure = ase.io.read(trajectory_path)
    if isinstance(structure, list):
        structure = structure[0]
    return structure


@pytest.fixture(name="vasp_atoms")
def fixture_vasp_atoms(datadir: Path) -> ase.Atoms:
    """Load the structure stored in the VASP ``POSCAR`` file.

    If ``ase.io.read`` hands back a list of images, only the first image is
    used.
    """
    poscar_path = Path(datadir, "POSCAR")
    structure = ase.io.read(poscar_path)
    if isinstance(structure, list):
        structure = structure[0]
    return structure


@pytest.fixture(name="reordered_vasp_atoms")
def fixture_reordered_vasp_atoms(
    vasp_atoms: ase.Atoms, datadir: Path
) -> ase.Atoms:
    """Return the POSCAR structure reordered with ``direction="to_ase"``.

    ``datadir`` is forwarded as ``dir_name`` (presumably so the helper can
    find its ordering information there — TODO confirm).
    """
    reordered = reorder_vasp_sequence(
        vasp_atoms, direction="to_ase", dir_name=datadir
    )
    return reordered


@pytest.fixture(name="reordered_ase_atoms")
def fixture_reordered_ase_atoms(
    ase_atoms: ase.Atoms, datadir: Path
) -> ase.Atoms:
    """Return the ``atoms.traj`` structure reordered with ``direction="to_vasp"``.

    ``datadir`` is forwarded as ``dir_name`` (presumably so the helper can
    find its ordering information there — TODO confirm).
    """
    reordered = reorder_vasp_sequence(
        ase_atoms, direction="to_vasp", dir_name=datadir
    )
    return reordered


class TestReorderVASPIterableToASE:
    """Check ``reorder_vasp_sequence(..., direction="to_ase")``.

    The POSCAR structure, once reordered, should agree atom-for-atom with the
    structure stored in ``atoms.traj``.
    """

    @staticmethod
    def test_should_reorder_vasp_atoms_to_match_atomic_symbols(
        ase_atoms: ase.Atoms, reordered_vasp_atoms: ase.Atoms
    ) -> None:
        # Comparing the whole symbol lists also fails on a length mismatch;
        # a per-index loop over ``ase_atoms`` would ignore extra atoms in the
        # reordered structure.
        assert (
            reordered_vasp_atoms.get_chemical_symbols()
            == ase_atoms.get_chemical_symbols()
        )

    @staticmethod
    def test_should_reorder_vasp_atoms_to_match_atomic_positions(
        ase_atoms: ase.Atoms, reordered_vasp_atoms: ase.Atoms
    ) -> None:
        assert len(reordered_vasp_atoms) == len(ase_atoms)
        # The two structures come from different file formats (POSCAR text
        # vs. a trajectory file), so compare with a tolerance instead of
        # exact float equality.
        assert reordered_vasp_atoms.get_positions() == pytest.approx(
            ase_atoms.get_positions()
        )


class TestReorderVASPIterableToVASP:
    """Check ``reorder_vasp_sequence(..., direction="to_vasp")``.

    The ``atoms.traj`` structure, once reordered, should agree atom-for-atom
    with the structure stored in the POSCAR. (The test names mention
    "vasp_atoms" but the structure being reordered here is ``ase_atoms``.)
    """

    @staticmethod
    def test_should_reorder_vasp_atoms_to_match_atomic_symbols(
        vasp_atoms: ase.Atoms, reordered_ase_atoms: ase.Atoms
    ) -> None:
        # Comparing the whole symbol lists also fails on a length mismatch;
        # a per-index loop over ``vasp_atoms`` would ignore extra atoms in the
        # reordered structure.
        assert (
            reordered_ase_atoms.get_chemical_symbols()
            == vasp_atoms.get_chemical_symbols()
        )

    @staticmethod
    def test_should_reorder_vasp_atoms_to_match_atomic_positions(
        vasp_atoms: ase.Atoms, reordered_ase_atoms: ase.Atoms
    ) -> None:
        assert len(reordered_ase_atoms) == len(vasp_atoms)
        # The two structures come from different file formats (POSCAR text
        # vs. a trajectory file), so compare with a tolerance instead of
        # exact float equality.
        assert reordered_ase_atoms.get_positions() == pytest.approx(
            vasp_atoms.get_positions()
        )


class TestGetOutputAtoms:
    """Tests for ``vasp.get_output_atoms``.

    ``populate_outputs`` is expected to create the files selected by the
    ``output_files`` marker and return their directory (fixture defined
    elsewhere — TODO confirm).
    """

    @staticmethod
    @pytest.mark.output_files("structure")
    def test_should_retrieve_existing_output_atoms_named_final_traj(
        populate_outputs: Path,
        atoms: ase.Atoms,
    ) -> None:
        retrieved_atoms = vasp.get_output_atoms(dir_name=populate_outputs)
        assert retrieved_atoms == atoms

    @staticmethod
    @pytest.mark.vasp
    @pytest.mark.output_files("structure", "vasprun_xml", "ase_sort_dat")
    def test_should_retrieve_output_atoms_under_alternate_name1(
        populate_outputs: Path,
        output_atoms: ase.Atoms,
    ) -> None:
        retrieved_atoms = vasp.get_output_atoms(dir_name=populate_outputs)
        assert retrieved_atoms == output_atoms

    @staticmethod
    @pytest.mark.vasp
    @pytest.mark.output_files("structure", "contcar", "ase_sort_dat")
    def test_should_retrieve_output_atoms_under_alternate_name2(
        populate_outputs: Path,
        output_atoms: ase.Atoms,
    ) -> None:
        retrieved_atoms = vasp.get_output_atoms(dir_name=populate_outputs)
        assert retrieved_atoms == output_atoms

    @staticmethod
    @pytest.mark.output_files
    # ``argvalues`` must be a sequence of values: a bare string would be
    # iterated character by character, yielding nine single-letter cases
    # ("f", "a", "k", ...) instead of one "fake.traj" case.
    @pytest.mark.parametrize("output_structure_name", ["fake.traj"])
    def test_should_raise_file_not_found_error_if_no_structure_found(
        populate_outputs: Path,
    ) -> None:
        with pytest.raises(FileNotFoundError):
            _ = vasp.get_output_atoms(dir_name=populate_outputs)


@pytest.mark.vasp
class TestVaspLoadCalculationOutputs:
    """Tests for ``vasp.load_calculation_outputs``."""

    @staticmethod
    @pytest.fixture(name="outputs")
    def fixture_outputs(datadir: Path) -> Path:
        """Return the stored job directory named ``BADER_CHARGE_JOB_ID``."""
        return datadir.joinpath(BADER_CHARGE_JOB_ID)

    @staticmethod
    def test_should_load_outputs(
        outputs: Path,
    ) -> None:
        # Only checks that something truthy is returned.
        assert vasp.load_calculation_outputs(outputs)

    @staticmethod
    @pytest.mark.output_files("vasprun_xml", "contcar", "structure", "outcar")
    def test_should_load_vasprun_xml(populate_outputs: Path) -> None:
        assert vasp.load_calculation_outputs(dir_name=populate_outputs)

    @staticmethod
    @pytest.mark.xfail(reason="Not implemented")
    def test_should_not_load_dos_when_vasp_keep_dos_is_false() -> None:
        # Placeholder: fails on purpose until the behavior is implemented,
        # so the xfail marker keeps it visible in the test report.
        not_load_dos_when_vasp_keep_dos_is_false = False
        assert not_load_dos_when_vasp_keep_dos_is_false

    @staticmethod
    @pytest.mark.xfail(reason="Not implemented")
    def test_should_load_dos_when_vasp_keep_dos_is_true() -> None:
        # Placeholder: fails on purpose until the behavior is implemented.
        load_dos_when_vasp_keep_dos_is_true = False
        assert load_dos_when_vasp_keep_dos_is_true
