# Copyright (c) Microsoft. All rights reserved.

"""A basic JSON-based planner for the Python Semantic Kernel"""
import os
import itertools
import json
import logging
import regex
import unicodedata
from typing import Dict, List, Optional

from flexible_semantic_kernel.kernel import Kernel
from flexible_semantic_kernel.orchestration.sk_context import SKContext
from flexible_semantic_kernel.orchestration.sk_function_base import SKFunctionBase
from flexible_semantic_kernel.orchestration.context_variables import ContextVariables
from flexible_semantic_kernel.planning.planning_exception import PlanningException
from flexible_semantic_kernel.planning.plan import Plan
from flexible_semantic_kernel.plugin_definition import sk_function, sk_function_context_parameter
from flexible_semantic_kernel.plugin_definition.function_view import FunctionView
from flexible_semantic_kernel.plugin_definition.parameter_view import ParameterView

from flexible_semantic_kernel.planning.planner_config import PlannerConfig

# Module-level logger, named after this module via ``__name__``.
logger: logging.Logger = logging.getLogger(__name__)


class BasicPlanner:
    """
    Basic JSON-based planner for the Semantic Kernel.

    On construction the planner creates a semantic function from its prompt
    template and imports itself into the kernel as a plugin, which exposes
    ``ListOfFunctions``. ``create_plan_async`` invokes that semantic function
    for a goal, extracts the JSON object from the result and converts it into
    a ``Plan``.
    """

    # Plugin name used for the planner's own functions; functions registered
    # under it are filtered out of the list built by ``list_of_functions``.
    RESTRICTED_PLUGIN_NAME = "BasicPlanner_Excluded"
    config: PlannerConfig

    _planner_function: SKFunctionBase

    _kernel: Kernel
    _prompt_template: str

    def __init__(
            self,
            kernel: Kernel,
            config: Optional[PlannerConfig] = None,
            prompt: Optional[str] = None,
            **kwargs,
    ) -> None:
        """
        Initialize the planner and register it with the kernel.

        :param kernel: Semantic Kernel the planner works against
        :param config: planner configuration; a default ``PlannerConfig`` is
            created when omitted
        :param prompt: planner prompt template; when omitted or empty, the
            ``skprompt.txt`` file located next to this module is read instead
        :param kwargs: only ``logger`` is inspected; passing it just logs a
            deprecation warning
        :raises PlanningException: if ``kernel`` is ``None``
        """
        if kwargs.get("logger"):
            logger.warning("The `logger` parameter is deprecated. Please use the `logging` module instead.")
        if kernel is None:
            raise PlanningException(
                PlanningException.ErrorCodes.InvalidConfiguration,
                "Kernel cannot be `None`.",
            )

        self.config = config or PlannerConfig()

        if prompt:
            self._prompt_template = prompt
        else:
            prompt_file = os.path.join(os.path.dirname(os.path.abspath(__file__)), "skprompt.txt")
            # Read inside a context manager so the handle is closed right away,
            # and pin the encoding so the template does not depend on the
            # platform's locale.
            with open(prompt_file, "r", encoding="utf-8") as prompt_stream:
                self._prompt_template = prompt_stream.read()

        self._planner_function = kernel.create_semantic_function(
            plugin_name=self.RESTRICTED_PLUGIN_NAME,
            prompt_template=self._prompt_template,
            **self.config.extension_data
        )
        # Import the planner itself as a plugin so that ``ListOfFunctions``
        # becomes a kernel function.
        kernel.import_plugin(self, self.RESTRICTED_PLUGIN_NAME)

        self._kernel = kernel
        # A single context is created here and reused by every
        # ``create_plan_async`` call.
        self._context = kernel.create_new_context()

    @sk_function(description="List all functions available in the kernel", name="ListOfFunctions")
    @sk_function_context_parameter(name="goal", description="The current goal processed by the planner")
    def list_of_functions(self, goal: str, context: SKContext) -> str:
        """
        Build the textual list of functions that may be used in a plan.

        Functions that belong to the planner's own plugin, to a plugin listed
        in ``config.excluded_plugins`` or whose name is listed in
        ``config.excluded_functions`` are left out.

        :param goal: the current goal; not used when building the list
        :param context: context whose ``plugins`` collection is enumerated
        :return: one entry per function, entries separated by a blank line
        :raises PlanningException: if ``context.plugins`` is ``None``
        """
        if context.plugins is None:
            raise PlanningException(
                error_code=PlanningException.ErrorCodes.InvalidConfiguration,
                message="Suitable plan not generated by BasicPlanner.",
                inner_exception=ValueError("No plugins are available."),
            )

        functions_view = context.plugins.get_functions_view()

        # The values of both mappings are collections of function views, so
        # they are flattened into one stream: semantic functions first, then
        # native functions.
        all_functions = itertools.chain.from_iterable(
            [
                *functions_view.semantic_functions.values(),
                *functions_view.native_functions.values(),
            ]
        )

        function_strings = [
            self._create_function_string(func)
            for func in all_functions
            if (
                    func.plugin_name != self.RESTRICTED_PLUGIN_NAME
                    and func.plugin_name not in self.config.excluded_plugins
                    and func.name not in self.config.excluded_functions
            )
        ]

        available_functions_str = "\n\n".join(function_strings)

        logger.info(f"List of available functions:\n{available_functions_str}")

        return available_functions_str

    def _create_function_string(self, function: FunctionView) -> str:
        """
        Render one function as text for the list of available functions.

        The result has the form::

            <plugin>.<name>
            description: <description>.
            parameters:
            <one line per parameter, or "No parameters.">

        :param function: The function to be converted into a string.
        :return: the multi-line description of the function
        """
        if not function.description:
            logger.warning(f"{function.plugin_name}.{function.name} is missing a description")
            # Fall back to the qualified name when no description was given.
            description = f"description: {function.plugin_name}.{function.name}."
        else:
            description = f"description: {function.description}"

        # add trailing period for description if not present
        if not description.endswith("."):
            description = f"{description}."

        name = f"{function.plugin_name}.{function.name}"

        # Skip any parameter for which no string was produced.
        parameters_list = [
            result for x in function.parameters if (result := self._create_parameter_string(x)) is not None
        ]

        parameters = "\n".join(parameters_list) if parameters_list else "No parameters."

        return f"{name}\n{description}\nparameters:\n{parameters}"

    def _create_parameter_string(self, parameter: ParameterView) -> str:
        """
        Render one parameter as a single line of text in the format::

            - "<parameter-name>": <parameter-description>. (default value: <default-value>)

        The parameter name stands in for a missing description, and the
        default-value suffix is omitted when the default value is falsy.

        :param parameter: An instance of ParameterView for which the string representation needs to be generated
        :return: string representation of parameter
        """
        name = parameter.name
        description = parameter.description or name

        # add trailing period for description if not present; ``endswith``
        # also copes with an empty description instead of raising IndexError
        if not description.endswith("."):
            description = f"{description}."

        default_value = f"(default value: {val})" if (val := parameter.default_value) else ""

        # ``strip`` removes the trailing space left when there is no default.
        return f'- "{name}": {description} {default_value}'.strip()

    def _to_plan_from_dict(self, goal: str, generated_plan: Dict) -> Plan:
        """
        Convert the generated plan dictionary into a ``Plan`` object.

        :param goal: the goal the plan is created for
        :param generated_plan: the generated plan as a dict; it is read through
            the keys ``"input"`` and ``"subtasks"``, and every subtask through
            ``"function"`` plus the optional ``"parameters"``,
            ``"set_context_variable"`` and ``"append_to_result"``
        :return: the plan, with one step per subtask
        :raises PlanningException: if a subtask's function is not of the form
            ``<plugin>.<function>`` or cannot be found in the kernel
        :raises KeyError: if ``"input"``, ``"subtasks"`` or a subtask's
            ``"function"`` key is missing
        """
        plan = Plan.from_goal(goal)

        plan.parameters["input"] = str(generated_plan["input"])
        plan.state["input"] = str(generated_plan["input"])

        subtasks = generated_plan["subtasks"]
        if subtasks is None:
            logger.warning("No suitable function has been identified by BasicPlanner.")
            return plan

        for subtask in subtasks:
            function = subtask["function"]
            try:
                # Anything other than exactly one "." fails the unpacking.
                plugin_name, function_name = function.split(".")
            except ValueError as err:
                raise PlanningException(
                    PlanningException.ErrorCodes.InvalidPlan, f"Function {function} is invalid."
                ) from err

            # A subtask with an empty function name (e.g. "Plugin.") is
            # skipped rather than rejected.
            if not function_name:
                continue

            plugin_function = self._kernel.plugins.get_function(plugin_name, function_name)
            if plugin_function is None:
                raise PlanningException(
                    PlanningException.ErrorCodes.InvalidPlan,
                    f"Failed to find function '{function_name}' in plugin '{plugin_name}'.",
                )

            plan_step = Plan.from_function(plugin_function)

            # "parameters" may be absent or an explicit JSON null; both are
            # treated as "no parameters supplied".
            subtask_parameters = subtask.get("parameters") or {}

            # Every declared parameter gets a value: the one supplied by the
            # subtask, otherwise the declared default.
            function_variables = ContextVariables()
            for p in plugin_function.describe().parameters:
                if p.name in subtask_parameters:
                    function_variables.set(p.name, str(subtask_parameters[p.name]))
                else:
                    function_variables.set(p.name, str(p.default_value))

            # "set_context_variable" takes precedence over "append_to_result";
            # only the latter is also recorded as an output of the whole plan.
            function_outputs = []
            function_results = []
            if subtask.get("set_context_variable", None):
                function_outputs.append(subtask["set_context_variable"])
            elif subtask.get("append_to_result", None):
                function_outputs.append(subtask["append_to_result"])
                function_results.append(subtask["append_to_result"])

            # NOTE(review): parameters and state share one ContextVariables
            # instance - confirm that the aliasing is intended.
            plan_step._parameters = function_variables
            plan_step._state = function_variables
            plan_step._outputs = function_outputs

            for result in function_results:
                plan._outputs.append(result)

            plan.add_steps([plan_step])

        return plan

    async def create_plan_async(self, goal: str, **kwargs) -> Plan:
        """
        Creates a plan for the given goal based off the functions that
        are available in the kernel.

        :param goal: The goal to create a plan for.
        :param kwargs: ``logger`` may carry a logger to use for this call
            instead of the module-level one
        :return: the plan built from the planner function's JSON output
        :raises PlanningException: if ``goal`` is ``None``, if the result
            contains no JSON object, or if the plan references an invalid or
            unknown function
        :raises json.JSONDecodeError: if the extracted text is not valid JSON
        """
        # Use a separate local name: assigning to ``logger`` here would make it
        # local to the whole function and raise UnboundLocalError whenever no
        # ``logger`` keyword is passed.
        log = kwargs.get("logger") or logger

        if goal is None:
            raise PlanningException(PlanningException.ErrorCodes.InvalidGoal, "Goal cannot be `None`.")

        log.info(f"Finding the best function for achieving the goal: {goal}")

        # Add the goal to the context
        self._context.variables.update(goal)
        generated_plan_raw = await self._planner_function.invoke_async(context=self._context)

        # Plan string: the first balanced {...} object in the result
        # (``(?R)`` recurses into nested braces).
        json_regex = r"\{(?:[^{}]|(?R))*\}"
        raw_result = generated_plan_raw.result or ""
        match = regex.search(json_regex, raw_result)
        if match is None:
            log.error("No valid plan has been generated.")
            raise PlanningException(
                PlanningException.ErrorCodes.InvalidPlan,
                "No valid plan has been generated.",
                inner_exception=ValueError(raw_result),
            )

        generated_plan_str = match.group()
        # Collapse doubled quotes. NOTE: this also rewrites a genuine empty
        # string value ("") in the plan.
        generated_plan_str = generated_plan_str.replace('""', '"')
        # Drop the backslash from backslash-escaped underscores.
        generated_plan_str = generated_plan_str.replace('\\_', '_')
        # Strip // comments from the JSON, line by line. NOTE: a "//" inside
        # a string value (e.g. a URL) is cut off as well.
        generated_plan_str = " ".join(
            line.split("//")[0] for line in generated_plan_str.split("\n")
        )
        # NFKC folds compatibility characters (e.g. full-width forms) into
        # their canonical equivalents before parsing.
        generated_plan_str = unicodedata.normalize("NFKC", generated_plan_str)

        log.info(f"Plan generated by BasicPlanner:\n{generated_plan_str}")

        # Convert the JSON plan into a dict
        generated_plan = json.loads(generated_plan_str)

        return self._to_plan_from_dict(goal, generated_plan)
