from __future__ import annotations

import json
import os
from typing import Any

from pydantic import BaseModel, TypeAdapter
from slimschema import to_data, to_schema  # type: ignore[import-untyped]

from .events import ErrorEvent, OpenCodeEvent, TimeInfo
from .helper import unix_ms
from .logging_config import get_logger
from .request import Request
from .response import Response
from .usage import compute_usage

# Module-level logger for structured-output invocation, obtained from the
# package's logging configuration helper.
logger = get_logger("structured")


def _to_slimschema_yaml(
    response_format: str | type[BaseModel] | TypeAdapter[Any],
) -> str:
    """Render ``response_format`` as token-efficient SlimSchema YAML.

    A string is taken to already be SlimSchema YAML and is returned as-is.
    Any other format (a pydantic model class or a ``TypeAdapter``) is handed
    to ``slimschema.to_schema`` and the result is stringified.
    """
    if not isinstance(response_format, str):
        return str(to_schema(response_format))
    return response_format


def _dump_json(value: Any) -> str:
    """Serialize ``value`` to JSON text indented by two spaces.

    Pydantic models serialize themselves via ``model_dump_json``; every other
    value goes through the standard-library ``json.dumps``.
    """
    if not isinstance(value, BaseModel):
        return json.dumps(value, indent=2)
    return value.model_dump_json(indent=2)


def _resolve_output_path(
    workdir: str | None, session: str | None, override: str | None
) -> str:
    """Return the absolute path where structured output will be written.

    Args:
        workdir: Directory the output file must live under. Falls back to
            the current working directory when ``None`` or empty.
        session: Session id embedded in the default file name. ``None``
            becomes ``"default"``, and every non-alphanumeric character is
            replaced with ``_`` so the id is safe inside a file name.
        override: Explicit output path. A relative override is resolved
            against ``workdir`` (not the process cwd). An absolute override
            is used as given.

    Returns:
        The absolute output path. Its parent directory is created if it
        does not exist yet.

    Raises:
        ValueError: If the resolved path is not strictly inside ``workdir``.
    """
    wd = workdir or os.getcwd()
    if override:
        # Anchor relative overrides at workdir so they agree with the
        # containment check below. os.path.join returns an absolute
        # override unchanged.
        target = os.path.join(wd, override)
    else:
        sid = session or "default"
        safe = "".join(c if c.isalnum() else "_" for c in sid)
        target = os.path.join(
            wd, ".innerloop", f"structured_{safe}_{unix_ms()}.json"
        )

    abs_path = os.path.abspath(target)
    abs_wd = os.path.abspath(wd)
    # The trailing separator stops "/work" from matching "/workspace/...".
    if not abs_wd.endswith(os.sep):
        abs_wd += os.sep
    if not abs_path.startswith(abs_wd):
        raise ValueError(
            "Structured output file must be under workdir. "
            f"File: {abs_path}, Workdir: {abs_wd.rstrip(os.sep)}"
        )

    os.makedirs(os.path.dirname(abs_path), exist_ok=True)
    return abs_path


def _extract_and_parse_json(
    text: str, validation_format: str | type[BaseModel] | TypeAdapter[Any]
) -> Any:
    """Extract JSON from ``text`` and validate it against ``validation_format``.

    Both extraction and validation are delegated to ``slimschema.to_data``,
    which reports its outcome as a ``(data, error)`` pair.

    Raises:
        ValueError: If SlimSchema reports an error. The exception carries
            that error as its argument.
    """
    parsed, problem = to_data(text, validation_format)
    if problem is None:
        return parsed
    raise ValueError(problem)


def _write_structured_output(path: str, value: Any) -> None:
    """Persist validated structured output for downstream consumers.

    Writing is best-effort. An ``OSError`` while opening or writing the file
    is logged as a warning and swallowed, so a filesystem problem does not
    fail an otherwise successful invocation.

    Args:
        path: Destination file. Any existing content is overwritten.
        value: Validated output, serialized with ``_dump_json``.
    """
    # Serialize before opening: if serialization raises, the destination
    # file is left untouched instead of being truncated to empty.
    payload = _dump_json(value)
    try:
        # Explicit UTF-8: model JSON may contain non-ASCII text that the
        # platform default encoding cannot represent, and the resulting
        # UnicodeEncodeError would not be caught by the OSError handler.
        with open(path, "w", encoding="utf-8") as f:
            f.write(payload)
    except OSError as exc:
        logger.warning(
            "Failed to write structured output to %s: %s", path, exc
        )


async def _attempt(
    request: Request,
    *,
    prompt: str,
    abs_path: str,
    validation_format: str | type[BaseModel] | TypeAdapter[Any],
    session: str | None,
    timeout: float | None = None,
) -> tuple[Response[Any], Any | None, str | None]:
    """Run one plain (unstructured) invocation and validate its output.

    The original ``request`` is copied with ``prompt`` and ``session``
    substituted and ``response_format`` cleared. The copy is invoked, and
    the stringified output is validated against ``validation_format``.

    Returns:
        On success, ``(response, parsed, None)``; the parsed value is also
        persisted (best-effort) to ``abs_path``. On validation failure,
        ``(response, None, error_message)``.
    """
    from .invoke import async_invoke  # local import to avoid cycles

    plain_request = Request(
        model=request.model,
        prompt=prompt,
        permission=request.permission,
        providers=request.providers,
        mcp=request.mcp,
        response_format=None,
        session=session,
        workdir=request.workdir,
    )
    response = await async_invoke(plain_request, timeout=timeout)

    try:
        parsed = _extract_and_parse_json(
            str(response.output), validation_format
        )
    except ValueError as exc:
        return response, None, str(exc)
    else:
        _write_structured_output(abs_path, parsed)
        return response, parsed, None


