from datetime import datetime
from datetime import timedelta
import logging
from pathlib import Path
from typing import Any

from pydantic import ValidationError
import pytest

from autojob.hpc import Partition
from autojob.hpc import SchedulerInputs
from autojob.hpc import SchedulerOutputs
from autojob.hpc import validate_memory
from autojob.tasks.calculation import Calculation
from autojob.utils.parsing import TimedeltaTuple

# Logger named after this test module.
logger = logging.getLogger(__name__)

# Exactly one day; the zero-valued components are implied.
DEFAULT_ELAPSED = timedelta(days=1)

# Single-element list holding a bash interpreter line.
DEFAULT_SHEBANG = ["#!/usr/bin/bash"]


class TestSchedulerInputsTime:
    """Validation of the ``time`` field of ``SchedulerInputs``."""

    @staticmethod
    def test_should_validate_timedelta(
        elapsed: timedelta,
    ) -> None:
        """A ``timedelta`` given to the constructor compares equal on read."""
        inputs = SchedulerInputs(time=elapsed)
        assert inputs.time == elapsed

    @staticmethod
    def test_should_validate_slurm_time(
        elapsed: timedelta,
    ) -> None:
        """A Slurm time string validates back to the original ``timedelta``."""
        slurm_time = TimedeltaTuple.from_timedelta(elapsed).to_slurm_time()
        validated = SchedulerInputs.model_validate(
            {"time": slurm_time},
            context={"format": "slurm"},
        )
        assert validated.time == elapsed

    @staticmethod
    def test_should_not_validate_invalid_string() -> None:
        """The string ``"x"`` is rejected with a ``ValidationError``."""
        with pytest.raises(ValidationError):
            SchedulerInputs(time="x")


class TestSchedulerInputsMailType:
    """Validation of the mail-type field of ``SchedulerInputs``.

    Every test runs once per key spelling supplied by the ``name`` fixture.
    """

    @staticmethod
    @pytest.fixture(name="name", params=["mail_type", "mail-type"])
    def fixture_name(request: pytest.FixtureRequest) -> str:
        """Return the key spelling selected for the current parametrization."""
        # The typed local keeps the declared ``str`` return type explicit.
        key: str = request.param
        return key

    @staticmethod
    def test_should_validate_list_of_strings(
        name: str, mail_type: list[str]
    ) -> None:
        """A list of strings validates to an equal ``mail_type`` list."""
        validated = SchedulerInputs.model_validate({name: mail_type})
        assert validated.mail_type == mail_type

    @staticmethod
    def test_should_validate_comma_separated_strings_in_slurm_format(
        name: str,
        mail_type: list[str],
    ) -> None:
        """A comma-joined string validates to the list in Slurm format."""
        joined = ",".join(mail_type)
        validated = SchedulerInputs.model_validate(
            {name: joined}, context={"format": "slurm"}
        )
        assert validated.mail_type == mail_type


class TestSchedulerInputsMemory:
    """Validation of the memory fields of ``SchedulerInputs``.

    Every test runs once per key spelling supplied by the ``name`` fixture;
    the validated value is read back from the attribute whose name is the key
    with dashes replaced by underscores.
    """

    @staticmethod
    @pytest.fixture(name="name", params=["mem", "mem_per_cpu", "mem-per-cpu"])
    def fixture_name(request: pytest.FixtureRequest) -> str:
        """Return the key spelling selected for the current parametrization."""
        name: str = request.param
        return name

    @staticmethod
    def test_should_validate_int(name: str, mem_per_cpu: int) -> None:
        """A bare integer validates to an equal value."""
        validated = SchedulerInputs.model_validate({name: mem_per_cpu})
        assert getattr(validated, name.replace("-", "_")) == mem_per_cpu

    @staticmethod
    def test_should_validate_memory_string(
        name: str,
        mem_per_cpu: int,
    ) -> None:
        """A Slurm memory string with an ``MB`` suffix validates to the int."""
        # Pass an actual memory string: handing over the bare integer would
        # only repeat ``test_should_validate_int`` with a context attached.
        # NOTE(review): assumes "MB" is the unit accepted in Slurm format,
        # matching the suffix used elsewhere in this module - confirm.
        memory = f"{mem_per_cpu}MB"
        validated = SchedulerInputs.model_validate(
            {name: memory}, context={"format": "slurm"}
        )
        assert getattr(validated, name.replace("-", "_")) == mem_per_cpu


