import re
from typing import Dict, Type, List, Union, Optional, Callable

from langchain.chat_models import init_chat_model
from langchain_core.language_models import BaseChatModel
from langchain_core.messages import BaseMessage, SystemMessage, HumanMessage
from langchain_core.output_parsers import PydanticOutputParser
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_core.runnables import Runnable, RunnableLambda
from pydantic import BaseModel, ValidationError

from sycommon.config.LLMConfig import LLMConfig
from sycommon.llm.llm_logger import LLMLogger


# Matches, in priority order: a complete double-quoted JSON string (escapes
# honoured), a bare None/none/NONE token, or an empty single-quoted string.
# Strings are matched first so that their contents are never rewritten.
_JSON_NULL_FIX_RE = re.compile(r'"(?:\\.|[^"\\])*"|\b(?:None|none|NONE)\b|\'\'')


def _strip_code_block_markers(content: str) -> str:
    """Remove a Markdown code fence wrapped around an LLM reply.

    Handles replies such as ``"```json\\n{...}\\n```"`` with or without
    surrounding whitespace/newlines, and with any identifier language tag or
    none. Only a fence at the very start / very end is removed, so backticks
    appearing inside JSON values are left intact.

    Prefix/suffix slicing is used deliberately: ``str.strip("```json")``
    strips a *set of characters*, so it can eat payload characters
    (``null`` -> ``ull``) and it stops at the first whitespace character,
    leaving the fence in place when the reply starts/ends with a newline.

    Raises:
        ValueError: if ``content`` is not a string.
    """
    fence = "```"
    try:
        text = content.strip()
    except AttributeError as e:
        raise ValueError(
            f"移除代码块标记失败（内容：{str(content)[:100]}）：{str(e)}") from e

    if text.startswith(fence):
        text = text[len(fence):]
        first_line, newline, rest = text.partition("\n")
        tag = first_line.strip()
        if newline and (not tag or tag.isidentifier()):
            # Opening fence on its own line, optionally followed by a language tag.
            text = rest
        elif text[:4].lower() == "json":
            # Tag glued to the payload on the same line, e.g. "```json{...}```".
            text = text[4:]
    if text.endswith(fence):
        text = text[:-len(fence)]
    return text.strip()


def _normalize_in_json(content: str) -> str:
    """Make Python-style empty values JSON-compatible.

    Outside of double-quoted JSON strings, bare ``None``/``none``/``NONE``
    tokens become ``null`` and ``''`` becomes ``""``. Text inside JSON string
    values is left untouched (e.g. ``"nonetheless"`` stays as is).

    Raises:
        ValueError: if ``content`` is not a ``str``.
    """
    def _replace(match: "re.Match[str]") -> str:
        token = match.group(0)
        if token.startswith('"'):
            return token  # a JSON string literal: keep verbatim
        if token == "''":
            return '""'
        return "null"

    try:
        return _JSON_NULL_FIX_RE.sub(_replace, content)
    except TypeError as e:
        raise ValueError(
            f"替换None为null失败（内容：{str(content)[:100]}）：{str(e)}") from e


