import base64
import io
import time

import numpy as np
from openai import OpenAI
from PIL import Image

from docviz.constants import DEFAULT_VISION_PROMPT
from docviz.logging import get_logger

# Module-level logger from docviz's logging helper, keyed by this module's __name__.
logger = get_logger(__name__)


class ChartSummarizer:
    """
    Summarizes and exports data from charts and diagrams in a given image using OpenAI's vision models.

    Attributes:
        model_name (str): The name of the OpenAI model to use.
        retries (int): Number of attempts made per summarization request; values
            below 1 are treated as a single attempt.
        timeout (int): Seconds to wait between failed attempts. (Despite the name,
            it is not currently applied as a timeout on the API request itself.)
        client (OpenAI): The OpenAI client used for chat completion requests.
    """

    def __init__(
        self,
        model_name: str,
        api_key: str,
        base_url: str,
        retries: int,
        timeout: int,
    ) -> None:
        """
        Initialize the ChartSummarizer.

        Args:
            model_name (str): The OpenAI model to use for vision tasks.
            api_key (str): The OpenAI API key. If empty/falsy it is not passed to the
                client, so the OpenAI SDK falls back to its own configuration.
            base_url (str): The base URL for the OpenAI API. If empty/falsy it is not
                passed to the client.
            retries (int): Number of attempts per request (at least one is always made).
            timeout (int): Seconds to wait between failed attempts.
        """
        logger.info(f"Initializing ChartSummarizer with model: {model_name}")
        self.model_name = model_name
        self.retries = retries
        self.timeout = timeout
        # Only forward explicitly provided settings so the SDK defaults still apply otherwise.
        config: dict[str, str] = {}
        if api_key:
            config["api_key"] = api_key
        if base_url:
            config["base_url"] = base_url
        self.client = OpenAI(**config)
        logger.info("ChartSummarizer initialized successfully")

    def _numpy_to_base64(self, image: np.ndarray, format: str = "PNG") -> str:
        """
        Convert a numpy image array to a base64-encoded image string.

        Non-uint8 arrays are converted to uint8 first. If the array's maximum is
        <= 1.0, it is treated as normalized to [0, 1] and scaled by 255; otherwise
        it is used as-is. In both cases values are clipped to [0, 255] before casting.

        Args:
            image (np.ndarray): Image array of shape (H, W), (H, W, 1), (H, W, 3) or (H, W, 4).
            format (str): Pillow output format name (PNG, JPEG, etc.).

        Returns:
            str: Base64 text of the encoded image bytes.

        Raises:
            ValueError: If the array is empty or its shape is not a supported image layout.
        """
        source_shape = image.shape

        # image.max() below raises on zero-size arrays; fail with a clearer message.
        if image.size == 0:
            raise ValueError(f"Cannot encode an empty image: {source_shape}")

        if image.dtype != np.uint8:
            if image.max() <= 1.0:  # type: ignore
                # Assume normalized values in [0, 1]. Multiplying by a float promotes
                # small integer dtypes (e.g. int8) instead of overflowing, and clipping
                # afterwards keeps negative values from wrapping around in uint8.
                image = np.clip(image * 255.0, 0, 255).astype(np.uint8)
            else:
                image = np.clip(image, 0, 255).astype(np.uint8)

        if image.ndim == 3:
            channels = image.shape[2]
            if channels == 1:
                # Drop only the channel axis; squeeze() would also collapse a
                # height or width of 1 and produce an array of the wrong rank.
                image = image[:, :, 0]
            elif channels not in (3, 4):
                raise ValueError(f"Unsupported image shape: {image.shape}")
        elif image.ndim != 2:
            raise ValueError(f"Unsupported image dimensions: {image.shape}")

        # For uint8 data Pillow infers the mode from the shape: (H, W) -> "L",
        # (H, W, 3) -> "RGB", (H, W, 4) -> "RGBA"; the explicit mode= argument is
        # deprecated in recent Pillow releases.
        pil_image = Image.fromarray(image)

        buffer = io.BytesIO()
        pil_image.save(buffer, format=format)
        image_b64 = base64.b64encode(buffer.getvalue()).decode("utf-8")

        logger.debug(f"Converted numpy array {source_shape} to base64 {format}")
        return image_b64

    def summarize_charts_from_page(
        self,
        image: np.ndarray,
        prompt: str | None = None,
        extra_context: str | None = None,
    ) -> str:
        """
        Summarize charts found in the given image using the vision model.

        Args:
            image (np.ndarray): Image to summarize.
            prompt (Optional[str]): Custom prompt for the model; defaults to
                DEFAULT_VISION_PROMPT.
            extra_context (Optional[str]): Additional context appended to the prompt,
                separated by a blank line.

        Returns:
            str: The stripped summary generated by the model, or "" if the response
            has no choices or no content.

        Raises:
            RuntimeError: If the image cannot be converted or every request attempt fails.
        """
        logger.debug(f"Summarizing charts from image shape: {image.shape}")

        if prompt is None:
            prompt = DEFAULT_VISION_PROMPT
            logger.debug("Using default chart summarization prompt")
        else:
            logger.debug("Using custom prompt for summarization")

        if extra_context:
            prompt = f"{prompt}\n\n{extra_context}"
            logger.debug("Added extra context to prompt")

        try:
            image_b64 = self._numpy_to_base64(image)
        except Exception as e:
            logger.error(f"Failed to convert numpy array to base64: {e}")
            raise RuntimeError(f"Image conversion failed: {e}") from e

        # The payload is identical for every attempt, so build it once.
        messages = [
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": prompt},
                    {
                        "type": "image_url",
                        # Matches _numpy_to_base64's default PNG output format.
                        "image_url": {"url": f"data:image/png;base64,{image_b64}"},
                    },
                ],
            }
        ]

        # Always make at least one attempt: with retries <= 0 the loop would never
        # run and the method would raise without any underlying error.
        attempts = max(1, self.retries)
        last_exception: Exception | None = None
        for attempt in range(1, attempts + 1):
            try:
                logger.debug(
                    f"Sending request to model: {self.model_name} (attempt {attempt}/{attempts})"
                )
                response = self.client.chat.completions.create(
                    model=self.model_name,
                    messages=messages,  # type: ignore
                )
                result = response.choices[0].message.content if response.choices else ""
                logger.debug("Successfully received response from OpenAI")
                return result.strip() if result else ""
            except Exception as exc:
                logger.warning(f"Attempt {attempt} failed to summarize charts: {exc}")
                last_exception = exc
                if attempt < attempts:
                    # NOTE(review): self.timeout is used only as the delay between
                    # attempts; the request itself runs with the OpenAI client's own
                    # timeout. Consider passing timeout= to create() once the values
                    # callers configure are confirmed to be request timeouts.
                    logger.debug(f"Retrying in {self.timeout} seconds...")
                    time.sleep(self.timeout)

        logger.error(f"Failed to summarize charts after {attempts} attempts: {last_exception}")
        raise RuntimeError(f"Failed to summarize charts: {last_exception}") from last_exception
