from __future__ import annotations

import asyncio
import inspect
import uuid
from typing import Callable, Iterable, Optional, Sequence, Tuple

import grpc
from grpc import aio as grpc_aio

from .config import SDKOptions, TelemetryConfig

# Type of call metadata as handled in this module: a sequence of
# (key, value) string pairs, or None when the call carries no metadata.
MetadataSequence = Optional[Sequence[Tuple[str, str]]]


def _append_metadata(existing: MetadataSequence, additions: Iterable[Tuple[str, str]]) -> Sequence[Tuple[str, str]]:
    merged = []
    if existing:
        merged.extend(existing)
    merged.extend(additions)
    return merged


def _has_header(metadata: MetadataSequence, key: str) -> bool:
    if not metadata:
        return False
    key_lower = key.lower()
    return any(existing_key.lower() == key_lower for existing_key, _ in metadata)


def _metadata_additions(options: SDKOptions, existing: MetadataSequence, request_id_provider) -> list[Tuple[str, str]]:
    """Build the metadata pairs the SDK adds to an outgoing call.

    The pairs are produced in this order:

    1. ``tenant-id`` and ``user-id``, each only when the matching option is
       truthy.
    2. Every entry of ``options.metadata`` with its key lower-cased.
    3. ``x-request-id`` from ``request_id_provider()``, only when ``existing``
       has no such header (compared case-insensitively).

    Args:
        options: SDK options supplying ``tenant_id``, ``user_id`` and
            ``metadata``.
        existing: Metadata already attached to the call; may be ``None``.
        request_id_provider: Zero-argument callable returning a request id.

    Returns:
        The list of ``(key, value)`` pairs to append to the call metadata.
    """
    identity_headers = (
        ("tenant-id", options.tenant_id),
        ("user-id", options.user_id),
    )
    extra: list[Tuple[str, str]] = [(name, value) for name, value in identity_headers if value]
    extra.extend((name.lower(), value) for name, value in options.metadata.items())
    if _has_header(existing, "x-request-id"):
        return extra
    extra.append(("x-request-id", request_id_provider()))
    return extra


class _ClientCallDetails(grpc.ClientCallDetails):
    """Plain attribute holder used to pass modified call details onward.

    Each constructor argument is stored unchanged under the attribute of the
    same name: ``method``, ``timeout``, ``metadata``, ``credentials``,
    ``wait_for_ready`` and ``compression``.
    """

    def __init__(
        self,
        method: str,
        timeout: Optional[float],
        metadata: MetadataSequence,
        credentials,
        wait_for_ready: Optional[bool],
        compression,
    ):
        (
            self.method,
            self.timeout,
            self.metadata,
            self.credentials,
            self.wait_for_ready,
            self.compression,
        ) = (method, timeout, metadata, credentials, wait_for_ready, compression)


class _AsyncClientCallDetails(grpc_aio.ClientCallDetails):
    """``grpc.aio`` call details that additionally carry ``compression``.

    ``grpc.aio.ClientCallDetails`` is a namedtuple with exactly five fields
    (``method``, ``timeout``, ``metadata``, ``credentials``,
    ``wait_for_ready``). Its fields are fixed when the tuple is created and
    are read-only afterwards, and its constructor rejects any other keyword.
    The instance is therefore built in ``__new__`` from the five tuple fields,
    and ``compression`` is kept as an ordinary instance attribute next to them.

    Args:
        method: RPC method path.
        timeout: Call timeout in seconds, or ``None``.
        metadata: Metadata to send with the call, or ``None``.
        credentials: Call credentials, or ``None``.
        wait_for_ready: Wait-for-ready flag, or ``None``.
        compression: Compression setting; stored on the instance only.
    """

    def __new__(
        cls,
        method: str,
        timeout: Optional[float],
        metadata: MetadataSequence,
        credentials,
        wait_for_ready: Optional[bool],
        compression,
    ):
        details = super().__new__(cls, method, timeout, metadata, credentials, wait_for_ready)
        # Not one of the tuple fields, so it lives in the instance __dict__
        # (this subclass deliberately declares no __slots__).
        details.compression = compression
        return details


