"""Utilities for creating tasks from existing task directories."""

import json
import logging
from pathlib import Path
import shutil
import subprocess
from tempfile import TemporaryDirectory
from typing import TYPE_CHECKING
from typing import Any
from typing import Literal

import click

from autojob import SETTINGS
from autojob.bases.task_base import TASK_GROUP_FIELDS
from autojob.harvest.archive import archive_json
from autojob.hpc import get_scheduler_command
from autojob.plugins import get_task_class
from autojob.utils.files import copy_permissions_and_ownership
from autojob.utils.templates import substitute_placeholders

if TYPE_CHECKING:
    from autojob.bases.task_base import TaskBase
    from autojob.parametrizations import VariableReference
    from autojob.workflow import Step

# Module-level logger used by the helpers below
logger = logging.getLogger(__name__)
# Default size threshold (in bytes) at or above which clean_up_task deletes
# files from a completed task directory
FILE_SIZE_LIMIT = 1e8  # 100 MB
# If this file is present in the task directory, the task will not be restarted
STOP_FILE = "autojob.stop"


def finalize_task(
    src: Path,
    task: "TaskBase",
    record_task: bool = False,
) -> None:
    """Write the task archive and optionally log the task in the study record.

    Args:
        src: The task directory. The archive file is written directly inside
            it and the study record file is located two levels above it.
        task: The completed task to archive.
        record_task: If True, the ID of the task is appended as a new line
            to the study record file. Defaults to False.

    Note:
        An archive file that already exists in `src` is left untouched.
    """
    if record_task:
        record_file = Path(src.parents[1], SETTINGS.RECORD_FILE)
        with record_file.open(mode="a", encoding="utf-8") as record:
            record.write(f"{task.task_metadata.task_id}\n")

    archive_file = Path(src, SETTINGS.ARCHIVE_FILE)
    if archive_file.exists():
        return
    archive_json([task], archive_file)


def substitute_context(
    mods: dict[str, Any],
    context: dict[str, Any],
) -> dict[str, Any]:
    """Fill in templated string values using the given context.

    Args:
        mods: A dictionary mapping parameter names to values. Only string
            values are treated as templates; every other value is passed
            through unchanged.
        context: A dictionary mapping variable names to their values. These
            are handed to the placeholder substitution as keyword arguments.

    Returns:
        A new dictionary with the same keys as `mods` in which every string
        value has had its placeholders substituted. `mods` is not modified.
    """
    return {
        name: (
            substitute_placeholders(setting, **context)
            if isinstance(setting, str)
            else setting
        )
        for name, setting in mods.items()
    }


def initialize_task(
    *,
    task_class: str,
    parametrization: "list[VariableReference[Any]]",
    previous_task: "TaskBase",
    restart: bool = False,
) -> "TaskBase":
    """Build a new task from a default task shell and a parametrization.

    A default instance of the requested task class is dumped to a dictionary,
    each variable reference in `parametrization` copies its value from the
    dumped previous task into that dictionary, and the result is used to
    construct the new task.

    Args:
        task_class: A string representing the fully qualified class name
            of the type of task to be created.
        parametrization: The variable references used to populate the new
            task from the previous task.
        previous_task: The task from which values are drawn.
        restart: Whether or not the new task is a restart of the previous
            task. Defaults to False.

    Note:
        Anything to be inherited from `previous_task` **must** be named in
        `parametrization`; parameters it does not set keep their defaults.
        The ID of the previous task is always appended to the new task's
        tags. Unless `restart` is True, the previous task group ID is
        appended as well.

    Returns:
        The new TaskBase instance.
    """
    new_task_type = get_task_class(task_class)
    previous_state = previous_task.model_dump(exclude_none=True)
    new_state = new_task_type().model_dump(exclude_none=True)  # type: ignore[call-arg]

    for reference in parametrization:
        reference.set_input_value(previous_state, new_state)

    # TODO: add previous_task to TaskMetadataBase.parents once implemented
    previous_metadata = previous_state["task_metadata"]
    previous_id = previous_metadata["task_id"]
    new_tags = new_state["task_metadata"]["tags"]
    new_tags.append(str(previous_id))

    if not restart:
        new_tags.append(str(previous_metadata["task_group_id"]))

    return new_task_type(**new_state)


