"""State-space model (SSM) utility functions.

This module provides numerical stability utilities for NumPy-based SSM operations:
- Matrix validation and cleaning
- Covariance matrix stabilization
- Safe matrix operations (inverse, determinant)
- Numerical stability for Kalman filtering and EM algorithm

These utilities are critical for numerical stability, handling near-singular matrices
and ensuring positive definiteness for covariance matrices.
"""

from typing import Optional, Union, TYPE_CHECKING
import numpy as np

if TYPE_CHECKING:
    import torch

from ..logger import get_logger

# Module-level logger, obtained from the package's ``get_logger`` helper.
_logger = get_logger(__name__)

# Default numerical stability constants for NumPy operations
# Default eigenvalue floor (``min_eigenval``) for ensure_positive_definite,
# ensure_covariance_stable and clean_matrix.
DEFAULT_MIN_EIGENVAL = 1e-8
# Default lower clip applied to diagonal entries by
# clean_matrix(matrix_type='diagonal').
DEFAULT_MIN_DIAGONAL_VARIANCE = 1e-6
# Default amount safe_inverse adds to the diagonal when plain inversion raises.
DEFAULT_INV_REGULARIZATION = 1e-6


def _to_numpy(x: Union["torch.Tensor", np.ndarray]) -> np.ndarray:
    """Return ``x`` as a NumPy array.

    Parameters
    ----------
    x : torch.Tensor or np.ndarray
        Value to convert. Anything that is neither an ndarray nor a
        torch.Tensor is handed to ``np.asarray``.

    Returns
    -------
    np.ndarray
        ``x`` itself when it already is an ndarray; the detached, CPU-side
        NumPy view of a torch.Tensor; otherwise ``np.asarray(x)``.
    """
    # Fast path: ndarrays are returned untouched (same object, no copy).
    if isinstance(x, np.ndarray):
        return x

    from_tensor = None
    try:
        # torch is optional; it is imported lazily so the module works without it.
        import torch
        if isinstance(x, torch.Tensor):
            from_tensor = x.detach().cpu().numpy()
    except ImportError:
        from_tensor = None

    # Generic fallback for lists, scalars and other array-likes.
    return np.asarray(x) if from_tensor is None else from_tensor


def check_finite(arr: Union["torch.Tensor", np.ndarray], name: str = "array") -> bool:
    """Report whether every element of ``arr`` is finite.

    When NaN or Inf entries are present, a warning listing how many of each
    were found is logged before returning.

    Parameters
    ----------
    arr : torch.Tensor or np.ndarray
        Array to inspect
    name : str
        Label used at the start of the warning message

    Returns
    -------
    bool
        True when no NaN/Inf is present, False otherwise
    """
    values = _to_numpy(arr)
    nan_mask = np.isnan(values)
    inf_mask = np.isinf(values)

    if not (np.any(nan_mask) or np.any(inf_mask)):
        return True

    # Only kinds that actually occur are mentioned in the message.
    counts = ((np.sum(nan_mask), "NaN"), (np.sum(inf_mask), "Inf"))
    problems = [f"{count} {label} values" for count, label in counts if count > 0]
    _logger.warning(f"{name} contains " + " and ".join(problems))
    return False


def ensure_real(arr: Union["torch.Tensor", np.ndarray]) -> np.ndarray:
    """Drop the imaginary component of a complex array.

    Parameters
    ----------
    arr : torch.Tensor or np.ndarray
        Input array, real or complex

    Returns
    -------
    np.ndarray
        The real part for complex input; the converted array unchanged
        for anything else
    """
    values = _to_numpy(arr)
    return np.real(values) if np.iscomplexobj(values) else values


