"""Pydantic schema utilities."""

from collections.abc import Callable
import json
from typing import Annotated
from typing import Any
from uuid import uuid4

from ase import Atoms
from ase.io.jsonio import decode
from ase.io.jsonio import encode
from pydantic import UUID4
from pydantic import BaseModel
from pydantic import FieldSerializationInfo
from pydantic import GetCoreSchemaHandler
from pydantic import GetJsonSchemaHandler
from pydantic import SerializerFunctionWrapHandler
from pydantic import ValidationError
from pydantic.json_schema import JsonSchemaValue
from pydantic_core import PydanticUndefined
from pydantic_core import core_schema
from shortuuid import uuid

from autojob import SETTINGS

# Module-level alias for Pydantic's ``PydanticUndefined`` sentinel. It is
# annotated as ``Any`` -- presumably so that it can be used as a default for
# fields of any declared type without type-checker complaints (TODO confirm).
Unset: Any = PydanticUndefined


def atoms_as_dict(s: Atoms) -> dict[str, str]:
    """Serialize an :class:`ase.atoms.Atoms` object into a dictionary.

    The dictionary follows Monty's MSONable layout: ``@module`` and
    ``@class`` name the serialized type, and ``atoms_json`` holds the ASE
    JSON encoding of ``s``.
    """
    # Ideally this would be a wrapper around atoms.todict() that merely adds
    # the @module and @class key-value pairs. However,
    # atoms.todict()/atoms.fromdict() does not currently work properly with
    # constraints, so the ASE JSON encoder is used instead.
    encoded_atoms = encode(s)
    return {
        "@module": "ase.atoms",
        "@class": "Atoms",
        "atoms_json": encoded_atoms,
    }


def atoms_from_dict(d: dict[str, Any]) -> Atoms:
    """Rebuild an :class:`ase.atoms.Atoms` object from a dictionary.

    Only the ``atoms_json`` entry of ``d`` is read; it is passed to the ASE
    JSON decoder. Any ``@module``/``@class`` entries (Monty's MSONable
    layout) are ignored.
    """
    # Ideally this would be a wrapper around atoms.fromdict() that just
    # skips the @module/@class key-value pairs. However,
    # atoms.todict()/atoms.fromdict() does not currently work properly with
    # constraints, so the ASE JSON decoder is used instead.
    encoded_atoms = d["atoms_json"]
    return decode(encoded_atoms)


class AtomsAnnotation(BaseModel):
    """The Pydantic-compatible annotation for an `ase.atoms.Atoms` object."""

    @classmethod
    def __get_pydantic_core_schema__(
        cls,
        _source_type: Any,
        _handler: GetCoreSchemaHandler,
    ) -> core_schema.CoreSchema:
        """Return a pydantic_core.CoreSchema.

        The schema behaves in the following ways:

        `Atoms` instances will be parsed as `Atoms` instances without any
            changes
        Dictionaries will be converted to `Atoms` via `atoms_from_dict()`
        In JSON mode, serialization is done by `atoms_as_dict()`; in Python
            mode, the wrapped (default) serializer is used
        """

        def validate_from_dict(value: dict[str, Any] | None) -> Atoms | None:
            if value is None:
                return value

            try:
                return atoms_from_dict(value)
            except (KeyError, TypeError, ValueError) as err:
                # KeyError: the "atoms_json" entry is missing.
                # TypeError: the entry is not something the decoder accepts.
                # ValueError: covers json.JSONDecodeError for malformed JSON.
                # All are re-raised as ValueError so that the failure is
                # raised from the validator as a ValueError instead of
                # escaping as an unrelated exception type.
                msg = "Unable to convert 'atoms' value to Atoms object"
                raise ValueError(msg) from err

        from_dict_schema = core_schema.chain_schema(
            [
                core_schema.dict_schema(),
                core_schema.no_info_plain_validator_function(
                    validate_from_dict
                ),
            ]
        )

        def serialize_atoms(
            v: Any,
            serializer: SerializerFunctionWrapHandler,
            info: FieldSerializationInfo,
        ) -> Any:
            if info.mode == "python":
                return serializer(v)
            if isinstance(v, Atoms):
                # Serialize a copy rather than clearing ``v.calc`` in place,
                # so that dumping a model does not strip the calculator from
                # the caller's object. ``Atoms.copy()`` is expected not to
                # carry the calculator over -- verify against ASE docs.
                return atoms_as_dict(v.copy())
            msg = f"Unable to serialize atoms: {v}"
            # pydantic's ValidationError cannot be constructed from a bare
            # message, so a built-in exception carries the message instead.
            raise TypeError(msg)

        return core_schema.json_or_python_schema(
            json_schema=from_dict_schema,
            python_schema=core_schema.union_schema(
                [
                    core_schema.is_instance_schema(Atoms),
                    from_dict_schema,
                ]
            ),
            serialization=core_schema.wrap_serializer_function_ser_schema(
                serialize_atoms, info_arg=True
            ),
        )

    @classmethod
    def __get_pydantic_json_schema__(
        cls,
        _core_schema: core_schema.CoreSchema,
        handler: GetJsonSchemaHandler,
    ) -> JsonSchemaValue:
        """Get the Pydantic JSON schema."""
        # Use the same schema that would be used for `dict`
        return handler(core_schema.dict_schema())


# ``Atoms`` type annotated with ``AtomsAnnotation``, which supplies the
# Pydantic core/JSON schema hooks defined on that class.
PydanticAtoms = Annotated[Atoms, AtomsAnnotation]


def hyphenate(v: str) -> str:
    """Return ``v`` with every underscore turned into a hyphen."""
    pieces = v.split("_")
    return "-".join(pieces)


def space_capitalize(v: str) -> str:
    """Return ``v`` with underscores turned into spaces, in title case."""
    spaced = " ".join(v.split("_"))
    return spaced.title()


def id_factory(prefix: str) -> Callable[[], UUID4 | str]:
    """Build a zero-argument ID factory.

    Each call to the returned function checks ``SETTINGS.LEGACY_MODE``: when
    it is truthy, the result is ``prefix`` followed by the first nine
    characters of a short UUID; otherwise the result is a ``uuid4()`` value.
    """

    def func() -> UUID4 | str:
        # The setting is read on every call, not once at factory creation.
        if SETTINGS.LEGACY_MODE:
            short_id = uuid()[:9]
            return f"{prefix}{short_id}"
        return uuid4()

    return func