def _substitute_dir_index(dir_name: str, dest: Path) -> str:
    """Replace ``{i}`` in a directory name with the first unused index.

    Indices are tried in increasing order starting from 1 until the resulting
    name does not match the name of any entry in `dest`.

    Args:
        dir_name: The directory name template containing the literal
            placeholder ``{i}``.
        dest: The existing directory whose entries the name must not clash
            with.

    Returns:
        The directory name with every ``{i}`` replaced by the chosen index.
        If `dir_name` contains no ``{i}`` it is returned unchanged, since no
        index could make it unique.
    """
    if "{i}" not in dir_name:
        return dir_name

    taken = {entry.name for entry in dest.iterdir()}
    index = 1
    # Always substitute into the original template: once ``{i}`` has been
    # replaced there is nothing left to substitute on the next attempt.
    candidate = dir_name.replace("{i}", str(index))
    while candidate in taken:
        index += 1
        candidate = dir_name.replace("{i}", str(index))
    return candidate


def _ensure_unique_dir(dir_name: str, dest: Path) -> str:
    """Return a variant of `dir_name` that does not clash with entries of `dest`.

    Args:
        dir_name: The preferred directory name.
        dest: The existing directory whose entries the name must not clash
            with.

    Returns:
        `dir_name` itself if it is unused, otherwise `dir_name` followed by
        ``_1``, ``_2``, ... using the first suffix that is unused.
    """
    taken = [entry.name for entry in dest.iterdir()]
    candidate = dir_name
    suffix = 0
    while candidate in taken:
        suffix += 1
        candidate = f"{dir_name}_{suffix}"
        logger.info(
            "Directory name %s is not unique. Trying %s", dir_name, candidate
        )
    return candidate


def _create_templated_dir_name(
    name_template: str, dest: Path, task: "TaskBase"
) -> str:
    """Render a directory name from a template and make it unique in `dest`.

    The template is filled from the task's metadata (fields that are None
    are omitted) plus ``structure``, the stem of the task's atoms filename.

    Args:
        name_template: The template for the directory name.
        dest: The directory in which the name must be unique.
        task: The task supplying the values for the template.

    Returns:
        The rendered directory name. If ``{i}`` remains in the rendered name
        it is resolved to an index; otherwise a numeric suffix is appended
        when needed to avoid a clash.
    """
    placeholders = task.task_metadata.model_dump(exclude_none=True)
    placeholders["structure"] = Path(task.task_inputs.atoms_filename).stem
    rendered = substitute_placeholders(name_template, **placeholders)
    make_unique = (
        _substitute_dir_index if "{i}" in rendered else _ensure_unique_dir
    )
    return make_unique(rendered, dest)


def _write_task_group_metadata(dest: Path, task: "TaskBase") -> None:
    """Write the task group metadata file for a new task group.

    The metadata consists of the task group fields of the task's metadata
    (dumped in JSON mode) and a ``tasks`` list holding only the ID of `task`.

    Args:
        dest: The task group directory in which to write the file.
        task: The first task of the task group.
    """
    metadata = task.task_metadata.model_dump(
        mode="json", include=set(TASK_GROUP_FIELDS)
    )
    metadata["tasks"] = [str(task.task_metadata.task_id)]

    target = Path(dest, SETTINGS.TASK_GROUP_METADATA_FILE)
    with target.open(mode="w", encoding="utf-8") as stream:
        json.dump(metadata, stream, indent=4)


