from pathlib import Path
from typing import TYPE_CHECKING

from ase.calculators.calculator import Calculator
from ase.calculators.emt import EMT
import numpy as np
import pytest

from autojob import SETTINGS
from autojob.bases.task_base import TaskOutcome
from autojob.tasks.calculation import Calculation

if TYPE_CHECKING:
    from autojob.bases.task_base import TaskBase


class TestPrepareInputAtoms:
    """Tests for ``Calculation.prepare_input_atoms``."""

    @staticmethod
    @pytest.fixture(name="calculator")
    def fixture_calculator() -> Calculator:
        """Return an EMT calculator with a preset ``magmoms`` result.

        ``use_cache`` is switched on and ``results`` is replaced with a
        dictionary holding a single magnetic moment of ``0.0``.
        """
        calc = EMT()
        calc.use_cache = True
        calc.results = {"magmoms": [0.0]}
        return calc

    @staticmethod
    @pytest.fixture(name="prepare_input_atoms")
    def fixture_prepare_input_atoms(
        calculation: Calculation, calculator: Calculator
    ) -> None:
        """Attach ``calculator`` to the input atoms, then prepare them.

        Returns:
            Whatever ``Calculation.prepare_input_atoms`` returns; the tests
            in this class expect that to be ``None``.
        """
        calculation.task_inputs.atoms.calc = calculator
        return calculation.prepare_input_atoms()

    @staticmethod
    def test_should_copy_magnetic_moments_when_present(
        prepare_input_atoms: None,  # noqa: ARG004
        task: "TaskBase",
    ) -> None:
        """Initial magnetic moments should equal the calculator's ``magmoms``."""
        atoms = task.task_inputs.atoms
        # Explicit ``None`` check: the truth value of an ``Atoms`` object
        # depends on its length, not on whether it is set.
        assert atoms is not None
        magmoms = atoms.calc.results["magmoms"]
        # ``get_initial_magnetic_moments`` returns a NumPy array, so a bare
        # ``==`` is element-wise: asserting on it raises for more than one
        # element and broadcasts silently on shape mismatch. ``array_equal``
        # compares shape and values and always yields a single bool.
        assert np.array_equal(atoms.get_initial_magnetic_moments(), magmoms)

    @staticmethod
    def test_should_do_nothing_when_no_atoms_present(
        prepare_input_atoms: None,
    ) -> None:
        """``prepare_input_atoms`` should return ``None``.

        NOTE(review): the ``prepare_input_atoms`` fixture assigns a
        calculator to ``task_inputs.atoms``, so the "no atoms present"
        precondition is not established here -- confirm the intended setup.
        """
        assert prepare_input_atoms is None

    @staticmethod
    def test_should_do_nothing_when_no_magnetic_moments_present(
        prepare_input_atoms: None,
    ) -> None:
        """``prepare_input_atoms`` should return ``None``.

        NOTE(review): the ``calculator`` fixture always provides ``magmoms``
        in its results, so the "no magnetic moments" precondition is not
        established here -- confirm the intended setup.
        """
        assert prepare_input_atoms is None


class TestWriteCalculationScript:
    """Tests for ``Calculation.write_calculation_script``."""

    @staticmethod
    def test_should_write_calculation_script(
        calculation: Calculation,
        tmp_path: Path,
    ) -> None:
        """The configured calculation script should exist after writing."""
        calculation.write_calculation_script(tmp_path)
        script_name = calculation.calculation_inputs.calculation_script
        written_script = Path(tmp_path, script_name)
        assert written_script.exists()


class TestWriteScript:
    """Tests for ``Calculation.write_task_script``."""

    @staticmethod
    def test_should_write_script_to_default_filename(
        tmp_path: Path, calculation: Calculation
    ) -> None:
        """The task script named in the task inputs should exist after writing."""
        calculation.write_task_script(tmp_path)
        script_name = calculation.task_inputs.task_script
        written_script = Path(tmp_path, script_name)
        assert written_script.exists()


class TestFromDirectory:
    """Tests for ``Calculation.from_directory``."""

    @staticmethod
    def test_should_load_calculation_inputs(
        calculation: Calculation,
        task_directory: Path,
    ) -> None:
        """Calculation inputs should survive a write/load round trip."""
        calculation.write_inputs(task_directory)
        # The input atoms are written under the output atoms filename before
        # the directory is loaded.
        output_atoms_path = Path(task_directory, SETTINGS.OUTPUT_ATOMS_FILE)
        calculation.task_inputs.atoms.write(output_atoms_path)

        reloaded = Calculation.from_directory(
            task_directory, strict_mode=False
        )

        expected = calculation.calculation_inputs.model_dump()
        actual = reloaded.calculation_inputs.model_dump()
        assert expected == actual

    @staticmethod
    def test_should_set_task_outcome_as_successful_if_calculation_converged(
        task_directory: Path,
        calculation: Calculation,
        calculate: None,  # noqa: ARG004
    ) -> None:
        """A loaded calculation should report ``TaskOutcome.SUCCESS``."""
        calculation.write_inputs(task_directory)
        output_atoms_path = Path(task_directory, SETTINGS.OUTPUT_ATOMS_FILE)
        calculation.task_outputs.atoms.write(output_atoms_path)

        reloaded = Calculation.from_directory(task_directory)

        assert reloaded.task_outputs.outcome == TaskOutcome.SUCCESS

    @staticmethod
    @pytest.mark.xfail(reason="Not implemented")
    def test_should_patch_output_atoms_from_calculation_outputs_if_output_atoms_missing_from_task_outputs() -> (
        None
    ):
        """Placeholder for an unimplemented behaviour; fails and is marked xfail."""
        behaviour_implemented = False
        assert behaviour_implemented
