import asyncio
import json
import aiohttp
from typing import Union, List, Optional

from sycommon.config.Config import SingletonMeta
from sycommon.config.EmbeddingConfig import EmbeddingConfig
from sycommon.config.RerankerConfig import RerankerConfig
from sycommon.logging.kafka_log import SYLogger


class Embedding(metaclass=SingletonMeta):
    """Async client for the embedding and reranker HTTP services.

    The endpoint for each request is resolved from ``EmbeddingConfig`` /
    ``RerankerConfig`` using the requested model name. Every HTTP request
    issued by the two private helpers acquires one shared semaphore, so at
    most ``max_concurrency`` requests are in flight at a time.
    """

    def __init__(self):
        # 1. Concurrency limit for outgoing HTTP requests.
        self.max_concurrency = 20
        # Default model names, used when a caller does not pass ``model``.
        self.default_embedding_model = "bge-large-zh-v1.5"
        self.default_reranker_model = "bge-reranker-large"
        # Width of the zero vector substituted for a failed embedding request
        # when no request in the same batch succeeded (so the real width
        # cannot be inferred). Presumably the output size of the default
        # bge-large-zh-v1.5 model — adjust if the default model changes.
        self.default_embedding_dim = 1024

        # Base URLs of the default models, resolved once at construction.
        self.embeddings_base_url = EmbeddingConfig.from_config(
            self.default_embedding_model).baseUrl
        self.reranker_base_url = RerankerConfig.from_config(
            self.default_reranker_model).baseUrl

        # Semaphore shared by all requests to bound concurrency.
        self.semaphore = asyncio.Semaphore(self.max_concurrency)

    async def _get_embeddings_http_async(
        self,
        input: Union[str, List[str]],
        encoding_format: str = None,
        model: str = None,
        **kwargs
    ):
        """POST a single request to ``{baseUrl}/v1/embeddings``.

        Args:
            input: Text (or list of texts) to embed.
            encoding_format: ``encoding_format`` field of the request body;
                ``"float"`` when not given.
            model: Embedding model name; ``default_embedding_model`` when
                not given. Also selects the base URL via ``EmbeddingConfig``.
            **kwargs: Extra fields merged into the request body. They are
                applied last, so they can override the fields above.

        Returns:
            The decoded JSON response, or ``None`` if the service replied
            with a non-200 status or aiohttp raised a client/timeout error.
        """
        async with self.semaphore:
            # Prefer the caller's model name, otherwise fall back to the default.
            target_model = model or self.default_embedding_model
            target_base_url = EmbeddingConfig.from_config(target_model).baseUrl
            url = f"{target_base_url}/v1/embeddings"

            request_body = {
                "model": target_model,
                "input": input,
                "encoding_format": encoding_format or "float"
            }
            request_body.update(kwargs)

            try:
                async with aiohttp.ClientSession() as session:
                    async with session.post(url, json=request_body) as response:
                        if response.status != 200:
                            error_detail = await response.text()
                            SYLogger.error(
                                f"Embedding request failed (model: {target_model}): {error_detail}")
                            return None
                        return await response.json()
            except (aiohttp.ClientError, asyncio.TimeoutError) as e:
                # Treat transport-level failures like a non-200 reply so one
                # bad request degrades to None instead of aborting the whole
                # asyncio.gather batch in get_embeddings.
                SYLogger.error(
                    f"Embedding request error (model: {target_model}): {e!r}")
                return None

    async def _get_reranker_http_async(
        self,
        documents: List[str],
        query: str,
        top_n: Optional[int] = None,
        model: str = None,
        max_chunks_per_doc: Optional[int] = None,
        return_documents: Optional[bool] = True,
        return_len: Optional[bool] = True,
        **kwargs
    ):
        """POST a single request to ``{baseUrl}/v1/rerank``.

        Args:
            documents: Candidate texts to rerank.
            query: Query the documents are ranked against.
            top_n: Number of results requested; a falsy value (None or 0)
                becomes ``len(documents)``.
            model: Reranker model name; ``default_reranker_model`` when not
                given. Also selects the base URL via ``RerankerConfig``.
            max_chunks_per_doc: Sent as-is (``null`` when None).
            return_documents: Sent as-is.
            return_len: Sent as-is.
            **kwargs: Sent twice: JSON-encoded under the ``"kwargs"`` key and
                merged into the top level of the body (where they can
                override the fields above).

        Returns:
            The decoded JSON response, or ``None`` if the service replied
            with a non-200 status or aiohttp raised a client/timeout error.
        """
        async with self.semaphore:
            # Prefer the caller's model name, otherwise fall back to the default.
            target_model = model or self.default_reranker_model
            target_base_url = RerankerConfig.from_config(target_model).baseUrl
            url = f"{target_base_url}/v1/rerank"

            request_body = {
                "model": target_model,
                "documents": documents,
                "query": query,
                "top_n": top_n or len(documents),
                "max_chunks_per_doc": max_chunks_per_doc,
                "return_documents": return_documents,
                "return_len": return_len,
                # NOTE(review): presumably for a server (e.g. Xinference)
                # whose rerank schema takes extra options as a JSON string
                # field named "kwargs" — confirm against the deployed service.
                "kwargs": json.dumps(kwargs),
            }
            request_body.update(kwargs)

            try:
                async with aiohttp.ClientSession() as session:
                    async with session.post(url, json=request_body) as response:
                        if response.status != 200:
                            error_detail = await response.text()
                            SYLogger.error(
                                f"Rerank request failed (model: {target_model}): {error_detail}")
                            return None
                        return await response.json()
            except (aiohttp.ClientError, asyncio.TimeoutError) as e:
                # Same failure contract as a non-200 reply: log and return None.
                SYLogger.error(
                    f"Rerank request error (model: {target_model}): {e!r}")
                return None

    def _infer_embedding_dim(self, results) -> int:
        """Return the length of the first non-empty embedding in ``results``.

        Falls back to ``default_embedding_dim`` when no result carries a
        usable embedding (e.g. every request in the batch failed).
        """
        for result in results:
            if not isinstance(result, dict):
                continue
            for item in result.get("data") or []:
                embedding = item.get("embedding") if isinstance(item, dict) else None
                if isinstance(embedding, list) and embedding:
                    return len(embedding)
        return self.default_embedding_dim

    async def get_embeddings(
        self,
        corpus: List[str],
        model: str = None
    ) -> List[List[float]]:
        """Embed every text in ``corpus``, preserving input order.

        Each text is sent as its own request. The requests run concurrently,
        bounded by the shared semaphore (``max_concurrency``).

        Args:
            corpus: Texts to embed.
            model: Optional embedding model name; defaults to
                ``default_embedding_model`` (bge-large-zh-v1.5).

        Returns:
            One vector per returned ``data`` item, in request order. A failed
            request contributes a zero vector whose width matches the
            successful vectors of the same batch, or ``default_embedding_dim``
            if none succeeded.
        """
        SYLogger.info(
            f"Requesting embeddings for corpus: {corpus} (model: {model or self.default_embedding_model}, max_concurrency: {self.max_concurrency})")
        # One request per text, each carrying the requested model name.
        tasks = [self._get_embeddings_http_async(
            text, model=model) for text in corpus]
        # gather returns results in task order, which keeps vectors aligned
        # with corpus.
        results = await asyncio.gather(*tasks)

        # Size zero-vector placeholders like the real vectors of this batch so
        # a non-default model does not yield mixed-width output.
        dim = self._infer_embedding_dim(results)

        vectors = []
        for result in results:
            if result is None:
                vectors.append([0.0] * dim)
                SYLogger.warning(
                    f"Embedding request failed, append zero vector ({dim}D)")
                continue
            vectors.extend(item["embedding"] for item in result["data"])

        SYLogger.info(
            f"Embeddings for corpus: {corpus} created (model: {model or self.default_embedding_model})")
        return vectors

    async def get_reranker(
        self,
        top_results: List[str],
        query: str,
        model: str = None
    ):
        """Rerank ``top_results`` against ``query``.

        Args:
            top_results: Texts to rerank.
            query: Query used as the ranking reference.
            model: Optional reranker model name; defaults to
                ``default_reranker_model`` (bge-reranker-large).

        Returns:
            The rerank service's JSON response, or ``None`` on failure.
        """
        SYLogger.info(
            f"Requesting reranker for top_results: {top_results} (model: {model or self.default_reranker_model}, max_concurrency: {self.max_concurrency})")
        data = await self._get_reranker_http_async(top_results, query, model=model)
        SYLogger.info(
            f"Reranker for top_results: {top_results} completed (model: {model or self.default_reranker_model})")
        return data
