from collections.abc import Generator
import json
from pathlib import Path
from typing import Any
from typing import ClassVar

from ase import Atoms
from ase.calculators.calculator import Calculator
import ase.io
from ase.optimize.bfgs import BFGS
import numpy as np
from numpy.typing import NDArray
import pytest

from autojob import SETTINGS
from autojob.harvest.harvesters import vasp
from autojob.plugins import register_plugin
from autojob.tasks.calculation import Calculation
from autojob.tasks.calculation import CalculationOutputs

# Name of the job sub-directory (inside the test ``datadir``) loaded by
# ``TestVaspLoadCalculationOutputs``; presumably a stored VASP job that
# includes a Bader charge analysis (TODO confirm against the test data).
BADER_CHARGE_JOB_ID = "jJPQg3Puwz"


@pytest.mark.skip(reason="Not implemented")
@pytest.mark.vasp
class TestVaspLoadCalculationOutputs:
    """Tests for ``vasp.load_calculation_results`` (currently skipped)."""

    @staticmethod
    @pytest.fixture(name="outputs")
    def fixture_outputs(datadir: Path) -> Path:
        """Return the ``BADER_CHARGE_JOB_ID`` job directory in ``datadir``."""
        return Path(datadir, BADER_CHARGE_JOB_ID)

    @staticmethod
    def test_should_load_outputs(outputs: Path) -> None:
        """Loading the stored job directory must give a truthy result."""
        results = vasp.load_calculation_results(outputs)
        assert results

    @staticmethod
    @pytest.mark.usefixtures("write_calculation_inputs", "calculate")
    def test_should_load_vasprun_xml(task_directory: Path) -> None:
        """Loading a freshly calculated task directory must be truthy."""
        results = vasp.load_calculation_results(src=task_directory)
        assert results


def volume_analysis(dest: Path) -> None:
    """Write the cell parameters and volume of the output structure.

    Reads ``SETTINGS.OUTPUT_ATOMS_FILE`` from ``dest`` and writes
    ``volume_analysis.json`` into the same directory. The JSON object has
    the keys ``"cell_par"`` (``Cell.cellpar()`` as a list) and
    ``"cell_volume"`` (``Cell.volume``).

    Args:
        dest: Directory containing the output atoms file; the JSON report
            is written here too.
    """
    structure = ase.io.read(Path(dest, SETTINGS.OUTPUT_ATOMS_FILE))
    cell = structure.cell
    # Build the payload before opening the report so a failure while reading
    # the cell never leaves an empty/truncated JSON file behind.
    payload = {
        "cell_par": cell.cellpar().tolist(),
        "cell_volume": cell.volume,
    }
    report_path = Path(dest, "volume_analysis.json")
    with report_path.open(mode="w", encoding="utf-8") as stream:
        json.dump(payload, stream)


