from datetime import UTC
from datetime import datetime
from typing import Any

from ase import Atoms
import pytest

from autojob.bases.task_base import TaskBase
from autojob.hpc import SchedulerInputs
from autojob.parametrizations import AttributePath
from autojob.parametrizations import VariableReference
from autojob.parametrizations import create_parametrization
from autojob.tasks.calculation import Calculation
from autojob.tasks.calculation import CalculationInputs
from autojob.tasks.task import Task
from autojob.tasks.task import TaskInputs
from autojob.utils.schemas import Unset


class TestSetInputValue:
    """Tests for ``VariableReference.set_input_value`` using constants."""

    @staticmethod
    @pytest.fixture(name="path_key")
    def fixture_path_key() -> str:
        """Return the key that the default set path points at."""
        return "a"

    @staticmethod
    @pytest.fixture(name="set_path")
    def fixture_set_path(path_key: str) -> AttributePath:
        """Return a one-element set path consisting of ``path_key``."""
        return [path_key]

    @staticmethod
    def test_should_delete_unset_variable(
        set_path: AttributePath, path_key: str
    ) -> None:
        """An ``Unset`` constant removes the targeted key from a mapping."""
        reference = VariableReference(set_path=set_path, constant=Unset)
        target = {path_key: None}

        reference.set_input_value(context={}, shell=target)

        assert target == {}

    @staticmethod
    @pytest.mark.parametrize("constant", [1, 1.0, "", None, {}])
    def test_should_set_value_to_constant_for_mapping(
        constant: Any,
        set_path: AttributePath,
        path_key: str,
    ) -> None:
        """The constant is stored under ``path_key`` of an empty mapping."""
        reference = VariableReference(set_path=set_path, constant=constant)
        target = {}

        reference.set_input_value(context={}, shell=target)

        assert target == {path_key: constant}

    @staticmethod
    @pytest.mark.parametrize("constant", [1, 1.0, "", None, {}])
    def test_should_set_value_to_constant_for_object(
        constant: Any,
        set_path: AttributePath,
    ) -> None:
        """The constant replaces the attribute named by the set path."""
        reference = VariableReference(set_path=set_path, constant=constant)

        # The attribute name must match the last item in set_path.
        class Holder:
            def __init__(self, val=0) -> None:
                self.a = val

        target = Holder()

        reference.set_input_value(context={}, shell=target)

        assert getattr(target, set_path[-1]) == constant

    @staticmethod
    @pytest.mark.parametrize(("constant", "set_path"), [(1, ("a", "b"))])
    def test_should_set_value_to_constant_for_nested_object(
        constant: Any,
        set_path: AttributePath,
    ) -> None:
        """The constant is stored in a mapping held by an object attribute.

        The ``set_path`` parametrization overrides the ``set_path`` fixture.
        """
        reference = VariableReference(set_path=set_path, constant=constant)

        # The attribute name must match the first item in set_path and the
        # mapping key must match the last item.
        class Holder:
            def __init__(self, val=0) -> None:
                self.a = {"b": val}

        target = Holder()

        reference.set_input_value(context={}, shell=target)

        assert target.a.get(set_path[-1]) == constant


