# Copyright (c) 2025 Beijing Volcano Engine Technology Co., Ltd. and/or its affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Progress reporting infrastructure for workflows.

Provides a clean way to report workflow progress without coupling
UI logic to business logic.
"""

from typing import Protocol, List, Callable, Optional, Dict, Any
from enum import Enum
from dataclasses import dataclass, field
from datetime import datetime


class ProgressLevel(str, Enum):
    """
    Severity/category attached to a progress message.

    Because the enum mixes in ``str``, every member compares equal to its
    value, e.g. ``ProgressLevel.INFO == "info"``.
    """

    # Each value is the lowercase spelling of the member name.
    INFO = "info"
    SUCCESS = "success"
    WARNING = "warning"
    ERROR = "error"
    DEBUG = "debug"


@dataclass
class ProgressEvent:
    """
    Progress event data.

    Represents a single progress update from a workflow.

    Attributes:
        message: Progress message.
        percentage: Completion percentage, nominally 0-100 (not validated here).
        level: Message level/severity; defaults to ``ProgressLevel.INFO``.
        timestamp: When the event occurred. If omitted or ``None`` it is set to
            ``datetime.now()`` (local, naive) when the event is constructed.
        details: Additional context data; each event gets its own dict.
    """

    message: str
    percentage: int
    level: ProgressLevel = ProgressLevel.INFO
    # Optional because ``None`` means "stamp at construction time" (see
    # __post_init__); after construction it is always a datetime.
    timestamp: Optional[datetime] = None
    # default_factory so events never share one mutable dict.
    details: Dict[str, Any] = field(default_factory=dict)

    def __post_init__(self):
        if self.timestamp is None:
            self.timestamp = datetime.now()


class ProgressCallback(Protocol):
    """
    Structural type describing a progress listener.

    Any callable that can be invoked with a single ProgressEvent satisfies
    this protocol: plain functions, closures, bound methods, or objects that
    define ``__call__``. The return value is ignored.
    """

    def __call__(self, event: ProgressEvent) -> None:
        """Receive one progress event."""
        ...


class ProgressReporter:
    """
    Progress reporter with multiple listeners support.

    Uses the Observer pattern so several consumers can receive the same
    progress updates independently. A listener that raises never interrupts
    the workflow or the remaining listeners; the failure is logged at DEBUG
    level (with traceback) on this module's logger instead.

    Example:
        >>> reporter = ProgressReporter()
        >>> reporter.add_listener(console_listener)
        >>> reporter.add_listener(log_listener)
        >>> reporter.report("Building image", 50)
    """

    def __init__(self):
        self._listeners: List[ProgressCallback] = []
        self._enabled = True

    def add_listener(self, listener: ProgressCallback) -> None:
        """
        Add a progress listener.

        Registering a listener that is already present (compared with ``==``)
        is a no-op, so each listener receives every event at most once.

        Args:
            listener: Callable that accepts ProgressEvent.
        """
        if listener not in self._listeners:
            self._listeners.append(listener)

    def remove_listener(self, listener: ProgressCallback) -> None:
        """
        Remove a progress listener.

        Removing a listener that is not registered is a no-op.

        Args:
            listener: Previously added listener.
        """
        if listener in self._listeners:
            self._listeners.remove(listener)

    def clear_listeners(self) -> None:
        """Remove all listeners."""
        self._listeners.clear()

    def report(self,
               message: str,
               percentage: int = 0,
               level: ProgressLevel = ProgressLevel.INFO,
               **details) -> None:
        """
        Report progress to all listeners.

        Does nothing while the reporter is disabled. Listeners are notified in
        registration order, using the set of listeners registered when the
        call started: listeners added or removed by another listener during
        this call take effect from the next report onwards.

        Args:
            message: Progress message.
            percentage: Completion percentage (0-100).
            level: Message level.
            **details: Additional context to include in event.

        Example:
            >>> reporter.report("Uploading file", 75,
            ...                 level=ProgressLevel.INFO,
            ...                 file_size=1024000)
        """
        if not self._enabled:
            return

        event = ProgressEvent(
            message=message,
            percentage=percentage,
            level=level,
            details=details
        )

        # Iterate over a snapshot: a listener may add/remove listeners
        # (including itself) while being notified, and mutating the list we
        # are iterating would otherwise make the loop skip the next listener.
        for listener in list(self._listeners):
            try:
                listener(event)
            except Exception:
                # Don't let listener errors break the workflow or starve the
                # remaining listeners, but keep the failure diagnosable.
                # Imported lazily (only on failure), like the other
                # function-scope imports in this module.
                import logging
                logging.getLogger(__name__).debug(
                    "Progress listener %r raised while handling %r; ignoring",
                    listener, event.message, exc_info=True,
                )

    def disable(self) -> None:
        """
        Disable progress reporting.

        While disabled, report() returns immediately without notifying any
        listener. Useful for testing or when progress output is not desired.
        """
        self._enabled = False

    def enable(self) -> None:
        """Enable progress reporting."""
        self._enabled = True

    @property
    def is_enabled(self) -> bool:
        """Check if reporter is enabled."""
        return self._enabled

    @property
    def listener_count(self) -> int:
        """Get number of registered listeners."""
        return len(self._listeners)


# ============================================================================
# Built-in Listeners
# ============================================================================

def create_cli_progress_listener(console=None):
    """
    Create a progress listener for CLI output using Rich.

    Each event is printed on its own line, colored by its level. When the
    percentage is greater than zero it is shown right-aligned in three
    columns before the message; otherwise only the message is printed. The
    message is markup-escaped, so square brackets in it are printed
    literally instead of being interpreted as Rich markup.

    Args:
        console: Optional Rich Console instance. If None, creates new one.

    Returns:
        Listener function that can be added to ProgressReporter.

    Example:
        >>> from rich.console import Console
        >>> reporter = ProgressReporter()
        >>> reporter.add_listener(create_cli_progress_listener())
        >>> reporter.report("Building", 50)
         50% ▶ Building
    """
    from rich.console import Console
    from rich.markup import escape

    if console is None:
        console = Console()

    # Built once per listener instead of on every event.
    color_map = {
        ProgressLevel.INFO: "cyan",
        ProgressLevel.SUCCESS: "green",
        ProgressLevel.WARNING: "yellow",
        ProgressLevel.ERROR: "red",
        ProgressLevel.DEBUG: "dim"
    }

    def listener(event: ProgressEvent):
        color = color_map.get(event.level, "white")
        # Workflow messages may contain "[...]" (e.g. "pip install pkg[extra]").
        # Unescaped, Rich would treat that as markup: the text vanishes, or an
        # unmatched closing tag raises MarkupError and the event is lost.
        message = escape(str(event.message))
        if event.percentage > 0:
            console.print(f"[{color}]{event.percentage:3d}% ▶ {message}[/{color}]")
        else:
            console.print(f"[{color}]▶ {message}[/{color}]")

    return listener


def create_simple_print_listener():
    """
    Create a listener that writes each event to stdout with ``print``.

    Intended for environments where Rich is not available. Output format is
    ``[LEVEL] NNN% - message`` when the percentage is positive, and
    ``[LEVEL] message`` otherwise.

    Returns:
        Listener function.
    """
    def listener(event: ProgressEvent):
        with_percentage = event.percentage > 0
        tag = event.level.value.upper()
        line = (
            f"[{tag}] {event.percentage:3d}% - {event.message}"
            if with_percentage
            else f"[{tag}] {event.message}"
        )
        print(line)

    return listener


def create_logging_listener(logger=None):
    """
    Create a listener that writes to a logger.

    Each event is logged as ``"<percentage>% - <message>"``. The logging level
    follows the event's ProgressLevel, with SUCCESS mapped to INFO.
    ``event.details`` is passed as ``extra`` so formatters and handlers can
    read the entries as record attributes.

    ``logging`` refuses some ``extra`` keys: ``"message"``, ``"asctime"`` and
    any existing LogRecord attribute such as ``"name"`` or ``"filename"``.
    Such keys are renamed to ``"detail_<key>"`` rather than making the
    logging call raise ``KeyError``.

    Args:
        logger: Optional logger instance. If None, uses this module's logger
            (``logging.getLogger(__name__)``), not the root logger.

    Returns:
        Listener function.

    Example:
        >>> import logging
        >>> logger = logging.getLogger(__name__)
        >>> reporter.add_listener(create_logging_listener(logger))
    """
    import logging

    if logger is None:
        logger = logging.getLogger(__name__)

    level_map = {
        ProgressLevel.DEBUG: logging.DEBUG,
        ProgressLevel.INFO: logging.INFO,
        ProgressLevel.SUCCESS: logging.INFO,
        ProgressLevel.WARNING: logging.WARNING,
        ProgressLevel.ERROR: logging.ERROR,
    }

    # Logger.makeRecord raises KeyError when ``extra`` contains "message",
    # "asctime" or any attribute a LogRecord already has. Derive the set from
    # a real record so it tracks the running Python version (e.g. taskName).
    reserved_keys = {"message", "asctime"} | set(
        vars(logging.LogRecord("", logging.INFO, "", 0, "", (), None))
    )

    def listener(event: ProgressEvent):
        log_level = level_map.get(event.level, logging.INFO)
        extra = {
            (f"detail_{key}" if key in reserved_keys else key): value
            for key, value in event.details.items()
        }
        logger.log(log_level, f"{event.percentage}% - {event.message}",
                   extra=extra)

    return listener


def create_test_progress_listener():
    """
    Create a recording listener for use in tests.

    Every event the listener receives is appended, in arrival order, to the
    list exposed as the listener's ``events`` attribute.

    Returns:
        Listener function with an ``events`` list attribute.

    Example:
        >>> reporter = ProgressReporter()
        >>> listener = create_test_progress_listener()
        >>> reporter.add_listener(listener)
        >>> reporter.report("Test", 50)
        >>> assert len(listener.events) == 1
        >>> assert listener.events[0].message == "Test"
    """
    recorded = []

    def listener(event: ProgressEvent):
        recorded.append(event)

    # The closure and the attribute share one list object, so callers see
    # events as they arrive.
    setattr(listener, "events", recorded)
    return listener


def create_websocket_listener(websocket):
    """
    Create a listener that sends progress via WebSocket.

    Each event is serialized to a JSON object with the keys ``type``
    (always ``"progress"``), ``message``, ``percentage``, ``level``,
    ``timestamp`` (ISO 8601) and ``details``. The result is passed to
    ``websocket.send()``. Detail values JSON cannot encode are sent as
    ``str(value)`` instead of dropping the event.

    ``send()`` is called synchronously and its return value is ignored. An
    async WebSocket whose ``send`` returns a coroutine would therefore not
    be awaited.

    Sending is best effort: serialization or transport errors are logged at
    DEBUG level on this module's logger and otherwise ignored.

    Args:
        websocket: WebSocket connection object with send() method.

    Returns:
        Listener function.

    Example:
        >>> reporter.add_listener(create_websocket_listener(ws))
    """
    import json
    import logging

    def listener(event: ProgressEvent):
        data = {
            "type": "progress",
            "message": event.message,
            "percentage": event.percentage,
            "level": event.level.value,
            "timestamp": event.timestamp.isoformat(),
            "details": event.details
        }
        try:
            # details carry arbitrary caller context (paths, datetimes, ...);
            # stringify what JSON can't encode rather than losing the event.
            payload = json.dumps(data, default=str)
            websocket.send(payload)
        except Exception:
            # Best effort: a closed or broken socket must not interrupt the
            # workflow, but keep the failure diagnosable.
            logging.getLogger(__name__).debug(
                "Failed to send progress event over WebSocket", exc_info=True
            )

    return listener
