import asyncio
from typing import Optional, List, Set, Iterator, Tuple, Dict, Callable
from aio_pika import connect_robust, Channel, Message
from aio_pika.abc import (
    AbstractRobustConnection, AbstractQueue, AbstractExchange, AbstractMessage
)
from aio_pika.exceptions import ChannelClosed
import aiormq.exceptions

from sycommon.logging.kafka_log import SYLogger

logger = SYLogger


class RabbitMQConnectionPool:
    """单连接RabbitMQ通道池（核心特性：严格单连接+重连释放旧资源+新连接保留自动恢复+全场景加锁）"""

    def __init__(
        self,
        hosts: List[str],
        port: int,
        username: str,
        password: str,
        virtualhost: str = "/",
        channel_pool_size: int = 1,
        heartbeat: int = 30,
        app_name: str = "",
        connection_timeout: int = 30,
        reconnect_interval: int = 30,
        prefetch_count: int = 2,
    ):
        # 基础配置校验与初始化
        self.hosts = [host.strip() for host in hosts if host.strip()]
        if not self.hosts:
            raise ValueError("至少需要提供一个RabbitMQ主机地址")

        self.port = port
        self.username = username
        self.password = password
        self.virtualhost = virtualhost
        self.app_name = app_name or "rabbitmq-client"
        self.heartbeat = heartbeat
        self.connection_timeout = connection_timeout
        self.reconnect_interval = reconnect_interval
        self.prefetch_count = prefetch_count
        self.channel_pool_size = max(1, channel_pool_size)  # 确保池大小不小于1

        # 节点轮询（仅重连时切换）
        self._host_iterator: Iterator[str] = self._create_host_iterator()
        self._current_host: Optional[str] = None

        # 核心资源（严格单连接+通道池，仅绑定当前活跃连接）
        self._connection: Optional[AbstractRobustConnection] = None  # 唯一活跃连接
        self._free_channels: List[Channel] = []  # 当前连接的空闲通道（带自动恢复）
        self._used_channels: Set[Channel] = set()  # 当前连接的使用中通道（带自动恢复）
        self._consumer_channels: Dict[str, Tuple[Channel,
                                                 AbstractRobustConnection, Callable, bool, dict]] = {}  # 消费者通道跟踪

        # 状态控制（确保并发安全和单连接）
        self._lock = asyncio.Lock()  # 全局唯一锁，保护所有共享状态
        self._initialized = False
        self._is_shutdown = False
        self._reconnecting = False  # 避免并发重连
        self._connection_version = 0  # 连接版本号（区分新旧连接/通道）

    def _create_host_iterator(self) -> Iterator[str]:
        """创建无限循环的节点轮询迭代器"""
        while True:
            for host in self.hosts:
                yield host

    async def _is_connection_valid(self) -> bool:
        """原子化检查连接有效性（加锁保证无竞态）"""
        async with self._lock:
            return (
                self._connection is not None
                and not self._connection.is_closed
                and not self._reconnecting
            )

    @property
    async def is_alive(self) -> bool:
        """对外暴露的连接存活状态（异步+原子化）"""
        if self._is_shutdown:
            return False
        return await self._is_connection_valid()

    async def _safe_close_old_resources(self):
        """强制关闭所有旧资源（加锁保证原子性，重连前必调用）"""
        async with self._lock:
            logger.info(f"开始释放旧资源（连接版本: {self._connection_version}）...")

            # 1. 关闭所有消费者通道（独立管理，终止旧自动恢复）
            for queue_name, (channel, _, _, _, _) in self._consumer_channels.items():
                try:
                    if not channel.is_closed:
                        await channel.close()
                    logger.info(f"已关闭队列 {queue_name} 的旧消费者通道（自动恢复终止）")
                except Exception as e:
                    logger.warning(f"关闭消费者通道 {queue_name} 失败: {str(e)}")
            self._consumer_channels.clear()

            # 2. 关闭所有普通通道（空闲+使用中，终止旧自动恢复）
            all_channels = self._free_channels + list(self._used_channels)
            for channel in all_channels:
                try:
                    if not channel.is_closed:
                        await channel.close()
                except Exception as e:
                    logger.warning(f"关闭旧通道失败: {str(e)}")
            self._free_channels.clear()
            self._used_channels.clear()

            # 3. 强制关闭旧连接（彻底终止旧连接的所有自动恢复）
            if self._connection:
                try:
                    if not self._connection.is_closed:
                        await self._connection.close()
                    logger.info(
                        f"已关闭旧连接: {self._current_host}:{self.port}（版本: {self._connection_version}）")
                except Exception as e:
                    logger.warning(f"关闭旧连接失败: {str(e)}")
                self._connection = None  # 置空，确保单连接

            logger.info("旧资源释放完成（所有旧自动恢复逻辑已终止）")

    async def _create_single_connection(self) -> AbstractRobustConnection:
        """创建唯一活跃连接（重连前已释放旧资源，确保单连接）"""
        max_attempts = len(self.hosts) * 2  # 每个节点尝试2次
        attempts = 0
        last_error: Optional[Exception] = None

        while attempts < max_attempts and not self._is_shutdown:
            self._current_host = next(self._host_iterator)
            conn_url = f"amqp://{self.username}:{self.password}@{self._current_host}:{self.port}/{self.virtualhost}"

            try:
                target_version = self._connection_version + 1
                logger.info(
                    f"尝试创建连接: {self._current_host}:{self.port} "
                    f"（目标版本: {target_version}，{attempts+1}/{max_attempts}）"
                )
                # 创建连接（保留aio-pika原生自动恢复）
                conn = await connect_robust(
                    conn_url,
                    properties={
                        "connection_name": f"{self.app_name}_conn_v{target_version}",
                        "product": self.app_name
                    },
                    heartbeat=self.heartbeat,
                    timeout=self.connection_timeout,
                    reconnect_interval=5,  # 单节点内部短间隔重连（原生自动恢复）
                    max_reconnect_attempts=3,  # 单节点最大重试3次
                )
                logger.info(
                    f"连接创建成功: {self._current_host}:{self.port}（版本: {target_version}）")
                return conn
            except Exception as e:
                attempts += 1
                last_error = e
                logger.error(
                    f"连接节点 {self._current_host}:{self.port} 失败（{attempts}/{max_attempts}）: {str(e)}",
                    exc_info=True
                )
                await asyncio.sleep(min(5 * attempts, self.reconnect_interval))

        raise ConnectionError(
            f"所有节点创建连接失败（节点列表: {self.hosts}）"
        ) from last_error

    async def _init_channel_pool(self):
        """初始化通道池（加锁保证原子性，绑定当前连接）"""
        async with self._lock:
            if self._is_shutdown:
                raise RuntimeError("通道池已关闭，无法初始化")

            # 校验当前连接有效性
            if not self._connection or self._connection.is_closed:
                raise RuntimeError("无有效连接，无法初始化通道池")

            self._free_channels.clear()
            self._used_channels.clear()

            # 创建指定数量的通道（保留原生自动恢复）
            for i in range(self.channel_pool_size):
                try:
                    channel = await self._connection.channel()  # 新通道自带自动恢复
                    await channel.set_qos(prefetch_count=self.prefetch_count)
                    self._free_channels.append(channel)
                except Exception as e:
                    logger.error(
                        f"创建通道失败（第{i+1}个，连接版本: {self._connection_version}）: {str(e)}",
                        exc_info=True
                    )
                    continue

            logger.info(
                f"通道池初始化完成 - 连接: {self._current_host}:{self.port}（版本: {self._connection_version}）, "
                f"可用通道数: {len(self._free_channels)}/{self.channel_pool_size}（均带自动恢复）"
            )

    async def _reconnect_if_needed(self) -> bool:
        """连接失效时重连（加锁保护，严格单连接+释放旧资源）"""
        # 快速判断，避免无效加锁
        if self._is_shutdown or self._reconnecting:
            return False

        self._reconnecting = True
        try:
            logger.warning(f"连接失效（当前版本: {self._connection_version}），开始重连...")

            # 1. 强制释放所有旧资源（加锁保证原子性）
            await self._safe_close_old_resources()

            # 2. 递增连接版本号（加锁保证原子性，区分新旧连接）
            async with self._lock:
                self._connection_version += 1
                target_version = self._connection_version

            # 3. 创建新连接（保留原生自动恢复）
            new_conn = await self._create_single_connection()

            # 4. 绑定新连接（加锁保证原子性）
            async with self._lock:
                self._connection = new_conn

            # 5. 重新初始化通道池（新通道带自动恢复）
            await self._init_channel_pool()

            # 6. 恢复消费者通道（新通道带自动恢复）
            await self._restore_consumer_channels()

            logger.info(f"重连成功（新连接版本: {target_version}），所有通道均带自动恢复")
            async with self._lock:
                self._initialized = True  # 重连成功后标记为已初始化
            return True
        except Exception as e:
            logger.error(f"重连失败: {str(e)}", exc_info=True)
            async with self._lock:
                self._initialized = False
            return False
        finally:
            self._reconnecting = False

    async def _restore_consumer_channels(self):
        """重连后恢复消费者通道（加锁保证原子性，新通道带自动恢复）"""
        async with self._lock:
            if not self._consumer_channels or not self._connection or self._connection.is_closed:
                return
            logger.info(
                f"开始恢复 {len(self._consumer_channels)} 个消费者通道（连接版本: {self._connection_version}）")

            # 临时保存消费者配置（队列名、回调、auto_ack、参数）
            consumer_configs = list(self._consumer_channels.items())
            self._consumer_channels.clear()

        # 重新创建消费者（不加锁，避免阻塞其他操作）
        for queue_name, (_, _, callback, auto_ack, kwargs) in consumer_configs:
            try:
                await self.consume_queue(queue_name, callback, auto_ack, **kwargs)
            except Exception as e:
                logger.error(
                    f"恢复消费者队列 {queue_name} 失败: {str(e)}", exc_info=True)

    async def _clean_invalid_channels(self):
        """清理失效通道并补充（加锁保证原子性，仅处理当前连接）"""
        async with self._lock:
            if self._is_shutdown or self._reconnecting:
                return

            # 1. 校验当前连接有效性
            current_valid = (
                self._connection is not None
                and not self._connection.is_closed
                and not self._reconnecting
            )
            if not current_valid:
                # 连接失效，触发重连（不加锁，避免死锁）
                asyncio.create_task(self._reconnect_if_needed())
                return

            # 2. 清理空闲通道（仅保留当前连接的有效通道）
            valid_free = []
            for chan in self._free_channels:
                try:
                    if not chan.is_closed and chan.connection == self._connection:
                        valid_free.append(chan)
                    else:
                        logger.warning(f"清理失效空闲通道（连接版本不匹配或已关闭）")
                except Exception:
                    logger.warning(f"清理异常空闲通道")
            self._free_channels = valid_free

            # 3. 清理使用中通道（仅保留当前连接的有效通道）
            valid_used = set()
            for chan in self._used_channels:
                try:
                    if not chan.is_closed and chan.connection == self._connection:
                        valid_used.add(chan)
                    else:
                        logger.warning(f"清理失效使用中通道（连接版本不匹配或已关闭）")
                except Exception:
                    logger.warning(f"清理异常使用中通道")
            self._used_channels = valid_used

            # 4. 补充通道到指定大小（新通道带自动恢复）
            total_valid = len(self._free_channels) + len(self._used_channels)
            missing = self.channel_pool_size - total_valid
            if missing > 0:
                logger.info(
                    f"通道池缺少{missing}个通道，补充中（连接版本: {self._connection_version}）...")
                for _ in range(missing):
                    try:
                        channel = await self._connection.channel()  # 新通道带自动恢复
                        await channel.set_qos(prefetch_count=self.prefetch_count)
                        self._free_channels.append(channel)
                    except Exception as e:
                        logger.error(f"补充通道失败: {str(e)}", exc_info=True)
                        break

    async def init_pools(self):
        """初始化通道池（加锁保证原子性，仅执行一次）"""
        async with self._lock:
            if self._initialized:
                logger.warning("通道池已初始化，无需重复调用")
                return
            if self._is_shutdown:
                raise RuntimeError("通道池已关闭，无法初始化")

        try:
            # 1. 创建新连接（保留原生自动恢复）
            new_conn = await self._create_single_connection()

            # 2. 初始化连接版本号和绑定连接（加锁保证原子性）
            async with self._lock:
                self._connection_version += 1
                self._connection = new_conn

            # 3. 初始化通道池（新通道带自动恢复）
            await self._init_channel_pool()

            # 4. 标记为已初始化（加锁保证原子性）
            async with self._lock:
                self._initialized = True

            logger.info("RabbitMQ单连接通道池初始化完成（所有通道均带自动恢复）")
        except Exception as e:
            logger.error(f"初始化失败: {str(e)}", exc_info=True)
            await self._safe_close_old_resources()
            raise

    async def acquire_channel(self) -> Tuple[Channel, AbstractRobustConnection]:
        """获取通道（加锁保证原子性，返回当前连接+带自动恢复的通道）"""
        # 快速判断，避免无效加锁
        async with self._lock:
            if not self._initialized:
                raise RuntimeError("通道池未初始化，请先调用init_pools()")
            if self._is_shutdown:
                raise RuntimeError("通道池已关闭，无法获取通道")

        # 先清理失效通道（加锁保证原子性）
        await self._clean_invalid_channels()

        async with self._lock:
            # 双重校验连接有效性
            current_valid = (
                self._connection is not None
                and not self._connection.is_closed
                and not self._reconnecting
            )
            if not current_valid:
                # 连接失效，触发重连（不加锁，避免死锁）
                reconnect_success = await self._reconnect_if_needed()
                if not reconnect_success:
                    raise RuntimeError("连接失效且重连失败，无法获取通道")

            # 优先从空闲池获取（带自动恢复的通道）
            if self._free_channels:
                channel = self._free_channels.pop()
                self._used_channels.add(channel)
                return channel, self._connection

            # 通道池已满，创建临时通道（带自动恢复，用完关闭）
            try:
                channel = await self._connection.channel()  # 临时通道带自动恢复
                await channel.set_qos(prefetch_count=self.prefetch_count)
                self._used_channels.add(channel)
                logger.warning(
                    f"通道池已达上限（{self.channel_pool_size}），创建临时通道（带自动恢复，用完自动关闭）"
                )
                return channel, self._connection
            except Exception as e:
                logger.error(f"获取通道失败: {str(e)}", exc_info=True)
                raise

    async def release_channel(self, channel: Channel, conn: AbstractRobustConnection):
        """释放通道（加锁保证原子性，仅归还当前连接的有效通道）"""
        # 快速判断，避免无效加锁
        if not channel or not conn or self._is_shutdown:
            return

        async with self._lock:
            # 仅处理当前连接的通道（旧连接的通道直接关闭）
            if conn != self._connection:
                try:
                    if not channel.is_closed:
                        await channel.close()
                    logger.warning(f"已关闭非当前连接的通道（版本不匹配，自动恢复终止）")
                except Exception as e:
                    logger.warning(f"关闭非当前连接通道失败: {str(e)}")
                return

            # 通道不在使用中，直接返回
            if channel not in self._used_channels:
                return

            # 移除使用中标记
            self._used_channels.remove(channel)

            # 仅归还有效通道（当前连接有效+通道未关闭+池未满）
            current_valid = (
                self._connection is not None
                and not self._connection.is_closed
                and not self._reconnecting
            )
            if current_valid and not channel.is_closed and len(self._free_channels) < self.channel_pool_size:
                self._free_channels.append(channel)
            else:
                # 无效通道直接关闭（终止自动恢复）
                try:
                    if not channel.is_closed:
                        await channel.close()
                except Exception as e:
                    logger.warning(f"关闭通道失败: {str(e)}")

    async def declare_queue(self, queue_name: str, **kwargs) -> AbstractQueue:
        """声明队列（使用池内通道，带自动恢复）"""
        channel, conn = await self.acquire_channel()
        try:
            return await channel.declare_queue(queue_name, **kwargs)
        finally:
            await self.release_channel(channel, conn)

    async def declare_exchange(self, exchange_name: str, exchange_type: str = "direct", **kwargs) -> AbstractExchange:
        """声明交换机（使用池内通道，带自动恢复）"""
        channel, conn = await self.acquire_channel()
        try:
            return await channel.declare_exchange(exchange_name, exchange_type, **kwargs)
        finally:
            await self.release_channel(channel, conn)

    async def publish_message(self, routing_key: str, message_body: bytes, exchange_name: str = "", **kwargs):
        """发布消息（使用池内通道，带自动恢复）"""
        channel, conn = await self.acquire_channel()
        try:
            exchange = channel.default_exchange if not exchange_name else await channel.get_exchange(exchange_name)
            message = Message(body=message_body, **kwargs)
            await exchange.publish(message, routing_key=routing_key)
            logger.debug(
                f"消息发布成功 - 连接: {self._current_host}:{self.port}（版本: {self._connection_version}）, "
                f"交换机: {exchange.name}, 路由键: {routing_key}"
            )
        except Exception as e:
            logger.error(f"发布消息失败: {str(e)}", exc_info=True)
            # 发布失败触发重连（下次使用新通道）
            asyncio.create_task(self._reconnect_if_needed())
            raise
        finally:
            await self.release_channel(channel, conn)

    async def consume_queue(self, queue_name: str, callback: Callable[[AbstractMessage], asyncio.Future], auto_ack: bool = False, **kwargs):
        """消费队列（使用独立通道，带自动恢复，支持多消费者）"""
        # 快速判断，避免无效加锁
        async with self._lock:
            if not self._initialized:
                raise RuntimeError("通道池未初始化，请先调用init_pools()")
            if self._is_shutdown:
                raise RuntimeError("通道池已关闭，无法启动消费")

        # 先声明队列（确保队列存在）
        await self.declare_queue(queue_name, **kwargs)

        # 获取独立通道（消费者通道不放入普通池）
        channel, conn = await self.acquire_channel()

        # 注册消费者通道（加锁保证原子性）
        async with self._lock:
            self._consumer_channels[queue_name] = (
                channel, conn, callback, auto_ack, kwargs)

        async def consume_callback_wrapper(message: AbstractMessage):
            """消费回调包装（处理通道失效和自动恢复）"""
            try:
                # 加锁校验通道和连接状态（原子性）
                async with self._lock:
                    conn_matched = conn == self._connection
                    channel_valid = not channel.is_closed
                    current_conn_valid = (
                        self._connection is not None
                        and not self._connection.is_closed
                        and not self._reconnecting
                    )

                # 通道/连接失效，触发重连恢复
                if not conn_matched or not channel_valid or not current_conn_valid:
                    logger.warning(
                        f"消费者通道 {queue_name} 失效（连接版本不匹配/通道关闭），触发重连恢复")

                    # 移除旧消费者记录（加锁保证原子性）
                    async with self._lock:
                        if self._consumer_channels.get(queue_name) == (channel, conn, callback, auto_ack, kwargs):
                            del self._consumer_channels[queue_name]

                    # 释放旧通道（加锁保证原子性）
                    await self.release_channel(channel, conn)

                    # 重新创建消费者（新通道带自动恢复）
                    asyncio.create_task(self.consume_queue(
                        queue_name, callback, auto_ack, **kwargs))

                    # Nack消息（避免丢失）
                    if not auto_ack:
                        await message.nack(requeue=True)
                    return

                # 执行业务回调
                await callback(message)
                if not auto_ack:
                    await message.ack()
            except ChannelClosed as e:
                logger.error(f"消费者通道 {queue_name} 关闭: {str(e)}", exc_info=True)
                if not auto_ack:
                    await message.nack(requeue=True)
                asyncio.create_task(self._reconnect_if_needed())
            except aiormq.exceptions.ChannelInvalidStateError as e:
                logger.error(
                    f"消费者通道 {queue_name} 状态异常: {str(e)}", exc_info=True)
                if not auto_ack:
                    await message.nack(requeue=True)
                asyncio.create_task(self.consume_queue(
                    queue_name, callback, auto_ack, **kwargs))
            except Exception as e:
                logger.error(
                    f"消费消息失败（队列: {queue_name}）: {str(e)}", exc_info=True)
                if not auto_ack:
                    await message.nack(requeue=True)

        # 日志输出（加锁获取当前连接信息）
        async with self._lock:
            current_host = self._current_host
            current_version = self._connection_version

        logger.info(
            f"开始消费队列: {queue_name} - 连接: {current_host}:{self.port}（版本: {current_version}）, "
            f"通道带自动恢复"
        )

        try:
            # 启动消费（使用带自动恢复的通道）
            await channel.basic_consume(
                queue_name,
                consumer_callback=consume_callback_wrapper,
                auto_ack=auto_ack,
                **kwargs
            )
        except Exception as e:
            logger.error(f"启动消费失败（队列: {queue_name}）: {str(e)}", exc_info=True)
            # 清理异常的消费者记录和通道（加锁保证原子性）
            await self.release_channel(channel, conn)
            async with self._lock:
                if self._consumer_channels.get(queue_name) == (channel, conn, callback, auto_ack, kwargs):
                    del self._consumer_channels[queue_name]
            raise

    async def close(self):
        """关闭通道池（加锁保证原子性，释放所有资源）"""
        async with self._lock:
            if self._is_shutdown:
                logger.warning("通道池已关闭，无需重复操作")
                return
            self._is_shutdown = True

        logger.info("开始关闭RabbitMQ单连接通道池...")
        # 强制释放所有资源（包括自动恢复的通道）
        await self._safe_close_old_resources()
        logger.info("RabbitMQ单连接通道池已完全关闭（所有自动恢复逻辑已终止）")
