"""Semi-automatically advance workflows.

Examples:
    Programmatically,

    .. code-block:: python

        from pathlib import Path

        from autojob.advance.advance import advance

        advance(src=Path.cwd())

    From the command-line,

    .. code-block:: console

        autojob advance
"""

import logging
from pathlib import Path
from typing import TYPE_CHECKING

from pydantic import TypeAdapter

from autojob import SETTINGS
from autojob.bases.task_base import TaskOutcome
from autojob.harvest.harvest import harvest
from autojob.next import FILE_SIZE_LIMIT
from autojob.next import create_next_step
from autojob.next import finalize_task
from autojob.workflow import Step
from autojob.workflow import Workflow

if TYPE_CHECKING:
    from autojob.bases.task_base import TaskBase

# Module-level logger named after this module (``__name__``).
logger = logging.getLogger(__name__)


def get_next_steps(task: "TaskBase", study_dir: Path) -> list[str]:
    """Return the workflow step IDs to start now that ``task`` has finished.

    The workflow is loaded from ``study_dir`` via ``Workflow.from_directory``
    and only the first step yielded by ``Workflow.static_order()`` is
    returned. ``task`` itself does not influence the result; it is used only
    for logging.

    Args:
        task: The task that just finished.
        study_dir: The root directory of the study containing the finished
            task; the workflow is read from here.

    Returns:
        A list with at most one step ID: the first step of the workflow's
        static order, or an empty list if the workflow yields no steps.
    """
    logger.debug(f"Determining next steps for {task.task_metadata.task_id}")
    workflow = Workflow.from_directory(study_dir)

    # ! For backwards-compatibility, assume only the first task (a
    # ! relaxation Calculation) can fail; if it does, restart. Consequently
    # ! only the first step of the static order is ever selected.
    # NOTE(review): this ignores the position of ``task`` in the workflow;
    # confirm this is still the intended behaviour.
    next_steps: list[str] = []
    for step_id in workflow.static_order():
        next_steps.append(step_id)
        break

    count = len(next_steps)
    plural = "" if count == 1 else "s"
    logger.debug(f"{count} next step{plural}")

    return next_steps


# TODO: Expand templating outside of metadata by flattening task
# TODO: support/implement setting the "label" key
# ! Note that all parametrizations of a given step are currently
# ! initiated at once
# TODO: Support calc/slurm_mods
def advance(
    *,
    src: Path,
    file_size_limit: float = FILE_SIZE_LIMIT,
    submit: bool = True,
    name_template: str | None = None,
) -> list[tuple["TaskBase", Path]]:
    """Advance to the next task in the workflow.

    The task in ``src`` is harvested, checked for success and finalized.
    The next workflow steps are then determined, and each one is created
    (and optionally submitted) from the step parametrizations read from
    ``SETTINGS.PARAMETRIZATION_FILE``.

    Args:
        src: The directory of the completed calculation. The study
            directory is taken to be ``src.parent.parent``.
        file_size_limit: A float specifying the threshold above which files
            of this size will be deleted. Defaults to FILE_SIZE_LIMIT.
        submit: Whether or not to submit the new job after creation. Defaults
            to True.
        name_template: A template to use for the directory name. Defaults to
            None in which case the task ID will be used.

    Raises:
        RuntimeError: If no task could be harvested from ``src`` or if the
            harvested task did not succeed. Cannot advance to next step!
        KeyError: If a next step has no entry in the parametrization file.
            This is checked before any new step is created, so the workflow
            is never partially advanced because of a missing entry.

    Returns:
        A list of tuples (task_i, path_i) where task_i is the ith created task
        and path_i is the Path representing the directory containing the ith
        created task.
    """
    logger.debug("Advancing task in %s", src)
    study_dir = src.parent.parent

    # ``harvest`` may yield nothing (e.g. ``src`` holds no task). Fail with a
    # clear message instead of letting a bare ``StopIteration`` escape.
    completed_task = next(iter(harvest(src, use_cache=True)), None)
    if completed_task is None:
        msg = f"No task found in {src}! Cannot advance to next step!"
        raise RuntimeError(msg)

    if (
        completed_task.task_outputs is None
        or completed_task.task_outputs.outcome != TaskOutcome.SUCCESS
    ):
        msg = "Task failed! Cannot advance to next step!"
        raise RuntimeError(msg)

    finalize_task(
        src=src,
        task=completed_task,
        record_task=True,
    )
    next_steps = get_next_steps(completed_task, study_dir)

    # NOTE(review): a relative PARAMETRIZATION_FILE is resolved against the
    # current working directory, not ``study_dir``; confirm this is intended.
    params = Path(SETTINGS.PARAMETRIZATION_FILE)
    steps = TypeAdapter(dict[str, Step]).validate_json(
        params.read_text(encoding="utf-8")
    )

    # Validate all steps before creating any, so a missing parametrization
    # cannot leave some next steps created/submitted and others not.
    missing = [step for step in next_steps if step not in steps]
    if missing:
        msg = f"No parametrization found in {params} for step(s): {missing}"
        raise KeyError(msg)

    tasks_and_dirs: list[tuple["TaskBase", Path]] = []

    for step in next_steps:
        new_tasks_and_dirs = create_next_step(
            src=src,
            step=steps[step],
            previous_task=completed_task,
            file_size_limit=file_size_limit,
            submit=submit,
            restart=False,
            name_template=name_template,
        )
        tasks_and_dirs.extend(new_tasks_and_dirs)

    logger.debug("Successfully advanced from task in %s", src)
    return tasks_and_dirs