def ensure_symmetric(arr: Union["torch.Tensor", np.ndarray]) -> np.ndarray:
    """Ensure matrix is symmetric by averaging with its transpose.

    Works on a single matrix and on stacks of matrices: only the last two
    axes are transposed, so an input of shape ``(..., n, n)`` is symmetrized
    matrix by matrix.

    Parameters
    ----------
    arr : torch.Tensor or np.ndarray
        Matrix (or stack of matrices along leading axes) to symmetrize

    Returns
    -------
    np.ndarray
        ``0.5 * (A + A^T)`` computed over the last two axes
    """
    arr_np = _to_numpy(arr)
    if arr_np.ndim < 2:
        # Scalars and vectors have no matrix axes to swap; ``.T`` is a no-op
        # for them, which keeps the historical result for such inputs.
        return 0.5 * (arr_np + arr_np.T)
    # ``.T`` would reverse *all* axes; for a stack (..., n, n) that either
    # fails to broadcast or mixes the batch axis into the matrices. Swapping
    # only the trailing two axes is identical to ``.T`` for 2-D input.
    return 0.5 * (arr_np + np.swapaxes(arr_np, -1, -2))


def cap_max_eigenval(
    M: Union["torch.Tensor", np.ndarray], 
    max_eigenval: float = 1e6,
    warn: bool = True
) -> np.ndarray:
    """Rescale a matrix so its largest eigenvalue does not exceed a bound.

    The whole matrix is multiplied by ``max_eigenval / largest_eigenvalue``
    when the largest eigenvalue (from ``np.linalg.eigvalsh``) is above the
    bound, then re-symmetrized. If the eigenvalue computation raises, the
    matrix is returned as it stands.

    Parameters
    ----------
    M : torch.Tensor or np.ndarray
        Square matrix to bound
    max_eigenval : float, default 1e6
        Upper bound for the largest eigenvalue
    warn : bool, default True
        Log a warning whenever rescaling happens

    Returns
    -------
    np.ndarray
        The (possibly rescaled) matrix
    """
    matrix = _to_numpy(M)
    is_empty = matrix.size == 0 or matrix.shape[0] == 0
    if is_empty:
        return matrix

    try:
        largest = np.max(np.linalg.eigvalsh(matrix))
        if largest > max_eigenval:
            # A uniform rescale shrinks every eigenvalue by the same factor.
            scale_factor = max_eigenval / largest
            matrix = matrix * scale_factor
            matrix = ensure_symmetric(matrix)
            if warn:
                _logger.warning(
                    f"Matrix maximum eigenvalue capped: {largest:.2e} -> {max_eigenval:.2e} "
                    f"(scale_factor={scale_factor:.2e})"
                )
    except (RuntimeError, ValueError, np.linalg.LinAlgError):
        # Best effort: leave the matrix untouched when eigenvalues are unavailable.
        pass

    return matrix


def ensure_real_and_symmetric(arr: Union["torch.Tensor", np.ndarray]) -> np.ndarray:
    """Return the real, symmetrized version of a matrix.

    Parameters
    ----------
    arr : torch.Tensor or np.ndarray
        Matrix to process

    Returns
    -------
    np.ndarray
        Result of ``ensure_real`` followed by ``ensure_symmetric``
    """
    # Order matters: strip the imaginary part first, then symmetrize.
    return ensure_symmetric(ensure_real(arr))


