from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Any, Dict, Optional, Callable

from aiconfig_tools.AIConfigSettings import AIConfig, InferenceResponse, Output, Prompt


if TYPE_CHECKING:
    from aiconfig_tools.Config import AIConfigRuntime


class ModelParser(ABC):
    """
    Abstract interface implemented by every model-specific parser.

    A ModelParser converts between the model-agnostic Prompt objects stored in an
    AIConfig and the structures a particular model needs, runs inference for a
    prompt, and extracts text from the resulting outputs.
    """

    @abstractmethod
    def id(self) -> str:
        """
        Returns an identifier for the model (e.g. llama-2, gpt-4, etc.).
        """
        pass

    @abstractmethod
    def serialize(
        self,
        prompt_name: str,
        data: Any,
        ai_config: 'AIConfigRuntime',
        parameters: Optional[Dict] = None,
        **kwargs
    ) -> Prompt:
        """
        Serialize a prompt and additional metadata/model settings into a Prompt object that can be saved in the AIConfig.

        Args:
            prompt_name (str): Name to identify the prompt.
            data (Any): The prompt data to be serialized.
            ai_config (AIConfigRuntime): The AIConfig that the prompt belongs to.
            parameters (dict, optional): Optional parameters to include in the serialization.
            **kwargs: Additional keyword arguments, if needed.

        Returns:
            Prompt: Serialized representation of the input data.
        """
        pass

    @abstractmethod
    async def deserialize(
        self,
        prompt: Prompt,
        aiConfig: 'AIConfigRuntime',
        options: Optional['InferenceOptions'] = None,
        params: Optional[Dict] = None,
    ) -> Any:
        """
        Deserialize a Prompt object loaded from an AIConfig into a structure that can be used for model inference.

        Args:
            prompt (Prompt): The Prompt object from an AIConfig to deserialize into a structure that can be used for model inference.
            aiConfig (AIConfigRuntime): The AIConfig that the prompt belongs to.
            options (InferenceOptions, optional): Options that determine how inference will be run for the prompt.
            params (dict, optional): Optional parameters to override the prompt's parameters with.

        Returns:
            Any: Completion params that can be used for model inference.
        """
        pass

    @abstractmethod
    async def run(
        self,
        prompt: Prompt,
        aiconfig: AIConfig,
        options: Optional['InferenceOptions'] = None,
        parameters: Optional[Dict] = None,
    ) -> InferenceResponse:
        """
        Execute model inference based on completion data to be constructed in deserialize(), which includes the input prompt and
        model-specific inference settings. Saves the response or output in prompt.outputs.

        Args:
            prompt (Prompt): The prompt to be used for inference.
            aiconfig (AIConfig): The AIConfig object containing all prompts and parameters.
            options (InferenceOptions, optional): Options that determine how to run inference for the prompt.
            parameters (dict, optional): Optional parameters for this run (presumably forwarded to
                deserialize() to override the prompt's parameters). Defaults to None; implementations
                should treat None as "no parameters". None is used instead of ``{}`` because a mutable
                default would be a single dict shared across every call.

        Returns:
            InferenceResponse: The response generated by the model.
        """
        pass

    @abstractmethod
    def get_output_text(self, prompt: Prompt, aiconfig: 'AIConfigRuntime', output: Optional[Output] = None) -> str:
        """
        Get the output text from the model inference response.

        Args:
            prompt (Prompt): The prompt to be used for inference.
            aiconfig (AIConfigRuntime): The AIConfig object containing all prompts and parameters.
            output (Output, optional): A specific output to extract text from. When None,
                implementations presumably use the prompt's most recent output (verify per parser).

        Returns:
            str: The output text from the model inference response.
        """
        pass


def print_stream_callback(data, accumulated_data, index: int):
    """
    Stream callback that dumps each update verbatim to the console.

    Prints the raw chunk, the data accumulated so far, and the choice index,
    each on its own line. Intended for debugging streamed responses.
    """
    report = f"\ndata: {data}\naccumulated_data:{accumulated_data}\nindex:{index}\n"
    print(report)

def print_stream_delta(data, accumulated_data, index: int):
    """
    Default streamCallback function that prints the output to the console.

    Writes only the ``"content"`` field of each streamed chunk, without a trailing
    newline and flushed immediately, so the text appears incrementally.

    Args:
        data: The new data chunk from the stream (a mapping with an optional "content" key).
        accumulated_data: The data accumulated so far (unused here).
        index (int): The index of the choice the chunk belongs to (unused here).
    """
    # Skip empty/None chunks and plain strings. For a str, `in` is a substring
    # test, and a later data["content"] lookup would raise TypeError.
    if not data or isinstance(data, str) or "content" not in data:
        return
    content = data["content"]
    # A chunk may carry an explicit null content. Printing it would emit "None".
    if content is None:
        return
    print(content, end="", flush=True)


class InferenceOptions:
    """
    Options that determine how to run inference for the prompt (e.g., whether to stream the output or not, callbacks, etc.)

    Attributes:
        stream_callback: Callable invoked as ``stream_callback(data, accumulated_data, index)``
            when a model is in streaming mode and an update is available.
        stream (bool): Whether inference should stream its output.

    Any additional keyword arguments given to the constructor are stored as
    attributes of the same name.
    """

    def __init__(
        self,
        stream_callback: Callable[[Any, Any, int], Any] = print_stream_delta,
        stream: bool = True,
        **kwargs: Any,
    ) -> None:
        """
        Create a set of inference options.

        Args:
            stream_callback: Called when a model is in streaming mode and an update is available,
                with arguments:
                    data: The new data chunk from the stream.
                    accumulated_data: The running sum of all data chunks received so far.
                    index (int): The index of the choice that the data chunk belongs to
                        (default is 0, but if the model generates multiple choices, this will be the index of
                        the choice that the data chunk belongs to).
                Its return value is not used by this class. Defaults to print_stream_delta,
                which prints each chunk's "content" field to the console.
            stream (bool): Whether to stream the output. Defaults to True.
            **kwargs: Extra options. Each key/value pair is set as an attribute on the instance.
        """
        super().__init__()

        self.stream_callback = stream_callback
        self.stream = stream

        # Arbitrary extra options become plain attributes, so parsers can read
        # model-specific settings via attribute access.
        for key, value in kwargs.items():
            setattr(self, key, value)

    def update_stream_callback(self, callback: Callable[[Any, Any, int], Any]) -> None:
        """
        Replace the callback invoked for streamed updates.

        Args:
            callback: New callable taking ``(data, accumulated_data, index)``.
        """
        self.stream_callback = callback