from collections.abc import Generator
import json
from pathlib import Path

from ase import Atoms
from ase.calculators.calculator import Calculator
from ase.optimize.minimahopping import MinimaHopping
import pytest

from autojob import SETTINGS
from autojob.tasks.calculation import CalculationInputs
from autojob.tasks.md import MDInputs
from autojob.tasks.md import MDOutputs
from autojob.tasks.md import MolecularDynamics
from autojob.tasks.task import Task
from autojob.tasks.task import TaskInputs
from autojob.tasks.task import TaskMetadata


@pytest.fixture(name="trajectory_file")
def fixture_trajectory_file() -> str:
    """Provide the trajectory file name shared by the fixtures and tests."""
    minima_trajectory = "minima.traj"
    return minima_trajectory


@pytest.fixture(name="totalsteps")
def fixture_totalsteps() -> int:
    """Provide the ``totalsteps`` value used for the hopper and MD inputs."""
    step_count = 2
    return step_count


@pytest.fixture(name="run_md")
def fixture_run_md(
    output_atoms: Atoms,
    task_directory: Path,
    calculator: Calculator,
    monkeypatch: pytest.MonkeyPatch,
    totalsteps: int,
    trajectory_file: str,
) -> Generator[None, None, None]:
    """Run a short minima-hopping simulation from within ``task_directory``.

    The current working directory is switched to ``task_directory`` before
    the hopper is invoked and stays switched while the dependent test runs;
    it is restored when the ``monkeypatch`` context exits during teardown.
    After the run, ``output_atoms`` is written to
    ``SETTINGS.OUTPUT_ATOMS_FILE`` inside ``task_directory``.

    Args:
        output_atoms: Structure handed to ``MinimaHopping``; the calculator
            is attached to it before the hopper is called.
        task_directory: Directory used as the working directory and as the
            location of the written output-atoms file.
        calculator: Calculator assigned to ``output_atoms.calc``.
        monkeypatch: Pytest fixture used to change directory temporarily.
        totalsteps: Value forwarded as ``totalsteps`` when calling the hopper.
        trajectory_file: Value forwarded as ``minima_traj`` to the hopper.

    Yields:
        ``None``; the fixture is used only for its side effects.
    """
    hopper = MinimaHopping(output_atoms, minima_traj=trajectory_file)
    output_atoms.calc = calculator
    with monkeypatch.context() as m:
        m.chdir(task_directory)
        hopper(totalsteps=totalsteps)
        outputs_atoms_file = Path(task_directory, SETTINGS.OUTPUT_ATOMS_FILE)
        output_atoms.write(outputs_atoms_file)
        yield None
        # A test (or the code under test) may already have removed the file;
        # teardown should not turn that into a spurious error.
        outputs_atoms_file.unlink(missing_ok=True)


@pytest.fixture(name="md_inputs")
def fixture_md_inputs(
    calculation_inputs: CalculationInputs,
    trajectory_file: str,
    totalsteps: int,
) -> MDInputs:
    """Build ``MDInputs`` from the calculation inputs plus ``md_params``.

    The dumped ``calculation_inputs`` fields are extended with an
    ``md_params`` entry holding the ``init`` and ``run`` parameter
    dictionaries, and the result is passed to ``MDInputs`` as keyword
    arguments.
    """
    init_params = {"_trajectory_file": trajectory_file}
    run_params = {"totalsteps": totalsteps}
    fields = calculation_inputs.model_dump()
    fields["md_params"] = {"init": init_params, "run": run_params}
    return MDInputs(**fields)


@pytest.fixture(name="minima_hopping")
def fixture_minima_hopping(
    task_metadata: TaskMetadata, task_inputs: TaskInputs, md_inputs: MDInputs
) -> MolecularDynamics:
    """Create a ``MolecularDynamics`` task from the shared fixtures.

    The task metadata is passed as a dumped dictionary with the
    ``task_class`` field excluded.
    """
    metadata_fields = task_metadata.model_dump(exclude={"task_class"})
    md_task = MolecularDynamics(
        task_metadata=metadata_fields,
        task_inputs=task_inputs,
        md_inputs=md_inputs,
    )
    return md_task


@pytest.fixture(name="write_md_inputs")
def fixture_write_md_inputs(
    minima_hopping: MolecularDynamics, task_directory: Path
) -> list[Path]:
    """Write the task's inputs into ``task_directory``.

    Returns:
        The value returned by ``MolecularDynamics.write_inputs``.
    """
    written_files = minima_hopping.write_inputs(task_directory)
    return written_files


class TestWriteInputs:
    """Tests for writing ``MolecularDynamics`` inputs to disk."""

    @staticmethod
    def test_should_write_md_inputs_to_inputs_json(tmp_path: Path) -> None:
        """The written inputs JSON should contain an ``md_inputs`` key."""
        MolecularDynamics().write_inputs_json(tmp_path)
        json_path = tmp_path / SETTINGS.INPUTS_FILE
        with json_path.open(mode="r", encoding="utf-8") as handle:
            written = json.load(handle)
        assert "md_inputs" in written


@pytest.mark.usefixtures("run_md")
class TestMDOutputs:
    """Tests for loading ``MDOutputs`` after the ``run_md`` fixture has run."""

    @staticmethod
    def test_should_load_md_outputs_from_directory(
        task_directory: Path, trajectory_file: str
    ) -> None:
        """``MDOutputs.from_directory`` should return a truthy result."""
        loaded_outputs = MDOutputs.from_directory(
            task_directory,
            trajectory=trajectory_file,
        )
        assert loaded_outputs


@pytest.mark.usefixtures(
    "run_md",
    "write_job_stats_file",
    "write_slurm_output_file",
    "write_md_inputs",
)
class TestMolecularDynamics:
    """Tests for loading a ``MolecularDynamics`` task from a directory."""

    @staticmethod
    def test_should_load_molecular_dynamics_task_from_directory(
        task_directory: Path,
    ) -> None:
        """``Task.from_directory`` should yield a ``MolecularDynamics``."""
        loaded_task = Task.from_directory(task_directory, magic_mode=True)
        assert isinstance(loaded_task, MolecularDynamics)