# TODO: Expand templating outside of metadata by flattening task
def create_task_group_tree(
    task: "TaskBase",
    dest: Path,
    *,
    src: Path | None = None,
    name_template: str | None = None,
) -> Path:
    """Create the directory for a new task group.

    Besides creating the directory (including missing parents), this writes
    the task group metadata file into it and, when `src` is given, copies
    permissions and ownership from the parent directory of `src`.

    Args:
        task: The new task for which the task group directory will be made.
        dest: The directory in which the new task group directory will be
            created.
        src: The directory of the completed task. Defaults to None in which
            case permissions and ownership are not set.
        name_template: A template to use for the directory name. Defaults to
            None in which case the task group ID will be used.

    Raises:
        ValueError: No name template was given and the task has no task
            group ID.

    Returns:
        The path to the newly created task group directory.
    """
    if name_template:
        dir_name = _create_templated_dir_name(name_template, dest, task)
    elif task.task_metadata.task_group_id is None:
        msg = "Cannot create new task group without task group ID."
        raise ValueError(msg)
    else:
        dir_name = str(task.task_metadata.task_group_id)

    new_group_dir = Path(dest, dir_name)
    new_group_dir.mkdir(parents=True)
    if src:
        # Permissions come from the previous task's group directory
        copy_permissions_and_ownership(src.parent, new_group_dir)
    _write_task_group_metadata(new_group_dir, task)
    return new_group_dir


def carry_over_files(
    *,
    previous_task_src: Path,
    new_task_dest: Path,
    new_task: "TaskBase",
    files_to_carry_over: list[str] | None = None,
) -> None:
    """Copy files from a previous task directory to a new task directory.

    Args:
        previous_task_src: A Path object representing the directory of the
            completed task.
        new_task_dest: A Path object representing the destination directory of
            the new task.
        new_task: The new task.
        files_to_carry_over: A list of strings indicating the files to
            carry over from the previous task. Defaults to an empty list.
    """
    files_to_carry_over = files_to_carry_over or []
    for file in files_to_carry_over:
        try:
            p = Path(previous_task_src, file)
            copier = shutil.copytree if p.is_dir() else shutil.copy
            copier(
                src=Path(previous_task_src, file),
                dst=Path(new_task_dest, file),
            )

            logger.info(
                "Successfully copied %s to new task directory for task: %s",
                file,
                str(new_task.task_metadata.task_id),
            )
        except FileNotFoundError:
            logger.warning(
                "Unable to copy %s to new task directory for task: %s",
                file,
                str(new_task.task_metadata.task_id),
            )


# TODO: Expand templating outside of metadata by flattening task
def create_task_tree(
    task: "TaskBase",
    dest: Path,
    *,
    src: Path | None = None,
    files_to_carry_over: list[str] | None = None,
    name_template: str | None = None,
) -> Path:
    """Create the directory for a new task and write the task's inputs.

    Args:
        task: The new task for which the directory will be made.
        dest: A Path representing the directory in which to create the
            directory of the new task.
        src: The source directory for the new task. When given, its
            permissions and ownership are copied to the new directory and
            the requested files are carried over. Defaults to None in which
            case neither happens.
        files_to_carry_over: A list of strings indicating the files to
            carry over from the previous task. Defaults to None.
        name_template: A template to use for the directory name. Defaults to
            None in which case the task ID will be used.

    Returns:
        The path to the newly created task directory.
    """
    dir_name = (
        _create_templated_dir_name(name_template, dest, task)
        if name_template
        else str(task.task_metadata.task_id)
    )

    new_task_dir = Path(dest, dir_name)
    new_task_dir.mkdir(parents=True)
    if src:
        copy_permissions_and_ownership(src, new_task_dir)
        carry_over_files(
            previous_task_src=src,
            new_task_dest=new_task_dir,
            new_task=task,
            files_to_carry_over=files_to_carry_over,
        )
    # The return value of write_inputs is not needed here
    task.write_inputs(dest=new_task_dir)
    return new_task_dir


