"""Represent a reference to a variable."""

from collections.abc import Callable
from collections.abc import Iterable
from collections.abc import Mapping
from collections.abc import MutableMapping
from functools import reduce
import logging
from typing import TYPE_CHECKING
from typing import Annotated
from typing import Any
from typing import Generic
from typing import TypeVar

from pydantic import Field
from pydantic import GetCoreSchemaHandler
from pydantic import ValidatorFunctionWrapHandler
from pydantic_core import CoreSchema
from pydantic_core import core_schema

from autojob.utils.schemas import Unset

if TYPE_CHECKING:
    from autojob.bases.task_base import TaskBase

# Type of the value a ``VariableReference`` evaluates to.
_T = TypeVar("_T")
# A context/shell: either a mutable mapping (indexed by key) or an arbitrary
# object (accessed by attribute).
_Referenceable = TypeVar("_Referenceable", MutableMapping[str, Any], object)

logger = logging.getLogger(__name__)

# A non-empty list of keys/attribute names leading to a single value.
AttributePath = Annotated[list[str], Field(min_length=1)]
# A non-empty list of ``AttributePath`` s.
AttributePaths = Annotated[list[AttributePath], Field(min_length=1)]


# ! Only single-source VariableReferences are supported ATM
# ! There should be a check upon Workflow creation for circular
# ! references
class VariableReference(Generic[_T]):
    """A reference to a variable.

    A ``VariableReference`` describes where a value should be written
    (``set_path``) and where that value comes from: another location in a
    context (``get_path``) or a fixed ``constant``.

    Attributes:
        set_path: A list of strings indicating the path to the variable
            to be set.
        get_path: A list of strings indicating the path to the variable
            to be obtained. When set, it takes precedence over
            ``constant``.
        get_paths: A list of lists of strings each indicating a path to
            a variable to be obtained. Not yet supported by ``evaluate``.
        constant: A value to be used to set the variable when
            ``get_path`` is None.
        composer: A function that accepts the value(s) of the source
            variable(s) and returns a value to be used to set the variable.
            Not yet supported by ``evaluate``.

    Example: Evaluate the value of a VariableReference
        >>> from autojob.parametrizations import VariableReference
        >>> context = {
        ...     "a": {
        ...         "b": 4,
        ...     }
        ... }
        >>> ref = VariableReference(
        ...     set_path=["a"],
        ...     get_path=["a", "b"],
        ...     constant=4,
        ... )
        >>> ref.evaluate(context)
        4

    Example: Set a value using a VariableReference
        >>> from autojob.parametrizations import VariableReference
        >>> context = {
        ...     "a": {
        ...         "b": 4,
        ...     }
        ... }
        >>> ref = VariableReference(
        ...     set_path=["a"],
        ...     get_path=["a", "b"],
        ...     constant=4,
        ... )
        >>> class Object(object):
        ...     pass
        >>> target_object = Object()
        >>> target_dict = {}
        >>> ref.set_input_value(context, target_object)
        >>> target_object.a
        4
        >>> ref.set_input_value(context, target_dict)
        >>> target_dict["a"]
        4
    """

    def __init__(
        self,
        *,
        set_path: AttributePath,
        get_path: AttributePath | None = None,
        get_paths: AttributePaths | None = None,
        constant: Any = None,
        composer: Callable | None = None,
    ) -> None:
        """Instantiate a ``VariableReference``.

        Args:
            set_path: An ``AttributePath`` indicating the variable to set.
            get_path: An ``AttributePath`` indicating the source variable.
                Defaults to None.
            get_paths: A list of ``AttributePath`` s, each of which will be
                combined to the source variable. Defaults to None.
            constant: A constant value used to set the variable. Defaults to
                None.
            composer: A function that accepts the value of the source
                variable(s) and returns a value to be used to set the
                variable. Defaults to None.
        """
        self.set_path = set_path
        self.get_path = get_path
        self.get_paths = get_paths
        self.constant = constant
        self.composer = composer
        super().__init__()

    @classmethod
    def __get_pydantic_core_schema__(
        cls, source_type: Any, handler: GetCoreSchemaHandler
    ) -> CoreSchema:
        """Get a Pydantic schema.

        In Python mode, an existing instance of ``cls`` is accepted and each
        of its attributes is validated in place. In JSON mode, an object
        with the constructor's keyword arguments is validated, converted
        into an instance of ``cls``, and then run through the Python-mode
        validators.
        """
        set_path_schema = handler.generate_schema(AttributePath)
        get_path_schema = handler.generate_schema(AttributePath | None)
        get_paths_schema = handler.generate_schema(AttributePaths | None)
        constant_schema = handler.generate_schema(Any)
        composer_schema = handler.generate_schema(Callable | None)

        # Each wrap validator validates a single attribute of an already
        # constructed instance and writes the validated value back onto it.
        def _set_path(
            v: VariableReference[Any], handler: ValidatorFunctionWrapHandler
        ) -> VariableReference[Any]:
            v.set_path = handler(v.set_path)
            return v

        def _get_path(
            v: VariableReference[Any], handler: ValidatorFunctionWrapHandler
        ) -> VariableReference[Any]:
            # get_path_schema is nullable, so None passes through unchanged.
            v.get_path = handler(v.get_path)
            return v

        def _get_paths(
            v: VariableReference[Any], handler: ValidatorFunctionWrapHandler
        ) -> VariableReference[Any]:
            v.get_paths = handler(v.get_paths)
            return v

        def _constant(
            v: VariableReference[Any], handler: ValidatorFunctionWrapHandler
        ) -> VariableReference[Any]:
            v.constant = handler(v.constant)
            return v

        def _composer(
            v: VariableReference[Any], handler: ValidatorFunctionWrapHandler
        ) -> VariableReference[Any]:
            v.composer = handler(v.composer)
            return v

        def _from_json_dict(data: dict[str, Any]) -> VariableReference[Any]:
            # Build ``cls`` (not the base class) so subclasses survive the
            # is_instance_schema(cls) check in ``python_schema``.
            return cls(
                set_path=data["set_path"],
                get_path=data.get("get_path", None),
                get_paths=data.get("get_paths", None),
                constant=data.get("constant", None),
                composer=data.get("composer", None),
            )

        python_schema = core_schema.chain_schema(
            [
                core_schema.is_instance_schema(cls),
                core_schema.no_info_wrap_validator_function(
                    _set_path, set_path_schema
                ),
                core_schema.no_info_wrap_validator_function(
                    _get_path, get_path_schema
                ),
                core_schema.no_info_wrap_validator_function(
                    _get_paths, get_paths_schema
                ),
                core_schema.no_info_wrap_validator_function(
                    _constant, constant_schema
                ),
                core_schema.no_info_wrap_validator_function(
                    _composer, composer_schema
                ),
            ]
        )

        return core_schema.json_or_python_schema(
            json_schema=core_schema.chain_schema(
                [
                    core_schema.typed_dict_schema(
                        {
                            "set_path": core_schema.typed_dict_field(
                                set_path_schema
                            ),
                            # get_path is optional and nullable, so it must
                            # use the nullable get_path schema.
                            "get_path": core_schema.typed_dict_field(
                                get_path_schema,
                                required=False,
                            ),
                            "get_paths": core_schema.typed_dict_field(
                                get_paths_schema,
                                required=False,
                            ),
                            # ! Use default JSON caster
                            "constant": core_schema.typed_dict_field(
                                constant_schema,
                                required=False,
                            ),
                            "composer": core_schema.typed_dict_field(
                                composer_schema,
                                required=False,
                            ),
                        }
                    ),
                    core_schema.no_info_before_validator_function(
                        _from_json_dict,
                        python_schema,
                    ),
                ]
            ),
            python_schema=python_schema,
        )

    def evaluate(self, context: _Referenceable) -> _T:
        """Evaluate a variable reference in the given context.

        If ``get_path`` is set, the value at that path in ``context`` is
        returned. Otherwise ``constant`` is returned, unless ``get_paths``
        or ``composer`` is set, which is not supported yet.

        Note:
            ``get_path`` is resolved with ``getattrpath``, which inserts an
            empty ``dict`` for any key missing from a mapping along the
            path. ``context`` may therefore be mutated.

        Args:
            context: A dictionary (or object) containing values to be used to
                evaluate the ``VariableReference``.

        Raises:
            NotImplementedError: ``get_paths`` and ``composer``
            ``VariableReference`` s are not supported.

        Returns:
            The value.
        """
        if self.get_path is not None:
            value: _T = getattrpath(
                context,
                self.get_path,
            )
        elif not all(x is None for x in (self.get_paths, self.composer)):
            msg = "Multiple get paths and composers are not yet implemented"
            raise NotImplementedError(msg)
        else:
            value = self.constant

        return value

    def set_input_value(
        self, context: dict[str, Any], shell: _Referenceable
    ) -> None:
        """Set the value of a key specified by the ``VariableReference``.

        This method modifies ``shell`` in place. Intermediate mapping keys
        along ``set_path[:-1]`` that are missing are created as empty
        dictionaries. If the evaluated value equals ``Unset``, the target key
        or attribute is deleted instead of being set.

        Args:
            context: A dictionary containing values to be used to evaluate
                the ``VariableReference``.
            shell: A dictionary or object containing values to be set.
        """
        to_set = self.set_path[-1]
        to_get = getattrpath(shell, self.set_path[:-1])
        value = self.evaluate(context)

        # NOTE(review): equality (not identity) comparison; a value whose
        # __eq__ returns a non-bool (e.g. an array) may raise here. Confirm
        # whether ``value is Unset`` would be safe for ``Unset``.
        if value == Unset:
            logger.info("Unsetting value: %s", to_set)
            if isinstance(to_get, Mapping):
                del to_get[to_set]
            else:
                delattr(to_get, to_set)
        else:
            logger.info("Setting value: %s to: %s", to_set, value)
            if isinstance(to_get, Mapping):
                to_get[to_set] = value
            else:
                setattr(to_get, to_set, value)


