from __future__ import annotations

from typing import Any, Dict
import base64
import os
from html import escape
from mimetypes import guess_type

from .base import Payload


class ImagePayload(Payload):
    """Payload that renders an image as an HTML ``<img>`` tag.

    ``location`` may be a path to a local file, a remote URL, an existing
    ``data:image/...`` URL, or a bare base64-encoded string.
    """

    def __init__(self, location: str) -> None:
        super().__init__()
        self._location = location

    def get_type(self) -> str:
        return "custom"

    def get_content(self) -> Dict[str, Any]:
        """Build the payload content: an ``<img>`` tag pointing at the image.

        Local files are embedded as base64 data URLs. Bare base64 strings are
        wrapped in a ``data:image/png;base64,`` URL. ``data:image/...`` URLs
        with a base64 payload keep their declared MIME type. Anything else
        (e.g. http(s) URLs) is used verbatim.

        Returns:
            A dict with ``"content"`` (the HTML snippet) and ``"label"``.
        """
        location = self._location

        # If the location is a local file, read it and embed it as a data URL so
        # Buggregator/Ray can render it even when they run in a different
        # environment (e.g. a Docker container without access to the host
        # filesystem).
        if os.path.isfile(location):
            location = self._embed_file(location)
        else:
            location = self._normalize_base64(location)

        # A well-formed URL never contains a raw double quote; strip any as a
        # defensive measure, then HTML-escape the rest for the attribute.
        location = location.replace('"', "")
        src = escape(location, quote=True)
        return {
            "content": f"<img src=\"{src}\" alt=\"\" />",
            "label": "Image",
        }

    def _embed_file(self, path: str) -> str:
        """Return a base64 data URL for the local file at *path*.

        Falls back to a ``file://`` URI when the file cannot be read or is
        empty.
        """
        try:
            with open(path, "rb") as fh:
                raw = fh.read()
        except OSError:
            # Best effort: e.g. permission denied, or the file vanished after
            # the isfile() check.
            raw = b""

        if not raw:
            # Path.as_uri() percent-encodes special characters and produces a
            # valid URI on Windows too (file:///C:/...), unlike string concat.
            from pathlib import Path

            return Path(os.path.abspath(path)).as_uri()

        mime, _ = guess_type(path)
        if not mime or not mime.startswith("image/"):
            mime = "image/png"
        b64 = base64.b64encode(raw).decode("ascii")
        return f"data:{mime};base64,{b64}"

    def _normalize_base64(self, location: str) -> str:
        """Turn *location* into a base64 data URL when it holds base64 data.

        A bare base64 string becomes ``data:image/png;base64,<data>``. For a
        ``data:image/...`` URL whose payload is base64, the MIME type declared
        in the URL is kept. Anything else is returned unchanged.

        NOTE(review): short purely-alphanumeric strings (e.g. a non-existent
        relative path like ``"logo"``) also round-trip as base64 and will be
        treated as image data.
        """
        mime = "image/png"
        if location.startswith("data:image/"):
            header, _, _ = location.partition(",")
            # "data:image/svg+xml;base64" -> "image/svg+xml"
            mime = header[len("data:"):].split(";", 1)[0]
        data = self._strip_data_prefix(location)
        if self._is_base64_data(data):
            return f"data:{mime};base64,{data}"
        return location

    def _strip_data_prefix(self, data: str) -> str:
        """Return the payload after the first comma of a ``data:image/`` URL.

        Input that is not a ``data:image/`` URL, or that has no comma, is
        returned unchanged.
        """
        if not data.startswith("data:image/"):
            return data
        _, sep, payload = data.partition(",")
        return payload if sep else data

    def _is_base64_data(self, data: str) -> bool:
        """Return True if *data* is a non-empty, canonically encoded base64 string."""
        if not data:
            # b64decode("") succeeds, but an empty payload is not image data.
            return False
        try:
            decoded = base64.b64decode(data, validate=True)
        except (TypeError, ValueError):
            # binascii.Error (bad characters/padding) subclasses ValueError,
            # as does the error for a str containing non-ASCII characters.
            return False
        # Re-encoding must reproduce the input exactly (canonical padding).
        return base64.b64encode(decoded) == data.encode()
