from collections.abc import Generator
from datetime import UTC
from datetime import datetime
from itertools import groupby
import json
from pathlib import Path

from ase import Atoms
import pytest

from autojob import SETTINGS
from autojob.study import Study
from autojob.tasks.calculation import Calculation


@pytest.fixture(name="files_to_carryover")
def fixture_files_to_carryover() -> list[str]:
    """Provide a fresh, empty ``files_to_carryover`` list for each test."""
    return []


class TestToDirectory:
    """Tests for ``Study.to_directory``."""

    @staticmethod
    def test_should_dump_metadata_file(tmp_path: Path, study: Study) -> None:
        """The study metadata file is written under ``<tmp_path>/<study_id>``."""
        study.to_directory(tmp_path)
        study_directory = Path(tmp_path, str(study.study_id))
        metadata_file = Path(study_directory, SETTINGS.STUDY_METADATA_FILE)
        assert metadata_file.exists()

    @staticmethod
    def test_should_dump_correct_study_metadata(
        tmp_path: Path, study: Study
    ) -> None:
        """The metadata file holds the study fields plus its task-group IDs.

        Study fields (excluding ``tasks``) are compared against the JSON
        dump of the in-memory study; the ``task_groups`` entry is compared
        against the distinct task-group IDs of ``study.tasks`` in sorted
        order.
        """

        def tg_key(t) -> str:
            return str(t.task_metadata.task_group_id)

        # Snapshot the expected values before dumping, as the original did.
        dumped_study = study.model_dump(mode="json", exclude={"tasks"})
        # groupby only merges *adjacent* equal keys, so sort first; the keys
        # it then yields are the distinct task-group IDs in sorted order.
        tasks = sorted(study.tasks, key=tg_key)
        dumped_task_groups = [key for key, _ in groupby(tasks, key=tg_key)]

        study.to_directory(tmp_path)
        study_directory = Path(tmp_path, str(study.study_id))
        metadata_file = Path(study_directory, SETTINGS.STUDY_METADATA_FILE)

        with metadata_file.open(mode="r", encoding="utf-8") as file:
            metadata = json.load(file)

        # Task groups are stored under their own key; compare them separately.
        loaded_task_groups = metadata.pop("task_groups")

        assert dumped_study == metadata
        assert dumped_task_groups == loaded_task_groups

    @staticmethod
    def test_should_dump_all_tasks_to_directories(
        tmp_path: Path, study: Study
    ) -> None:
        """Each task's task-group directory exists after dumping the study."""
        study.to_directory(tmp_path)
        study_directory = Path(tmp_path, str(study.study_id))

        group_dirs = [
            Path(study_directory, str(task.task_metadata.task_group_id))
            for task in study.tasks
        ]
        # Report which directories are missing rather than a bare False.
        missing = [path for path in group_dirs if not path.exists()]

        assert not missing, f"missing task-group directories: {missing}"


class TestFromDirectory:
    """Round-trip tests: dump ``study`` to disk, then reload it.

    Subclasses may override the ``study`` fixture to rerun these tests on
    other study contents.
    """

    @staticmethod
    @pytest.fixture(name="dump_study")
    def fixture_dump_study(study_group_directory: Path, study: Study) -> None:
        """Write ``study`` into ``study_group_directory``."""
        study.to_directory(study_group_directory)

    @staticmethod
    @pytest.fixture(name="write_output_atoms", autouse=True)
    def fixture_write_output_atoms(
        dump_study: None,  # noqa: ARG004
        task_directory: Path,
        calculate: None,  # noqa: ARG004
        output_atoms: Atoms,
    ) -> Generator[None, None, None]:
        """Write ``output_atoms`` to ``task_directory`` for the test's duration.

        Depends on ``dump_study`` so the study is written before the atoms
        file (presumably ``task_directory`` lies inside the dumped study —
        confirm in conftest). The file is removed again on teardown.
        """
        output_atoms_filename = Path(
            task_directory, SETTINGS.OUTPUT_ATOMS_FILE
        )
        output_atoms.write(output_atoms_filename)
        yield None
        # missing_ok: don't fail teardown if a test already removed the file.
        output_atoms_filename.unlink(missing_ok=True)

    @staticmethod
    @pytest.fixture(name="loaded_study")
    def fixture_loaded_study(
        dump_study: None,  # noqa: ARG004
        study_directory: Path,
        write_slurm_output_file: Path,  # noqa: ARG004
        write_output_atoms: None,  # noqa: ARG004
    ) -> Study:
        """Reload the dumped study from ``study_directory`` (non-strict)."""
        return Study.from_directory(study_directory, strict_mode=False)

    @staticmethod
    def test_should_generate_same_metadata(
        study: Study,
        loaded_study: Study,
    ) -> None:
        """Study-level fields survive the round trip (tasks checked apart)."""
        dumped_study = study.model_dump(exclude={"tasks"})
        reloaded_study = loaded_study.model_dump(exclude={"tasks"})
        assert reloaded_study == dumped_study

    @staticmethod
    def test_should_generate_same_tasks(
        study: Study, loaded_study: Study
    ) -> None:
        """The first task matches after reload, ignoring outputs and URIs.

        NOTE(review): only ``tasks[0]`` is compared; with several tasks this
        relies on the reloaded task order matching the original.
        """
        to_exclude = {
            "task_outputs",
            "scheduler_outputs",
            "calculation_outputs",
        }
        dumped_task = study.tasks[0].model_dump(exclude=to_exclude)
        loaded_task = loaded_study.tasks[0].model_dump(exclude=to_exclude)

        # Outputs are excluded above; URIs and last-updated timestamps are
        # expected to differ between the original and the reloaded task.
        dumped_task["task_metadata"].pop("uri")
        loaded_task["task_metadata"].pop("uri")
        dumped_task["task_metadata"].pop("last_updated")
        loaded_task["task_metadata"].pop("last_updated")

        assert dumped_task == loaded_task


class TestFromDirectoryCalculation(TestFromDirectory):
    """Rerun the ``TestFromDirectory`` tests on a one-calculation study."""

    @staticmethod
    @pytest.fixture(name="study")
    def fixture_study(
        calculation: Calculation, study_id: str, study_group_id: str
    ) -> Study:
        """Build a study whose only task is ``calculation``, created now (UTC)."""
        return Study(
            study_id=study_id,
            study_group_id=study_group_id,
            date_created=datetime.now(tz=UTC),
            tasks=[calculation],
        )