def metadata_interceptor(options: SDKOptions, *, async_mode: bool = False):
    """Create a unary-unary client interceptor that adds SDK metadata.

    For every call the interceptor appends the pairs computed by
    ``_metadata_additions`` to the call's existing metadata and forwards a
    copy of the call details to the continuation. Request ids come from
    ``options.request_id_provider`` when it is set, otherwise from a random
    UUID4 string.

    Args:
        options: SDK options that drive the added metadata.
        async_mode: When true, return a ``grpc.aio`` interceptor; otherwise a
            synchronous ``grpc`` interceptor.

    Returns:
        A new interceptor instance of the requested flavour.
    """
    make_request_id = options.request_id_provider
    if not make_request_id:
        def make_request_id():
            return str(uuid.uuid4())

    def _decorated_fields(call_details):
        # Constructor keywords for a copy of ``call_details`` whose metadata
        # has the SDK additions appended.
        extra = _metadata_additions(options, call_details.metadata, make_request_id)
        return {
            "method": call_details.method,
            "timeout": call_details.timeout,
            "metadata": _append_metadata(call_details.metadata, extra),
            "credentials": call_details.credentials,
            "wait_for_ready": getattr(call_details, "wait_for_ready", None),
            "compression": getattr(call_details, "compression", None),
        }

    if not async_mode:
        class SyncMetadataInterceptor(grpc.UnaryUnaryClientInterceptor):
            def intercept_unary_unary(self, continuation, client_call_details, request):
                new_details = _ClientCallDetails(**_decorated_fields(client_call_details))
                return continuation(new_details, request)

        return SyncMetadataInterceptor()

    class AsyncMetadataInterceptor(grpc_aio.UnaryUnaryClientInterceptor):
        async def intercept_unary_unary(self, continuation, client_call_details, request):
            new_details = _AsyncClientCallDetails(**_decorated_fields(client_call_details))
            return await continuation(new_details, request)

    return AsyncMetadataInterceptor()


async def _resolve_token(provider: Callable[[], Optional[str]]):
    token = provider()
    if inspect.isawaitable(token):
        return await token
    return token


def _resolve_token_sync(provider: Callable[[], Optional[str]]) -> Optional[str]:
    """Call ``provider`` and return its token from synchronous code.

    A plain return value is passed through. An awaitable is driven to
    completion on a private event loop:

    * With no event loop running in the calling thread, the loop is started
      right here.
    * With a loop already running in the calling thread, blocking on that
      loop would deadlock it, so the awaitable is run on a fresh loop in a
      short-lived helper thread while this thread waits for the result.

    Exceptions raised by the provider or by the awaitable propagate to the
    caller unchanged.
    """
    token = provider()
    if not inspect.isawaitable(token):
        return token

    async def _await_token():
        # ``asyncio.run`` only accepts coroutines; this wrapper lets it drive
        # any awaitable the provider may return.
        return await token

    try:
        asyncio.get_running_loop()
    except RuntimeError:
        # No loop is running in this thread.
        return asyncio.run(_await_token())

    # Local import: only this fallback path needs a thread pool.
    from concurrent.futures import ThreadPoolExecutor

    with ThreadPoolExecutor(max_workers=1) as executor:
        return executor.submit(asyncio.run, _await_token()).result()


def auth_interceptor(
    provider: Optional[Callable[[], Optional[str]]],
    *,
    async_mode: bool = False,
):
    """Create a unary-unary client interceptor that adds a bearer token.

    On every call the interceptor asks ``provider`` for a token. A truthy
    token is appended to the call metadata as
    ``("authorization", "Bearer <token>")``; a falsy token leaves the metadata
    unchanged. The provider may return the token directly or as an awaitable.

    Args:
        provider: Zero-argument token source, or ``None`` to disable auth.
        async_mode: When true, return a ``grpc.aio`` interceptor; otherwise a
            synchronous ``grpc`` interceptor.

    Returns:
        ``None`` when ``provider`` is ``None``, else a new interceptor
        instance of the requested flavour.
    """
    if provider is None:
        return None

    if async_mode:
        class AsyncAuthInterceptor(grpc_aio.UnaryUnaryClientInterceptor):
            async def intercept_unary_unary(self, continuation, client_call_details, request):
                token = await _resolve_token(provider)
                additions = [("authorization", f"Bearer {token}")] if token else []
                details = _AsyncClientCallDetails(
                    method=client_call_details.method,
                    timeout=client_call_details.timeout,
                    metadata=_append_metadata(client_call_details.metadata, additions),
                    credentials=client_call_details.credentials,
                    wait_for_ready=getattr(client_call_details, "wait_for_ready", None),
                    compression=getattr(client_call_details, "compression", None),
                )
                return await continuation(details, request)

        return AsyncAuthInterceptor()

    class SyncAuthInterceptor(grpc.UnaryUnaryClientInterceptor):
        def intercept_unary_unary(self, continuation, client_call_details, request):
            token_value = _resolve_token_sync(provider)
            additions = [("authorization", f"Bearer {token_value}")] if token_value else []
            details = _ClientCallDetails(
                method=client_call_details.method,
                timeout=client_call_details.timeout,
                metadata=_append_metadata(client_call_details.metadata, additions),
                credentials=client_call_details.credentials,
                wait_for_ready=getattr(client_call_details, "wait_for_ready", None),
                compression=getattr(client_call_details, "compression", None),
            )
            return continuation(details, request)

    return SyncAuthInterceptor()