def ensure_positive_definite(
    M: Union["torch.Tensor", np.ndarray], 
    min_eigenval: float = DEFAULT_MIN_EIGENVAL, 
    warn: bool = True
) -> np.ndarray:
    """Symmetrize a matrix and lift its smallest eigenvalue to a floor.

    The matrix is symmetrized first. If its smallest eigenvalue (from
    ``np.linalg.eigvalsh``) is below ``min_eigenval``, the shortfall is added
    to the diagonal. If the eigenvalue computation raises, ``min_eigenval``
    itself is added to the diagonal instead.

    Parameters
    ----------
    M : torch.Tensor or np.ndarray
        Matrix to stabilize
    min_eigenval : float
        Eigenvalue floor to enforce
    warn : bool
        Log a warning whenever the diagonal is modified

    Returns
    -------
    np.ndarray
        Symmetric matrix, diagonally shifted when regularization was needed
    """
    def _shift_diagonal(matrix: np.ndarray, amount: float) -> np.ndarray:
        # Add ``amount`` times the identity, keeping the matrix dtype for the identity.
        return matrix + np.eye(matrix.shape[0], dtype=matrix.dtype) * amount

    sym = ensure_symmetric(M)
    if sym.size == 0 or sym.shape[0] == 0:
        return sym

    try:
        smallest = float(np.min(np.linalg.eigvalsh(sym)))
        if smallest < min_eigenval:
            shortfall = min_eigenval - smallest
            sym = _shift_diagonal(sym, shortfall)
            sym = ensure_symmetric(sym)
            if warn:
                _logger.warning(
                    f"Matrix regularization applied: min eigenvalue {smallest:.2e} < {min_eigenval:.2e}, "
                    f"added {shortfall:.2e} to diagonal. This biases the covariance matrix."
                )
    except (RuntimeError, ValueError, np.linalg.LinAlgError) as e:
        # Eigenvalues unavailable: apply a fixed shift of the floor itself.
        sym = _shift_diagonal(sym, min_eigenval)
        sym = ensure_symmetric(sym)
        if warn:
            _logger.warning(
                f"Matrix regularization applied (eigendecomposition failed: {e}). "
                f"Added {min_eigenval:.2e} to diagonal. This biases the covariance matrix."
            )

    return sym


def ensure_covariance_stable(
    M: Union["torch.Tensor", np.ndarray], 
    min_eigenval: float = DEFAULT_MIN_EIGENVAL,
    ensure_real_flag: bool = True
) -> np.ndarray:
    """Make a covariance matrix real, symmetric and eigenvalue-floored.

    Parameters
    ----------
    M : torch.Tensor or np.ndarray
        Covariance matrix to stabilize
    min_eigenval : float
        Eigenvalue floor passed on to ``ensure_positive_definite``
    ensure_real_flag : bool
        When True, the imaginary part is discarded first (the name carries a
        ``_flag`` suffix so it does not shadow the ``ensure_real`` function)

    Returns
    -------
    np.ndarray
        Stabilized covariance matrix; empty input is returned unchanged
    """
    cov = _to_numpy(M)
    if cov.size == 0 or cov.shape[0] == 0:
        return cov

    real_cov = ensure_real(cov) if ensure_real_flag else cov

    # Symmetrization and the eigenvalue floor are both handled here, silently.
    return ensure_positive_definite(real_cov, min_eigenval=min_eigenval, warn=False)