class TestSchedulerInputsPartition:
    """Validation of the partitions field of ``SchedulerInputs``.

    Every test runs once per key spelling supplied by the ``name`` fixture.
    """

    @staticmethod
    @pytest.fixture(name="name", params=["partition", "partitions"])
    def fixture_name(request: pytest.FixtureRequest) -> str:
        """Return the key spelling selected for the current parametrization."""
        # The typed local keeps the declared ``str`` return type explicit.
        key: str = request.param
        return key

    @staticmethod
    def test_should_validate_list_of_partitions(
        name: str, partitions: list[Partition]
    ) -> None:
        """A list of ``Partition`` objects validates to an equal list."""
        validated = SchedulerInputs.model_validate({name: partitions})
        assert validated.partitions == partitions

    @staticmethod
    def test_should_validate_memory_string(
        name: str,
        partitions: list[Partition],
    ) -> None:
        """Comma-joined cluster names validate to the partition list.

        Despite its name, this test does not involve memory: it joins each
        partition's ``cluster_name`` with commas and validates the result in
        Slurm format.
        """
        cluster_names = ",".join(p.cluster_name for p in partitions)
        validated = SchedulerInputs.model_validate(
            {name: cluster_names},
            context={"format": "slurm"},
        )
        assert validated.partitions == partitions


class TestSchedulerInputs:
    """Smoke test for the ``scheduler_inputs`` fixture."""

    @staticmethod
    def test_should_instantiate_dict(
        scheduler_inputs: SchedulerInputs,
    ) -> None:
        """The fixture-provided ``SchedulerInputs`` object is truthy."""
        # Presumably the real check is that the fixture can be built at all;
        # the assertion itself only tests truthiness.
        assert scheduler_inputs


class TestExtractSchedulerInputs:
    """Tests for ``SchedulerInputs.extract_scheduler_inputs``."""

    @staticmethod
    @pytest.fixture(name="extracted_parameters")
    def fixture_extracted_parameters(
        write_scheduler_script: Path,
    ) -> dict[str, Any]:
        """Extract the scheduler parameters from the written script file."""
        with write_scheduler_script.open(mode="r", encoding="utf-8") as f:
            # Extraction happens while the stream is still open.
            parameters = SchedulerInputs.extract_scheduler_inputs(stream=f)
        return parameters

    # This functional test covers a number of cases by virtue of the
    # parametrization (see the "slurm_script" fixture and the fixtures it
    # requests). Some key cases:
    #   - with/without bash shebang at top of file
    #   - various values for slurm parameters
    #   - with/without auto-restart logic and files deleted
    #   - slurm options interspersed with code

    @staticmethod
    def test_should_read_all_and_only_options_in_heading(
        scheduler_inputs: SchedulerInputs,
        extracted_parameters: dict[str, Any],
    ) -> None:
        """Each extracted raw value matches the Slurm form of the inputs."""
        extracted = extracted_parameters

        slurm_time = TimedeltaTuple.from_timedelta(
            scheduler_inputs.time
        ).to_slurm_time()
        assert extracted["time"] == slurm_time

        partition_names = ",".join(
            p.cluster_name for p in scheduler_inputs.partitions
        )
        assert extracted["partition"] == partition_names

        assert extracted["nodes"] == str(scheduler_inputs.nodes)
        assert extracted["job-name"] == scheduler_inputs.job_name

        mail_types = ",".join(scheduler_inputs.mail_type)
        assert extracted["mail-type"] == mail_types
        assert extracted["mail-user"] == scheduler_inputs.mail_user

        # memory is written/specified in MB
        memory = extracted["mem-per-cpu"].removesuffix("MB")
        assert memory == str(scheduler_inputs.mem_per_cpu)

        cores_per_node = str(scheduler_inputs.cores_per_node)
        assert extracted["ntasks-per-node"] == cores_per_node


class TestUpdateValues:
    """Tests for ``SchedulerInputs.update_values``."""

    @staticmethod
    def test_should_update_memory() -> None:
        """Setting ``mem_per_cpu`` removes an existing ``mem`` entry."""
        inputs = {"mem": 1}
        mods = {"mem_per_cpu": 2}
        SchedulerInputs.update_values(inputs, mods)
        assert "mem" not in inputs
        assert inputs["mem_per_cpu"] == mods["mem_per_cpu"]

    @staticmethod
    def test_should_update_time() -> None:
        """A new key in the modifications is added to empty inputs.

        Despite its name, the parameter exercised here is ``nodes``, not
        ``time``.
        """
        inputs: dict[str, Any] = {}
        mods = {"nodes": 1}
        SchedulerInputs.update_values(inputs, mods)
        assert inputs["nodes"] == 1