def timeout_interceptor(timeout: Optional[float], *, async_mode: bool = False):
    """Create a unary-unary client interceptor that caps call timeouts.

    A call that has no (or a falsy) timeout of its own receives ``timeout``;
    a call that has one receives the smaller of the two values. All other
    call details are forwarded unchanged.

    Args:
        timeout: Upper bound in seconds. ``None`` or a value ``<= 0``
            disables the interceptor.
        async_mode: When true, return a ``grpc.aio`` interceptor; otherwise a
            synchronous ``grpc`` interceptor.

    Returns:
        ``None`` when the interceptor is disabled, else a new interceptor
        instance of the requested flavour.
    """
    if timeout is None or timeout <= 0:
        return None

    def _capped_fields(call_details):
        # Constructor keywords for a copy of ``call_details`` whose timeout is
        # bounded by the configured ``timeout``.
        requested = call_details.timeout
        if requested:
            bounded = min(requested, timeout)
        else:
            bounded = timeout
        return {
            "method": call_details.method,
            "timeout": bounded,
            "metadata": call_details.metadata,
            "credentials": call_details.credentials,
            "wait_for_ready": getattr(call_details, "wait_for_ready", None),
            "compression": getattr(call_details, "compression", None),
        }

    if not async_mode:
        class SyncTimeoutInterceptor(grpc.UnaryUnaryClientInterceptor):
            def intercept_unary_unary(self, continuation, client_call_details, request):
                new_details = _ClientCallDetails(**_capped_fields(client_call_details))
                return continuation(new_details, request)

        return SyncTimeoutInterceptor()

    class AsyncTimeoutInterceptor(grpc_aio.UnaryUnaryClientInterceptor):
        async def intercept_unary_unary(self, continuation, client_call_details, request):
            new_details = _AsyncClientCallDetails(**_capped_fields(client_call_details))
            return await continuation(new_details, request)

    return AsyncTimeoutInterceptor()


def _tracing_method_path(method) -> str:
    """Return the RPC method path as text, or ``""`` when it is missing.

    Bytes are decoded as UTF-8 so that a bytes method path (as ``grpc.aio``
    may supply) is handled the same way as the ``str`` of the sync API.
    """
    if not method:
        return ""
    if isinstance(method, bytes):
        return method.decode("utf-8", errors="replace")
    return str(method)


def _tracing_span_fields(telemetry, method):
    """Compute the span name and attributes for one RPC.

    The span name is ``telemetry.span_name`` when set, otherwise the last
    ``/``-separated segment of the method path. Attributes start from a copy
    of ``telemetry.attributes``; ``rpc.system``, ``rpc.method`` (the full
    method path) and ``rpc.service`` are filled in only where the
    configuration has not already set them.
    """
    method_path = _tracing_method_path(method)
    span_name = telemetry.span_name or method_path.rsplit("/", 1)[-1]
    attributes = dict(telemetry.attributes) if telemetry.attributes else {}
    attributes.setdefault("rpc.system", "grpc")
    attributes.setdefault("rpc.method", method_path)
    # "/package.Service/Method".split("/") == ["", "package.Service", "Method"]
    segments = method_path.split("/")
    if len(segments) > 1 and segments[1]:
        attributes.setdefault("rpc.service", segments[1])
    return span_name, attributes


def _tracing_status_types():
    """Import and return OpenTelemetry's ``(Status, StatusCode)``.

    Raises:
        RuntimeError: If the ``opentelemetry-api`` package is not installed.
    """
    try:
        from opentelemetry.trace import Status, StatusCode
    except ModuleNotFoundError as exc:
        raise RuntimeError(
            "Tracing requires the 'opentelemetry-api' package; install it to enable telemetry.",
        ) from exc
    return Status, StatusCode


