from collections.abc import Generator
import json
from pathlib import Path
import shutil

from ase import Atoms
import pytest

from autojob import SETTINGS
from autojob.bases.task_base import TaskBase
from autojob.tasks.calculation import Calculation
from autojob.tasks.task import Task
from autojob.tasks.task import TaskInputs
from autojob.tasks.task import TaskOutputs


@pytest.mark.usefixtures("write_task_inputs")
class TestWriteInputs:
    """Check the files written, and paths returned, by ``Task.write_inputs``.

    Every fixture here is class-scoped, so all tests share one task and one
    output directory, and ``write_task_inputs`` runs once for the class.
    """

    @staticmethod
    @pytest.fixture(name="input_atoms", scope="class")
    def fixture_input_atoms() -> Atoms | None:
        """Provide a single-carbon ``Atoms`` object used as task input."""
        return Atoms("C")

    @staticmethod
    @pytest.fixture(name="task", scope="class")
    def fixture_task(input_atoms: Atoms | None) -> Task:
        """Build a default ``Task`` carrying ``input_atoms`` as its inputs."""
        new_task = Task()
        new_task.task_inputs.atoms = input_atoms
        return new_task

    @staticmethod
    @pytest.fixture(name="task_directory", scope="class")
    def fixture_task_directory(
        tmp_path_factory: pytest.TempPathFactory,
    ) -> Generator[Path, None, None]:
        """Yield a fresh temporary directory, removing it after the class."""
        directory = tmp_path_factory.mktemp("dest")
        yield directory
        shutil.rmtree(directory)

    @staticmethod
    @pytest.fixture(name="write_task_inputs", scope="class")
    def fixture_write_task_inputs(
        task_directory: Path, task: Task
    ) -> list[Path]:
        """Write the task inputs into ``task_directory``; return the paths."""
        written_paths = task.write_inputs(task_directory)
        return written_paths

    @staticmethod
    def test_should_write_input_atoms(
        task_directory: Path, task: Task
    ) -> None:
        atoms_file = task_directory / task.task_inputs.atoms_filename
        assert atoms_file.exists()

    @staticmethod
    def test_should_return_path_matching_input_atoms_filename(
        task: Task, write_task_inputs: list[Path]
    ) -> None:
        returned_names = [path.name for path in write_task_inputs]
        assert task.task_inputs.atoms_filename in returned_names

    @staticmethod
    def test_should_write_inputs_json(task_directory: Path) -> None:
        inputs_file = task_directory / SETTINGS.INPUTS_FILE
        assert inputs_file.exists()

    @staticmethod
    def test_should_return_path_matching_inputs_json_filename(
        write_task_inputs: list[Path],
    ) -> None:
        returned_names = [path.name for path in write_task_inputs]
        assert SETTINGS.INPUTS_FILE in returned_names

    @staticmethod
    def test_should_write_metadata(task_directory: Path) -> None:
        metadata_file = task_directory / SETTINGS.TASK_METADATA_FILE
        assert metadata_file.exists()

    @staticmethod
    def test_should_return_path_matching_metadata_filename(
        write_task_inputs: list[Path],
    ) -> None:
        returned_names = [path.name for path in write_task_inputs]
        assert SETTINGS.TASK_METADATA_FILE in returned_names


class TestWriteInputAtoms:
    """Tests for ``Task.write_input_atoms``."""

    @staticmethod
    def test_should_write_input_atoms(tmp_path: Path) -> None:
        task = Task()
        # A ``filename`` entry in ``info`` is set; the file is still expected
        # under ``task_inputs.atoms_filename``.
        task.task_inputs.atoms = Atoms("C", info={"filename": "dummy.traj"})
        task.write_input_atoms(tmp_path)
        expected_file = tmp_path / task.task_inputs.atoms_filename
        assert expected_file.exists()

    @staticmethod
    def test_should_return_path_matching_input_atoms(tmp_path: Path) -> None:
        task = Task()
        task.task_inputs.atoms = Atoms("C")
        returned_path = task.write_input_atoms(tmp_path)
        assert tmp_path / task.task_inputs.atoms_filename == returned_path

    @staticmethod
    def test_should_not_write_input_atoms_if_input_atoms_is_none(
        tmp_path: Path,
    ) -> None:
        # A default ``Task`` is used without assigning any input atoms.
        task = Task()
        task.write_input_atoms(tmp_path)
        unexpected_file = tmp_path / task.task_inputs.atoms_filename
        assert not unexpected_file.exists()


