import asyncio
import random
from typing import Optional, List, Set, Tuple, Dict, Callable
from aio_pika import connect_robust, Channel, Message
from aio_pika.abc import (
    AbstractRobustConnection, AbstractQueue, AbstractExchange, AbstractMessage
)
from aio_pika.exceptions import ChannelClosed
import aiormq.exceptions

from sycommon.logging.kafka_log import SYLogger

logger = SYLogger


class RabbitMQConnectionPool:
    """单连接RabbitMQ通道池（核心特性：依赖connect_robust原生自动重连/恢复 + 仅关闭时释放资源 + 全场景加锁）"""

    def __init__(
        self,
        hosts: List[str],
        port: int,
        username: str,
        password: str,
        virtualhost: str = "/",
        channel_pool_size: int = 1,
        heartbeat: int = 30,
        app_name: str = "",
        connection_timeout: int = 30,
        reconnect_interval: int = 30,
        prefetch_count: int = 2,
    ):
        # 基础配置校验与初始化
        self.hosts = [host.strip() for host in hosts if host.strip()]
        if not self.hosts:
            raise ValueError("至少需要提供一个RabbitMQ主机地址")

        self.port = port
        self.username = username
        self.password = password
        self.virtualhost = virtualhost
        self.app_name = app_name or "rabbitmq-client"
        self.heartbeat = heartbeat
        self.connection_timeout = connection_timeout
        self.reconnect_interval = reconnect_interval
        self.prefetch_count = prefetch_count
        self.channel_pool_size = max(1, channel_pool_size)  # 确保池大小不小于1

        # 初始化时随机选择一个主机地址（固定使用，依赖原生重连）
        self._current_host: str = random.choice(self.hosts)
        logger.info(
            f"随机选择RabbitMQ主机: {self._current_host}（依赖connect_robust原生自动重连/恢复）")

        # 核心资源（单连接+通道池，基于原生自动重连）
        self._connection: Optional[AbstractRobustConnection] = None  # 原生自动重连连接
        self._free_channels: List[Channel] = []  # 空闲通道（原生自动恢复）
        self._used_channels: Set[Channel] = set()  # 使用中通道（原生自动恢复）
        # 消费者通道跟踪
        self._consumer_channels: Dict[str,
                                      Tuple[Channel, Callable, bool, dict]] = {}

        # 状态控制（并发安全+生命周期管理）
        self._lock = asyncio.Lock()
        self._initialized = False
        self._is_shutdown = False

    async def _is_connection_valid(self) -> bool:
        """原子化检查连接有效性（仅判断是否初始化完成且未关闭）"""
        async with self._lock:
            return (
                self._initialized
                and self._connection is not None
                and not self._connection.is_closed
                and not self._is_shutdown
            )

    @property
    async def is_alive(self) -> bool:
        """对外暴露的连接存活状态"""
        if self._is_shutdown:
            return False
        return await self._is_connection_valid()

    async def _create_connection(self) -> AbstractRobustConnection:
        """创建原生自动重连连接（仅创建一次，内部自动重试）"""
        if self._is_shutdown:
            raise RuntimeError("通道池已关闭，无法创建连接")

        conn_url = f"amqp://{self.username}:{self.password}@{self._current_host}:{self.port}/{self.virtualhost}"
        logger.info(f"尝试创建原生自动重连连接: {self._current_host}:{self.port}")

        try:
            conn = await connect_robust(
                conn_url,
                properties={
                    "connection_name": f"{self.app_name}_conn",
                    "product": self.app_name
                },
                heartbeat=self.heartbeat,
                timeout=self.connection_timeout,
                reconnect_interval=self.reconnect_interval,  # 原生重连间隔
                max_reconnect_attempts=None,  # 无限重试（按需调整）
            )
            logger.info(f"连接创建成功: {self._current_host}:{self.port}（原生自动重连已启用）")
            return conn
        except Exception as e:
            logger.error(f"连接创建失败: {str(e)}", exc_info=True)
            raise ConnectionError(
                f"无法连接RabbitMQ主机 {self._current_host}:{self.port}") from e

    async def _init_channel_pool(self):
        """初始化通道池（通道自带原生自动恢复）"""
        async with self._lock:
            if self._is_shutdown:
                raise RuntimeError("通道池已关闭，无法初始化")
            if not self._connection or self._connection.is_closed:
                raise RuntimeError("无有效连接，无法初始化通道池")

            # 清理无效通道（初始化前确保池为空）
            self._free_channels.clear()
            self._used_channels.clear()

            # 创建指定数量的通道
            for i in range(self.channel_pool_size):
                try:
                    channel = await self._connection.channel()  # 通道原生自动恢复
                    await channel.set_qos(prefetch_count=self.prefetch_count)
                    self._free_channels.append(channel)
                except Exception as e:
                    logger.error(f"创建通道失败（第{i+1}个）: {str(e)}", exc_info=True)
                    continue

            logger.info(
                f"通道池初始化完成 - 可用通道数: {len(self._free_channels)}/{self.channel_pool_size} "
                f"（均带原生自动恢复）"
            )

    async def _clean_invalid_channels(self):
        """清理失效通道并补充（仅处理通道级失效，依赖原生恢复）"""
        async with self._lock:
            if self._is_shutdown or not self._connection:
                return

            # 1. 清理空闲通道（保留有效通道）
            valid_free = []
            for chan in self._free_channels:
                try:
                    if not chan.is_closed:
                        valid_free.append(chan)
                    else:
                        logger.warning(f"清理失效空闲通道（将自动补充）")
                except Exception:
                    logger.warning(f"清理异常空闲通道")
            self._free_channels = valid_free

            # 2. 清理使用中通道（保留有效通道）
            valid_used = set()
            for chan in self._used_channels:
                try:
                    if not chan.is_closed:
                        valid_used.add(chan)
                    else:
                        logger.warning(f"清理失效使用中通道")
                except Exception:
                    logger.warning(f"清理异常使用中通道")
            self._used_channels = valid_used

            # 3. 补充缺失的通道（新通道带原生自动恢复）
            total_valid = len(self._free_channels) + len(self._used_channels)
            missing = self.channel_pool_size - total_valid
            if missing > 0:
                logger.info(f"通道池缺少{missing}个通道，补充中...")
                for _ in range(missing):
                    try:
                        channel = await self._connection.channel()
                        await channel.set_qos(prefetch_count=self.prefetch_count)
                        self._free_channels.append(channel)
                    except Exception as e:
                        logger.error(f"补充通道失败: {str(e)}", exc_info=True)
                        break

    async def init_pools(self):
        """初始化通道池（仅执行一次）"""
        async with self._lock:
            if self._initialized:
                logger.warning("通道池已初始化，无需重复调用")
                return
            if self._is_shutdown:
                raise RuntimeError("通道池已关闭，无法初始化")

        try:
            # 1. 创建原生自动重连连接
            self._connection = await self._create_connection()

            # 2. 初始化通道池
            await self._init_channel_pool()

            # 3. 标记为已初始化
            async with self._lock:
                self._initialized = True

            logger.info("RabbitMQ通道池初始化完成（原生自动重连/恢复已启用）")
        except Exception as e:
            logger.error(f"初始化失败: {str(e)}", exc_info=True)
            await self.close()  # 初始化失败直接关闭
            raise

    async def acquire_channel(self) -> Tuple[Channel, AbstractRobustConnection]:
        """获取通道（返回 (通道, 连接) 元组，保持API兼容）"""
        # 快速校验状态
        async with self._lock:
            if not self._initialized:
                raise RuntimeError("通道池未初始化，请先调用init_pools()")
            if self._is_shutdown:
                raise RuntimeError("通道池已关闭，无法获取通道")
            if not self._connection or self._connection.is_closed:
                raise RuntimeError("连接已关闭（等待原生重连）")

        # 清理失效通道
        await self._clean_invalid_channels()

        async with self._lock:
            # 优先从空闲池获取
            if self._free_channels:
                channel = self._free_channels.pop()
                self._used_channels.add(channel)
                return channel, self._connection  # 返回 (通道, 连接) 元组

            # 池满时创建临时通道（用完自动关闭）
            try:
                channel = await self._connection.channel()
                await channel.set_qos(prefetch_count=self.prefetch_count)
                self._used_channels.add(channel)
                logger.warning(
                    f"通道池已达上限（{self.channel_pool_size}），创建临时通道（用完自动关闭）"
                )
                return channel, self._connection  # 返回 (通道, 连接) 元组
            except Exception as e:
                logger.error(f"获取通道失败: {str(e)}", exc_info=True)
                raise

    async def release_channel(self, channel: Channel, conn: AbstractRobustConnection):
        """释放通道（接收通道和连接参数，保持API兼容）"""
        if not channel or not conn or self._is_shutdown:
            return

        async with self._lock:
            # 通道不在使用中，直接返回
            if channel not in self._used_channels:
                return

            # 移除使用中标记
            self._used_channels.remove(channel)

            # 仅归还有效通道（通道未关闭+池未满+连接匹配）
            if (not channel.is_closed
                and len(self._free_channels) < self.channel_pool_size
                    and conn == self._connection):  # 确保是当前连接的通道
                self._free_channels.append(channel)
            else:
                # 无效通道直接关闭（临时通道、池满或连接不匹配时）
                try:
                    if not channel.is_closed:
                        await channel.close()
                except Exception as e:
                    logger.warning(f"关闭通道失败: {str(e)}")

    async def declare_queue(self, queue_name: str, **kwargs) -> AbstractQueue:
        """声明队列（使用池内通道）"""
        channel, conn = await self.acquire_channel()
        try:
            return await channel.declare_queue(queue_name, **kwargs)
        finally:
            await self.release_channel(channel, conn)

    async def declare_exchange(self, exchange_name: str, exchange_type: str = "direct", **kwargs) -> AbstractExchange:
        """声明交换机（使用池内通道）"""
        channel, conn = await self.acquire_channel()
        try:
            return await channel.declare_exchange(exchange_name, exchange_type, **kwargs)
        finally:
            await self.release_channel(channel, conn)

    async def publish_message(self, routing_key: str, message_body: bytes, exchange_name: str = "", **kwargs):
        """发布消息（依赖原生自动重连/恢复）"""
        channel, conn = await self.acquire_channel()
        try:
            exchange = channel.default_exchange if not exchange_name else await channel.get_exchange(exchange_name)
            message = Message(body=message_body, **kwargs)
            await exchange.publish(message, routing_key=routing_key)
            logger.debug(
                f"消息发布成功 - 交换机: {exchange.name}, 路由键: {routing_key}"
            )
        except Exception as e:
            logger.error(f"发布消息失败: {str(e)}", exc_info=True)
            raise  # 原生会自动重连，无需手动处理
        finally:
            await self.release_channel(channel, conn)

    async def consume_queue(self, queue_name: str, callback: Callable[[AbstractMessage], asyncio.Future], auto_ack: bool = False, **kwargs):
        """消费队列（独立通道，带原生自动恢复）"""
        # 快速校验状态
        async with self._lock:
            if not self._initialized:
                raise RuntimeError("通道池未初始化，请先调用init_pools()")
            if self._is_shutdown:
                raise RuntimeError("通道池已关闭，无法启动消费")

        # 先声明队列（确保队列存在）
        await self.declare_queue(queue_name, **kwargs)

        # 获取独立通道（不放入普通池）
        channel, conn = await self.acquire_channel()

        # 注册消费者通道
        async with self._lock:
            self._consumer_channels[queue_name] = (
                channel, callback, auto_ack, kwargs)

        async def consume_callback_wrapper(message: AbstractMessage):
            """消费回调包装（处理通道失效，依赖原生恢复）"""
            try:
                # 校验通道和连接状态
                async with self._lock:
                    channel_valid = not channel.is_closed
                    conn_valid = self._connection and not self._connection.is_closed
                    conn_matched = conn == self._connection

                if not channel_valid or not conn_valid or not conn_matched:
                    logger.warning(f"消费者通道 {queue_name} 失效（等待原生自动恢复）")
                    if not auto_ack:
                        await message.nack(requeue=True)
                    return

                # 执行业务回调
                await callback(message)
                if not auto_ack:
                    await message.ack()
            except ChannelClosed as e:
                logger.error(f"消费者通道 {queue_name} 关闭: {str(e)}", exc_info=True)
                if not auto_ack:
                    await message.nack(requeue=True)
            except aiormq.exceptions.ChannelInvalidStateError as e:
                logger.error(
                    f"消费者通道 {queue_name} 状态异常: {str(e)}", exc_info=True)
                if not auto_ack:
                    await message.nack(requeue=True)
            except Exception as e:
                logger.error(
                    f"消费消息失败（队列: {queue_name}）: {str(e)}", exc_info=True)
                if not auto_ack:
                    await message.nack(requeue=True)

        logger.info(f"开始消费队列: {queue_name}（通道带原生自动恢复）")

        try:
            await channel.basic_consume(
                queue_name,
                consumer_callback=consume_callback_wrapper,
                auto_ack=auto_ack,
                **kwargs
            )
        except Exception as e:
            logger.error(f"启动消费失败（队列: {queue_name}）: {str(e)}", exc_info=True)
            # 清理异常资源
            await self.release_channel(channel, conn)
            async with self._lock:
                if self._consumer_channels.get(queue_name) == (channel, callback, auto_ack, kwargs):
                    del self._consumer_channels[queue_name]
            raise

    async def close(self):
        """关闭通道池（仅主动关闭时释放所有资源）"""
        async with self._lock:
            if self._is_shutdown:
                logger.warning("通道池已关闭，无需重复操作")
                return
            self._is_shutdown = True
            self._initialized = False

        logger.info("开始关闭RabbitMQ通道池（释放所有资源）...")

        # 1. 关闭所有消费者通道
        async with self._lock:
            consumer_channels = list(self._consumer_channels.values())
            self._consumer_channels.clear()
        for channel, _, _, _ in consumer_channels:
            try:
                if not channel.is_closed:
                    await channel.close()
            except Exception as e:
                logger.warning(f"关闭消费者通道失败: {str(e)}")

        # 2. 关闭所有普通通道
        async with self._lock:
            all_channels = self._free_channels + list(self._used_channels)
            self._free_channels.clear()
            self._used_channels.clear()
        for channel in all_channels:
            try:
                if not channel.is_closed:
                    await channel.close()
            except Exception as e:
                logger.warning(f"关闭通道失败: {str(e)}")

        # 3. 关闭连接（终止原生自动重连）
        if self._connection:
            try:
                if not self._connection.is_closed:
                    await self._connection.close()
                logger.info(
                    f"已关闭连接: {self._current_host}:{self.port}（终止原生自动重连）")
            except Exception as e:
                logger.warning(f"关闭连接失败: {str(e)}")
            self._connection = None

        logger.info("RabbitMQ通道池已完全关闭")
