"""MCP tool decorator for standardized request/response logging."""

import functools
from collections.abc import Callable
from typing import Any

from quantum_code.utils.context import clear_context, set_request_context
from quantum_code.utils.mcp_logger import log_mcp_interaction


def mcp_monitor(func: Callable | None = None, *, tool_name: str | None = None) -> Callable:
    """Decorator that wraps MCP tools with request/response logging.

    Args:
        func: The function to wrap (when used without parentheses)
        tool_name: Name of the tool for logging (defaults to function name)

    Usage:
        @mcp.tool()
        @mcp_monitor
        async def codereview(...) -> dict:
            return await codereview_impl(**locals())

        # Or with explicit name:
        @mcp.tool()
        @mcp_monitor(tool_name="custom_name")
        async def codereview(...) -> dict:
            return await codereview_impl(**locals())

    Note:
        thread_id is auto-generated in mcp_factory.py before reaching the decorator,
        so all tools receive a valid thread_id (never None).
    """

    def decorator(fn: Callable) -> Callable:
        name = tool_name or fn.__name__

        @functools.wraps(fn)
        async def wrapper(*args: Any, **kwargs: Any) -> dict:
            # Extract thread_id from kwargs (already generated by mcp_factory)
            thread_id = kwargs.get("thread_id")

            # Set request context at entry point (for logging throughout the call stack)
            set_request_context(
                thread_id=thread_id,
                workflow=name,
                step_number=kwargs.get("step_number"),
                base_path=kwargs.get("base_path"),
                name=kwargs.get("name"),
            )

            try:
                # Log request (uses context for thread_id)
                log_mcp_interaction(
                    direction="request",
                    tool_name=name,
                    data=kwargs,
                )

                result = await fn(*args, **kwargs)

                log_mcp_interaction(
                    direction="response",
                    tool_name=name,
                    data=result,
                )

                return result

            except Exception as e:
                log_mcp_interaction(
                    direction="error",
                    tool_name=name,
                    data={"error": str(e), "type": type(e).__name__},
                )
                raise

            finally:
                # Always clear context after request completes
                clear_context()

        return wrapper

    # Support both @mcp_monitor and @mcp_monitor(tool_name="...")
    if func is not None:
        return decorator(func)
    return decorator