def getattrpath(obj: _Referenceable, path: Iterable[str]) -> Any:
    """Follow ``path`` through nested mappings and/or object attributes.

    Each name in ``path`` is resolved against the current value. Mappings
    are indexed by key, and anything else is accessed with ``getattr``. An
    empty ``path`` returns ``obj`` itself.

    Note:
        When a key is missing from a mapping encountered along the path, an
        empty ``dict`` is inserted under that key (mutating the mapping) and
        traversal continues into it. Missing object attributes are not
        created; ``getattr`` raises ``AttributeError`` for them.

    Args:
        obj: The dictionary or object to start from.
        path: An iterable of keys/attribute names to follow, in order.

    Returns:
        The value found at the end of ``path``.
    """
    current: Any = obj
    for name in path:
        if not isinstance(current, Mapping):
            current = getattr(current, name)
            continue
        if name not in current:
            # Auto-create intermediate containers so that callers such as
            # ``VariableReference.set_input_value`` can set values beneath
            # this key afterwards.
            current[name] = {}
        current = current.get(name)
    return current


# TODO: Add support for task_mods, opt_mods, and analysis_mods
def create_parametrization(
    previous: "TaskBase",
    calc_mods: dict[str, Any] | None = None,
    sched_mods: dict[str, Any] | None = None,
    exclude_metadata: Iterable[str] | None = None,
) -> list[VariableReference[Any]]:
    """Create a parametrization from parameter mods and a previous task.

    Each top-level field of the previous task's dumped metadata and inputs
    (and, when present, its calculation and scheduler inputs) becomes a
    constant ``VariableReference`` targeting the same location. Each entry
    of ``calc_mods`` and ``sched_mods`` then becomes one more constant
    ``VariableReference``.

    Args:
        previous: The previous task. Its ``task_metadata`` and
            ``task_inputs`` are always used. Its ``calculation_inputs`` and
            ``scheduler_inputs`` are used only if it has both attributes.
        calc_mods: A dictionary containing modifications to calculator
            **parameters** (set under
            ``calculation_inputs.calc_params``). Defaults to an empty
            dictionary.
        sched_mods: A dictionary containing modifications to scheduler
            inputs. Defaults to an empty dictionary.
        exclude_metadata: Field names to exclude when dumping
            ``task_metadata`` and ``task_inputs``. Defaults to None.

    Returns:
        A list of ``VariableReference`` s that can be used to set the values
        of the new calculation.

    Warning:
        When specifying `sched_mods`, be wary of setting mutually exclusive
        scheduler parameters (e.g, `mem` and `mem_per_cpu` or `cores` and
        `cores_per_node`). For example, if the `mem` parameter is set and one
        wants to set the `mem_per_cpu` parameter, set the `mem` key to `Unset`
        in `sched_mods` in addition to setting the `mem_per_cpu` key.

        `calc_mods` and `sched_mods` are only applied when `previous` has
        both `calculation_inputs` and `scheduler_inputs`. Otherwise they are
        ignored and a warning is logged.
    """
    calc_mods = calc_mods or {}
    sched_mods = sched_mods or {}
    exclude_metadata = set(exclude_metadata) if exclude_metadata else None
    metadata = previous.task_metadata.model_dump(exclude=exclude_metadata)
    # NOTE(review): ``exclude_metadata`` is also applied to ``task_inputs``.
    # Confirm this is intended rather than copied from the line above.
    task_inputs = previous.task_inputs.model_dump(exclude=exclude_metadata)

    # Order matters since earlier VariableReferences will be overwritten by
    # later ones (e.g. mods override the dumped inputs they follow).
    targets_and_sources: list[tuple[list[str], dict[str, Any]]] = [
        (["task_metadata"], metadata),
        (["task_inputs"], task_inputs),
    ]

    if hasattr(previous, "calculation_inputs") and hasattr(
        previous, "scheduler_inputs"
    ):
        calc_inputs = previous.calculation_inputs.model_dump(exclude_none=True)
        sched_inputs = previous.scheduler_inputs.model_dump(exclude_none=True)
        targets_and_sources.extend(
            [
                (["calculation_inputs"], calc_inputs),
                # calc_mods only modify calc_params
                (["calculation_inputs", "calc_params"], calc_mods),
                (["scheduler_inputs"], sched_inputs),
                (["scheduler_inputs"], sched_mods),
            ]
        )
    elif calc_mods or sched_mods:
        # Previously these were dropped silently; surface it to the caller.
        logger.warning(
            "Ignoring calc_mods/sched_mods: %s has no calculation_inputs "
            "and scheduler_inputs",
            type(previous).__name__,
        )

    return [
        VariableReference(set_path=[*target, input_], constant=value)
        for target, source in targets_and_sources
        for input_, value in source.items()
    ]