@pytest.mark.usefixtures("write_scheduler_script")
class TestSchedulerInputsFromDirectory:
    """Tests for ``SchedulerInputs.from_directory``."""

    @staticmethod
    @pytest.fixture(name="write_scheduler_script")
    def fixture_write_scheduler_script(
        calculation: Calculation, task_directory: Path
    ) -> None:
        """Write the calculation's task script into the task directory."""
        calculation.write_task_script(task_directory)

    @staticmethod
    def test_should_load_scheduler_inputs(
        task_directory: Path,
        scheduler_inputs: SchedulerInputs,
    ) -> None:
        """Inputs loaded from the directory dump equal to the fixture's."""
        loaded = SchedulerInputs.from_directory(src=task_directory)
        actual_dump = loaded.model_dump()
        expected_dump = scheduler_inputs.model_dump()
        assert actual_dump == expected_dump


class TestSchedulerOutputsValidate:
    """Validation tests for ``SchedulerOutputs`` and ``validate_memory``.

    The ``val``, ``units`` and ``num_type`` fixtures supply defaults that
    individual tests override through direct parametrization.
    """

    @staticmethod
    @pytest.fixture(name="val")
    def fixture_val() -> float:
        """Return the numeric memory value used to build memory strings."""
        return 1.0

    @staticmethod
    @pytest.fixture(name="units")
    def fixture_units() -> str:
        """Return the default unit suffix appended to the memory value."""
        return "MB"

    @staticmethod
    @pytest.fixture(name="num_type")
    def fixture_num_type() -> type[int | float]:
        """Return the default numeric type used to render the value."""
        return int

    @staticmethod
    def test_should_validate_missing_time(submit_time: datetime) -> None:
        """Outputs holding only a ``Submit`` timestamp validate as truthy."""
        submit_only = {"Submit": submit_time.isoformat()}
        assert SchedulerOutputs.model_validate(submit_only)

    @staticmethod
    def test_should_validate_integer_memory(
        val: float, num_type: type[int | float], units: str
    ) -> None:
        """An integer-rendered value with units validates to ``val``."""
        rendered = num_type(val)
        validated = validate_memory(f"{rendered}{units}", lambda x: x, None)
        assert validated == val

    @staticmethod
    @pytest.mark.parametrize("units", [""])
    def test_should_validate_memory_without_units(
        val: float, num_type: type[int | float], units: str
    ) -> None:
        """A value rendered with an empty unit suffix validates to ``val``."""
        rendered = num_type(val)
        validated = validate_memory(f"{rendered}{units}", lambda x: x, None)
        assert validated == val

    @staticmethod
    @pytest.mark.parametrize("num_type", [float])
    def test_should_validate_decimal_memory(
        val: float, num_type: type[int | float], units: str
    ) -> None:
        """A float-rendered value with units validates to ``val``."""
        rendered = num_type(val)
        validated = validate_memory(f"{rendered}{units}", lambda x: x, None)
        assert validated == val


@pytest.mark.usefixtures("write_job_stats_file", "write_slurm_output_file")
class TestSchedulerOutputsFromDirectory:
    """Tests for ``SchedulerOutputs.from_directory``."""

    @staticmethod
    def test_should_load_scheduler_outputs(
        task_directory: Path,
    ) -> None:
        """Loading outputs from the task directory yields a truthy object."""
        loaded = SchedulerOutputs.from_directory(task_directory)
        assert loaded

    # regression test
    @staticmethod
    def test_should_load_all_scheduler_outputs_in_stats_file(
        task_directory: Path,
    ) -> None:
        """The JSON-mode dump of the loaded outputs has the expected values."""
        outputs = SchedulerOutputs.from_directory(task_directory)
        loaded_outputs = outputs.model_dump(mode="json")

        # Checked in insertion order, so the first mismatch fails the test.
        expected_values = {
            "elapsed": "10:50:16",
            "idle_time": "09:39:46",
            "job_id": 37042290,
            "max_rss": "115715MB",
            "nodes": ["cdr1974"],
            "partition": "cpubase_bycore_b3",
            "state": "F",
        }
        for key, expected in expected_values.items():
            assert loaded_outputs[key] == expected, key
