"""
Structured Output

Tool-based structured output using Pydantic models.
Forces the model (via tool_choice) to call a 'respond' tool whose input is validated against the model's schema.
"""

from __future__ import annotations

from typing import TYPE_CHECKING, Any, TypeVar

from pydantic import BaseModel, ValidationError

from .types import Response, Tool

if TYPE_CHECKING:
    from .api import Loop

# Type variable tying run_structured's return type to the Pydantic model class
# passed as ``output_type`` (bounded so only BaseModel subclasses are accepted).
T = TypeVar("T", bound=BaseModel)


class ResponseTool(Tool):
    """
    Tool that carries the model's final, schema-checked answer.

    The model is made to call this tool through ``tool_choice``; whatever
    arguments it supplies are checked against the Pydantic model given at
    construction time.
    """

    _output_type: type[BaseModel]

    def __init__(self, output_type: type[BaseModel]):
        # The tool's input schema is the Pydantic model's JSON schema, with
        # any nested-model definitions ($defs) passed through _inline_defs.
        input_schema = output_type.model_json_schema()
        if "$defs" in input_schema:
            input_schema = _inline_defs(input_schema)

        summary = f"Submit your final response as {output_type.__name__}"
        super().__init__(name="respond", description=summary, input_schema=input_schema)

        # Assigned with object.__setattr__ rather than plain assignment;
        # NOTE(review): presumably because Tool instances are frozen — confirm in .types.
        object.__setattr__(self, "_output_type", output_type)

    async def execute(self, input: dict[str, Any]) -> tuple[str, bool]:
        """
        Check the model-supplied arguments against the output model.

        Returns:
            A ``(message, is_error)`` pair: ``("Success", False)`` when the
            input validates; otherwise the Pydantic error text plus a request
            to call ``respond`` again, flagged with ``True``.
        """
        try:
            self._output_type.model_validate(input)
        except ValidationError as exc:
            return f"Validation error: {exc}. Fix the errors and call respond again.", True
        else:
            return "Success", False


def create_response_tool(output_type: type[BaseModel]) -> ResponseTool:
    """Build the ``respond`` tool whose input schema mirrors *output_type*."""
    tool = ResponseTool(output_type=output_type)
    return tool


async def run_structured(
    loop: Loop,
    prompt: str,
    output_type: type[T],
    max_retries: int = 3,
) -> T:
    """
    Run a prompt and return validated structured output.

    A temporary ``respond`` tool built from *output_type* is appended to
    ``loop.tools``, and every model call is forced to use it via
    ``tool_choice``. If an attempt yields no valid ``respond`` call, the next
    attempt sends no new user text and relies on the session history (which
    holds the rejected tool result) to let the model correct itself.

    Args:
        loop: Loop instance to use. Its ``tools`` are restored on exit; its
            message history is extended by every attempt.
        prompt: User prompt, sent on the first attempt only.
        output_type: Pydantic model class for the output.
        max_retries: Total number of model calls to make, the first attempt
            included. Must be at least 1.

    Returns:
        Validated instance of output_type.

    Raises:
        ValueError: If ``max_retries`` is less than 1, or no valid
            ``respond`` call was produced within ``max_retries`` attempts.
    """
    # Fail fast: with max_retries < 1 the loop below would never call the
    # model and the caller would get a misleading "failed after 0 attempts".
    if max_retries < 1:
        raise ValueError(f"max_retries must be at least 1, got {max_retries}")

    respond_tool = create_response_tool(output_type)

    # Add respond to the loop's tools only for the duration of this call.
    original_tools = loop.tools
    loop.tools = [*original_tools, respond_tool]

    try:
        for attempt in range(max_retries):
            # Retries send an empty prompt: the session already has context.
            current_prompt = prompt if attempt == 0 else ""
            response = await _run_with_tool_choice(loop, current_prompt)

            # Only the first respond call of this turn is considered.
            respond_result = next(
                (tr for tr in response.tool_results if tr.tool_name == "respond"),
                None,
            )
            if respond_result is not None and not respond_result.is_error:
                return output_type.model_validate(respond_result.input)
            # Otherwise the model either skipped respond or its input failed
            # validation; fall through to the next attempt.

        raise ValueError(f"Structured output failed after {max_retries} attempts")

    finally:
        # Restore the caller's tool list even if an attempt raised.
        loop.tools = original_tools


async def _run_with_tool_choice(loop: Loop, prompt: str) -> Response:
    """
    Run one loop execution with ``tool_choice`` forcing the ``respond`` tool.

    Appends (and saves) a user message when *prompt* is non-empty, executes
    the loop, saves every message the execution added, and stores the
    updated history back on *loop*.

    Args:
        loop: Loop whose provider, tools, config and message history are used.
        prompt: User text to append first; an empty string adds no message.

    Returns:
        The Response from this execution, tagged with ``loop.session_id``.
    """
    from .loop import execute as loop_execute
    from .types import UserMessage

    if prompt:
        user_msg = UserMessage(content=prompt)
        loop.messages.append(user_msg)
        loop._save_message(user_msg)

    config = loop._build_config()

    # Record the history length *before* executing. If loop_execute extends
    # loop.messages in place (returning the same list), measuring afterwards
    # would see zero new messages and silently skip persisting them.
    saved_count = len(loop.messages)

    updated_messages, response = await loop_execute(
        provider=loop._provider,
        messages=loop.messages,
        tools=loop.tools,
        config=config,
        tool_choice={"type": "tool", "name": "respond"},
    )

    # Persist only the messages produced by this execution.
    for msg in updated_messages[saved_count:]:
        loop._save_message(msg)
    loop.messages = updated_messages

    response.session_id = loop.session_id
    return response


def _inline_defs(schema: dict[str, Any]) -> dict[str, Any]:
    """Inline $defs references for simpler schema.

    This is a simplified version - just removes $defs for now.
    Full implementation would recursively resolve $ref pointers.
    """
    result = dict(schema)
    result.pop("$defs", None)
    return result


# Public API of this module; _run_with_tool_choice and _inline_defs are internal.
__all__ = [
    "ResponseTool",
    "create_response_tool",
    "run_structured",
]