def _tracing_record_sync_outcome(span, outcome, status_cls, status_codes) -> None:
    """Set the span status from what a synchronous continuation returned.

    A synchronous continuation does not raise when the RPC fails; it returns
    a Call-Future whose ``exception()`` carries the error. That value is
    inspected here so failed calls are not reported as OK.
    """
    is_done = getattr(outcome, "done", None)
    get_exception = getattr(outcome, "exception", None)
    if not callable(is_done) or not callable(get_exception):
        # Not a Call-Future (for example a plain response object).
        span.set_status(status_cls(status_code=status_codes.OK))
        return
    if not is_done():
        # Future-style invocation still in flight: the final status is not
        # known yet and waiting for it here would make the call blocking.
        return
    try:
        error = get_exception()
    except Exception as exc:  # e.g. a cancelled future raises instead of returning
        error = exc
    if error is None:
        span.set_status(status_cls(status_code=status_codes.OK))
        return
    span.record_exception(error)
    span.set_status(status_cls(status_code=status_codes.ERROR, description=str(error)))


async def _tracing_aio_failure(call) -> Optional[str]:
    """Return a description of a failed ``grpc.aio`` call, or ``None`` if OK.

    Awaiting ``call.code()`` waits for the RPC to finish, so the enclosing
    span covers the whole call. Objects that are not ``grpc.aio`` calls are
    treated as successful.
    """
    if not isinstance(call, grpc_aio.Call):
        return None
    code = await call.code()
    if code == grpc.StatusCode.OK:
        return None
    return f"{code.name}: {await call.details()}"


def tracing_interceptor(
    telemetry: Optional[TelemetryConfig],
    *,
    async_mode: bool = False,
):
    """Create a unary-unary client interceptor that traces each RPC.

    Each call runs inside a span started with
    ``telemetry.tracer.start_as_current_span``. The span carries the
    configured attributes plus ``rpc.system``, ``rpc.method`` and
    ``rpc.service`` defaults, and its status reflects the call outcome: OK on
    success, ERROR when the continuation raises or the call finishes with an
    error. When ``telemetry.tracer`` is ``None`` the call is forwarded
    untouched.

    Args:
        telemetry: Telemetry configuration, or ``None`` to disable tracing.
        async_mode: When true, return a ``grpc.aio`` interceptor; otherwise a
            synchronous ``grpc`` interceptor.

    Returns:
        ``None`` when ``telemetry`` is ``None``, else a new interceptor
        instance of the requested flavour.

    Raises:
        RuntimeError: From the interceptor, at call time, if a tracer is
            configured but ``opentelemetry-api`` is not installed.
    """
    if telemetry is None:
        return None

    if async_mode:
        class AsyncTracingInterceptor(grpc_aio.UnaryUnaryClientInterceptor):
            async def intercept_unary_unary(self, continuation, client_call_details, request):
                tracer = telemetry.tracer
                if tracer is None:
                    return await continuation(client_call_details, request)
                Status, StatusCode = _tracing_status_types()
                span_name, attributes = _tracing_span_fields(telemetry, client_call_details.method)

                with tracer.start_as_current_span(span_name) as span:
                    for key, value in attributes.items():
                        span.set_attribute(key, value)
                    try:
                        call = await continuation(client_call_details, request)
                        failure = await _tracing_aio_failure(call)
                    except Exception as exc:  # pragma: no cover
                        span.record_exception(exc)
                        span.set_status(Status(status_code=StatusCode.ERROR, description=str(exc)))
                        raise
                    if failure is None:
                        span.set_status(Status(status_code=StatusCode.OK))
                    else:
                        span.set_status(Status(status_code=StatusCode.ERROR, description=failure))
                    return call

        return AsyncTracingInterceptor()

    class SyncTracingInterceptor(grpc.UnaryUnaryClientInterceptor):
        def __init__(self):
            self._telemetry = telemetry

        def intercept_unary_unary(self, continuation, client_call_details, request):
            tracer = self._telemetry.tracer
            if tracer is None:
                return continuation(client_call_details, request)
            Status, StatusCode = _tracing_status_types()
            span_name, attributes = _tracing_span_fields(self._telemetry, client_call_details.method)

            with tracer.start_as_current_span(span_name) as span:
                for key, value in attributes.items():
                    span.set_attribute(key, value)
                try:
                    outcome = continuation(client_call_details, request)
                except Exception as exc:  # pragma: no cover - exercised in integration
                    span.record_exception(exc)
                    span.set_status(Status(status_code=StatusCode.ERROR, description=str(exc)))
                    raise
                _tracing_record_sync_outcome(span, outcome, Status, StatusCode)
                return outcome

    return SyncTracingInterceptor()