def _render_attempt_prompt(
    request: Request,
    *,
    schema_yaml: str,
    abs_path: str,
    attempt: int,
    last_err: str,
) -> str:
    """Render the prompt for structured attempt number ``attempt`` (1-based)."""
    from .prompt_context import PromptContext
    from .prompt_renderer import render_prompt

    ctx = PromptContext(
        user_prompt=request.prompt,
        permissions=request.permission,
        response_format=request.response_format,
        output_file=abs_path,
        attempt=attempt,
        # Retries feed the previous validation error back to the model.
        validation_errors=[last_err] if attempt > 1 else [],
    )
    # Avoid re-rendering SlimSchema for every attempt.
    ctx._slimschema_yaml = schema_yaml
    return render_prompt(ctx)


async def invoke_structured(
    request: Request,
    *,
    max_retries: int = 3,
    timeout: float | None = None,
) -> Response[Any]:
    """Structured output invocation with SlimSchema-only validation.

    If ``request.response_format`` is ``None``, this passes straight through
    to ``async_invoke``. Otherwise the model is prompted with the SlimSchema
    rendering of the format, and its output is validated. On a validation
    failure the call is retried, and each retry includes the previous
    validation error.

    Args:
        request: The request to run. Its ``response_format`` selects the
            schema.
        max_retries: Total number of attempts, counting the first one. Must
            be >= 1 when a response format is set.
        timeout: Forwarded to every underlying invocation. It applies per
            attempt, not to the whole loop.

    Returns:
        A ``Response`` whose ``output`` is the first validated result. Its
        events cover every attempt, plus one ``ErrorEvent`` per failed
        validation, and usage is computed over all of them.

    Raises:
        ValueError: If ``max_retries`` < 1 or the output path is outside
            the workdir.
        RuntimeError: If no attempt produced valid output.
    """
    if request.response_format is None:
        from .invoke import async_invoke  # local import

        return await async_invoke(request, timeout=timeout)

    if max_retries < 1:
        raise ValueError(f"max_retries must be >= 1, got {max_retries}")

    schema_yaml = _to_slimschema_yaml(request.response_format)
    validation_format: str | type[BaseModel] | TypeAdapter[Any]
    if isinstance(request.response_format, TypeAdapter):
        # TypeAdapters are validated through their rendered SlimSchema.
        validation_format = schema_yaml
    else:
        validation_format = request.response_format

    abs_path = _resolve_output_path(
        request.workdir, request.session, request.output_file
    )

    logger.info(
        "Starting structured invocation (max_retries=%s, model=%s)",
        max_retries,
        request.model,
    )
    logger.debug("Output file: %s", abs_path)
    logger.debug(
        "Response format: %s",
        getattr(request.response_format, "__name__", "SlimSchema"),
    )

    attempts = 0
    total_events: list[OpenCodeEvent] = []
    final_session: str | None = request.session
    # Any validated value (model instance, dict, list, or even None).
    final_output: Any = None
    succeeded = False
    wall_start = unix_ms()
    last_err = "invalid"

    while attempts < max_retries:
        attempts += 1
        logger.info("Attempt %s/%s", attempts, max_retries)

        prompt = _render_attempt_prompt(
            request,
            schema_yaml=schema_yaml,
            abs_path=abs_path,
            attempt=attempts,
            last_err=last_err,
        )

        if attempts > 1:
            logger.warning(
                "Retrying structured output validation (attempt %s/%s) - Previous validation error: %s",
                attempts,
                max_retries,
                last_err,
            )

        resp, out, err = await _attempt(
            request,
            prompt=prompt,
            abs_path=abs_path,
            validation_format=validation_format,
            session=final_session,
            timeout=timeout,
        )
        total_events.extend(resp.events)
        final_session = resp.session_id

        # _attempt signals success with err=None. The parsed value itself
        # may legitimately be None (e.g. a schema that accepts null).
        if err is None:
            final_output = out
            succeeded = True
            logger.info(
                "Structured output validation successful on attempt %s",
                attempts,
            )
            break

        last_err = err or last_err
        logger.warning(
            "Validation failed on attempt %s: %s", attempts, last_err
        )

        total_events.append(
            ErrorEvent(
                timestamp=unix_ms(),
                sessionID=final_session or "",
                type="error",
                message=last_err,
                code=None,
                severity="error",
            )
        )

    if not succeeded:
        logger.error(
            "Structured output validation failed after %s attempts. Final error: %s",
            attempts,
            last_err,
        )
        raise RuntimeError(
            "Structured output validation failed after "
            f"{attempts} attempts. Expected file: {abs_path}\nError: {last_err}"
        )

    wall_end = unix_ms()
    logger.info(
        "Structured invocation completed successfully after %s attempt(s)",
        attempts,
    )
    logger.debug(
        "Total events: %s, Session: %s", len(total_events), final_session
    )

    out_resp = Response(
        session_id=final_session or "",
        input=request.prompt,
        output=final_output,
        structured_output_file=abs_path,
        events=total_events,
        attempts=attempts,
        time=TimeInfo(start=wall_start, end=wall_end),
    )
    out_resp.usage = compute_usage(total_events)
    return out_resp