def clean_matrix(
    M: Union["torch.Tensor", np.ndarray], 
    matrix_type: str = 'general', 
    default_nan: float = 0.0, 
    default_inf: Optional[float] = None,
    min_eigenval: float = DEFAULT_MIN_EIGENVAL,
    min_diagonal_variance: float = DEFAULT_MIN_DIAGONAL_VARIANCE
) -> np.ndarray:
    """Replace NaN/Inf entries and apply type-specific stabilization.

    Behaviour per ``matrix_type``:

    - ``'covariance'``: NaN -> ``default_nan``, +/-Inf -> +/-1e6, then
      symmetrize and lift the smallest eigenvalue to ``min_eigenval``.
    - ``'diagonal'``: only the diagonal is kept; NaN and -Inf ->
      ``default_nan``, +Inf -> ``default_inf`` (1e4 when None), entries are
      clipped from below at ``min_diagonal_variance``.
    - ``'loading'``: NaN -> ``default_nan``, +/-Inf -> +/-1.0.
    - anything else: NaN -> ``default_nan``, +/-Inf -> +/-``default_inf``
      (1e6 when None).

    Parameters
    ----------
    M : torch.Tensor or np.ndarray
        Matrix to clean
    matrix_type : str
        Type of matrix: 'covariance', 'diagonal', 'loading', or 'general'
    default_nan : float
        Replacement for NaN entries
    default_inf : float, optional
        Replacement magnitude for Inf entries ('diagonal' and general only)
    min_eigenval : float
        Eigenvalue floor for covariance matrices
    min_diagonal_variance : float
        Lower bound for diagonal entries of diagonal matrices

    Returns
    -------
    np.ndarray
        Cleaned matrix
    """
    data = _to_numpy(M)

    if matrix_type == 'covariance':
        data = np.nan_to_num(data, nan=default_nan, posinf=1e6, neginf=-1e6)
        data = ensure_symmetric(data)
        try:
            smallest = np.min(np.linalg.eigvalsh(data))
            if smallest < min_eigenval:
                # Shift the spectrum up by exactly the shortfall.
                shortfall = min_eigenval - smallest
                data = data + np.eye(data.shape[0], dtype=data.dtype) * shortfall
                data = ensure_symmetric(data)
        except (RuntimeError, ValueError, np.linalg.LinAlgError):
            # Eigenvalues unavailable: fall back to a fixed diagonal shift.
            data = data + np.eye(data.shape[0], dtype=data.dtype) * min_eigenval
            data = ensure_symmetric(data)
        return data

    if matrix_type == 'diagonal':
        variances = np.diag(data)
        inf_fill = 1e4 if default_inf is None else default_inf
        variances = np.nan_to_num(
            variances, nan=default_nan, posinf=inf_fill, neginf=default_nan
        )
        variances = np.clip(variances, a_min=min_diagonal_variance, a_max=None)
        return np.diag(variances)

    if matrix_type == 'loading':
        return np.nan_to_num(data, nan=default_nan, posinf=1.0, neginf=-1.0)

    inf_fill = 1e6 if default_inf is None else default_inf
    return np.nan_to_num(data, nan=default_nan, posinf=inf_fill, neginf=-inf_fill)


def safe_inverse(
    M: Union["torch.Tensor", np.ndarray],
    regularization: float = DEFAULT_INV_REGULARIZATION,
    use_pinv_fallback: bool = True
) -> np.ndarray:
    """Safely compute matrix inverse with robust error handling.

    This function implements a progressive fallback strategy for matrix inversion:
    1. Try standard np.linalg.inv()
    2. If that raises, invert ``M + regularization * I``
    3. If that raises too, use the pseudo-inverse (if enabled)

    Note that ``np.linalg.inv`` only raises for matrices it detects as
    singular; a merely ill-conditioned matrix is inverted by step 1.

    Parameters
    ----------
    M : torch.Tensor or np.ndarray
        Matrix to invert (must be square)
    regularization : float
        Regularization amount to add to diagonal before inversion
    use_pinv_fallback : bool
        Whether to use pseudo-inverse as final fallback

    Returns
    -------
    np.ndarray
        Inverse of M, inverse of the regularized matrix, or a pseudo-inverse,
        whichever step succeeds first

    Raises
    ------
    RuntimeError
        If steps 1 and 2 fail and ``use_pinv_fallback`` is False. The
        original inversion error is attached as the cause.
    """
    M_np = _to_numpy(M)
    linalg_errors = (RuntimeError, ValueError, np.linalg.LinAlgError)

    try:
        # First try: standard inversion (fastest)
        return np.linalg.inv(M_np)
    except linalg_errors as inv_error:
        # Stays None when the regularized matrix itself cannot be built
        # (e.g. the identity does not broadcast against M).
        M_reg = None
        try:
            # Second try: regularized inversion
            M_reg = M_np + np.eye(M_np.shape[0], dtype=M_np.dtype) * regularization
            return np.linalg.inv(M_reg)
        except linalg_errors:
            if not use_pinv_fallback:
                raise RuntimeError(
                    f"Matrix inversion failed and pinv fallback disabled: {inv_error}"
                ) from inv_error
            # Third try: pseudo-inverse (most robust). Reuse the regularized
            # matrix when it exists; otherwise fall back to M itself rather
            # than repeating the expression that just raised.
            return np.linalg.pinv(M_np if M_reg is None else M_reg)


