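"""This module defines a concrete implementation of :class:`TaskBase`.

It also defines the concrete metadata, inputs, and outputs documents
(:class:`TaskMetadata`, :class:`TaskInputs`, :class:`TaskOutputs`) that
compose a :class:`Task`.
"""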

import json
import logging
from pathlib import Path
from typing import Any
from typing import Self
from uuid import UUID

import ase.io
import jinja2
from pydantic import UUID4
from pydantic import Field
from pydantic import ValidationError
from pydantic import ValidationInfo
from pydantic import ValidatorFunctionWrapHandler
from pydantic import field_validator

from autojob import SETTINGS
from autojob.bases.task_base import SetTaskClassMixin
from autojob.bases.task_base import TaskBase
from autojob.bases.task_base import TaskInputsBase
from autojob.bases.task_base import TaskMetadataBase
from autojob.bases.task_base import TaskOutputsBase
from autojob.plugins import get_task_class
from autojob.utils.files import get_loader
from autojob.utils.files import get_uri

logger: logging.Logger = logging.getLogger(__name__)

# Length of legacy shortuuid-style IDs, which TaskMetadata.validate_ids
# accepts alongside UUIDs.
# TODO: define annotated ID_string type for Pydantic models
LEGACY_TASK_ID_LENGTH = 10


class TaskMetadata(TaskMetadataBase):
    """A concrete implementation of TaskMetadataBase."""

    @field_validator(
        "study_group_id", "study_id", "task_group_id", "task_id", mode="wrap"
    )
    @classmethod
    def validate_ids(
        cls,
        v: Any,
        handler: ValidatorFunctionWrapHandler,
        info: ValidationInfo,
    ) -> str | UUID4 | None:
        """Validate an ID.

        IDs can either be a UUID or a 10-character ASCII alphanumeric
        shortuuid string. The ``study_group_id``, ``study_id``, and
        ``task_group_id`` fields may additionally be ``None``.

        Args:
            v: The raw value supplied for the field.
            handler: The field's default Pydantic validator.
            info: Validation context; ``info.field_name`` determines whether
                ``None`` is an acceptable value.

        Raises:
            ValueError: If the value is not a UUID, not a legacy shortuuid
                string, and not ``None`` for a nullable field.

        Returns:
            The validated ID, or ``None`` for an unset nullable field.
        """
        value = v
        try:
            value = handler(v)
        except ValidationError:
            pass
        else:
            # String validation by the default handler is insufficient, so
            # its result is only accepted outright when it is a UUID.
            # TODO: define annotated ID_string
            if isinstance(value, UUID):
                return value

        # Legacy shortuuid IDs: check the (possibly coerced) value itself,
        # restricted to ASCII since ``str.isalnum`` also accepts non-ASCII
        # letters and digits.
        if (
            isinstance(value, str)
            and len(value) == LEGACY_TASK_ID_LENGTH
            and value.isascii()
            and value.isalnum()
        ):
            return value

        # NOTE(review): "workflow_step_id" is not among the fields this
        # validator is registered for above; kept for compatibility.
        nullable_fields = (
            "study_group_id",
            "study_id",
            "workflow_step_id",
            "task_group_id",
        )
        if info.field_name in nullable_fields and value is None:
            return None

        msg = f"{v} is not a UUID4 or a 10-digit alphanumeric shortuuid string"
        raise ValueError(msg)

    # TODO: use get_last_updated and test
    @classmethod
    def from_directory(cls, src: str | Path) -> Self:
        """Create a TaskMetadata document from a task directory.

        Args:
            src: The task directory containing ``SETTINGS.TASK_METADATA_FILE``.

        Returns:
            The loaded metadata. Its ``uri`` is always set from
            ``get_uri(dir_name=src)``, replacing any value stored in the file.
        """
        logger.debug("Loading task metadata from %s", src)

        task_file = Path(src, SETTINGS.TASK_METADATA_FILE)
        with task_file.open(mode="r", encoding="utf-8") as file:
            raw_metadata: dict[str, Any] = json.load(file)

        raw_metadata["uri"] = get_uri(dir_name=src)

        logger.debug("Successfully loaded task metadata from %s", src)
        return cls(**raw_metadata)


class TaskInputs(TaskInputsBase):
    """The set of task-level inputs."""

    @classmethod
    def from_directory(cls, src: str | Path, **kwargs) -> "TaskInputs":
        """Generate a TaskInputs document from a completed task's directory.

        Args:
            src: The directory of a completed Task.
            kwargs: Additional keyword arguments:

            - strict_mode: Whether or not to raise an error if the input
                atoms are not found. Defaults to `SETTINGS.STRICT_MODE`.

        Raises:
            FileNotFoundError: If the inputs JSON is missing, or if the input
                atoms file is missing and strict mode is enabled.

        Returns:
            A :class:`TaskInputs` object.
        """
        strict_mode = kwargs.get("strict_mode", SETTINGS.STRICT_MODE)
        logger.debug("Loading task inputs from directory: %s", src)
        logger.debug("Strict mode: %sabled", "en" if strict_mode else "dis")
        inputs_json = Path(src, SETTINGS.INPUTS_FILE)

        with inputs_json.open(mode="r", encoding="utf-8") as file:
            data = json.load(file)

        # The task inputs are nested under "task_inputs"; any remaining
        # top-level keys (e.g., extra data merged in when the inputs JSON was
        # written) are passed as additional keyword arguments.
        inputs = cls(**data.pop("task_inputs"), **data)

        if not inputs.atoms_filename:
            logger.debug("No input atoms to retrieve")
        else:
            logger.debug("Retrieving input atoms")
            try:
                inputs.atoms = ase.io.read(Path(src, inputs.atoms_filename))
            except FileNotFoundError:
                logger.warning(
                    "Unable to retrieve input atoms from directory: %s", src
                )
                if strict_mode:
                    raise
            else:
                logger.debug(
                    "Successfully loaded input atoms from directory: %s", src
                )

        logger.debug("Successfully loaded task inputs from directory: %s", src)
        return inputs


class TaskOutputs(TaskOutputsBase):
    """The set of task-level outputs."""

    @classmethod
    def from_directory(
        cls,
        src: str | Path,
        **kwargs,
    ) -> "TaskOutputs":
        """Build a TaskOutputs document from the directory of a finished task.

        Args:
            src: Directory of the completed task.
            kwargs: Extra keyword arguments. Recognized:

            - strict_mode: If True, a missing output structure file is
                re-raised instead of tolerated. Falls back to
                `SETTINGS.STRICT_MODE` when omitted.

        Returns:
            A :class:`~TaskOutputs` object. Its ``atoms`` is ``None`` when the
            output structure file is missing and strict mode is disabled.
        """
        directory = Path(src)
        strict = kwargs.get("strict_mode", SETTINGS.STRICT_MODE)
        logger.debug("Loading task outputs from directory: %s", directory)
        logger.debug("Strict mode: %sabled", "en" if strict else "dis")
        output_structure = directory / SETTINGS.OUTPUT_ATOMS_FILE

        final_atoms = None
        try:
            final_atoms = ase.io.read(output_structure)
        except FileNotFoundError:
            if strict:
                raise
            logger.warning(
                "Unable to retrieve output atoms from directory: %s", directory
            )

        logger.debug(
            "Successfully loaded task outputs from directory: %s", directory
        )
        return cls(atoms=final_atoms)


class Task(TaskBase, SetTaskClassMixin):
    """A concrete implementation of TaskBase.

    A task bundles its metadata, inputs, and (optional) outputs, and can be
    loaded from or written to a task directory.
    """

    task_metadata: TaskMetadata = Field(
        default_factory=TaskMetadata, description="Task metadata"
    )
    task_inputs: TaskInputs = Field(
        default_factory=TaskInputs, description="Task inputs"
    )
    task_outputs: TaskOutputs | None = Field(
        default=None, description="Task outputs"
    )

    @classmethod
    def load_magic(cls, src: str | Path, *, strict_mode: bool = True) -> Self:
        """Load a :class:`~TaskBase` subclass using its "base class" metadata.

        The class to build is read from the ``task_class`` field of the task
        metadata in ``src`` and resolved with ``get_task_class``.

        Args:
            src: The directory from which to load the task.
            strict_mode: Whether or not to require all outputs. If True,
                errors will be thrown on missing outputs. Defaults to True
                (when reached via :meth:`from_directory`, the value passed
                there, ``SETTINGS.STRICT_MODE`` by default, is forwarded).

        Raises:
            RuntimeError: No build class specified in the task metadata. Only
                raised if ``strict_mode`` is True.

        Returns:
            The loaded task. If no build class is specified and
            ``strict_mode`` is False, an instance of ``cls`` is returned.
        """
        logger.debug("Magically loading task from directory: %s", src)
        logger.debug("Strict mode: %sabled", "en" if strict_mode else "dis")

        class_name = TaskMetadata.from_directory(src).task_class

        if class_name:
            logger.debug("Loading task with task class: %s", class_name)
            task_class = get_task_class(class_name)
            return task_class.from_directory(src, strict_mode=strict_mode)

        if strict_mode:
            msg = (
                f"No build class provided for task in {src!s}. "
                "Unable to use magic mode"
            )
            raise RuntimeError(msg)

        logger.warning(
            "No build class provided for task in %s. Unable to "
            "use magic mode, so a %s will be created instead.",
            src,
            cls.__name__,
        )
        return cls.from_directory(src=src, strict_mode=strict_mode)

    @classmethod
    def from_directory(cls, src: str | Path, **kwargs) -> Self:
        """Generate a Task document from a completed task's directory.

        Args:
            src: The directory of a completed Task.
            kwargs: Additional keyword arguments:

            - strict_mode: Whether or not to require all outputs. If True,
                errors will be thrown on missing outputs. Defaults to
                ``SETTINGS.STRICT_MODE``.
            - magic_mode: Whether or not to instantiate subclasses. If True,
                the task returned must be an instance determined by
                metadata in the directory. Defaults to False.

        Returns:
            An instance of :class:`Task` or a :class:`Task` subclass.
        """
        strict_mode = kwargs.get("strict_mode", SETTINGS.STRICT_MODE)
        magic_mode = kwargs.get("magic_mode", False)
        logger.debug("Loading task from directory: %s", src)
        logger.debug("Magic mode: %sabled", "en" if magic_mode else "dis")
        logger.debug("Strict mode: %sabled", "en" if strict_mode else "dis")

        if magic_mode:
            return cls.load_magic(src=src, strict_mode=strict_mode)

        metadata = TaskMetadata.from_directory(src=src)
        inputs = TaskInputs.from_directory(src=src, strict_mode=strict_mode)
        outputs = TaskOutputs.from_directory(src=src, strict_mode=strict_mode)
        new_task = cls(
            task_metadata=metadata, task_inputs=inputs, task_outputs=outputs
        )

        logger.debug("Successfully loaded task from directory: %s", src)
        return new_task

    def write_input_atoms(self, dest: str | Path) -> Path | None:
        """Write the input atoms to a file.

        Args:
            dest: The directory in which to write the Atoms file.

        Returns:
            The filename in which the Atoms were written, or None if the task
            has no input atoms.
        """
        logger.debug("Writing input atoms to directory: %s", dest)
        if self.task_inputs.atoms is None:
            logger.debug("No input atoms to write")
            return None

        atoms_file = Path(dest, self.task_inputs.atoms_filename)
        self.task_inputs.atoms.write(atoms_file)
        # Report success only after the write has actually completed.
        logger.debug("Successfully wrote input atoms to file: %s", atoms_file)
        return atoms_file

    def write_inputs_json(
        self,
        dest: str | Path,
        *,
        additional_data: dict[str, Any] | None = None,
    ) -> Path:
        """Write the inputs JSON to a file.

        Args:
            dest: The directory in which to write the inputs JSON.
            additional_data: A dictionary mapping strings to JSON-serializable
                values to be merged with the task inputs that will be written
                to the inputs JSON. Its keys are written at the top level,
                alongside (and able to override) the ``"task_inputs"`` key.
                Defaults to an empty dictionary.

        Returns:
            The filename in which the inputs JSON was written.
        """
        logger.debug("Writing inputs JSON to directory: %s", dest)
        additional_data = additional_data or {}
        inputs_json_data = {
            # Atoms are written separately by ``write_input_atoms``.
            "task_inputs": self.task_inputs.model_dump(
                mode="json", exclude={"atoms"}
            ),
            **additional_data,
        }

        inputs_json = Path(dest, SETTINGS.INPUTS_FILE)
        with inputs_json.open(mode="w", encoding="utf-8") as file:
            json.dump(inputs_json_data, file, indent=4)

        logger.debug("Successfully wrote input json to file: %s", inputs_json)
        return inputs_json

    def write_metadata(self, dest: str | Path) -> Path:
        """Write the task metadata to a file.

        Args:
            dest: The directory in which to write the task metadata.

        Returns:
            The filename in which the task metadata was written.
        """
        logger.debug("Writing task metadata to directory: %s", dest)
        task_metadata = self.task_metadata.model_dump(mode="json")
        metadata = Path(dest, SETTINGS.TASK_METADATA_FILE)
        with metadata.open(mode="w", encoding="utf-8") as file:
            json.dump(task_metadata, file, indent=4)
        logger.debug("Successfully wrote task metadata to file: %s", metadata)
        return metadata

    def write_task_script(
        self,
        dest: str | Path,
        *,
        additional_data: dict[str, Any] | None = None,
    ) -> Path:
        """Write the SLURM input script using the template given.

        The template named by ``task_inputs.task_script_template`` is rendered
        with the task's ``model_dump()`` fields plus a ``settings`` entry;
        ``additional_data`` is merged last and overrides those keys.

        Args:
            dest: The directory in which to write the SLURM file.
            additional_data: A dictionary mapping strings to JSON-serializable
                values to be merged with the task inputs that will be written
                to the task script. Defaults to an empty dictionary.

        Raises:
            FileExistsError: If the script file already exists (the file is
                opened in exclusive-creation mode).

        Returns:
            A Path representing the filename of the written SLURM script.
        """
        logger.debug("Writing task script to directory: %s", dest)
        env = jinja2.Environment(
            loader=get_loader(),
            autoescape=False,  # noqa: S701
            trim_blocks=True,
            lstrip_blocks=True,
            keep_trailing_newline=True,
        )
        to_render = env.get_template(self.task_inputs.task_script_template)
        filename = Path(dest, self.task_inputs.task_script)
        context = {**self.model_dump(), "settings": SETTINGS.model_dump()}
        context |= additional_data or {}

        with filename.open(mode="x", encoding="utf-8") as file:
            file.write(to_render.render(**context))

        logger.debug("Successfully wrote task script to file: %s", filename)
        return filename

    def write_inputs(
        self,
        dest: str | Path,
        **kwargs,  # noqa: ARG002
    ) -> list[Path]:
        """Write required inputs for a task to a directory.

        Writes, in order, the input atoms (if any), the inputs JSON, the task
        metadata, and the task script.

        Args:
            dest: The directory in which to write the task inputs.
            kwargs: Additional keyword arguments (currently unused).

        Returns:
            A list of input files written, in the order they were written.
        """
        logger.debug(
            "Writing %s inputs to directory: %s", self.__class__.__name__, dest
        )
        atoms = self.write_input_atoms(dest)
        input_json = self.write_inputs_json(dest)
        task_metadata = self.write_metadata(dest)
        script = self.write_task_script(dest)
        files = [atoms] if atoms else []
        files = [*files, input_json, task_metadata, script]
        logger.debug(
            "Successfully wrote %s inputs to directory: %s",
            self.__class__.__name__,
            dest,
        )
        return files