def get_llm(model: Optional[str] = None, streaming: bool = False) -> BaseChatModel:
    """Create a chat model from ``LLMConfig`` and patch in a prompt-based
    ``with_structured_output``.

    Args:
        model: Configuration name passed to ``LLMConfig.from_config``; an
            empty value falls back to ``"Qwen2.5-72B"``.
        streaming: Forwarded to ``init_chat_model``.

    Returns:
        The chat model built by ``init_chat_model`` (temperature 0.1).

    Raises:
        ValueError: if ``LLMConfig.from_config`` returns a falsy value.
            ``ValueError`` is an ``Exception`` subclass, so existing
            ``except Exception`` handlers keep working.
    """
    if not model:
        model = "Qwen2.5-72B"
        # Alternative default: "SyMid"
    llm_config = LLMConfig.from_config(model)
    if not llm_config:
        raise ValueError("Invalid model")

    llm = init_chat_model(
        model_provider=llm_config.provider,
        model=llm_config.model,
        base_url=llm_config.baseUrl,
        # Placeholder key; presumably the endpoint does not validate it — TODO confirm.
        api_key="-",
        temperature=0.1,
        streaming=streaming,
    )

    # Replacement for LangChain's with_structured_output: the official one
    # (function calling) can make Qwen2.5 hang without returning data, because
    # Qwen2.5 supports function calling poorly. This version instead puts the
    # Pydantic format instructions into the prompt and parses the text reply.
    def with_structured_output(
        self: BaseChatModel,
        output_model: Type[BaseModel],
        max_retries: int = 3,
        is_extract: bool = False,
        override_prompt: Optional[ChatPromptTemplate] = None,
        # Custom processing steps (each takes a str and returns a str)
        custom_processors: Optional[List[Callable[[str], str]]] = None,
        # Custom parser (takes a str and returns a BaseModel)
        custom_parser: Optional[Callable[[str], BaseModel]] = None
    ) -> Runnable[List[BaseMessage], BaseModel]:
        """Build a runnable that returns ``output_model`` instances parsed from text.

        Pipeline: prompt -> this model -> reply text -> processors -> parser.
        The whole pipeline (including the model call) is retried up to
        ``max_retries`` attempts on ``ValidationError``/``ValueError``.

        Args:
            output_model: Pydantic model the reply is parsed into.
            max_retries: Total number of attempts for the retry wrapper.
            is_extract: Use the fixed extraction prompt (with accuracy
                rules). When True, ``override_prompt`` is ignored.
            override_prompt: Prompt used instead of the default one when
                ``is_extract`` is False; it receives a ``messages`` variable.
            custom_processors: Replace (not extend) the default processors
                (strip code fence, normalize None/'' to JSON).
            custom_parser: Replace the default ``PydanticOutputParser`` parse.

        Returns:
            A runnable accepting a message list, a single message, a str, or
            ``{"input": ...}``.
        """
        parser = PydanticOutputParser(pydantic_object=output_model)

        accuracy_instructions = """
        字段值的抽取准确率（0~1之间），评分规则：
        1.0（完全准确）：直接从原文提取，无需任何加工，且格式与原文完全一致
        0.9（轻微处理）：数据来源明确，但需进行格式标准化或冗余信息剔除（不改变原始数值）
        0.8（有限推断）：数据需通过上下文关联或简单计算得出，仍有明确依据
        0.8以下（不可靠）：数据需大量推测、存在歧义或来源不明，处理方式：直接忽略该数据，设置为None
        """

        # NOTE: the prompt f-strings below keep their original nesting depth on
        # purpose — their leading whitespace is part of the text sent to the model.
        if is_extract:
            # Extraction mode always uses this fixed extraction prompt
            prompt = ChatPromptTemplate.from_messages([
                MessagesPlaceholder(variable_name="messages"),
                HumanMessage(content=f"""
                请提取信息并遵循以下规则：
                1. 准确率要求：{accuracy_instructions.strip()}
                2. 输出格式：{parser.get_format_instructions()}
                """)
            ])
        else:
            if override_prompt:
                prompt = override_prompt
            else:
                prompt = ChatPromptTemplate.from_messages([
                    MessagesPlaceholder(variable_name="messages"),
                    HumanMessage(content=f"""
                    输出格式：{parser.get_format_instructions()}
                    """)
                ])

        # ========== Basic processing steps ==========
        def extract_response_content(response: BaseMessage) -> str:
            """Return the text of the model reply.

            ``BaseMessage.content`` may be a list of content blocks; in that
            case plain-string parts and ``{"type": "text"}`` blocks are
            concatenated so the string processors always receive a ``str``.
            """
            try:
                content = response.content
            except AttributeError as e:
                raise ValueError(f"提取响应内容失败：{str(e)}") from e
            if isinstance(content, list):
                return "".join(
                    block if isinstance(block, str) else str(block.get("text", ""))
                    for block in content
                    if isinstance(block, str)
                    or (isinstance(block, dict) and block.get("type") == "text")
                )
            return content

        def default_parse_to_pydantic(content: str) -> BaseModel:
            """Parse processed text into ``output_model`` via the Pydantic parser.

            Errors are re-raised with a preview of the offending text. When the
            original exception type cannot be rebuilt from a single message
            (e.g. Pydantic's ``ValidationError``), a plain ``ValueError`` is
            raised instead so the retry filter still matches it.
            """
            try:
                return parser.parse(content)
            except (ValidationError, ValueError) as e:
                message = f"解析失败（原始内容：{str(content)[:200]}）：{str(e)}"
                try:
                    wrapped = type(e)(message)
                except TypeError:
                    wrapped = ValueError(message)
                raise wrapped from e

        # ========== Assemble the pipeline ==========
        # Base chain: prompt -> LLM -> reply text
        base_chain = (
            prompt
            | self
            | RunnableLambda(extract_response_content)
        )

        # Processing steps: custom ones replace the defaults entirely
        if custom_processors:
            processors = list(custom_processors)
        else:
            # Default: strip code fence -> normalize empty values to JSON
            processors = [_strip_code_block_markers, _normalize_in_json]

        process_chain = base_chain
        for processor in processors:
            process_chain = process_chain | RunnableLambda(processor)

        # Parser: custom one wins, otherwise the Pydantic parser
        parse_func = custom_parser if custom_parser else default_parse_to_pydantic
        parse_chain = process_chain | RunnableLambda(parse_func)

        retry_chain = parse_chain.with_retry(
            retry_if_exception_type=(ValidationError, ValueError),
            stop_after_attempt=max_retries,
            wait_exponential_jitter=True,
            exponential_jitter_params={
                "initial": 0.1,  # initial wait (seconds)
                "max": 3.0,      # maximum wait (seconds)
                "exp_base": 2.0,  # exponential base
                "jitter": 1.0    # random jitter added to each wait
            }
        )

        class StructuredRunnable(Runnable[Union[List[BaseMessage], BaseMessage, str, Dict[str, str]], BaseModel]):
            """Front-end that normalizes input and runs the retrying pipeline."""

            def _adapt_input(self, input: Union[List[BaseMessage], BaseMessage, str, Dict[str, str]]) -> List[BaseMessage]:
                """Convert any supported input shape into ``List[BaseMessage]``."""
                if isinstance(input, list) and all(isinstance(x, BaseMessage) for x in input):
                    return input
                elif isinstance(input, BaseMessage):
                    return [input]
                elif isinstance(input, str):
                    return [HumanMessage(content=input)]
                elif isinstance(input, dict) and "input" in input:
                    return [HumanMessage(content=str(input["input"]))]
                else:
                    raise ValueError(
                        "不支持的输入格式，请使用消息列表、单条消息、文本或 {'input': '文本'}")

            @staticmethod
            def _resolve_config(config: Optional[Dict]) -> Dict:
                # Build the logger per call: a default-argument LLMLogger() would be
                # created once and shared (with its dict) by every invocation.
                return {"callbacks": [LLMLogger()]} if config is None else config

            def invoke(self, input: Union[List[BaseMessage], BaseMessage, str, Dict[str, str]],
                       config: Optional[Dict] = None, **kwargs):
                adapted_input = self._adapt_input(input)
                return retry_chain.invoke(
                    {"messages": adapted_input}, config=self._resolve_config(config), **kwargs)

            async def ainvoke(self, input: Union[List[BaseMessage], BaseMessage, str, Dict[str, str]],
                              config: Optional[Dict] = None, **kwargs):
                adapted_input = self._adapt_input(input)
                return await retry_chain.ainvoke(
                    {"messages": adapted_input}, config=self._resolve_config(config), **kwargs)

        return StructuredRunnable()

    # NOTE: this assigns to the *class* of ``llm``, so every instance of that
    # chat-model class in the process (not only ``llm``) gets this replacement,
    # whose first parameter is ``output_model`` rather than LangChain's ``schema``.
    llm.__class__.with_structured_output = with_structured_output
    return llm