def safe_determinant(M: Union["torch.Tensor", np.ndarray], use_logdet: bool = True) -> float:
    """Compute determinant safely to avoid overflow warnings.

    Strategy:
    1. Matrices up to 2x2 use ``np.linalg.det`` directly.
    2. With ``use_logdet``, a symmetric matrix is tried with Cholesky
       (``det = exp(2 * sum(log(diag(L))))``); anything else, or a failed
       Cholesky, goes through ``np.linalg.slogdet``.
    3. As a last resort ``np.linalg.det`` is called directly.

    Parameters
    ----------
    M : torch.Tensor or np.ndarray
        Matrix for which to compute determinant
    use_logdet : bool
        Whether to use log-determinant computation (default: True)

    Returns
    -------
    float
        Determinant of M. 0.0 is returned for empty, non-square or
        non-finite input, when the log-determinant exceeds 700, when the
        ``slogdet`` path finds a non-positive determinant, or when every
        method fails.
    """
    linalg_errors = (RuntimeError, ValueError, np.linalg.LinAlgError)
    max_log_det = 700  # exp(700) is near float64 max

    M_np = _to_numpy(M)
    if M_np.size == 0 or M_np.shape[0] == 0:
        return 0.0

    if M_np.shape[0] != M_np.shape[1]:
        _logger.debug("safe_determinant: non-square matrix, returning 0.0")
        return 0.0

    # Check for NaN/Inf
    if not np.all(np.isfinite(M_np)):
        _logger.debug("safe_determinant: matrix contains NaN/Inf, returning 0.0")
        return 0.0

    # For small matrices, direct computation is safe
    if M_np.shape[0] <= 2:
        try:
            det = np.linalg.det(M_np)
            if np.isfinite(det):
                return float(det)
        except linalg_errors:
            pass

    # Use log-determinant for stability
    if use_logdet:
        # np.linalg.cholesky reads only the lower triangle and does not verify
        # symmetry, so on a non-symmetric matrix it can "succeed" and yield the
        # determinant of a different matrix. Only take that path when M is
        # symmetric up to a tolerance relative to its largest entry.
        is_symmetric = False
        if M_np.ndim == 2:
            asym_tol = 1e-10 * float(np.max(np.abs(M_np)))
            is_symmetric = bool(np.allclose(M_np, M_np.T, rtol=0.0, atol=asym_tol))

        need_slogdet = True
        if is_symmetric:
            try:
                chol = np.linalg.cholesky(M_np)
                log_det = 2.0 * np.sum(np.log(np.diag(chol)))
                if log_det > max_log_det:
                    _logger.debug("safe_determinant: log_det too large, returning 0.0")
                    return 0.0
                det = np.exp(log_det)
                if np.isfinite(det) and det > 0:
                    return float(det)
                # Cholesky worked but exp() underflowed: skip slogdet and let
                # the direct computation below decide.
                need_slogdet = False
            except linalg_errors:
                # Not positive definite: handled by slogdet below.
                pass

        if need_slogdet:
            try:
                sign, log_det = np.linalg.slogdet(M_np)
                # NOTE(review): negative determinants are reported as 0.0 on
                # this path, while the <=2x2 fast path and use_logdet=False
                # return them as-is. Kept unchanged for backward compatibility.
                if not np.isfinite(log_det) or sign <= 0:
                    return 0.0
                if log_det > max_log_det:
                    _logger.debug("safe_determinant: log_det too large, returning 0.0")
                    return 0.0
                det = np.exp(log_det)
                if np.isfinite(det):
                    return float(det)
            except linalg_errors:
                pass

    # Fallback: direct computation
    try:
        det = np.linalg.det(M_np)
        if np.isfinite(det):
            return float(det)
    except linalg_errors:
        pass

    _logger.debug("safe_determinant: all methods failed, returning 0.0")
    return 0.0

