"""Supplement harvested data with data patches.

Oftentimes, you may have additional data that you either

a. can't determine a priori (and thus mark the task with it prior to
   submission), or

b. can't extract programmatically (these may be analyses that require
   fuzzy intuition).

but nonetheless want to store with your data. This module defines some
simple routines and classes to facilitate the latter use-case.

A :class:`Patch` is just that, a "patch" - it fills in the gap in data
that may exist. To define one, you specify a feature of the data to
which it should be applied and what data should be added when it is
applied.

.. code-block:: python

    from ase import Atoms
    from autojob.harvest.patch import Patch

    pch = Patch(match_path=["study_id"],
        match_value="123456789",
        patch_path=["atoms", "positions"],
        patch_value=[0.0, 0.0, 0.0]
    )

    datapoint1 = {
        "study_id": None,
        "atoms": None
    }

    atoms = Atoms("C", positions=[[0.0, 1.0, 2.0]])
    datapoint2 = {
        "study_id": "123456789",
        "atoms": atoms
    }

    pch.apply(datapoint1)
    print(datapoint1["atoms"])
    None

    pch.apply(datapoint2)
    print(datapoint2["atoms"].positions)
    [0.0, 0.0, 0.0]

The data to which a :class:`Patch` applies is specified by ``match_path`` and
``match_value``, while what is applied is specified by ``patch_path`` and
``patch_value``.

Note:
    Patch applies to both dictionaries and objects alike!

Example:
    Apply a set of patches in batch

    .. code-block:: python

        from autojob.task import Task

        tasks = [Task(...), Task(...), ...]
        patches = [Patch(...), Patch(...), ...]

        for task in tasks:
            for patch in patches:
                patch.apply(task)

"""

import json
import logging
from pathlib import Path
from typing import TYPE_CHECKING
from typing import Any
from typing import Literal
from typing import NamedTuple

from autojob import SETTINGS
from autojob.parametrizations import VariableReference
from autojob.parametrizations import getattrpath
from autojob.utils.files import find_study_dirs
from autojob.utils.files import find_study_group_dirs
from autojob.utils.files import find_task_group_dirs

if TYPE_CHECKING:
    from autojob.bases.task_base import TaskBase

logger = logging.getLogger(__name__)


class Patch(NamedTuple):
    """A conditional, in-place data patch.

    A patch pairs a match condition with a value to write: when the value
    found at ``match_path`` equals ``match_value``, ``patch_value`` is set at
    ``patch_path``.

    Attributes:
        match_path: A list of attribute/key names locating the value that
            decides whether the patch applies.
        match_value: The value that must be found at ``match_path`` for the
            patch to be applied.
        patch_path: A list of attribute/key names locating the value to be
            set.
        patch_value: The value to set at ``patch_path``.

    """

    match_path: list[str]
    match_value: Any
    patch_path: list[str]
    patch_value: Any

    def apply(self, data: object) -> None:
        """Patch ``data`` in place if it satisfies the match condition.

        Args:
            data: The data to inspect and, on a match, modify. Nothing is
                written when the value found at ``match_path`` does not
                compare equal to ``match_value``.
        """
        found = getattrpath(data, self.match_path)
        is_match = found == self.match_value
        if not is_match:
            return

        # The write is delegated to a VariableReference that carries
        # ``patch_value`` as a constant and has no ``get_path``.
        # NOTE(review): the empty dict is presumably the (unused) source of
        # input values - confirm against VariableReference.set_input_value.
        writer = VariableReference(
            constant=self.patch_value,
            get_path=None,
            set_path=self.patch_path,
        )
        writer.set_input_value({}, data)


def patch_tasks(patches: list[Patch], tasks: "list[TaskBase]") -> None:
    """Apply every patch to every task.

    Tasks are handled one after another; for each task, the patches are
    applied in the order in which they are given. ``tasks`` is modified in
    place and nothing is returned.

    Args:
        patches: The patches to apply.
        tasks: The tasks to which the patches will be applied.
    """
    # Task-major ordering: all patches for one task before the next task.
    pairs = ((task, patch) for task in tasks for patch in patches)
    for task, patch in pairs:
        patch.apply(task)


def build_metadata_patches(
    dir_name: Path,
    *,
    metadata_type: Literal[
        "study_group", "study", "task_group"
    ] = "study_group",
    strict_mode: bool | None = None,
) -> list[Patch]:
    """Create patches from metadata files.

    Args:
        dir_name: The name of the directory under which to search for
            metadata. Defaults to the current working directory.
        strict_mode: Whether or not to abort metadata collection if
            metadata cannot be found. Defaults to ``SETTINGS.STRICT_MODE``.
        metadata_type: The type of metadata file from which patches are to be
            built. Must be one of ``"study_group"``, ``"study"``,
            ``"calculation"``. Defaults to ``"study_group"``.

    Returns:
        A list of :class:`Patch` objects which will add metadata
        to :attr:`TaskMetadata.__pydantic_extra__`. Further, patch paths are
        defined such that study group, study, and calculation metadata will be
        added under the ``"study_group_metadata"``, ``"study_metadata"``, and
        ``"calculation_metadata"`` keys, respectively.

    Example:
        Patch study group and study metadata for all tasks in a subdirectory.

        .. code-block:: python

            from pathlib import Path

            from autojob.harvest.harvest import harvest
            from autojob.harvest.patch import build_metadata_patches
            from autojob.harvest.patch import patch_tasks

            dir_name = Path().cwd()
            tasks = harvest(dir_name)
            patches = build_metadata_patches(dir_name)
            patch_tasks(patches, tasks)
    """
    strict_mode = SETTINGS.STRICT_MODE if strict_mode is None else strict_mode
    match metadata_type:
        case "study_group":
            finder = find_study_group_dirs
            filename = SETTINGS.STUDY_GROUP_METADATA_FILE
            subsidiary: Literal["study", "task_group", ""] = "study"
        case "study":
            finder = find_study_dirs
            filename = SETTINGS.STUDY_METADATA_FILE
            subsidiary = "task_group"
        case "task_group":
            finder = find_task_group_dirs
            filename = SETTINGS.TASK_GROUP_METADATA_FILE
            subsidiary = ""

    source_dirs = finder(dir_name)
    patches: list[Patch] = []

    for source_dir in source_dirs:
        src = Path(source_dir, filename)
        try:
            with src.open(mode="r", encoding="utf-8") as file:
                metadata = json.load(file)
                match_path = ["task_metadata", f"{metadata_type}_id"]
                patch_path = [
                    "task_metadata",
                    # ? Can this be substituted with model_extra?
                    "__pydantic_extra__",
                    f"{metadata_type}_metadata",
                ]
                patches.append(
                    Patch(
                        match_path=match_path,
                        match_value=source_dir.name,
                        patch_path=patch_path,
                        patch_value=metadata,
                    )
                )
                if subsidiary:
                    patches.extend(
                        build_metadata_patches(
                            dir_name=source_dir,
                            metadata_type=subsidiary,
                            strict_mode=strict_mode,
                        )
                    )
        except FileNotFoundError:
            if strict_mode:
                raise

            logger.warning(
                "Unable to build metadata patches for %s %s",
                metadata_type,
                source_dir,
            )

    return patches
