"""Restart a completed task.

Examples:
    Programmatically,

    .. code-block:: python

        from pathlib import Path

        from autojob.next.restart import restart

        restart(src=Path.cwd())

    From the command-line,

    .. code-block:: console

        autojob restart
"""

from collections.abc import Iterable
import logging
from pathlib import Path
from typing import TYPE_CHECKING
from typing import Any

from autojob.harvest.harvest import harvest
from autojob.next import FILE_SIZE_LIMIT
from autojob.next import create_next_step
from autojob.next import finalize_task
from autojob.next import substitute_context
from autojob.parametrizations import VariableReference
from autojob.parametrizations import create_parametrization
from autojob.workflow import Step

if TYPE_CHECKING:
    # Only needed for the string annotation in ``restart``'s return type.
    from autojob.bases.task_base import TaskBase

logger = logging.getLogger(__name__)

# Task-metadata fields (identifiers and timestamps) passed to
# ``create_parametrization`` as ``exclude_metadata`` when building the
# parametrization for a restart, so they are not copied from the completed
# task.
_to_exclude_on_restart = [
    "task_id",
    "uri",
    "date_created",
    "last_updated",
]


def restart(
    src: str | Path | None = None,
    *,
    calc_mods: dict[str, Any] | None = None,
    sched_mods: dict[str, Any] | None = None,
    file_size_limit: float = FILE_SIZE_LIMIT,
    submit: bool = True,
    auto_restart: bool = False,  # noqa: ARG001
    files_to_carry_over: Iterable[str] | None = None,  # noqa: ARG001
    name_template: str | None = None,
) -> tuple["TaskBase", Path]:
    """Create a new task that restarts the completed task in `src`.

    The completed task is harvested from `src` and finalized. Then a single
    new task of the same task class and workflow step is created from it.
    The new task reuses the completed task's parameters, with `calc_mods`
    and `sched_mods` applied and the metadata fields that identify the old
    task excluded. Its input atoms are set to the completed task's output
    atoms.

    Args:
        src: The directory of the completed task. Defaults to the current
            working directory, which is also used when `src` is an empty
            string.
        calc_mods: A dictionary mapping calculator parameters to values that
            should be used to overwrite the existing parameters.
        sched_mods: A dictionary mapping Slurm options to values that
            should be used to overwrite the existing parameters. Before
            being applied, the values are templated with the completed
            task's metadata and with ``structure`` (the stem of the completed
            task's atoms filename).
        file_size_limit: A float specifying the threshold above which files
            of this size will be deleted. Forwarded to `create_next_step`.
            Defaults to FILE_SIZE_LIMIT.
        submit: Whether or not to submit the new job after creation. Defaults
            to True.
        auto_restart: Currently ignored. It is accepted for interface
            compatibility but is not used by this function.
        files_to_carry_over: Currently ignored. It is accepted for interface
            compatibility but is not forwarded to `create_next_step`, which
            decides what to copy (presumably from the previous task).
        name_template: A template to use for the directory name. Defaults to
            None, in which case the task ID will be used.

    Returns:
        A tuple ``(task, path)``, where ``task`` is the newly created restart
        task and ``path`` is the directory containing it.

    Raises:
        StopIteration: If no task is harvested from `src`, or if
            `create_next_step` yields no task.

    Warning:
        When specifying `sched_mods`, be wary of setting mutually exclusive
        scheduler parameters (e.g., `mem` and `mem_per_cpu` or `cores` and
        `cores_per_node`). For example, if the `mem` parameter is set and one
        wants to set the `mem_per_cpu` parameter, set the `mem` key to `Unset`
        in `sched_mods` in addition to setting the `mem_per_cpu` key.
    """
    src = Path(src) if src else Path.cwd()
    logger.debug("Restarting task in %s", src)
    calc_mods = calc_mods or {}
    sched_mods = sched_mods or {}

    # Only the first harvested task is restarted. A bare ``next`` raises
    # StopIteration when nothing is harvested from ``src``.
    completed_task = next(iter(harvest(src, use_cache=True)))
    finalize_task(
        src=src,
        task=completed_task,
        # TODO: must implement record keeping first, then expose as arg
        record_task=False,
    )

    # Template scheduler inputs with task metadata and structure name. An
    # empty filename stem falls back to the literal "{structure}"
    # placeholder (presumably so substitution leaves it untouched; confirm
    # against substitute_context).
    context = completed_task.task_metadata.model_dump(exclude_none=True)
    context["structure"] = (
        Path(completed_task.task_inputs.atoms_filename).stem or "{structure}"
    )
    sched_mods = substitute_context(sched_mods, context)

    # Reuse the completed task's parameters, minus the identifier/timestamp
    # metadata listed in _to_exclude_on_restart.
    parametrization = create_parametrization(
        completed_task,
        calc_mods=calc_mods,
        sched_mods=sched_mods,
        exclude_metadata=_to_exclude_on_restart,
    )
    # Set new input atoms to completed task's output atoms so the restart
    # continues from where the completed task ended.
    parametrization.append(
        VariableReference(
            set_path=["task_inputs", "atoms"],
            constant=completed_task.task_outputs.atoms,
        )
    )
    # A one-parametrization step with the same step ID and task class as
    # the completed task.
    step = Step(
        workflow_step_id=completed_task.task_metadata.workflow_step_id,
        task_class=completed_task.task_metadata.task_class,
        progression="independent",
        parametrizations=[parametrization],
    )

    # create_next_step yields (task, directory) pairs. Only the first one is
    # used, since the step above has a single parametrization.
    new_task, new_task_dir = next(
        iter(
            create_next_step(
                src=src,
                step=step,
                previous_task=completed_task,
                file_size_limit=file_size_limit,
                submit=submit,
                restart=True,
                name_template=name_template,
            )
        )
    )
    logger.debug("Successfully created restart task from %s", src)
    return new_task, new_task_dir
