import threading
from abc import ABC, abstractmethod
from typing import Optional

from prompt_toolkit import print_formatted_text
from prompt_toolkit.formatted_text import FormattedText
from prompt_toolkit.styles import Style

from python_tty.runtime.context import get_current_emitter, get_current_run_id, get_current_source
from python_tty.runtime.events import RuntimeEvent, RuntimeEventKind, UIEvent, UIEventLevel
from python_tty.runtime.provider import get_router


# Prefix symbol printed before a message, keyed by ``UIEventLevel.value``
# (looked up in ``_format_event``).
MSG_LEVEL_SYMBOL = dict(enumerate((
    "[*] ",  # 0
    "[!] ",  # 1
    "[x] ",  # 2
    "[+] ",  # 3
    "[-] ",  # 4
    "[@] ",  # 5
)))

# prompt_toolkit style string applied to the prefix symbol, keyed the same way.
MSG_LEVEL_SYMBOL_STYLE = dict(enumerate((
    "fg:green",   # 0
    "fg:yellow",  # 1
    "fg:red",     # 2
    "fg:blue",    # 3
    "fg:white",   # 4
    "fg:pink",    # 5
)))


class BaseRouter(ABC):
    """Abstract destination for UI/runtime events.

    Subclasses must implement :meth:`emit`; the base class cannot be
    instantiated on its own.
    """

    @abstractmethod
    def emit(self, event):
        """Deliver ``event`` to this router's destination."""
        raise NotImplementedError


class OutputRouter(BaseRouter):
    """Render events through a bound prompt_toolkit session, if one is bound.

    ``bind_session`` records the session's ``app`` and ``output`` attributes;
    ``emit`` prints through that output, or through prompt_toolkit's default
    output when no output is bound. Reads and writes of the bound pair are
    guarded by ``self._lock``.
    """

    def __init__(self):
        self._lock = threading.Lock()
        self._app = None     # ``session.app`` of the bound session, or None
        self._output = None  # ``session.output`` of the bound session, or None

    def bind_session(self, session):
        """Bind ``session.app`` / ``session.output`` as the render target.

        A ``None`` session is ignored and leaves any existing binding intact.
        Missing attributes are stored as ``None``.
        """
        if session is None:
            return
        with self._lock:
            self._app = getattr(session, "app", None)
            self._output = getattr(session, "output", None)

    def clear_session(self, session=None):
        """Drop the current binding.

        Without an argument the binding is always cleared. With a session,
        the binding is cleared only if that session's ``app`` is the very
        object currently bound, so clearing on behalf of a session that is
        no longer bound leaves the current binding alone.
        """
        with self._lock:
            # Identity, not equality: we mean "the same app object".
            if session is None or getattr(session, "app", None) is self._app:
                self._app = None
                self._output = None

    def emit(self, event):
        """Render ``event`` on the bound output.

        ``RuntimeEvent`` instances of kind STDOUT, STATE or LOG are converted
        with ``to_ui_event()``; any other runtime kind is silently dropped.
        All other events are handed to ``_format_event`` as-is (presumably
        ``UIEvent`` instances).

        When the bound app reports ``is_running`` and exposes both
        ``call_from_executor`` and ``run_in_terminal``, rendering is
        scheduled through them (presumably so output does not corrupt the
        live prompt). Otherwise rendering happens immediately on the calling
        thread.
        """
        if isinstance(event, RuntimeEvent):
            if event.kind in (RuntimeEventKind.STDOUT, RuntimeEventKind.STATE, RuntimeEventKind.LOG):
                event = event.to_ui_event()
            else:
                return
        # Snapshot the binding so rendering happens without holding the lock.
        with self._lock:
            app = self._app
            output = self._output

        def _render():
            text, style = _format_event(event)
            if output is not None:
                print_formatted_text(text, style=style, output=output)
            else:
                print_formatted_text(text, style=style)

        if app is not None and getattr(app, "is_running", False):
            if hasattr(app, "call_from_executor") and hasattr(app, "run_in_terminal"):
                app.call_from_executor(lambda: app.run_in_terminal(_render))
                return
        _render()


def _normalize_level(level):
    """Coerce ``level`` into a :class:`UIEventLevel`.

    A ``UIEventLevel`` is returned unchanged. ``None`` and the raw value of
    ``UIEventLevel.TEXT`` become ``UIEventLevel.TEXT``. Any other input goes
    through ``UIEventLevel.map_level``, and an unmapped result (``None``)
    also falls back to ``UIEventLevel.TEXT``.
    """
    if isinstance(level, UIEventLevel):
        return level
    if level is None or level == UIEventLevel.TEXT.value:
        return UIEventLevel.TEXT
    resolved = UIEventLevel.map_level(level)
    if resolved is not None:
        return resolved
    return UIEventLevel.TEXT


def _format_event(event: UIEvent):
    """Build the ``(text, style)`` pair used to print ``event``.

    TEXT-level events are returned as their raw ``msg`` with no style.
    Every other level yields a ``FormattedText`` made of the level's prefix
    symbol (``MSG_LEVEL_SYMBOL``) followed by ``str(event.msg)``. It comes
    with a ``Style`` that colors the prefix using ``MSG_LEVEL_SYMBOL_STYLE``.
    """
    level = _normalize_level(event.level)
    if level == UIEventLevel.TEXT:
        # Plain text: no prefix and no custom style.
        return event.msg, None
    key = level.value
    fragments = [
        ("class:level", MSG_LEVEL_SYMBOL[key]),
        ("class:text", str(event.msg)),
    ]
    prefix_style = Style.from_dict({"level": MSG_LEVEL_SYMBOL_STYLE[key]})
    return FormattedText(fragments), prefix_style


def get_output_router() -> Optional[BaseRouter]:
    """Return the router currently supplied by the runtime provider.

    This is a thin alias for ``get_router()``; the result may be ``None``.
    """
    router = get_router()
    return router


def proxy_print(text="", text_type=UIEventLevel.TEXT, source="custom", run_id=None):
    """Print ``text`` through the runtime's event pipeline.

    Inside an active run, a ``RuntimeEvent`` is sent to the current emitter.
    A run is active when both a current run id and a current emitter are set.
    The event kind is STDOUT for TEXT-level output and LOG otherwise. In that
    case the context's run id is used instead of ``run_id``, and the
    context's source (when truthy) takes precedence over ``source``.

    Otherwise a ``UIEvent`` is handed to the router returned by
    ``get_router()``; when there is no router, nothing is printed.

    Args:
        text: Display text or object to render.
        text_type: UIEventLevel or int; normalized via ``_normalize_level``.
        source: Event source. Use "tty"/"rpc" for framework events.
            External callers can rely on the default "custom".
        run_id: Optional run identifier to correlate output with an
            invocation (only used outside an active run).
    """
    level = _normalize_level(text_type)
    active_run_id = get_current_run_id()
    sink = get_current_emitter()

    if active_run_id is None or sink is None:
        # No active run: route a UIEvent through the global router, if any.
        ui_event = UIEvent(msg=text, level=level, source=source, run_id=run_id)
        router = get_router()
        if router is not None:
            router.emit(ui_event)
        return

    # Active run: forward to the run's emitter.
    if level == UIEventLevel.TEXT:
        kind = RuntimeEventKind.STDOUT
    else:
        kind = RuntimeEventKind.LOG
    sink(RuntimeEvent(
        kind=kind,
        msg=text,
        level=level,
        run_id=active_run_id,
        source=get_current_source() or source,
    ))
