"""EM algorithm for Dynamic Factor Models.

This module provides EMAlgorithm, a pure NumPy implementation for the Expectation-Maximization
algorithm. Uses Kalman smoother for E-step and closed-form OLS for M-step.
Follows MATLAB Nowcasting reference implementation.
"""

import gc
from typing import Tuple, Optional, Dict, Any, List, Union, TYPE_CHECKING
import numpy as np
from dataclasses import dataclass

if TYPE_CHECKING:
    import torch

from ..logger import get_logger
from ..config.utils import FREQUENCY_HIERARCHY
from .kalman import KalmanFilter
from .utils import ensure_positive_definite, ensure_symmetric, cap_max_eigenval, _to_numpy

_logger = get_logger(__name__)


@dataclass
class EMStepParams:
    """Bundle of inputs for a single EM iteration, held as NumPy arrays.

    ``EMAlgorithm.forward`` reads only ``y``, ``A``, ``C``, ``Q``, ``R``,
    ``Z_0`` and ``V_0`` from this object. The remaining fields are carried
    along for other consumers; their descriptions below follow the same-named
    arguments of ``EMAlgorithm.initialize_parameters`` and should be confirmed
    against the code that actually uses them.

    Attributes
    ----------
    y : np.ndarray
        Observations laid out as (N, T): ``forward`` takes ``T = y.shape[1]``.
        NaN entries are tolerated (``forward`` zeroes them for the M-step).
    A : np.ndarray
        Transition matrix (m x m).
    C : np.ndarray
        Observation/loading matrix (N x m).
    Q : np.ndarray
        Process noise covariance (m x m).
    R : np.ndarray
        Observation noise covariance (N x N).
    Z_0 : np.ndarray
        Initial state vector (m,).
    V_0 : np.ndarray
        Initial state covariance (m x m).
    r : np.ndarray
        Presumably the number of factors per block (n_blocks,) - TODO confirm.
    p : int
        Presumably the AR lag order - TODO confirm.
    R_mat : np.ndarray, optional
        Presumably the tent-kernel aggregation constraint matrix - TODO confirm.
    q : np.ndarray, optional
        Presumably the tent-kernel aggregation constraint vector - TODO confirm.
    nQ : int
        Presumably the number of slower-frequency series - TODO confirm.
    i_idio : np.ndarray
        Presumably an indicator per series (1 = clock frequency, 0 = slower
        frequency) - TODO confirm.
    blocks : np.ndarray
        Presumably the block structure array (N x n_blocks) - TODO confirm.
    tent_weights_dict : dict
        Mapping from a frequency key to an array of tent weights.
    clock : str
        Clock frequency code (e.g. 'm').
    frequencies : np.ndarray, optional
        Presumably the frequency of each series - TODO confirm encoding.
    idio_chain_lengths : np.ndarray
        Presumably the idiosyncratic chain length per series - TODO confirm.
    config : Any
        Configuration object (annotated in the source as DFMConfig).
    structures_dict : dict, optional
        Maps a frequency to an ``(R_mat, q)`` tuple.
    frequencies_list : list of str, optional
        Original frequency strings, one per series.
    """
    # --- State-space quantities consumed by EMAlgorithm.forward ---
    y: np.ndarray
    A: np.ndarray
    C: np.ndarray
    Q: np.ndarray
    R: np.ndarray
    Z_0: np.ndarray
    V_0: np.ndarray
    # --- Model-structure metadata (not read by the forward() shown in this module) ---
    r: np.ndarray
    p: int
    R_mat: Optional[np.ndarray]
    q: Optional[np.ndarray]
    nQ: int
    i_idio: np.ndarray
    blocks: np.ndarray
    tent_weights_dict: Dict[str, np.ndarray]
    clock: str
    frequencies: Optional[np.ndarray]
    idio_chain_lengths: np.ndarray
    config: Any  # DFMConfig
    structures_dict: Optional[Dict[str, Tuple[np.ndarray, np.ndarray]]] = None  # Maps frequency to (R_mat, q) tuple
    frequencies_list: Optional[List[str]] = None  # Original frequency strings for each series


class EMAlgorithm:
    """NumPy-based EM algorithm.
    
    This class implements the Expectation-Maximization algorithm for Dynamic
    Factor Models. It composes a KalmanFilter for the E-step and performs
    closed-form parameter updates in the M-step.
    
    Parameters
    ----------
    kalman : KalmanFilter, optional
        KalmanFilter instance to use for E-step. If None, creates a new instance.
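    regularization_scale : float, default 1e-6
        Base ridge term added to the normal-equation matrices in the M-step;
        it is scaled up adaptively when those matrices are ill-conditioned
    chunk_threshold : int, default 10000
        Number of time steps above which the smoothed second moments are
        accumulated in chunks instead of in a single pass
    chunk_size : int, default 1000
        Number of time steps per chunk when chunked accumulation is active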
    """
    
    def __init__(
        self,
        kalman: Optional[KalmanFilter] = None,
        regularization_scale: float = 1e-6,
        chunk_threshold: int = 10000,
        chunk_size: int = 1000
    ):
        # Compose KalmanFilter (create if not provided)
        if kalman is None:
            self.kalman = KalmanFilter()
        else:
            self.kalman = kalman
        self.regularization_scale = regularization_scale
        self.chunk_threshold = chunk_threshold
        self.chunk_size = chunk_size
    
    def _cap_max_eigenval(self, M: np.ndarray, max_eigenval: float = 1e6) -> np.ndarray:
        """Return ``M`` with its largest eigenvalue capped at ``max_eigenval``.

        Delegates to the module-level ``cap_max_eigenval`` helper with its
        warnings switched off.
        """
        capped = cap_max_eigenval(M, max_eigenval=max_eigenval, warn=False)
        return capped
    
    def _clean_array(self, arr: np.ndarray, default_value: float = 0.0, 
                     clamp_min: Optional[float] = None, clamp_max: Optional[float] = None) -> np.ndarray:
        """Clean array by removing NaN/Inf and optionally clamping.
        
        Parameters
        ----------
        arr : np.ndarray
            Array to clean
        default_value : float
            Value to replace NaN/Inf with
        clamp_min : float, optional
            Minimum value for clamping
        clamp_max : float, optional
            Maximum value for clamping
            
        Returns
        -------
        np.ndarray
            Cleaned array
        """
        arr = np.nan_to_num(arr, nan=default_value, posinf=default_value, neginf=default_value)
        if clamp_min is not None or clamp_max is not None:
            arr = np.clip(arr, a_min=clamp_min if clamp_min is not None else -np.inf,
                         a_max=clamp_max if clamp_max is not None else np.inf)
        return arr
    
    def _ensure_positive_definite(
        self,
        M: np.ndarray,
        min_eigenval: float = 1e-8,
        dtype: Optional[np.dtype] = None
    ) -> np.ndarray:
        """Shift ``M`` so its smallest eigenvalue is at least ``min_eigenval``.

        Three escalating attempts are made:

        1. Compute the eigenvalues and, if the smallest is too low, add the
           missing amount to the diagonal and symmetrize.
        2. If that raises, add ``10 * min_eigenval`` to the diagonal,
           symmetrize, and repeat the eigenvalue check.
        3. If that raises too, return a diagonal matrix built from the
           current diagonal floored at ``min_eigenval``.

        Parameters
        ----------
        M : np.ndarray
            Matrix to repair (converted with ``_to_numpy`` first).
        min_eigenval : float
            Smallest eigenvalue to enforce.
        dtype : np.dtype, optional
            Dtype of the identity used for the shift; defaults to ``M.dtype``.

        Returns
        -------
        np.ndarray
            The repaired matrix (``M`` itself when no shift was needed).
        """
        M = _to_numpy(M)
        eye_dtype = M.dtype if dtype is None else dtype
        recoverable = (RuntimeError, ValueError, np.linalg.LinAlgError)

        # Attempt 1: exact shift by the eigenvalue shortfall.
        try:
            smallest = np.min(np.linalg.eigvalsh(M))
            if smallest < min_eigenval:
                M = M + np.eye(M.shape[0], dtype=eye_dtype) * (min_eigenval - smallest)
                M = ensure_symmetric(M)
            return M
        except recoverable:
            pass

        # Attempt 2: blunt diagonal boost, then retry the exact shift.
        M = M + np.eye(M.shape[0], dtype=eye_dtype) * min_eigenval * 10
        M = ensure_symmetric(M)
        try:
            smallest = np.min(np.linalg.eigvalsh(M))
            if smallest < min_eigenval:
                M = M + np.eye(M.shape[0], dtype=eye_dtype) * (min_eigenval - smallest)
                M = ensure_symmetric(M)
        except recoverable:
            # Attempt 3: give up on off-diagonals, keep a floored diagonal.
            current_diag = np.diag(M)
            floor = np.ones_like(current_diag) * min_eigenval
            M = np.diag(np.maximum(current_diag, floor))

        return M
    
    def _compute_sum_EZZ(
        self,
        Vsmooth: np.ndarray,
        EZ: np.ndarray,
        chunk_size: Optional[int] = None
    ) -> np.ndarray:
        """Accumulate sum_t E[Z_t Z_t^T] without any (T x m x m) temporary.

        For every time step the smoothed second moment is
        ``Vsmooth[:, :, t+1] + outer(EZ[t], EZ[t])``. The sum of outer
        products is evaluated as the matrix product ``EZ.T @ EZ``, which needs
        only (m x m) working memory; the covariance slices are reduced along
        the time axis directly.

        Parameters
        ----------
        Vsmooth : np.ndarray
            Smoothed factor covariance (m x m x T+1); index 0 along the last
            axis is the initial state and is excluded from the sum.
        EZ : np.ndarray
            Smoothed factor means (T x m).
        chunk_size : int, optional
            Number of time steps handled per iteration. If None (or not
            smaller than T) everything is processed in one pass.

        Returns
        -------
        np.ndarray
            sum_EZZ (m x m) - sum over time of E[Z_t Z_t^T].

        Raises
        ------
        ValueError
            If ``chunk_size`` is given and is not a positive integer.
        """
        Vsmooth = _to_numpy(Vsmooth)
        EZ = _to_numpy(EZ)

        if chunk_size is not None and chunk_size <= 0:
            # A non-positive step would either fail inside range() or skip
            # every time step and silently return an all-zero matrix.
            raise ValueError(f"chunk_size must be a positive integer, got {chunk_size}")

        T, m = EZ.shape
        single_pass = chunk_size is None or chunk_size >= T
        # max(..., 1) keeps range() valid for an empty series (T == 0).
        step = max(T, 1) if single_pass else chunk_size

        sum_EZZ = np.zeros((m, m), dtype=EZ.dtype)
        for chunk_start in range(0, T, step):
            chunk_end = min(chunk_start + step, T)
            EZ_chunk = EZ[chunk_start:chunk_end, :]

            # Covariance part: Vsmooth is offset by one because slot 0 holds
            # the initial state.
            cov_part = np.sum(Vsmooth[:, :, chunk_start + 1:chunk_end + 1], axis=2)
            # Mean part: sum_t EZ[t] EZ[t]^T == EZ_chunk^T @ EZ_chunk.
            mean_part = EZ_chunk.T @ EZ_chunk

            # Out-of-place add so the result dtype can promote if needed.
            sum_EZZ = sum_EZZ + cov_part + mean_part

        return sum_EZZ
    
    def _extract_or_pad_matrix(self, M: np.ndarray, target_size: int, dtype: np.dtype) -> np.ndarray:
        """Return a ``target_size`` x ``target_size`` version of ``M``.

        If ``M`` is at least that large in both dimensions its top-left corner
        is returned as a slice. Otherwise a zero matrix of the requested dtype
        is allocated and the overlapping top-left part of ``M`` is copied in.
        """
        source = _to_numpy(M)
        n_rows, n_cols = source.shape[0], source.shape[1]

        # Too small in at least one direction: zero-pad.
        if n_rows < target_size or n_cols < target_size:
            padded = np.zeros((target_size, target_size), dtype=dtype)
            keep_rows = min(n_rows, target_size)
            keep_cols = min(n_cols, target_size)
            padded[:keep_rows, :keep_cols] = source[:keep_rows, :keep_cols]
            return padded

        # Large enough: crop the leading block.
        return source[:target_size, :target_size]
    
    def _compute_adaptive_regularization(
        self, 
        M: np.ndarray, 
        matrix_name: str = "matrix",
        min_reg: float = 1e-3
    ) -> float:
        """Compute adaptive regularization based on condition number.
        
        Parameters
        ----------
        M : np.ndarray
            Matrix to compute regularization for
        matrix_name : str
            Name for logging
        min_reg : float
            Minimum regularization value
            
        Returns
        -------
        float
            Regularization scale
        """
        base_reg = float(self.regularization_scale)
        reg_scale: float
        try:
            eigenvals = np.linalg.eigvalsh(M)
            eigenvals = eigenvals[eigenvals > 1e-12]
            if len(eigenvals) > 0:
                max_eig = np.max(eigenvals)
                min_eig = np.min(eigenvals)
                cond_num = float(max_eig / min_eig) if min_eig > 1e-12 else float('inf')
                
                if cond_num > 1e8:
                    reg_scale = base_reg * (cond_num / 1e8)
                    _logger.debug(f"EM: {matrix_name} ill-conditioned (cond={cond_num:.2e}), reg={reg_scale:.2e}")
                else:
                    reg_scale = base_reg
            else:
                reg_scale = max(base_reg * 100.0, min_reg)
                _logger.warning(f"EM: {matrix_name} has no valid eigenvalues, using reg={reg_scale:.2e}")
        except (RuntimeError, ValueError, np.linalg.LinAlgError) as e:
            reg_scale = max(base_reg * 10.0, min_reg)
            _logger.warning(f"EM: Failed to compute condition number for {matrix_name} ({e}), using reg={reg_scale:.2e}")
        
        return float(reg_scale)
    
    def forward(
        self,
        params: EMStepParams
    ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray, float]:
        """Perform one EM iteration. Main entry point.

        Follows MATLAB EMstep() function structure (lines 243-543 in dfm.m).
        Only ``y``, ``A``, ``C``, ``Q``, ``R``, ``Z_0`` and ``V_0`` are read
        from ``params``.

        The E-step runs ``self.kalman.forward``. The M-step then works on the
        smoothed state means:

        * ``A`` - ridge-regularised regression of Z_t on Z_{t-1}, rescaled so
          its spectral radius is at most 0.99 and clipped to [-0.99, 0.99].
        * ``C`` - regression of y_t on Z_t using the smoothed second moments,
          followed by scaling every column to unit norm.
        * ``Q`` - covariance of the transition residuals with its diagonal
          floored at 0.01 (off-diagonal covariances are left untouched).
        * ``R`` - diagonal matrix of observation-residual variances, each
          confined to [1e-4, 1e4].
        * ``Z_0`` / ``V_0`` - copies of the smoothed state mean / covariance
          at time 0.

        Missing observations (NaN in ``y``) are treated as 0 in the M-step.
        When an update fails numerically the previous value of that parameter
        is kept; NaN entries surviving the M-step are replaced (by 0 for
        ``C``/``A``, by the previous value otherwise) and logged.

        Parameters
        ----------
        params : EMStepParams
            Parameters for this EM step

        Returns
        -------
        C : np.ndarray
            Updated observation matrix (N x m)
        R : np.ndarray
            Updated observation covariance (N x N)
        A : np.ndarray
            Updated transition matrix (m x m)
        Q : np.ndarray
            Updated process noise covariance (m x m)
        Z_0 : np.ndarray
            Updated initial state (m,)
        V_0 : np.ndarray
            Updated initial covariance (m x m)
        loglik : float
            Log-likelihood value
        """
        recoverable = (RuntimeError, ValueError, np.linalg.LinAlgError)

        # Convert all parameters to NumPy arrays
        y = _to_numpy(params.y)
        A = _to_numpy(params.A)
        C = _to_numpy(params.C)
        Q = _to_numpy(params.Q)
        R = _to_numpy(params.R)
        Z_0 = _to_numpy(params.Z_0)
        V_0 = _to_numpy(params.V_0)

        dtype = y.dtype

        # ------------------------------------------------------------------
        # E-step: Kalman smoother
        # ------------------------------------------------------------------
        zsmooth, Vsmooth, VVsmooth, loglik = self.kalman.forward(
            y, A, C, Q, R, Z_0, V_0
        )
        # The lag-one covariances are not used by this M-step (the A update
        # regresses on smoothed means only), so release them right away.
        del VVsmooth

        # zsmooth is m x (T+1); work with time along the first axis.
        Zsmooth = zsmooth.T

        T = y.shape[1]
        m = A.shape[0]

        # E[Z_t | y_{1:T}] for t = 1..T (row 0 is the initial state).
        EZ = Zsmooth[1:, :]  # T x m

        # sum_t E[Z_t Z_t^T], accumulated without a (T x m x m) intermediate.
        chunk_size = self.chunk_size if T > self.chunk_threshold else None
        sum_EZZ = self._compute_sum_EZZ(Vsmooth, EZ, chunk_size=chunk_size)

        # ------------------------------------------------------------------
        # M-step: A (transition matrix), regression of Z_t on Z_{t-1}
        # ------------------------------------------------------------------
        if T > 1:
            Y_A = EZ[1:, :]   # (T-1) x m
            X_A = EZ[:-1, :]  # (T-1) x m
            try:
                # Normal equations as matrix products: O(m^2) memory instead
                # of a (T x m x m) stack of outer products.
                XTX_A = X_A.T @ X_A
                XTY_A = X_A.T @ Y_A

                reg_scale = self._compute_adaptive_regularization(XTX_A, "XTX_A", min_reg=1e-6)
                XTX_A_reg = XTX_A + np.eye(m, dtype=dtype) * reg_scale
                # solve() yields B with Z_t ~ B^T Z_{t-1}, hence the transpose.
                A_new = np.linalg.solve(XTX_A_reg, XTY_A).T

                # Shrink towards stability if the spectral radius is too large.
                spectral_radius = np.max(np.abs(np.linalg.eigvals(A_new)))
                if spectral_radius >= 0.99:
                    A_new = A_new * (0.99 / spectral_radius)
            except recoverable:
                A_new = A.copy()

            # Clean and clip (applied in both success and error cases)
            A_new = self._clean_array(A_new, default_value=0.0, clamp_min=-0.99, clamp_max=0.99)
        else:
            A_new = A.copy()

        # ------------------------------------------------------------------
        # M-step: C (observation matrix), regression of y_t on Z_t
        # ------------------------------------------------------------------
        # Missing observations are set to 0 for the M-step calculations.
        y_for_mstep = y.copy()
        y_for_mstep[np.isnan(y_for_mstep)] = 0.0

        # C = (sum_t y_t E[Z_t^T]) (sum_t E[Z_t Z_t^T])^{-1}
        try:
            # sum_t y_t E[Z_t]^T as one (N x T) @ (T x m) product.
            sum_yEZ = y_for_mstep @ EZ  # (N, m)

            corrupted = np.isnan(sum_EZZ) | np.isinf(sum_EZZ)
            if np.any(corrupted):
                corrupted_count = np.sum(corrupted)
                _logger.warning(f"EM: sum_EZZ contains {corrupted_count}/{sum_EZZ.size} NaN/Inf, cleaning")
                sum_EZZ = self._clean_array(sum_EZZ, default_value=0.0)

            # Cap maximum eigenvalue to prevent condition number explosion
            sum_EZZ = self._cap_max_eigenval(sum_EZZ, max_eigenval=1e6)

            reg_scale = self._compute_adaptive_regularization(sum_EZZ, "sum_EZZ", min_reg=1e-3)
            sum_EZZ_reg = sum_EZZ + np.eye(m, dtype=dtype) * reg_scale

            # Use pseudo-inverse as fallback when solve fails
            try:
                C_new = np.linalg.solve(sum_EZZ_reg.T, sum_yEZ.T).T
            except recoverable as e:
                _logger.warning(f"EM: solve failed for C matrix ({e}), using pseudo-inverse fallback")
                C_new = (np.linalg.pinv(sum_EZZ_reg.T) @ sum_yEZ.T).T

            # Handle NaN in C_new
            nan_mask = np.isnan(C_new)
            if np.any(nan_mask):
                nan_count = np.sum(nan_mask)
                nan_ratio = nan_count / C_new.size
                _logger.warning(f"EM: C matrix contains {nan_count}/{C_new.size} NaN ({nan_ratio:.1%})")

                if not np.any(np.isnan(C)):
                    # Preserve previous iteration values where the update failed.
                    C_new[nan_mask] = C[nan_mask]
                    if nan_ratio > 0.1:
                        _logger.warning(f"EM: Preserved {nan_count} NaN values from previous iteration")
                else:
                    # If previous C also has NaN, set to zero as last resort
                    C_new[nan_mask] = 0.0
                    _logger.warning(
                        f"Previous C matrix also contains NaN. Set {np.sum(nan_mask)} NaN values to zero."
                    )

            # Normalize C columns (factor loadings). Columns are independent,
            # so all norms can be taken before any column is rescaled.
            for j, norm in enumerate(np.linalg.norm(C_new, axis=0)):
                if norm > 1e-8:
                    C_new[:, j] = C_new[:, j] / norm
                else:
                    # Negligible column: zero it rather than divide by ~0.
                    C_new[:, j] = 0.0
                    _logger.debug(f"C matrix column {j} has very small norm ({norm:.2e}), set to zero.")
        except recoverable as e:
            _logger.warning(f"EM algorithm: Error updating C matrix: {e}. Keeping previous C matrix.")
            C_new = C.copy()
            # Check if previous C also contains NaN
            if np.any(np.isnan(C_new)):
                _logger.error(
                    f"EM algorithm: Previous C matrix also contains NaN. "
                    f"This indicates the model cannot be trained with current data/parameters."
                )

        # ------------------------------------------------------------------
        # M-step: Q (process noise covariance) from transition residuals
        # ------------------------------------------------------------------
        if T > 1:
            residuals_Q = EZ[1:, :] - EZ[:-1, :] @ A_new.T  # (T-1) x m
            if residuals_Q.shape[1] == 1:
                # Single factor: np.cov would return a 0-d array, so use the
                # (population) variance and shape it as a 1 x 1 matrix.
                Q_new = np.var(residuals_Q, axis=0, ddof=0).reshape(1, 1)
            else:
                Q_new = ensure_symmetric(np.cov(residuals_Q.T))

            # Ensure positive definite (with robust error handling)
            Q_new = self._ensure_positive_definite(Q_new, min_eigenval=1e-8, dtype=dtype)

            # Floor the variances at 0.01. Only the diagonal is raised: an
            # element-wise maximum against 0.01 * I would also clamp every
            # negative covariance to 0 and distort the matrix. Adding a
            # non-negative diagonal keeps positive definiteness.
            Q_new = Q_new + np.diag(np.maximum(0.01 - np.diag(Q_new), 0.0))
            Q_new = ensure_symmetric(Q_new)
            Q_new = self._clean_array(Q_new, default_value=0.01, clamp_max=1e6)

            # Final check: ensure positive definite with stronger regularization
            Q_new = ensure_positive_definite(Q_new, min_eigenval=1e-6, warn=False)
        else:
            Q_new = Q.copy()

        # ------------------------------------------------------------------
        # M-step: R (observation covariance), diagonal by construction
        # ------------------------------------------------------------------
        residuals_R = y_for_mstep.T - EZ @ C_new.T  # (T, N)
        # Only the per-series variances are kept, so compute them directly
        # instead of building a full N x N covariance. The estimators match
        # the previous behaviour: population variance for a single series,
        # sample variance (as np.cov uses) in float64 otherwise.
        if residuals_R.shape[1] == 1:
            diag_R = np.var(residuals_R, axis=0, ddof=0)
        else:
            diag_R = np.var(residuals_R, axis=0, ddof=1, dtype=np.float64)

        # Clean and clamp diagonal; 1e-4 is the minimum variance floor, so the
        # resulting diagonal matrix is positive definite.
        diag_R = self._clean_array(diag_R, default_value=1e-4, clamp_min=1e-4, clamp_max=1e4)
        R_new = np.diag(diag_R)

        # ------------------------------------------------------------------
        # Z_0 and V_0: smoothed state at time 0
        # ------------------------------------------------------------------
        # Copy rather than slice: a view would keep the full smoothed arrays
        # alive for as long as the caller holds the returned parameters.
        Z_0_new = np.array(Zsmooth[0, :])
        V_0_new = np.array(Vsmooth[:, :, 0])

        # ------------------------------------------------------------------
        # Final NaN sweep over all updated parameters
        # ------------------------------------------------------------------
        updated = {
            'C': C_new,
            'A': A_new,
            'Q': Q_new,
            'R': R_new,
            'Z_0': Z_0_new,
            'V_0': V_0_new
        }
        # Parameters that fall back to the previous iterate; C and A fall
        # back to zero instead.
        previous = {'Q': Q, 'R': R, 'Z_0': Z_0, 'V_0': V_0}

        nan_detected = False
        for param_name, param_arr in updated.items():
            nan_mask = np.isnan(param_arr)
            if not np.any(nan_mask):
                continue
            nan_count = np.sum(nan_mask)
            nan_ratio = nan_count / param_arr.size
            _logger.warning(
                f"EM algorithm: {param_name} matrix contains {nan_count}/{param_arr.size} NaN values "
                f"({nan_ratio:.1%}) after M-step. This indicates numerical instability."
            )
            nan_detected = True
            if param_name in previous:
                updated[param_name] = np.where(nan_mask, previous[param_name], param_arr)
            else:
                updated[param_name] = self._clean_array(param_arr, default_value=0.0)

        C_new, A_new, Q_new = updated['C'], updated['A'], updated['Q']
        R_new, Z_0_new, V_0_new = updated['R'], updated['Z_0'], updated['V_0']

        if nan_detected:
            _logger.error(
                "EM algorithm: NaN detected in parameter updates. "
                "This usually indicates: (1) singular matrix in solve operations, "
                "(2) extreme data values, (3) insufficient regularization, or "
                "(4) numerical precision issues. Consider increasing regularization_scale "
                "or checking data quality."
            )

        # Drop the large smoothed arrays before the final collection pass.
        del updated, EZ, Zsmooth, zsmooth, Vsmooth

        # Force Python garbage collection for CPU memory management
        gc.collect()

        # Ensure V_0 is positive definite (with robust error handling)
        V_0_new = self._ensure_positive_definite(V_0_new, min_eigenval=1e-8, dtype=dtype)

        return C_new, R_new, A_new, Q_new, Z_0_new, V_0_new, loglik
    
    def initialize_parameters(
        self,
        x: Union["torch.Tensor", np.ndarray],
        r: Union["torch.Tensor", np.ndarray],
        p: int,
        blocks: Union["torch.Tensor", np.ndarray],
        opt_nan: Dict[str, Any],
        R_mat: Optional[Union["torch.Tensor", np.ndarray]] = None,
        q: Optional[Union["torch.Tensor", np.ndarray]] = None,
        nQ: int = 0,
        i_idio: Optional[Union["torch.Tensor", np.ndarray]] = None,
        clock: str = 'm',
        tent_weights_dict: Optional[Dict[str, Union["torch.Tensor", np.ndarray]]] = None,
        frequencies: Optional[Union["torch.Tensor", np.ndarray]] = None,
        idio_chain_lengths: Optional[Union["torch.Tensor", np.ndarray]] = None,
        config: Optional[Any] = None,
        structures_dict: Optional[Dict[str, Tuple[Union["torch.Tensor", np.ndarray], Union["torch.Tensor", np.ndarray]]]] = None,
        frequencies_list: Optional[List[str]] = None,
    ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
        """Initialize DFM parameters using residual-based PCA (matching MATLAB InitCond).
        
        This method implements the MATLAB InitCond() approach:
        1. Start with residuals = spline-interpolated data
        2. For each block: compute PCA on residuals, extract factors, update residuals
        3. Build transition matrices block-by-block
        4. Handle idiosyncratic components (monthly AR(1) and quarterly 5-state chains)
        
        Parameters
        ----------
        x : torch.Tensor
            Standardized data matrix (T x N)
        r : torch.Tensor
            Number of factors per block (n_blocks,)
        p : int
            AR lag order (typically 1)
        blocks : torch.Tensor
            Block structure array (N x n_blocks)
        opt_nan : dict
            Missing data handling options {'method': int, 'k': int}
        R_mat : torch.Tensor, optional
            Constraint matrix for tent kernel aggregation
        q : torch.Tensor, optional
            Constraint vector for tent kernel aggregation
        nQ : int
            Number of slower-frequency series
        i_idio : torch.Tensor, optional
            Indicator array (1 for clock frequency, 0 for slower frequencies)
        clock : str
            Clock frequency ('d', 'w', 'm', 'q', 'sa', 'a')
        tent_weights_dict : dict, optional
            Dictionary mapping frequency pairs to tent weights
        frequencies : torch.Tensor, optional
            Array of frequencies for each series
        idio_chain_lengths : torch.Tensor, optional
            Array of idiosyncratic chain lengths per series
        config : Any, optional
            Configuration object
            
        Returns
        -------
        A : torch.Tensor
            Initial transition matrix (m x m)
        C : torch.Tensor
            Initial observation/loading matrix (N x m)
        Q : torch.Tensor
            Initial process noise covariance (m x m)
        R : torch.Tensor
            Initial observation noise covariance (N x N)
        Z_0 : torch.Tensor
            Initial state vector (m,)
        V_0 : torch.Tensor
            Initial state covariance (m x m)
        """
        # Convert all inputs to NumPy arrays for processing
        x = _to_numpy(x)
        T, N = x.shape
        dtype_np = x.dtype  # NumPy dtype for internal processing
        
        # Convert blocks and i_idio to numpy for processing
        blocks_np = _to_numpy(blocks) if blocks is not None else None
        i_idio_np = _to_numpy(i_idio) if i_idio is not None else None
        r_np = _to_numpy(r) if r is not None else None
        R_mat_np = _to_numpy(R_mat) if R_mat is not None else None
        q_np = _to_numpy(q) if q is not None else None
        
        # Use numpy versions for processing
        # blocks and r are required (not None)
        if blocks_np is None:
            raise ValueError("blocks cannot be None")
        if r_np is None:
            raise ValueError("r cannot be None")
        
        blocks = blocks_np
        i_idio = i_idio_np  # Can be None, handled in code
        r = r_np
        R_mat = R_mat_np  # Can be None, handled in code
        q = q_np  # Can be None, handled in code
        
        n_blocks = blocks.shape[1]
        # Note: nM and nQ are counts, but series may not be sorted by frequency
        # Use i_idio to identify clock-frequency vs slower-frequency series
        if i_idio is not None:
            nM = int(np.sum(i_idio))  # Number of clock-frequency series
            nQ = N - nM  # Number of slower-frequency series
        else:
            nM = N - nQ  # Fallback: assume first nM are clock-frequency
            nQ = nQ  # Use provided nQ
        
        # Handle missing data for initialization using NumPy-based function
        from ..utils.data import rem_nans_spline
        
        # Convert to numpy if needed
        x_np = _to_numpy(x)
        
        x_clean_np, indNaN_np = rem_nans_spline(
            x_np,
            method=opt_nan.get('method', 2), 
            k=opt_nan.get('k', 3)
        )
        
        # Remove any remaining NaN/inf
        x_clean_np = np.where(np.isfinite(x_clean_np), x_clean_np, 0.0)
        
        # Initialize residuals: res = x_clean (spline-interpolated data)
        # This matches MATLAB: res = xBal; resNaN = xNaN;
        # Note: We'll work with numpy arrays internally, convert to torch only when needed
        res = x_clean_np.copy()  # T x N (numpy array)
        resNaN = x_clean_np.copy()
        resNaN[indNaN_np] = np.nan
        indNaN = indNaN_np  # Keep reference for later use
        
        # Determine tent kernel size (pC) for slower-frequency aggregation
        # pC is determined from tent_weights_dict: use the maximum length across all slower frequencies
        # This properly handles mixed frequencies (e.g., both 'm' and 'q' series)
        # If no mixed frequency (nQ=0, tent_weights_dict empty/None), pC=1 (no tent kernel needed)
        pC = 1  # Default: no tent kernel (unified frequency case)
        if tent_weights_dict is not None and len(tent_weights_dict) > 0:
            # Use maximum tent kernel size across all slower frequencies
            max_pC = 0
            for freq_key, tent_weights in tent_weights_dict.items():
                tent_len = len(tent_weights) if hasattr(tent_weights, '__len__') else tent_weights.shape[0]
                max_pC = max(max_pC, tent_len)
            if max_pC > 0:
                pC = max_pC
        elif R_mat is not None:
            # Fallback: use R_mat shape if tent_weights_dict is not available
            # R_mat.shape[1] equals tent kernel size (n columns for n-period tent kernel)
            pC = R_mat.shape[1]
        elif nQ > 0:
            # If nQ > 0 but no tent_weights_dict or R_mat, this is an error condition
            # But for safety, use a reasonable default based on common quarterly-monthly case
            _logger.warning(
                f"nQ={nQ} > 0 but tent_weights_dict is empty and R_mat is None. "
                f"Assuming quarterly-monthly aggregation (pC=5). This may be incorrect."
            )
            pC = 5  # Only as last resort when nQ > 0 but no tent info available
        
        ppC = max(p, pC)  # max(p, pC) for lag structure
        
        # Set first pC-1 observations as NaN for slower-frequency aggregation scheme
        # Only needed when pC > 1 (i.e., when there are slower-frequency series)
        if pC > 1:
            resNaN[:pC-1, :] = np.nan
        
        # Initialize output matrices
        C_list = []  # Will concatenate block loadings
        A_list = []  # Will build block-diagonal transition matrix
        Q_list = []  # Will build block-diagonal process noise
        V_0_list = []  # Will build block-diagonal initial covariance
        
        # Process each block sequentially (residual-based approach)
        
        # Validate blocks and i_idio shapes match data
        if blocks is not None:
            if blocks.shape[0] > N:
                _logger.warning(f"blocks.shape[0]={blocks.shape[0]} > N={N}. Truncating blocks to match data size.")
                blocks = blocks[:N, :]
            elif blocks.shape[0] < N:
                _logger.warning(f"blocks.shape[0]={blocks.shape[0]} < N={N}. Padding blocks with zeros.")
                padding = np.zeros((N - blocks.shape[0], blocks.shape[1]), dtype=dtype_np)
                blocks = np.vstack([blocks, padding])
        
        # Validate i_idio shape matches data
        if i_idio is not None:
            if i_idio.shape[0] > N:
                _logger.warning(f"i_idio.shape[0]={i_idio.shape[0]} > N={N}. Truncating i_idio to match data size.")
                i_idio = i_idio[:N]
            elif i_idio.shape[0] < N:
                _logger.warning(f"i_idio.shape[0]={i_idio.shape[0]} < N={N}. Padding i_idio with ones (clock frequency).")
                padding = np.ones(N - i_idio.shape[0], dtype=dtype_np)
                i_idio = np.concatenate([i_idio, padding])
        
        for i in range(n_blocks):
            r_i = int(r[i])  # Number of factors for this block (r is numpy array)
            
            # Find series indices loading on this block (only valid indices < N)
            # Use numpy operations since blocks is numpy array
            idx_i_all = np.where(blocks[:, i] > 0)[0]  # All series loading block i
            idx_i = idx_i_all[idx_i_all < N]  # Filter to valid indices
            
            # Identify clock-frequency vs slower-frequency series using i_idio
            # Clock-frequency series: i_idio == 1
            if len(idx_i) > 0 and i_idio is not None:
                i_idio_idx = i_idio[idx_i]
                idx_iM = idx_i[i_idio_idx > 0]  # Clock-frequency series indices (numpy array)
                idx_iQ = idx_i[i_idio_idx == 0]  # Slower-frequency series indices (numpy array)
            else:
                idx_iM = np.array([], dtype=np.int64)
                idx_iQ = np.array([], dtype=np.int64)
            
            # Initialize observation matrix for this block (use numpy)
            C_i = np.zeros((N, r_i * ppC), dtype=dtype_np)
            
            if len(idx_iM) > 0:
                # === CLOCK-FREQUENCY SERIES: PCA on residuals ===
                # Compute covariance of residuals for clock-frequency series in this block
                # res is numpy array
                res_M = res[:, idx_iM]  # T x n_iM (numpy array)
                # Center the data
                res_M_centered = res_M - res_M.mean(axis=0, keepdims=True)
                # Compute covariance
                n_iM = len(idx_iM)
                if res_M_centered.shape[0] > 1 and n_iM > 1:
                    # Multiple series: use np.cov
                    cov_res = np.cov(res_M_centered.T)  # n_iM x n_iM
                    cov_res = (cov_res + cov_res.T) / 2  # Symmetrize
                elif res_M_centered.shape[0] > 1 and n_iM == 1:
                    # Single series: use np.var and convert to 2D matrix
                    var_val = np.var(res_M_centered, axis=0, ddof=0)
                    cov_res = var_val.reshape(1, 1)  # (1, 1)
                else:
                    # Not enough data: use identity
                    cov_res = np.eye(n_iM, dtype=dtype_np)
                
                # Compute PCA: extract r_i principal components (use numpy version)
                from ..encoder.pca import compute_principal_components
                try:
                    eigenvalues, eigenvectors = compute_principal_components(cov_res, r_i, block_idx=i)
                    v = eigenvectors  # n_iM x r_i (numpy array)
                    
                    # Sign flipping for cleaner output (MATLAB: if sum(v) < 0, v = -v)
                    v_sum = np.sum(v, axis=0)
                    v = np.where(v_sum < 0, -v, v)
                except (RuntimeError, ValueError):
                    # Fallback: use identity
                    v = np.eye(len(idx_iM), dtype=dtype_np)[:, :r_i]
                
                # Set loadings for clock-frequency series
                C_i[idx_iM, :r_i] = v
                
                # Extract factors: f = res(:,idx_iM) * v
                f = res[:, idx_iM] @ v  # T x r_i (numpy array)
            else:
                # No clock-frequency series in this block, use zeros
                f = np.zeros((T, r_i), dtype=dtype_np)
            
            # Build lag matrix F for slower-frequency series (and transition equation)
            # MATLAB: for kk = 0:max(p+1,pC)-1, F = [F f(pC-kk:end-kk,:)]
            # This builds lagged factors for tent kernel aggregation
            # Use numpy since f is numpy array
            F = np.zeros((T, 0), dtype=dtype_np)
            max_lag = max(p + 1, pC)
            for kk in range(max_lag):
                start_idx = pC - kk
                end_idx = T - kk
                if start_idx < 0:
                    start_idx = 0
                if end_idx > T:
                    end_idx = T
                if start_idx < end_idx:
                    f_lag = f[start_idx:end_idx, :]
                    # Ensure f_lag has correct number of columns (r_i)
                    if f_lag.shape[1] != r_i:
                        _logger.warning(f"Block {i}, kk={kk}: f_lag shape mismatch: {f_lag.shape}, expected (?, {r_i}). Adjusting...")
                        if f_lag.shape[1] < r_i:
                            # Pad columns
                            padding = np.zeros((f_lag.shape[0], r_i - f_lag.shape[1]), dtype=dtype_np)
                            f_lag = np.hstack([f_lag, padding])
                        else:
                            # Trim columns
                            f_lag = f_lag[:, :r_i]
                    # Pad to match T
                    if start_idx > 0:
                        padding = np.zeros((start_idx, r_i), dtype=dtype_np)
                        f_lag = np.vstack([padding, f_lag])
                    if len(f_lag) < T:
                        padding = np.zeros((T - len(f_lag), r_i), dtype=dtype_np)
                        f_lag = np.vstack([f_lag, padding])
                    F = np.hstack([F, f_lag])  # T x (r_i * (kk+1))
            
            # Extract ff for slower-frequency series: ff = F(:, 1:r_i*pC)
            # This is used for constrained least squares with tent kernel
            ff = F[:, :r_i * pC] if F.shape[1] >= r_i * pC else F
            
            # === SLOWER-FREQUENCY SERIES: Constrained least squares with tent kernel ===
            if R_mat is not None and q is not None and len(idx_iQ) > 0:
                # Validate indices (idx_iQ is numpy array)
                if np.any(idx_iQ >= N):
                    invalid_indices = idx_iQ[idx_iQ >= N]
                    _logger.error(f"Block {i}: Invalid indices in idx_iQ: {invalid_indices.tolist()}. N={N}, idx_i={idx_i.tolist()}, idx_iQ={idx_iQ.tolist()}")
                    # Filter out invalid indices
                    idx_iQ = idx_iQ[idx_iQ < N]
                    if len(idx_iQ) == 0:
                        _logger.warning(f"Block {i}: No valid slower-frequency series after filtering")
                        continue
                
                for j in idx_iQ:
                    j_idx = int(j)  # j is numpy array element, no .item() needed
                    if j_idx >= N:
                        _logger.error(f"Block {i}: j_idx={j_idx} >= N={N}, skipping")
                        continue
                    
                    # Get frequency-specific R_mat and q for this series
                    if structures_dict is not None and frequencies_list is not None:
                        series_freq = frequencies_list[j_idx]
                        if series_freq in structures_dict:
                            series_R_mat, series_q = structures_dict[series_freq]
                            series_R_mat = _to_numpy(series_R_mat)
                            series_q = _to_numpy(series_q)
                        else:
                            raise ValueError(
                                f"Block {i}, series {j_idx}: frequency '{series_freq}' not in structures_dict. "
                                f"Available: {list(structures_dict.keys())}"
                            )
                    elif R_mat is not None and q is not None:
                        series_R_mat = _to_numpy(R_mat)
                        series_q = _to_numpy(q)
                    else:
                        raise ValueError(f"Block {i}, series {j_idx}: R_mat/q not available for slower-frequency series")
                    
                    # Compute Kronecker products using numpy
                    # Rcon_i = kron(series_R_mat, eye(r_i))
                    eye_r_i = np.eye(r_i, dtype=dtype_np)
                    Rcon_i = np.kron(series_R_mat, eye_r_i)
                    # q_i = kron(series_q, zeros(r_i))
                    zeros_r_i = np.zeros(r_i, dtype=dtype_np)
                    q_i = np.kron(series_q, zeros_r_i)
                    
                    # Extract series j data (drop first pC observations for lag structure)
                    # resNaN and res are numpy arrays
                    xx_j = resNaN[pC:, j_idx]
                    
                    # Check if enough non-NaN observations
                    non_nan_mask = ~np.isnan(xx_j)
                    if np.sum(non_nan_mask) < ff.shape[1] + 2:
                        # Use spline data if too many NaNs
                        xx_j = res[pC:, j_idx]
                        non_nan_mask = np.ones(len(xx_j), dtype=bool)
                    
                    # Extract non-NaN rows
                    # Ensure ff[pC:] and xx_j have matching lengths
                    ff_slice = ff[pC:, :]
                    min_len = min(len(ff_slice), len(xx_j))
                    if len(ff_slice) > min_len:
                        ff_slice = ff_slice[:min_len, :]
                    if len(xx_j) > min_len:
                        xx_j = xx_j[:min_len]
                        non_nan_mask = non_nan_mask[:min_len]
                    
                    ff_j = ff_slice[non_nan_mask, :]
                    xx_j_clean = xx_j[non_nan_mask]
                    
                    if len(ff_j) > 0 and ff_j.shape[0] >= ff_j.shape[1]:
                        try:
                            # OLS: Cc = (ff_j'*ff_j)^{-1} * ff_j' * xx_j
                            iff_j = np.linalg.pinv(ff_j.T @ ff_j)
                            Cc = iff_j @ ff_j.T @ xx_j_clean  # r_i*pC
                            
                            # Apply tent kernel constraint: Cc = Cc - iff_j*Rcon_i'*inv(Rcon_i*iff_j*Rcon_i')*(Rcon_i*Cc-q_i)
                            Rcon_iff = Rcon_i @ iff_j
                            Rcon_iff_RconT = Rcon_iff @ Rcon_i.T
                            try:
                                Cc_constrained = Cc - iff_j @ Rcon_i.T @ np.linalg.solve(
                                    Rcon_iff_RconT + np.eye(Rcon_iff_RconT.shape[0], dtype=dtype_np) * 1e-6,
                                    Rcon_i @ Cc - q_i
                                )
                            except (RuntimeError, ValueError, np.linalg.LinAlgError):
                                Cc_constrained = Cc
                            
                            # Set loadings for slower-frequency series
                            C_i[j_idx, :r_i * pC] = Cc_constrained
                        except (RuntimeError, ValueError, np.linalg.LinAlgError):
                            # Fallback: use zeros
                            pass
            
            # Pad ff with zeros for first pC-1 entries (MATLAB: ff = [zeros(pC-1,pC*r_i);ff])
            # ff is numpy array
            if pC > 1:
                ff_slice = ff[:T - (pC - 1), :r_i * pC] if T > (pC - 1) else ff[:, :r_i * pC]
                padding_top = np.zeros((pC - 1, r_i * pC), dtype=dtype_np)
                ff_padded = np.vstack([padding_top, ff_slice])
                if len(ff_padded) < T:
                    padding_bottom = np.zeros((T - len(ff_padded), r_i * pC), dtype=dtype_np)
                    ff_padded = np.vstack([ff_padded, padding_bottom])
                ff = ff_padded[:T, :]
            
            # Update residuals: res = res - ff * C_i'
            # MATLAB: res = res - ff*C_i'
            # Ensure dimensions match (ff should be T x (r_i * pC), res should be T x N)
            # res, ff, C_i are all numpy arrays
            if res.shape[0] != ff.shape[0]:
                # Pad or trim ff to match res
                if res.shape[0] < ff.shape[0]:
                    ff = ff[:res.shape[0], :]
                else:
                    padding = np.zeros((res.shape[0] - ff.shape[0], ff.shape[1]), dtype=dtype_np)
                    ff = np.vstack([ff, padding])
            res = res - ff @ C_i[:, :r_i * pC].T
            resNaN = res.copy()
            resNaN[indNaN] = np.nan
            
            # Accumulate C
            C_list.append(C_i)
            
            # === TRANSITION EQUATION for this block ===
            # MATLAB: z = F(:,1:r_i), Z = F(:,r_i+1:r_i*(p+1))
            # F is numpy array
            z = F[:, :r_i] if F.shape[1] >= r_i else np.zeros((T, r_i), dtype=dtype_np)
            Z = F[:, r_i:r_i * (p + 1)] if F.shape[1] >= r_i * (p + 1) else np.zeros((T, r_i * p), dtype=dtype_np)
            
            # Initialize transition matrix for this block
            A_i = np.zeros((r_i * ppC, r_i * ppC), dtype=dtype_np)
            
            if T > p and Z.shape[1] > 0:
                try:
                    # OLS: A_temp = inv(Z'*Z)*Z'*z
                    ZTZ = Z.T @ Z
                    reg_scale_val = self.regularization_scale
                    # Convert to float (handles both scalar and array types)
                    try:
                        # Convert to float (handles scalar, array, tensor)
                        if isinstance(reg_scale_val, (int, float)):
                            reg_scale = float(reg_scale_val)
                        elif hasattr(reg_scale_val, 'item'):
                            # Handle numpy array or torch tensor
                            reg_scale = float(reg_scale_val.item())
                        else:
                            reg_scale = float(reg_scale_val)
                    except (TypeError, ValueError, AttributeError):
                        reg_scale = 1e-6  # Default fallback
                    ZTZ_reg = ZTZ + np.eye(ZTZ.shape[0], dtype=dtype_np) * reg_scale
                    # Ensure z is 2D with shape (T, r_i); a 1D z (single factor) is reshaped to a (T, 1) column
                    if z.ndim == 1:
                        z = z.reshape(-1, 1)  # (T, 1) if single factor
                    ZTz = Z.T @ z  # (r_i*p, r_i) or (r_i*p, T) depending on z shape
                    A_temp = np.linalg.solve(ZTZ_reg, ZTz).T  # r_i x (r_i*p)
                    
                    # Ensure A_temp has correct shape
                    if A_temp.shape != (r_i, r_i * p):
                        A_temp_new = np.zeros((r_i, r_i * p), dtype=dtype_np)
                        min_rows = min(A_temp.shape[0], r_i)
                        min_cols = min(A_temp.shape[1], r_i * p)
                        A_temp_new[:min_rows, :min_cols] = A_temp[:min_rows, :min_cols]
                        A_temp = A_temp_new
                    
                    # Set transition matrix (MATLAB: A_i(1:r_i,1:r_i*p) = A_temp'); A_temp is already transposed here, so assign it directly
                    A_i[:r_i, :r_i * p] = A_temp
                    # Identity matrices for lag structure: A_i(r_i+1:end,1:r_i*(ppC-1)) = eye
                    if r_i * (ppC - 1) > 0:
                        A_i[r_i:, :r_i * (ppC - 1)] = np.eye(r_i * (ppC - 1), dtype=dtype_np)
                except (RuntimeError, ValueError, np.linalg.LinAlgError):
                    # Fallback: use identity for AR(1) part
                    A_i[:r_i, :r_i] = np.eye(r_i, dtype=dtype_np) * 0.9
                    if r_i * (ppC - 1) > 0:
                        A_i[r_i:, :r_i * (ppC - 1)] = np.eye(r_i * (ppC - 1), dtype=dtype_np)
            else:
                # Not enough data: fall back to 0.9 * identity for the AR(1) part
                A_i[:r_i, :r_i] = np.eye(r_i, dtype=dtype_np) * 0.9
                if r_i * (ppC - 1) > 0:
                    A_i[r_i:, :r_i * (ppC - 1)] = np.eye(r_i * (ppC - 1), dtype=dtype_np)
            
            # Initialize Q_i (process noise covariance) for this block
            Q_i = np.zeros((r_i * ppC, r_i * ppC), dtype=dtype_np)
            if T > p:
                # Compute VAR residuals: e = z - Z*A_temp
                if Z.shape[1] > 0:
                    try:
                        e = z[p:, :] - (Z[p:, :] @ A_i[:r_i, :r_i * p].T)
                        if e.shape[0] > 1:
                            # Handle single factor case
                            if e.shape[1] == 1:
                                var_val = np.var(e, axis=0, ddof=0)
                                Q_i[:r_i, :r_i] = var_val.reshape(1, 1)  # (1, 1)
                            else:
                                Q_i[:r_i, :r_i] = np.cov(e.T)
                                Q_i[:r_i, :r_i] = (Q_i[:r_i, :r_i] + Q_i[:r_i, :r_i].T) / 2
                        else:
                            Q_i[:r_i, :r_i] = np.eye(r_i, dtype=dtype_np) * 0.1
                    except (RuntimeError, ValueError):
                        Q_i[:r_i, :r_i] = np.eye(r_i, dtype=dtype_np) * 0.1
                else:
                    Q_i[:r_i, :r_i] = np.eye(r_i, dtype=dtype_np) * 0.1
            else:
                Q_i[:r_i, :r_i] = np.eye(r_i, dtype=dtype_np) * 0.1
            
            # Ensure Q_i is positive definite (with robust error handling)
            Q_i[:r_i, :r_i] = self._ensure_positive_definite(
                Q_i[:r_i, :r_i], min_eigenval=1e-8, dtype=dtype_np
            )
            
            # Initial covariance for this block: initV_i = inv(eye - kron(A_i, A_i)) * Q_i(:)
            try:
                A_i_block = A_i[:r_i * ppC, :r_i * ppC]
                kron_AA = np.kron(A_i_block, A_i_block)
                eye_kron = np.eye((r_i * ppC) ** 2, dtype=dtype_np)
                initV_i_flat = np.linalg.solve(
                    eye_kron - kron_AA + np.eye((r_i * ppC) ** 2, dtype=dtype_np) * 1e-6,
                    Q_i[:r_i * ppC, :r_i * ppC].flatten()
                )
                initV_i = initV_i_flat.reshape(r_i * ppC, r_i * ppC)
            except (RuntimeError, ValueError, np.linalg.LinAlgError):
                initV_i = Q_i[:r_i * ppC, :r_i * ppC].copy()
            
            # Accumulate block matrices
            # Each block can have different sizes (r_i * ppC), which is fine for block_diag
            block_size = r_i * ppC
            
            # Extract or create square matrices of correct size
            # All matrices are numpy arrays now
            if A_i.shape[0] >= block_size and A_i.shape[1] >= block_size:
                A_i_final = A_i[:block_size, :block_size]
            else:
                A_i_final = np.zeros((block_size, block_size), dtype=dtype_np)
                min_rows = min(A_i.shape[0], block_size)
                min_cols = min(A_i.shape[1], block_size)
                A_i_final[:min_rows, :min_cols] = A_i[:min_rows, :min_cols]
            
            if Q_i.shape[0] >= block_size and Q_i.shape[1] >= block_size:
                Q_i_final = Q_i[:block_size, :block_size]
            else:
                Q_i_final = np.zeros((block_size, block_size), dtype=dtype_np)
                min_rows = min(Q_i.shape[0], block_size)
                min_cols = min(Q_i.shape[1], block_size)
                Q_i_final[:min_rows, :min_cols] = Q_i[:min_rows, :min_cols]
            
            if initV_i.shape[0] >= block_size and initV_i.shape[1] >= block_size:
                V_0_i_final = initV_i[:block_size, :block_size]
            else:
                V_0_i_final = np.zeros((block_size, block_size), dtype=dtype_np)
                min_rows = min(initV_i.shape[0], block_size)
                min_cols = min(initV_i.shape[1], block_size)
                V_0_i_final[:min_rows, :min_cols] = initV_i[:min_rows, :min_cols]
            
            A_list.append(A_i_final)
            Q_list.append(Q_i_final)
            V_0_list.append(V_0_i_final)
        
        # Concatenate C matrices (all are numpy arrays)
        if C_list:
            C = np.hstack(C_list)
        else:
            C = np.zeros((N, 0), dtype=dtype_np)
        
        # Build block-diagonal A, Q, V_0 using scipy
        from scipy.linalg import block_diag
        if A_list:
            A_factors = block_diag(*A_list)
            Q_factors = block_diag(*Q_list)
            V_0_factors = block_diag(*V_0_list)
        else:
            A_factors = np.zeros((0, 0), dtype=dtype_np)
            Q_factors = np.zeros((0, 0), dtype=dtype_np)
            V_0_factors = np.zeros((0, 0), dtype=dtype_np)
        
        # === IDIOSYNCRATIC COMPONENTS ===
        # Add identity matrix for clock-frequency idiosyncratic series
        if i_idio is not None:
            eyeN = np.eye(N, dtype=dtype_np)
            # Remove columns for non-idiosyncratic series
            i_idio_bool = i_idio.astype(bool)
            eyeN_idio = eyeN[:, i_idio_bool]  # N x n_idio
            C = np.hstack([C, eyeN_idio])
        else:
            # Default: all clock-frequency series have idiosyncratic components
            eyeN_clock = np.eye(N, dtype=dtype_np)[:, :nM] if nM > 0 else np.zeros((N, 0), dtype=dtype_np)
            C = np.hstack([C, eyeN_clock])
        
        # Add slower-frequency idiosyncratic chains using tent kernels
        # Each slower-frequency series has a state chain with length determined by tent kernel size
        if nQ > 0:
            # Get slower-frequency series indices first
            if i_idio is not None:
                slower_freq_indices = np.where(i_idio == 0)[0]
            else:
                slower_freq_indices = np.arange(nM, N, dtype=np.int64)
            
            # Validate: actual number of slower-frequency series should match nQ
            actual_nQ = len(slower_freq_indices)
            if actual_nQ != nQ:
                _logger.warning(
                    f"nQ={nQ} but actual slower-frequency series count={actual_nQ}. "
                    f"Using actual count={actual_nQ}."
                )
                nQ = actual_nQ
            
            # Convert idio_chain_lengths to numpy if needed
            if idio_chain_lengths is not None:
                idio_chain_lengths_arr = _to_numpy(idio_chain_lengths)
                # Ensure length matches actual number of slower-frequency series
                if len(idio_chain_lengths_arr) != nQ:
                    _logger.warning(
                        f"idio_chain_lengths length={len(idio_chain_lengths_arr)} != nQ={nQ}. "
                        f"Padding or truncating to match."
                    )
                    if len(idio_chain_lengths_arr) < nQ:
                        # Pad with pC
                        padding = np.full(nQ - len(idio_chain_lengths_arr), pC, dtype=np.int64)
                        idio_chain_lengths_arr = np.concatenate([idio_chain_lengths_arr, padding])
                    else:
                        # Truncate
                        idio_chain_lengths_arr = idio_chain_lengths_arr[:nQ]
                # Check if all values are 0 (invalid) and use fallback
                if np.all(idio_chain_lengths_arr == 0):
                    _logger.warning(
                        f"All idio_chain_lengths are 0 for slower-frequency series. "
                        f"Using fallback: pC={pC} for all {nQ} series."
                    )
                    idio_chain_lengths_arr = np.full(nQ, pC, dtype=np.int64)
            else:
                # Fallback: use pC (tent kernel size) for all slower-frequency series
                # This is generic: each slower-frequency series gets a chain of length pC
                idio_chain_lengths_arr = np.full(nQ, pC, dtype=np.int64)
            
            # Build idiosyncratic chains for each slower-frequency series
            # Each series j gets a chain of length idio_chain_lengths_arr[j]
            total_chain_states = int(np.sum(idio_chain_lengths_arr))
            C_slower = np.zeros((N, total_chain_states), dtype=dtype_np)
            
            col_offset = 0
            for j, series_idx in enumerate(slower_freq_indices):
                if j >= len(idio_chain_lengths_arr):
                    break
                
                # Validate series_idx is within bounds
                if series_idx >= N:
                    _logger.error(f"series_idx={series_idx} >= N={N}, skipping")
                    continue
                
                chain_len = int(idio_chain_lengths_arr[j])
                
                # A chain length of 0 is invalid: do not skip the series, fall back to pC instead
                if chain_len == 0:
                    _logger.warning(
                        f"Series {series_idx}: chain_len=0, using fallback pC={pC}."
                    )
                    chain_len = pC
                
                # Get tent weights for this series' frequency
                # Only use tent_weights_dict for slower-frequency series (not clock-frequency)
                if tent_weights_dict is not None and frequencies_list is not None and series_idx < len(frequencies_list):
                    series_freq = frequencies_list[series_idx]
                    # Check if this is actually a slower-frequency series (not clock-frequency)
                    clock_hierarchy = FREQUENCY_HIERARCHY.get(clock, 3)
                    freq_hierarchy = FREQUENCY_HIERARCHY.get(series_freq, 3)
                    
                    if freq_hierarchy > clock_hierarchy and series_freq in tent_weights_dict:
                        # This is a slower-frequency series, use tent weights
                        tent_weights = tent_weights_dict[series_freq]
                        tent_weights = _to_numpy(tent_weights)
                        # Ensure tent_weights length matches chain_len
                        if len(tent_weights) != chain_len:
                            _logger.warning(
                                f"Series {series_idx} (freq={series_freq}): tent_weights length={len(tent_weights)} != chain_len={chain_len}. "
                                f"Using chain_len={chain_len}."
                            )
                            if len(tent_weights) > chain_len:
                                tent_weights = tent_weights[:chain_len]
                            else:
                                # Pad with uniform weights
                                padding = np.ones(chain_len - len(tent_weights), dtype=dtype_np) / (chain_len - len(tent_weights))
                                tent_weights = np.concatenate([tent_weights, padding])
                        # Normalize tent weights
                        tent_weights = tent_weights / np.sum(tent_weights)
                    else:
                        # Fallback: uniform weights (for slower-frequency series without tent_weights or clock-frequency series)
                        tent_weights = np.ones(chain_len, dtype=dtype_np) / chain_len
                else:
                    # Fallback: uniform weights
                    tent_weights = np.ones(chain_len, dtype=dtype_np) / chain_len
                
                # Validate bounds before assignment
                if col_offset + chain_len > total_chain_states:
                    _logger.error(
                        f"col_offset={col_offset} + chain_len={chain_len} > total_chain_states={total_chain_states}. "
                        f"Truncating chain_len."
                    )
                    chain_len = max(0, total_chain_states - col_offset)
                    tent_weights = tent_weights[:chain_len] if chain_len > 0 else np.array([])
                
                if chain_len > 0:
                    # Set tent weights for this series: C[series_idx, col_offset:col_offset+chain_len] = tent_weights
                    C_slower[series_idx, col_offset:col_offset+chain_len] = tent_weights
                    col_offset += chain_len
            
            C = np.hstack([C, C_slower])
        
        # Initialize R (observation noise covariance) from final residuals
        # Ensure resNaN is 2D: (T, N)
        if resNaN.ndim > 2:
            _logger.warning(f"resNaN has unexpected shape: {resNaN.shape}, reshaping to 2D...")
            resNaN = resNaN.reshape(-1, resNaN.shape[-1])
        elif resNaN.ndim == 1:
            _logger.warning(f"resNaN is 1D: {resNaN.shape}, reshaping to 2D...")
            resNaN = resNaN.reshape(1, -1)
        
        # Check degrees of freedom before computing variance
        T_res, N_res = resNaN.shape
        if T_res <= 1:
            # Not enough data for variance calculation, use fallback
            _logger.warning(f"resNaN has T={T_res} <= 1, using fallback variance values")
            var_res = np.ones(N_res, dtype=dtype_np) * 1e-4
        else:
            # Count valid (non-NaN) values per column
            valid_counts = np.sum(np.isfinite(resNaN), axis=0)
            # Compute variance only for columns with at least 2 valid values
            var_res = np.zeros(N_res, dtype=dtype_np)
            for i in range(N_res):
                if valid_counts[i] > 1:
                    col_data = resNaN[:, i]
                    col_valid = col_data[np.isfinite(col_data)]
                    if len(col_valid) > 1:
                        var_res[i] = np.var(col_valid, ddof=0)
                    else:
                        var_res[i] = 1e-4
                else:
                    var_res[i] = 1e-4
            
            # Handle any remaining NaN/Inf values
            var_res = np.where(np.isfinite(var_res), var_res, 1e-4)
        
        # Ensure var_res is 1D
        if var_res.ndim > 1:
            _logger.warning(f"var_res has unexpected shape: {var_res.shape}, flattening...")
            var_res = var_res.flatten()
        elif var_res.ndim == 0:
            # Single value, expand to (N,)
            var_res = np.array([var_res])
        
        R = np.diag(var_res)  # (N, N)
        R = np.where(np.isfinite(R), R, 1e-4)
        
        # Set clock-frequency idiosyncratic variances to 1e-4 (MATLAB: R(ii_idio(i),ii_idio(i)) = 1e-04)
        if i_idio is not None:
            i_idio_indices = np.where(i_idio > 0)[0]
            for idx in i_idio_indices:
                R[idx, idx] = 1e-4
        else:
            # Default: all clock-frequency series
            for idx in range(nM):
                R[idx, idx] = 1e-4
        
        # Set slower-frequency variances: Rdiag(slower_freq_indices) = 1e-04
        if i_idio is not None:
            slower_freq_indices = np.where(i_idio == 0)[0]
            for idx in slower_freq_indices:
                R[idx, idx] = 1e-4
        else:
            # Fallback: assume last nQ are slower-frequency
            for idx in range(nM, N):
                R[idx, idx] = 1e-4
        
        # === IDIOSYNCRATIC TRANSITION MATRICES ===
        # Clock-frequency: AR(1) for each series
        n_idio_M = nM if i_idio is None else int(np.sum(i_idio))
        BM = np.zeros((n_idio_M, n_idio_M), dtype=dtype_np)
        SM = np.zeros((n_idio_M, n_idio_M), dtype=dtype_np)
        
        if i_idio is not None:
            ii_idio = np.where(i_idio > 0)[0]
        else:
            ii_idio = np.arange(nM, dtype=np.int64)
        
        for i, idx in enumerate(ii_idio):
            res_i = resNaN[:, idx]
            # Find leading and trailing NaNs
            non_nan_mask = ~np.isnan(res_i)
            if np.sum(non_nan_mask) > 1:
                non_nan_idx = np.where(non_nan_mask)[0]
                first_non_nan = non_nan_idx[0] if len(non_nan_idx) > 0 else 0
                last_non_nan = non_nan_idx[-1] if len(non_nan_idx) > 0 else T - 1
                res_i_clean = res[first_non_nan:last_non_nan + 1, idx]
                
                if len(res_i_clean) > 1:
                    # AR(1): res_i(t) = BM * res_i(t-1) + error
                    y_ar = res_i_clean[1:]
                    x_ar = res_i_clean[:-1].reshape(-1, 1)
                    try:
                        # OLS: BM = (x_ar'*x_ar)^{-1} * x_ar'*y_ar
                        XTX = x_ar.T @ x_ar + np.eye(1, dtype=dtype_np) * 1e-6
                        XTy = x_ar.T @ y_ar
                        BM_coef = np.linalg.solve(XTX, XTy)
                        BM[i, i] = float(BM_coef[0, 0] if BM_coef.ndim == 2 else BM_coef[0])
                        # Residual covariance
                        residuals_ar = y_ar - x_ar.flatten() * BM[i, i]
                        if len(residuals_ar) > 1:
                            SM[i, i] = float(np.var(residuals_ar, ddof=0))
                        else:
                            SM[i, i] = 0.1
                    except (RuntimeError, ValueError, np.linalg.LinAlgError):
                        BM[i, i] = 0.1
                        SM[i, i] = 0.1
                else:
                    BM[i, i] = 0.1
                    SM[i, i] = 0.1
            else:
                BM[i, i] = 0.1
                SM[i, i] = 0.1
        
        # Slower-frequency: state chains with rho0 = 0.1
        # Each slower-frequency series has a chain with length determined by tent kernel size
        rho0 = 0.1
        if nQ > 0:
            # Get slower-frequency series indices
            if i_idio is not None:
                slower_freq_indices = np.where(i_idio == 0)[0]
            else:
                slower_freq_indices = np.arange(nM, N, dtype=np.int64)
            
            # sig_e = Rdiag(slower_freq_indices) / (chain_length - 1) for each series
            # This is a generic approximation: variance scales with chain length
            if idio_chain_lengths is not None:
                idio_chain_lengths_arr = _to_numpy(idio_chain_lengths)
                # Filter to only slower-frequency series if idio_chain_lengths has length N
                if len(idio_chain_lengths_arr) == N:
                    # Extract only slower-frequency series chain lengths
                    idio_chain_lengths_arr = idio_chain_lengths_arr[slower_freq_indices]
                # Check if all values are 0 (invalid)
                if np.all(idio_chain_lengths_arr == 0):
                    _logger.warning(
                        f"All idio_chain_lengths are 0 for slower-frequency series. "
                        f"Using fallback: pC={pC} for all {nQ} series."
                    )
                    idio_chain_lengths_arr = np.full(nQ, pC, dtype=np.int64)
            else:
                # Fallback: use pC for all slower-frequency series
                idio_chain_lengths_arr = np.full(nQ, pC, dtype=np.int64)
            
            # Compute sig_e for each slower-frequency series
            sig_e_list = []
            for j, series_idx in enumerate(slower_freq_indices):
                if j >= len(idio_chain_lengths_arr):
                    break
                chain_len = int(idio_chain_lengths_arr[j])
                # Generic approximation: divide by (chain_len - 1) instead of hardcoded 19
                sig_e_j = R[series_idx, series_idx] / max(chain_len - 1, 1.0)
                sig_e_list.append(sig_e_j)
            
            if len(sig_e_list) > 0:
                sig_e = np.array(sig_e_list)
                sig_e = np.where(np.isfinite(sig_e), sig_e, 1e-4)
            else:
                sig_e = np.array([1e-4], dtype=dtype_np)
            
            # Build transition matrices for each slower-frequency series
            # Each series has a chain with its own length (determined by tent kernel size)
            BQ_list = []
            SQ_list = []
            
            for j, series_idx in enumerate(slower_freq_indices):
                if j >= len(idio_chain_lengths_arr) or j >= len(sig_e):
                    break
                chain_len = int(idio_chain_lengths_arr[j])
                
                # Skip if chain_len is 0 (should not happen, but handle gracefully)
                if chain_len == 0:
                    _logger.warning(
                        f"Series {series_idx}: chain_len=0 in BQ construction, skipping. This may indicate a configuration issue."
                    )
                    continue
                
                sig_e_j = sig_e[j]
                
                # Build transition matrix block for this series
                # BQ_block: [[rho0 zeros(1,chain_len-1)]; [eye(chain_len-1), zeros(chain_len-1,1)]]
                BQ_block = np.zeros((chain_len, chain_len), dtype=dtype_np)
                BQ_block[0, 0] = rho0
                if chain_len > 1:
                    BQ_block[1:, :chain_len-1] = np.eye(chain_len - 1, dtype=dtype_np)
                BQ_list.append(BQ_block)
                
                # SQ_block: temp = zeros(chain_len); temp(0,0) = 1
                temp = np.zeros((chain_len, chain_len), dtype=dtype_np)
                temp[0, 0] = 1.0
                # SQ_block = (1 - rho0^2) * sig_e_j * temp
                SQ_block = (1 - rho0 ** 2) * sig_e_j * temp
                SQ_list.append(SQ_block)
            
            # Combine all blocks into block-diagonal matrices
            if len(BQ_list) > 0:
                BQ = block_diag(*BQ_list)
                SQ = block_diag(*SQ_list)
                
                # initViQ = reshape(inv(eye - kron(BQ,BQ))*SQ(:), total_states, total_states)
                total_states = BQ.shape[0]
                try:
                    kron_BQBQ = np.kron(BQ, BQ)
                    eye_kron = np.eye(total_states ** 2, dtype=dtype_np)
                    initViQ_flat = np.linalg.solve(
                        eye_kron - kron_BQBQ + np.eye(total_states ** 2, dtype=dtype_np) * 1e-6,
                        SQ.flatten()
                    )
                    initViQ = initViQ_flat.reshape(total_states, total_states)
                except (RuntimeError, ValueError, np.linalg.LinAlgError):
                    initViQ = SQ.copy()
            else:
                BQ = np.zeros((0, 0), dtype=dtype_np)
                SQ = np.zeros((0, 0), dtype=dtype_np)
                initViQ = np.zeros((0, 0), dtype=dtype_np)
        else:
            BQ = np.zeros((0, 0), dtype=dtype_np)
            SQ = np.zeros((0, 0), dtype=dtype_np)
            initViQ = np.zeros((0, 0), dtype=dtype_np)
        
        # Clock-frequency initial covariance: initViM = diag(1./diag(eye - BM.^2)).*SM
        try:
            eye_BM = np.eye(n_idio_M, dtype=dtype_np)
            BM_sq = BM ** 2
            diag_inv = 1.0 / np.diag(eye_BM - BM_sq)
            diag_inv = np.where(np.isfinite(diag_inv), diag_inv, np.ones_like(diag_inv))
            initViM = np.diag(diag_inv) @ SM
        except (RuntimeError, ValueError, np.linalg.LinAlgError):
            initViM = SM.copy()
        
        # Combine all transition matrices: A = blkdiag(A_factors, BM, BQ)
        # Ensure BM, SM, initViM have correct dimensions
        if BM.shape[0] != n_idio_M or BM.shape[1] != n_idio_M:
            _logger.warning(f"BM shape mismatch: expected ({n_idio_M}, {n_idio_M}), got {BM.shape}. Resizing...")
            BM_new = np.zeros((n_idio_M, n_idio_M), dtype=dtype_np)
            min_dim = min(BM.shape[0], n_idio_M, BM.shape[1], n_idio_M)
            BM_new[:min_dim, :min_dim] = BM[:min_dim, :min_dim]
            BM = BM_new
        if SM.shape[0] != n_idio_M or SM.shape[1] != n_idio_M:
            _logger.warning(f"SM shape mismatch: expected ({n_idio_M}, {n_idio_M}), got {SM.shape}. Resizing...")
            SM_new = np.zeros((n_idio_M, n_idio_M), dtype=dtype_np)
            min_dim = min(SM.shape[0], n_idio_M, SM.shape[1], n_idio_M)
            SM_new[:min_dim, :min_dim] = SM[:min_dim, :min_dim]
            SM = SM_new
        if initViM.shape[0] != n_idio_M or initViM.shape[1] != n_idio_M:
            _logger.warning(f"initViM shape mismatch: expected ({n_idio_M}, {n_idio_M}), got {initViM.shape}. Resizing...")
            initViM_new = np.zeros((n_idio_M, n_idio_M), dtype=dtype_np)
            min_dim = min(initViM.shape[0], n_idio_M, initViM.shape[1], n_idio_M)
            initViM_new[:min_dim, :min_dim] = initViM[:min_dim, :min_dim]
            initViM = initViM_new
        
        # Build final block-diagonal matrices
        A = block_diag(A_factors, BM, BQ)
        Q = block_diag(Q_factors, SM, SQ)
        V_0 = block_diag(V_0_factors, initViM, initViQ)
        
        # Initial state: Z_0 = zeros
        m = A.shape[0]
        Z_0 = np.zeros(m, dtype=dtype_np)
        
        # Ensure V_0 is positive definite (with robust error handling)
        V_0 = self._ensure_positive_definite(V_0, min_eigenval=1e-8, dtype=dtype_np)
        
        # All results are already numpy arrays
        return A, C, Q, R, Z_0, V_0
    
    def check_convergence(
        self,
        loglik: float,
        previous_loglik: float,
        threshold: float,
        verbose: bool = False
    ) -> Tuple[bool, float]:
        """Check EM convergence from two consecutive log-likelihood values.

        The change is measured relative to ``previous_loglik``. When
        ``previous_loglik`` is essentially zero (magnitude below 1e-10) the
        absolute difference is used instead, so the ratio never divides by a
        vanishing denominator.

        Parameters
        ----------
        loglik : float
            Current log-likelihood value. NumPy scalars are accepted and are
            converted to built-in ``float`` before any arithmetic.
        previous_loglik : float
            Previous log-likelihood value. ``-inf`` is treated as "no previous
            iteration" and never reports convergence.
        threshold : float
            Convergence threshold (typically 1e-4 to 1e-5)
        verbose : bool
            Whether to log convergence status

        Returns
        -------
        converged : bool
            Whether convergence was achieved. Always a built-in ``bool``,
            even when the inputs are NumPy scalars.
        change : float
            Relative change in log-likelihood (absolute change when
            ``previous_loglik`` is essentially zero). Always a built-in
            ``float``; it is NaN when the inputs make the ratio undefined
            (e.g. a NaN log-likelihood), in which case ``converged`` is False.
        """
        # Normalise to built-in floats up front: NumPy scalars would otherwise
        # leak ``numpy.bool_`` out of the comparison below, which is not a
        # ``bool`` subclass and breaks ``is True`` checks and JSON encoding.
        current = float(loglik)
        previous = float(previous_loglik)

        # Sentinel for "no previous iteration": nothing to compare against yet.
        if previous == float('-inf'):
            return False, 0.0

        difference = current - previous
        if abs(previous) < 1e-10:
            # Previous loglik is essentially zero, use absolute change
            change = abs(difference)
        else:
            # Relative change
            change = abs(difference / previous)

        # Any comparison with NaN is False, so a NaN change never converges.
        converged = bool(change < threshold)

        if verbose and converged:
            _logger.info(f'EM algorithm converged: loglik change = {change:.2e} < {threshold:.2e}')

        return converged, change