def add_item_to_parent(
    item_id: str,
    metadata_file: Path,
    key: Literal["tasks", "task_groups"],
) -> None:
    """Register an ID in the metadata file of its parent.

    The metadata file is read as JSON, `item_id` is appended to the list
    stored under `key`, and the file is rewritten with 4-space indentation.

    Args:
        item_id: The ID to add.
        metadata_file: The path to the metadata file of the parent to which to
            add the item ID.
        key: The key to which to add. Either ``"tasks"`` or ``"task_groups"``.
    """
    logger.debug(f"Adding {item_id} to {metadata_file}")
    parent_metadata = json.loads(metadata_file.read_text(encoding="utf-8"))

    parent_metadata[key].append(item_id)

    with metadata_file.open(mode="w", encoding="utf-8") as stream:
        json.dump(parent_metadata, stream, indent=4)
    logger.debug(f"Successfully added {item_id} to {metadata_file}")


def clean_up_task(
    old_job: Path,
    *,
    file_size_limit: float = FILE_SIZE_LIMIT,
    files_to_delete: list[str] | None = None,
) -> None:
    """Delete listed entries and oversized files from a task directory.

    Only the direct children of `old_job` are considered. An entry is removed
    if its name is in `files_to_delete`, or if it is a non-symlinked,
    non-directory entry whose size is at least `file_size_limit`.

    Args:
        old_job: A Path object representing the directory holding the
            large files to be deleted.
        file_size_limit: A float specifying the file size in bytes at or
            above which files will be deleted. Defaults to
            ``FILE_SIZE_LIMIT``.
        files_to_delete: A list of strings specifying the names of entries
            to delete. A listed directory is removed together with its
            contents. Defaults to None, in which case only the size rule
            applies.
    """
    names_to_delete = set(files_to_delete or [])

    for path in old_job.iterdir():
        is_link = path.is_symlink()
        is_real_dir = not is_link and path.is_dir()
        listed = path.name in names_to_delete
        # The size rule is meant for files: a directory's own st_size says
        # nothing about its contents, and symlinks are never judged by size.
        oversized = (
            not listed
            and not is_link
            and not is_real_dir
            and path.stat().st_size >= file_size_limit
        )
        if not (listed or oversized):
            continue

        # NOTE: for a listed symlink this resolves to (and removes) the
        # link's target, as the function has always done.
        target = path.resolve()
        if is_real_dir:
            # Path.unlink() cannot remove directories
            shutil.rmtree(target)
        else:
            target.unlink()
        logger.info("%s deleted", "/".join(target.parts[-5:]))


def submit_new_task(new_task: Path) -> None:
    """Submit the newly created task to the configured scheduler.

    The submission command for ``SETTINGS.SCHEDULER`` is run on
    ``SETTINGS.SLURM_SCRIPT`` from within the task directory. The scheduler's
    output, followed by the last four components of the task path, is echoed
    to the terminal and logged.

    Args:
        new_task: A Path to the new task's directory.
    """
    logger.info(f"Submitting task in {new_task}")
    scheduler_cmd = get_scheduler_command(SETTINGS.SCHEDULER)
    submission = [scheduler_cmd, SETTINGS.SLURM_SCRIPT]
    response = subprocess.check_output(
        submission,
        cwd=new_task,
        encoding="utf-8",
    ).strip("\n")
    task_label = "/".join(new_task.parts[-4:])
    click.echo(f"{response} ({task_label})")
    logger.info("%s (%s)", response, task_label)


