from collections.abc import Generator
import json
from pathlib import Path
import shutil
from typing import Any

import pytest

from autojob import SETTINGS
from autojob.bases.task_base import TaskBase
from autojob.next import _write_task_group_metadata
from autojob.next.restart import restart
from autojob.tasks.calculation import Calculation


@pytest.mark.usefixtures(
    "write_job_stats_file",
    "calculate",
    "write_slurm_output_file",
    "write_calculation_inputs",
    "write_calculation_outputs",
)
class TestRestart:
    """Tests for :func:`autojob.next.restart.restart`.

    Every test in this class requests the fixtures named in the
    ``usefixtures`` marker; those fixtures are defined outside this module.
    """

    @staticmethod
    @pytest.fixture(name="write_task_group_metadata")
    def fixture_write_task_group_metadata(
        task_group_directory: Path, calculation: Calculation
    ) -> None:
        """Write the task group metadata for ``calculation``.

        Args:
            task_group_directory: Directory passed to
                ``_write_task_group_metadata``.
            calculation: The calculation passed to
                ``_write_task_group_metadata``.
        """
        _write_task_group_metadata(task_group_directory, calculation)

    @staticmethod
    @pytest.fixture(name="restart_calculation")
    def fixture_restart_calculation(
        write_task_group_metadata: None,  # noqa: ARG004
        write_calculation_outputs: None,  # noqa: ARG004
        task_directory: Path,
    ) -> Generator[tuple[TaskBase, Path], None, None]:
        """Restart the task in ``task_directory`` without submitting it.

        Args:
            write_task_group_metadata: Requested only so that the fixture
                runs before the restart.
            write_calculation_outputs: Requested only so that the fixture
                runs before the restart.
            task_directory: Directory handed to ``restart``.

        Yields:
            The ``(task, directory)`` pair returned by ``restart``.
        """
        new_task, new_task_dir = restart(task_directory, submit=False)
        yield new_task, new_task_dir
        # Teardown: remove the directory that ``restart`` returned.
        shutil.rmtree(new_task_dir)

    @staticmethod
    def test_should_set_up_new_task(
        restart_calculation: tuple[TaskBase, Path],
    ) -> None:
        """The directory returned by ``restart`` should exist."""
        restarted_dir = restart_calculation[1]
        assert restarted_dir.exists()

    @staticmethod
    def test_should_add_new_task_id_to_task_group_metadata(
        restart_calculation: tuple[TaskBase, Path], task_group_directory: Path
    ) -> None:
        """The restarted task's ID should be listed under ``"tasks"``.

        The task group metadata file is read back from
        ``task_group_directory`` and checked for the new task ID.
        """
        new_task = restart_calculation[0]
        metadata_path = task_group_directory / SETTINGS.TASK_GROUP_METADATA_FILE
        metadata: dict[str, Any] = json.loads(
            metadata_path.read_text(encoding="utf-8")
        )
        new_task_id = str(new_task.task_metadata.task_id)
        assert new_task_id in metadata["tasks"]

    @staticmethod
    def test_should_create_task_metadata_with_old_task_id_in_metadata_tags(
        calculation: Calculation,
        restart_calculation: tuple[TaskBase, Path],
    ) -> None:
        """The original task ID should appear in the new task's tags.

        Both the task metadata file written to the new task directory and
        the in-memory metadata of the returned task are checked.
        """
        new_task, new_task_dir = restart_calculation
        metadata_path = new_task_dir / SETTINGS.TASK_METADATA_FILE
        metadata: dict[str, Any] = json.loads(
            metadata_path.read_text(encoding="utf-8")
        )
        old_task_id = str(calculation.task_metadata.task_id)
        assert old_task_id in metadata["tags"]
        assert old_task_id in new_task.task_metadata.tags
