import json
from pathlib import Path

from ase import Atoms
import pytest

from autojob import SETTINGS
from autojob.tasks.task import TaskInputs


def _write_inputs_file(directory: Path, inputs: TaskInputs) -> dict:
    """Serialize ``inputs`` (minus ``atoms``) to the inputs file in ``directory``.

    The file is written to ``SETTINGS.INPUTS_FILE`` under ``directory`` as JSON
    of the form ``{"task_inputs": ...}``.

    Args:
        directory: Directory in which to create the inputs file.
        inputs: The task inputs to serialize. ``atoms`` is excluded.

    Returns:
        The ``task_inputs`` mapping that was written, for comparison against
        what ``TaskInputs.from_directory`` loads back.
    """
    task_inputs = inputs.model_dump(exclude={"atoms"})
    inputs_json = Path(directory, SETTINGS.INPUTS_FILE)

    with inputs_json.open(mode="w", encoding="utf-8") as file:
        json.dump({"task_inputs": task_inputs}, file)

    return task_inputs


class TestFromDirectory:
    """Tests for ``TaskInputs.from_directory``."""

    @staticmethod
    @pytest.fixture(name="atoms_filename")
    def fixture_atoms_filename(request: pytest.FixtureRequest) -> str:
        """Atoms filename; overridable via indirect parametrization (default ``""``)."""
        atoms_filename = str(getattr(request, "param", ""))
        return atoms_filename

    @staticmethod
    @pytest.fixture(name="inputs")
    def fixture_inputs(atoms_filename: str) -> TaskInputs:
        """A ``TaskInputs`` instance built from the ``atoms_filename`` fixture."""
        inputs = TaskInputs(atoms_filename=atoms_filename)
        return inputs

    @staticmethod
    @pytest.mark.parametrize("strict_mode", [True, False])
    def test_should_load_task_inputs_from_file_if_inputs_file_present_and_atoms_filename_not_specified_regardless_of_strict_mode(
        tmp_path: Path, inputs: TaskInputs, strict_mode: bool
    ) -> None:
        """With no atoms filename, loading succeeds in both strict modes."""
        written = _write_inputs_file(tmp_path, inputs)

        loaded_inputs = TaskInputs.from_directory(
            tmp_path, strict_mode=strict_mode
        ).model_dump(exclude={"atoms"})

        assert loaded_inputs == written

    @staticmethod
    @pytest.mark.parametrize("atoms_filename", ["in.traj"], indirect=True)
    def test_should_load_task_inputs_if_filename_specified_but_no_atoms_file_present_if_strict_mode_disabled(
        tmp_path: Path, inputs: TaskInputs
    ) -> None:
        """A missing atoms file is tolerated when ``strict_mode=False``."""
        written = _write_inputs_file(tmp_path, inputs)

        loaded_inputs = TaskInputs.from_directory(
            tmp_path, strict_mode=False
        ).model_dump(exclude={"atoms"})

        assert loaded_inputs == written

    @staticmethod
    def test_should_raise_error_if_inputs_file_not_present(
        tmp_path: Path,
    ) -> None:
        """An empty directory (no inputs file) raises in strict mode."""
        with pytest.raises(FileNotFoundError):
            TaskInputs.from_directory(tmp_path, strict_mode=True)

    @staticmethod
    @pytest.mark.parametrize("atoms_filename", ["in.traj"], indirect=True)
    def test_should_load_task_inputs_with_input_atoms_if_filename_specified_and_atoms_file_present(
        tmp_path: Path, inputs: TaskInputs
    ) -> None:
        """When the named atoms file exists, its structure is loaded as ``atoms``."""
        inputs.atoms = Atoms("C")
        _write_inputs_file(tmp_path, inputs)
        inputs.atoms.write(Path(tmp_path, inputs.atoms_filename))

        loaded_input_atoms = TaskInputs.from_directory(
            tmp_path, strict_mode=True
        ).atoms

        assert loaded_input_atoms == inputs.atoms

    @staticmethod
    @pytest.mark.parametrize("atoms_filename", ["in.traj"], indirect=True)
    def test_should_raise_file_not_found_error_if_filename_specified_and_atoms_file_not_present(
        tmp_path: Path, inputs: TaskInputs
    ) -> None:
        """In strict mode, a named-but-missing atoms file raises."""
        # The atoms are excluded from the dump and never written to disk, so
        # the file named by ``atoms_filename`` is deliberately absent.
        inputs.atoms = Atoms("C")
        _write_inputs_file(tmp_path, inputs)

        with pytest.raises(FileNotFoundError):
            _ = TaskInputs.from_directory(tmp_path, strict_mode=True)