class TestCreateParametrization:
    """Tests for ``create_parametrization``.

    Each test obtains its task from the ``task`` fixture, which instantiates
    the class supplied by the test's ``task_class`` parametrization.
    """

    @staticmethod
    def _check_references(
        parametrization: list[VariableReference],
        task: TaskBase,
        group: str,
        expected: dict[str, Any],
        *,
        only_expected_keys: bool = False,
    ) -> None:
        """Assert that the references selected by ``group`` evaluate as expected.

        A reference is selected when ``group`` is an element of its set path.
        Each selected reference is evaluated against ``task`` and compared
        with ``expected[ref.set_path[-1]]``.

        Args:
            parametrization: The references to inspect.
            task: The object passed to ``VariableReference.evaluate``.
            group: Set-path element that selects the references to check.
            expected: Maps the last set-path element to the expected value.
            only_expected_keys: If True, selected references whose last
                set-path element is not a key of ``expected`` are skipped.
                If False, such a reference raises ``KeyError``.

        Note:
            Nothing is asserted when no reference is selected, so this check
            passes vacuously for a parametrization without matching paths.
        """
        for ref in parametrization:
            if group not in ref.set_path:
                continue

            key = ref.set_path[-1]

            if only_expected_keys and key not in expected:
                continue

            # Asserting per reference (rather than collecting booleans)
            # reports which set path produced the mismatch.
            assert (
                ref.evaluate(task) == expected[key]
            ), f"unexpected value for set path {ref.set_path!r}"

    @staticmethod
    @pytest.fixture(name="calc_params")
    def fixture_calc_params() -> dict[str, Any]:
        """Return empty calculator parameters."""
        return {}

    @staticmethod
    @pytest.fixture(name="task_inputs")
    def fixture_task_inputs() -> TaskInputs:
        """Return task inputs holding a single-carbon ``Atoms`` object."""
        return TaskInputs(atoms=Atoms("C"))

    @staticmethod
    @pytest.fixture(name="calculation_inputs")
    def fixture_calculation_inputs(
        calc_params: dict[str, Any],
    ) -> CalculationInputs:
        """Return calculation inputs built from ``calc_params``."""
        return CalculationInputs(calc_params=calc_params)

    @staticmethod
    @pytest.fixture(name="scheduler_inputs")
    def fixture_scheduler_inputs() -> SchedulerInputs:
        """Return scheduler inputs constructed without arguments."""
        return SchedulerInputs()

    @staticmethod
    @pytest.fixture(name="task")
    def fixture_task(
        task_class: type[TaskBase],
        task_inputs: TaskInputs,
        calculation_inputs: CalculationInputs,
        scheduler_inputs: SchedulerInputs,
    ) -> TaskBase:
        """Return an instance of ``task_class`` built from the input fixtures.

        ``task_class`` is expected to be provided by the requesting test's
        parametrization. Calculation and scheduler inputs are only passed
        when ``task_class`` is a subclass of ``Calculation``.
        """
        extra_inputs: dict[str, Any] = {}

        if issubclass(task_class, Calculation):
            extra_inputs = {
                "calculation_inputs": calculation_inputs,
                "scheduler_inputs": scheduler_inputs,
            }

        return task_class(task_inputs=task_inputs, **extra_inputs)

    @staticmethod
    @pytest.mark.parametrize("task_class", [Task, Calculation])
    def test_should_create_parametrization_that_does_not_change_initial_task_when_no_mods_passed_and_exclude_is_none(
        task: TaskBase,
    ) -> None:
        """Applying an unmodified parametrization leaves the dump unchanged."""
        initial_task = task.model_dump()
        parametrization = create_parametrization(task)

        for ref in parametrization:
            ref.set_input_value(initial_task, task)

        dumped_task = task.model_dump()

        # The atoms entries are removed from both dumps and compared
        # separately from the remaining fields.
        initial_atoms = initial_task["task_inputs"].pop("atoms")
        dumped_atoms = dumped_task["task_inputs"].pop("atoms")

        assert initial_task == dumped_task
        assert initial_atoms == dumped_atoms

    @staticmethod
    @pytest.mark.parametrize("task_class", [Task, Calculation])
    def test_should_create_parametrization_with_identical_task_inputs(
        task: TaskBase,
    ) -> None:
        """References under ``task_inputs`` evaluate to the dumped inputs."""
        parametrization = create_parametrization(task)

        TestCreateParametrization._check_references(
            parametrization,
            task,
            "task_inputs",
            task.task_inputs.model_dump(),
        )

    @staticmethod
    @pytest.mark.parametrize("task_class", [Calculation])
    def test_should_create_parametrization_with_identical_calc_inputs_when_calc_mods_empty(
        task: Calculation,
    ) -> None:
        """References under ``calculation_inputs`` match the dumped inputs."""
        parametrization = create_parametrization(task)

        TestCreateParametrization._check_references(
            parametrization,
            task,
            "calculation_inputs",
            task.calculation_inputs.model_dump(),
        )

    @staticmethod
    @pytest.mark.parametrize("task_class", [Calculation])
    def test_should_create_parametrization_with_identical_scheduler_params_when_slurm_mods_empty(
        task: Calculation,
    ) -> None:
        """References under ``scheduler_inputs`` match the dumped inputs."""
        parametrization = create_parametrization(task)

        TestCreateParametrization._check_references(
            parametrization,
            task,
            "scheduler_inputs",
            task.scheduler_inputs.model_dump(),
        )

    @staticmethod
    @pytest.mark.parametrize("task_class", [Task, Calculation])
    def test_should_create_parametrization_with_identical_metadata_when_exclude_metadata_is_none(
        task: TaskBase,
    ) -> None:
        """Metadata references match the dump when nothing is excluded."""
        parametrization = create_parametrization(task)

        TestCreateParametrization._check_references(
            parametrization,
            task,
            "task_metadata",
            task.task_metadata.model_dump(),
        )

    @staticmethod
    @pytest.mark.parametrize("task_class", [Task, Calculation])
    def test_should_create_parametrization_with_identical_metadata_when_exclude_metadata_empty(
        task: TaskBase,
    ) -> None:
        """Metadata references match the dump for an empty exclusion set."""
        # Pass the empty collection explicitly so that this test exercises
        # a different call than the ``exclude_metadata`` default case.
        parametrization = create_parametrization(task, exclude_metadata=set())

        TestCreateParametrization._check_references(
            parametrization,
            task,
            "task_metadata",
            task.task_metadata.model_dump(),
        )

    @staticmethod
    @pytest.mark.parametrize("task_class", [Calculation])
    @pytest.mark.parametrize("calc_mods", [{"kpts": (1, 1, 1), "encut": 450}])
    def test_should_create_parametrization_with_calc_mods(
        task: Calculation, calc_mods: dict[str, Any]
    ) -> None:
        """References to modified calculator parameters yield the new values."""
        parametrization = create_parametrization(task, calc_mods=calc_mods)

        TestCreateParametrization._check_references(
            parametrization,
            task,
            "calc_params",
            calc_mods,
            only_expected_keys=True,
        )

    @staticmethod
    @pytest.mark.parametrize("task_class", [Calculation])
    @pytest.mark.parametrize(
        "sched_mods",
        [
            {
                "mem": 100,
                "time": datetime(year=1, month=1, day=1, tzinfo=UTC),
            }
        ],
    )
    def test_should_create_parametrization_with_sched_mods(
        task: Calculation, sched_mods: dict[str, Any]
    ) -> None:
        """References to modified scheduler inputs yield the new values."""
        parametrization = create_parametrization(task, sched_mods=sched_mods)

        TestCreateParametrization._check_references(
            parametrization,
            task,
            "scheduler_inputs",
            sched_mods,
            only_expected_keys=True,
        )

    @staticmethod
    @pytest.mark.parametrize("task_class", [Task, Calculation])
    @pytest.mark.parametrize(
        "exclude_metadata", [{"task_id", "uri", "date_created"}]
    )
    def test_should_exclude_task_metadata_in_exclude_metadata(
        task: TaskBase, exclude_metadata: set[str]
    ) -> None:
        """No metadata reference targets a field named in the exclusion set."""
        parametrization = create_parametrization(
            task, exclude_metadata=exclude_metadata
        )

        # Collect the offending paths so a failure shows what leaked through.
        leaked = [
            ref.set_path
            for ref in parametrization
            if "task_metadata" in ref.set_path
            and ref.set_path[-1] in exclude_metadata
        ]

        assert not leaked, f"excluded metadata still set: {leaked!r}"