class TestWriteInputJSON:
    """Tests for ``Task.write_inputs_json``."""

    @staticmethod
    def _read_inputs_json(directory: Path) -> dict:
        """Load and return the parsed inputs JSON file found in ``directory``.

        Not collected by pytest (no ``test`` prefix); shared by the tests
        below that inspect the written file's contents.
        """
        inputs_json = Path(directory, SETTINGS.INPUTS_FILE)
        with inputs_json.open(mode="r", encoding="utf-8") as file:
            return json.load(file)

    @staticmethod
    def test_should_write_input_json(tmp_path: Path) -> None:
        task = Task()
        task.write_inputs_json(tmp_path)
        assert Path(tmp_path, SETTINGS.INPUTS_FILE).exists()

    @staticmethod
    def test_should_not_include_atoms_key(tmp_path: Path) -> None:
        task = Task()
        task.write_inputs_json(tmp_path)
        inputs = TestWriteInputJSON._read_inputs_json(tmp_path)
        assert "atoms" not in inputs["task_inputs"]

    @staticmethod
    def test_should_write_equivalent_input_json_without_atoms(
        tmp_path: Path,
    ) -> None:
        task = Task()
        task.write_inputs_json(tmp_path)
        inputs = TestWriteInputJSON._read_inputs_json(tmp_path)
        # Atoms are excluded from the comparison on both sides because the
        # JSON file does not carry them.
        loaded_inputs = TaskInputs(**inputs["task_inputs"]).model_dump(
            exclude={"atoms"}
        )
        assert loaded_inputs == task.task_inputs.model_dump(exclude={"atoms"})

    @staticmethod
    @pytest.mark.xfail(reason="Not implemented")
    def test_should_overwrite_task_inputs_with_additional_data() -> None:
        overwrite_task_inputs_with_additional_data = False
        assert overwrite_task_inputs_with_additional_data

    @staticmethod
    @pytest.mark.xfail(reason="Not implemented")
    def test_should_overwrite_task_metadata_with_additional_data() -> None:
        overwrite_task_metadata_with_additional_data = False
        assert overwrite_task_metadata_with_additional_data

    @staticmethod
    @pytest.mark.xfail(reason="Not implemented")
    def test_should_overwrite_settings_with_additional_data() -> None:
        overwrite_settings_with_additional_data = False
        assert overwrite_settings_with_additional_data


class TestWriteMetadata:
    """Tests for ``Task.write_metadata``."""

    @staticmethod
    def test_should_write_metadata_file(tmp_path: Path) -> None:
        task = Task()
        task.write_metadata(tmp_path)
        # Build the expected path once; re-joining it onto ``tmp_path`` (as
        # previously done) was redundant because it is already absolute.
        metadata_file = Path(tmp_path, SETTINGS.TASK_METADATA_FILE)
        assert metadata_file.exists()


class TestLoadMagic:
    """Tests for ``Task.load_magic``."""

    @staticmethod
    def test_should_return_instance_of_base_class_if_provided(
        tmp_path: Path,
    ) -> None:
        task = Task()
        task.write_inputs(tmp_path)
        loaded = Task.load_magic(tmp_path, strict_mode=False)
        assert isinstance(loaded, Task)

    @staticmethod
    @pytest.mark.parametrize("task_type", [Task, Calculation])
    def test_should_return_instance_of_task_if_no_base_class_and_strict_mode_off(
        tmp_path: Path, task_type: type[TaskBase]
    ) -> None:
        task = task_type()
        # Clear the recorded class so load_magic has nothing to build from.
        task.task_metadata.task_class = None
        task.write_inputs(tmp_path)
        loaded = Task.load_magic(tmp_path, strict_mode=False)
        assert isinstance(loaded, Task)

    @staticmethod
    @pytest.mark.parametrize("task_type", [Task, Calculation])
    def test_should_raise_runtime_error_if_no_base_class_and_strict_mode_on(
        tmp_path: Path, task_type: type[TaskBase]
    ) -> None:
        task = task_type()
        task.task_metadata.task_class = None
        task.write_inputs(tmp_path)
        # ``match`` is applied with ``re.search``, so a plain prefix suffices;
        # the former trailing ``*`` only made the final "n" optional.
        with pytest.raises(
            RuntimeError, match=r"No build class provided for task in"
        ):
            _ = Task.load_magic(tmp_path, strict_mode=True)


class TestFromDirectory:
    """Tests for ``Task.from_directory``."""

    @staticmethod
    @pytest.mark.parametrize("magic_mode", [True, False])
    def test_should_load_task(tmp_path: Path, magic_mode: bool) -> None:
        """Write inputs plus an output-atoms file, then reload the directory.

        Runs with ``magic_mode`` both enabled and disabled.
        """
        task = Task()
        task.task_outputs = TaskOutputs()
        carbon = Atoms("C")
        task.task_inputs.atoms = carbon
        task.task_outputs.atoms = carbon
        task.write_inputs(tmp_path)

        output_atoms_file = tmp_path / SETTINGS.OUTPUT_ATOMS_FILE
        task.task_inputs.atoms.write(output_atoms_file)

        reloaded = Task.from_directory(tmp_path, magic_mode=magic_mode)
        assert isinstance(reloaded, Task)