def create_next_step(
    *,
    src: Path,
    step: "Step",
    previous_task: "TaskBase",
    file_size_limit: float = FILE_SIZE_LIMIT,
    submit: bool = True,
    restart: bool = False,
    # ? Consider setting at task_group/study level from file
    name_template: str | None = None,
) -> list[tuple["TaskBase", Path]]:
    """Initiate a step by creating all tasks that are ready to start.

    One task is created per parametrization of `step`. Each task is first
    assembled in a temporary directory and then copied into place: next to
    `src` for a restart, or as a new task group next to the task group of
    `src` otherwise. The new item is then registered in the metadata file of
    its parent and, if requested, submitted.

    The source directory is cleaned up once, after all tasks have been
    created, so that every parametrization can carry over the same files.

    Args:
        src: The source directory for the new tasks. That is, the directory
            containing the recently completed task.
        step: The Step to initiate.
        previous_task: The previous task.
        file_size_limit: A float specifying the threshold above which files
            of this size will be deleted from the source directory. Defaults to
            ``FILE_SIZE_LIMIT``.
        submit: Whether or not to submit the new TaskBases after creation. Defaults
            to True.
        restart: Whether or not the task to be created is a restart of a
            previous task. Defaults to False.
        name_template: A template to use for the directory name. Defaults to
            None in which case the task ID will be used.

    Returns:
        A list of 2-tuples (task, path) where task is the new TaskBase instance
        and path is the path in which it was dumped. For new task groups, `path`
        will point to the new task group directory.
    """
    files_to_carry_over = previous_task.task_inputs.files_to_carry_over
    tasks_and_dirs: list[tuple[TaskBase, Path]] = []
    try:
        for parametrization in step.parametrizations:
            new_task = initialize_task(
                task_class=step.task_class,
                parametrization=parametrization,
                previous_task=previous_task,
                restart=restart,
            )
            # Build the tree in a scratch location so that a failure midway
            # does not leave a partial directory in the study.
            with TemporaryDirectory() as tmpdir:
                if restart:
                    # A restart joins the task group of the previous task
                    task_dest = Path(tmpdir)
                    item_id = str(new_task.task_metadata.task_id)
                    final_dest = src.parent
                    metadata_file = Path(
                        final_dest, SETTINGS.TASK_GROUP_METADATA_FILE
                    )
                    key: Literal["tasks", "task_groups"] = "tasks"
                else:
                    # Otherwise the task starts a new task group
                    task_dest = create_task_group_tree(
                        src=src,
                        task=new_task,
                        dest=Path(tmpdir),
                        name_template=name_template,
                    )
                    item_id = str(new_task.task_metadata.task_group_id)
                    final_dest = src.parent.parent
                    metadata_file = Path(
                        final_dest, SETTINGS.STUDY_METADATA_FILE
                    )
                    key = "task_groups"

                new_task_dir = create_task_tree(
                    src=src,
                    task=new_task,
                    dest=task_dest,
                    files_to_carry_over=files_to_carry_over,
                    name_template=name_template,
                )
                new_tree = new_task_dir if restart else new_task_dir.parent
                final_tree_name = Path(final_dest, new_tree.name)
                # NOTE(review): templated names are made unique against the
                # temporary directory, not `final_dest`; copytree raises
                # FileExistsError if the name is already taken there.
                shutil.copytree(src=new_tree, dst=final_tree_name)

            add_item_to_parent(
                item_id=item_id,
                metadata_file=metadata_file,
                key=key,
            )
            tasks_and_dirs.append((new_task, final_tree_name))

            logger.debug(
                "New task%s created %s",
                "" if restart else " group",
                "/".join(final_tree_name.parts[-4:]),
            )

            if submit:
                submit_new_task(new_task=final_tree_name)
    finally:
        # Clean up only after every parametrization has had the chance to
        # carry files over from `src`; doing so inside the loop removed them
        # before later tasks were created. As before, nothing is deleted
        # unless at least one new task was created and registered.
        if tasks_and_dirs:
            clean_up_task(
                old_job=src,
                file_size_limit=file_size_limit,
                files_to_delete=previous_task.task_inputs.files_to_delete,
            )

    return tasks_and_dirs