@pytest.mark.usefixtures("register_plugins", "calculate")
class TestLoadCalculationOutputsFromDirectory:
    """Tests for ``CalculationOutputs.from_directory``.

    Every test in this class runs with the ``register_plugins`` fixture
    (registers a ``"volume"`` harvester plugin) and the ``calculate`` fixture.
    """

    @staticmethod
    @pytest.fixture(name="input_atoms")
    def fixture_input_atoms() -> Atoms:
        """Provide a CO molecule: C at the origin and O at (2, 2, 2)."""
        return Atoms("CO", positions=[np.zeros(3), np.ones(3) * 2])

    @staticmethod
    @pytest.fixture(name="test_calculator_class")
    def fixture_test_calculator_class() -> type[Calculator]:
        """Provide a ``Calculator`` subclass advertising energy and forces.

        The fixture value is the class itself, not an instance.
        """

        class TestCalculator(Calculator):
            implemented_properties: ClassVar[list[str]] = ["energy", "forces"]

        return TestCalculator

    # TODO: use calculate fixture
    @staticmethod
    @pytest.fixture(name="src")
    def fixture_src(
        task_directory: Path,
        calculation: Calculation,
        calculate: None,  # noqa: ARG004
        optimize: None,  # noqa: ARG004
    ) -> Generator[Path, None, None]:
        """Populate ``task_directory`` and yield it as the harvesting source.

        Writes the calculation inputs, the output atoms file and the
        ``volume_analysis.json`` report. ``calculate`` and ``optimize`` are
        requested only so that they run before this fixture; their values
        are unused.
        """
        src = task_directory

        # Declare a "volume" analysis so that the analyses the tests pass to
        # ``from_directory`` name the harvester from ``register_plugins``.
        calculation.calculation_inputs.analyses = {"volume": ([], {})}
        _ = calculation.write_inputs(src)
        calculation.task_outputs.atoms.write(
            Path(src, SETTINGS.OUTPUT_ATOMS_FILE)
        )
        volume_analysis(src)
        # Yield (rather than return) so the fixture matches its Generator
        # annotation; no teardown is needed.
        yield src

    @staticmethod
    @pytest.fixture(name="register_plugins")
    def fixture_register_plugins() -> None:
        """Register a ``"volume"`` harvester reading ``volume_analysis.json``.

        NOTE(review): the registration is not undone after the test; confirm
        whether the plugin registry needs resetting between tests.
        """

        def volume_analysis_harvester(dest: Path) -> dict[str, Any]:
            with Path(dest, "volume_analysis.json").open(
                mode="r", encoding="utf-8"
            ) as f:
                return json.load(f)

        register_plugin("volume", volume_analysis_harvester, "harvester")

    @staticmethod
    @pytest.mark.parametrize("to_calculate", [["energy"]])
    def test_should_retrieve_final_energy(
        src: Path,
        calculation: Calculation,
        energy: float,
        calculate: None,  # noqa: ARG004
    ) -> None:
        """The harvested energy equals the energy from the calculation."""
        calc_outputs = CalculationOutputs.from_directory(
            src=src,
            analyses=list(calculation.calculation_inputs.analyses),
        )
        assert calc_outputs.energy == energy

    @staticmethod
    @pytest.mark.parametrize("to_calculate", [["forces"]])
    def test_should_retrieve_atom_forces(
        src: Path, calculation: Calculation, forces: NDArray
    ) -> None:
        """The harvested forces equal the forces from the calculation."""
        calc_outputs = CalculationOutputs.from_directory(
            src=src,
            analyses=list(calculation.calculation_inputs.analyses),
        )
        # ``np.array_equal`` also checks shapes and handles ``None`` cleanly,
        # unlike ``(a == b).all()``, which errors on mismatched shapes.
        assert np.array_equal(calc_outputs.forces, forces)

    @staticmethod
    @pytest.mark.xfail(reason="Not implemented")
    def test_should_retrieve_calculator_results(
        src: Path, calculation: Calculation
    ) -> None:
        """Calculator results are harvested (not yet implemented)."""
        calc_outputs = CalculationOutputs.from_directory(
            src=src,
            analyses=list(calculation.calculation_inputs.analyses),
        )
        assert calc_outputs.calculator_results

    @staticmethod
    @pytest.mark.skip(reason="Optimizer harvesters not implemented")
    @pytest.mark.parametrize("opt_class", [BFGS])
    def test_should_retrieve_optimizer_results(
        src: Path,
        calculation: Calculation,
        optimize: None,  # noqa: ARG004
        opt_params: dict[str, Any],
    ) -> None:
        """The harvested optimizer step count matches ``opt_params``."""
        calc_outputs = CalculationOutputs.from_directory(
            src=src,
            analyses=list(calculation.calculation_inputs.analyses),
        )
        assert calc_outputs.optimizer_results["nsteps"] == opt_params["steps"]

    @staticmethod
    def test_should_retrieve_analysis_results(
        src: Path, calculation: Calculation
    ) -> None:
        """Results of the registered "volume" analysis are harvested."""
        calc_outputs = CalculationOutputs.from_directory(
            src=src,
            analyses=list(calculation.calculation_inputs.analyses),
        )
        assert "volume" in calc_outputs.analysis_results

    # TODO: implement; this placeholder fails so xfail reports it.
    @staticmethod
    @pytest.mark.xfail(reason="Not implemented")
    def test_should_pass_with_failing_calculator_harvester_if_strict_mode_disabled() -> (
        None
    ):
        implemented = False
        assert implemented

    # TODO: implement; this placeholder fails so xfail reports it.
    @staticmethod
    @pytest.mark.xfail(reason="Not implemented")
    def test_should_pass_with_failing_analysis_harvester_if_strict_mode_disabled() -> (
        None
    ):
        implemented = False
        assert implemented

    # TODO: implement; this placeholder fails so xfail reports it.
    @staticmethod
    @pytest.mark.xfail(reason="Not implemented")
    def test_should_raise_error_with_failing_calculator_harvester_if_strict_mode_enabled() -> (
        None
    ):
        implemented = False
        assert implemented

    # TODO: implement; this placeholder fails so xfail reports it.
    @staticmethod
    @pytest.mark.xfail(reason="Not implemented")
    def test_should_raise_error_with_failing_analysis_harvester_if_strict_mode_enabled() -> (
        None
    ):
        implemented = False
        assert implemented


class TestSerialization:
    """Tests for JSON serialization of ``CalculationOutputs``."""

    @staticmethod
    def test_should_serialize_calculator_results_with_numpy_arrays(
        calculation_outputs: CalculationOutputs,
    ) -> None:
        """NumPy arrays in ``calculator_results`` dump as plain JSON lists."""
        # Non-zero, exactly representable values so a serializer emitting
        # default/zero values cannot pass by accident.
        dipoles = np.array([0.5, -1.0, 2.0])
        # Capture the expected value before handing the array to the model.
        expected = dipoles.tolist()
        calculation_outputs.calculator_results = {"dipoles": dipoles}
        dumped = calculation_outputs.model_dump(mode="json")
        assert dumped["calculator_results"]["dipoles"] == expected
