"""Linear Dynamic Factor Model (DFM) implementation.

This module contains the linear DFM implementation using EM algorithm.
DFM is a PyTorch Lightning module that inherits from BaseFactorModel.
"""

# Standard library imports
from dataclasses import dataclass
from datetime import datetime
from pathlib import Path
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union

# Third-party imports
import numpy as np
import torch
import torch.nn as nn

# Local imports
from ..config import (
    DFMConfig,
    make_config_source,
    ConfigSource,
)
from ..config.results import DFMResult, FitParams
from ..config.utils import get_agg_structure, get_tent_weights, FREQUENCY_HIERARCHY, TENT_WEIGHTS_LOOKUP
from ..logger import get_logger
from ..ssm.em import EMAlgorithm, EMStepParams
from ..ssm.kalman import KalmanFilter
from .base import BaseFactorModel

# Integer codes for frequency strings, used when the per-series frequency list
# is packed into an int32 tensor in DFM.initialize_from_data. Strings missing
# from this table are mapped to 3 (the code for 'm') at that lookup.
_FREQ_TO_INT = {'d': 1, 'w': 2, 'm': 3, 'q': 4, 'sa': 5, 'a': 6}

if TYPE_CHECKING:
    # Imported for static type checking only; never executed at runtime.
    from omegaconf import DictConfig
    from ..lightning import DFMDataModule

# Module-level logger named after this module.
_logger = get_logger(__name__)


@dataclass
class DFMTrainingState:
    """Snapshot of the model parameters and EM bookkeeping after a fit.

    ``DFM.fit_em`` builds one instance when its EM loop ends, and
    ``DFM.get_result`` reads the fields back to run the final Kalman pass.
    """
    # The six tensors below are clones of the ``.data`` of the model
    # parameters with the same names, taken when the EM loop ends.
    # NOTE(review): presumably the usual state-space roles (A transition,
    # C loadings, Q/R noise covariances, Z_0/V_0 initial state mean and
    # covariance) -- confirm against the Kalman filter implementation.
    A: torch.Tensor
    C: torch.Tensor  # get_result computes the smoothed data as Z @ C.T
    Q: torch.Tensor
    R: torch.Tensor
    Z_0: torch.Tensor
    V_0: torch.Tensor
    # Log-likelihood reported by the last EM step (or the best value seen when
    # fit_em reverted to its best parameters).
    loglik: float
    # Number of EM iterations that completed.
    num_iter: int
    # True when the EM convergence check succeeded before the loop ended.
    converged: bool



class DFM(BaseFactorModel):
    """Linear Dynamic Factor Model using EM algorithm.
    
    PyTorch Lightning module for DFM estimation. Supports mixed-frequency data
    via tent kernels when mixed_freq=True.
    """
    
    def __init__(
        self,
        config: Optional[DFMConfig] = None,
        num_factors: Optional[int] = None,
        threshold: float = 1e-4,
        max_iter: int = 100,
        nan_method: int = 2,
        nan_k: int = 3,
        mixed_freq: bool = False,
        **kwargs
    ):
        """Build an unfitted DFM.

        Parameters
        ----------
        config : DFMConfig, optional
            Model configuration; it is passed through ``_initialize_config``
            before use. Can be replaced later via load_config().
        num_factors : int, optional
            Number of factors. When None it is derived from the configuration:
            the sum of ``factors_per_block`` if that is set and truthy,
            otherwise the sum of the first column of the blocks array (or 1
            when the blocks array has no columns).
        threshold : float, default 1e-4
            EM convergence threshold.
        max_iter : int, default 100
            Maximum number of EM iterations.
        nan_method : int, default 2
            Missing data handling method.
        nan_k : int, default 3
            Spline interpolation order.
        mixed_freq : bool, default False
            If True, tent kernels are used for mixed-frequency data and a
            ValueError is raised later (in initialize_from_data) when a needed
            frequency pair is missing from TENT_WEIGHTS_LOOKUP. If False, all
            series are treated as clock frequency.
        **kwargs
            Forwarded unchanged to ``BaseFactorModel.__init__``.
        """
        super().__init__(**kwargs)

        config = self._initialize_config(config)

        # Estimation settings are stored exactly as given.
        self.threshold = threshold
        self.max_iter = max_iter
        self.nan_method = nan_method
        self.nan_k = nan_k
        self.mixed_freq = mixed_freq

        # Mixed-frequency bookkeeping; real values arrive in
        # initialize_from_data().
        self._em_nQ = 0
        self._em_tent_weights_dict = None
        self._em_frequencies = None
        self._em_i_idio = None
        self._em_idio_chain_lengths = None
        self._em_structures_dict = None  # frequency -> (R_mat, q)
        self._em_frequencies_list = None  # frequency string of each series

        # Resolve the factor count: an explicit argument wins, then the
        # configuration's factors_per_block, then the blocks array.
        if num_factors is not None:
            self.num_factors = num_factors
        elif hasattr(config, 'factors_per_block') and config.factors_per_block:
            self.num_factors = int(np.sum(config.factors_per_block))
        else:
            block_matrix = config.get_blocks_array()
            if block_matrix.shape[1] > 0:
                self.num_factors = int(np.sum(block_matrix[:, 0]))
            else:
                self.num_factors = 1

        # Model structure: factors per block (one per block when the
        # configuration does not specify them), AR lag and block layout.
        if config.factors_per_block is not None:
            factors_per_block = config.factors_per_block
        else:
            factors_per_block = np.ones(config.get_blocks_array().shape[1])
        self.r = torch.tensor(factors_per_block, dtype=torch.float32)
        self.p = getattr(config, 'ar_lag', 1)
        self.blocks = torch.tensor(config.get_blocks_array(), dtype=torch.float32)

        # Kalman filter with regularization for numerical stability, and the
        # EM driver that uses it.
        kalman_options = dict(
            min_eigenval=1e-5,
            inv_regularization=1e-3,
            cholesky_regularization=1e-5,
            use_cpu=True,
        )
        self.kalman = KalmanFilter(**kalman_options)
        self.em = EMAlgorithm(kalman=self.kalman, regularization_scale=1e-3)

        # State-space parameters (torch.nn.Parameter once fitted); they stay
        # None until initialize_from_data() creates them.
        self.A = None
        self.C = None
        self.Q = None
        self.R = None
        self.Z_0 = None
        self.V_0 = None

        # Standardization parameters and processed data recorded by fit_em().
        self.Mx = None
        self.Wx = None
        self.data_processed = None

        # EM updates parameters directly, so Lightning must not run an
        # optimizer step on its own.
        self.automatic_optimization = False
    
    def setup(self, stage: Optional[str] = None) -> None:
        """Initialize model parameters.
        
        This is called by Lightning before training starts.
        Parameters are initialized from data if available.
        """
        # Parameters will be initialized during fit_em() or first training step
        pass
    
    def _create_em_step_params(
        self, 
        y: torch.Tensor, 
        device: torch.device, 
        dtype: torch.dtype
    ) -> EMStepParams:
        """Create EM step parameters using stored mixed frequency parameters.
        
        Parameters
        ----------
        y : torch.Tensor
            Data tensor (N x T)
        device : torch.device
            Device for tensors
        dtype : torch.dtype
            Data type for tensors
            
        Returns
        -------
        EMStepParams
            EM step parameters
            
        Raises
        ------
        ValueError
            If mixed_freq=True but required parameters are not initialized
        """
        clock = getattr(self.config, 'clock', 'm')
        N = y.shape[0]
        
        if self.mixed_freq and self._em_i_idio is None:
            raise ValueError(
                f"DFM._create_em_step_params: mixed_freq=True but _em_i_idio is not initialized. "
                f"Call initialize_from_data() before training."
            )
        
        return EMStepParams(
            y=y,
            A=self.A,
            C=self.C,
            Q=self.Q,
            R=self.R,
            Z_0=self.Z_0,
            V_0=self.V_0,
            r=self.r.to(device),
            p=self.p,
            R_mat=None,  # Not used when structures_dict is provided
            q=None,      # Not used when structures_dict is provided
            nQ=self._em_nQ,
            i_idio=self._em_i_idio if self._em_i_idio is not None else torch.ones(N, device=device, dtype=dtype),
            blocks=self.blocks.to(device),
            tent_weights_dict=self._em_tent_weights_dict,
            clock=clock,
            frequencies=self._em_frequencies,
            idio_chain_lengths=self._em_idio_chain_lengths if self._em_idio_chain_lengths is not None else torch.zeros(N, device=device, dtype=dtype),
            config=self.config,
            structures_dict=self._em_structures_dict,
            frequencies_list=self._em_frequencies_list
        )
    
    def initialize_from_data(self, X: torch.Tensor) -> None:
        """Initialize parameters from data using PCA and OLS.
        
        Parameters
        ----------
        X : torch.Tensor
            Standardized data (T x N)
        """
        opt_nan = {'method': self.nan_method, 'k': self.nan_k}
        clock = getattr(self.config, 'clock', 'm')
        
        # Handle mixed_freq parameter
        if self.mixed_freq:
            # Use tent kernels for mixed-frequency data
            agg_structure = get_agg_structure(self.config, clock=clock)
            
            # Validate that all required frequency pairs are in TENT_WEIGHTS_LOOKUP
            frequencies_list = [s.frequency for s in self.config.series]
            frequencies_set = set(frequencies_list)
            clock_hierarchy = FREQUENCY_HIERARCHY.get(clock, 3)
            
            missing_pairs = []
            for freq in frequencies_set:
                freq_hierarchy = FREQUENCY_HIERARCHY.get(freq, 3)
                if freq_hierarchy > clock_hierarchy:
                    # This frequency is slower than clock, needs tent kernel
                    tent_w = get_tent_weights(freq, clock)
                    if tent_w is None:
                        missing_pairs.append((freq, clock))
            
            if missing_pairs:
                raise ValueError(
                    f"mixed_freq=True but the following frequency pairs are not in TENT_WEIGHTS_LOOKUP: {missing_pairs}. "
                    f"Available pairs: {list(TENT_WEIGHTS_LOOKUP.keys())}. "
                    f"Either add the missing pairs to TENT_WEIGHTS_LOOKUP or set mixed_freq=False."
                )
            
            # Convert tent_weights to torch tensors
            tent_weights_dict = {k: torch.tensor(v, dtype=torch.float32, device=X.device) 
                                for k, v in agg_structure['tent_weights'].items()}
            
            # Store structures dict for per-series R_mat/q lookup
            structures_dict = {}
            if agg_structure['structures']:
                for (freq, clock_key), (R_mat_np, q_np) in agg_structure['structures'].items():
                    structures_dict[freq] = (
                        torch.tensor(R_mat_np, dtype=torch.float32, device=X.device),
                        torch.tensor(q_np, dtype=torch.float32, device=X.device)
                    )
            
            # Create frequencies array
            frequencies_array = np.array(frequencies_list, dtype=object)
            
            # Count slower-frequency series
            nQ = sum(1 for freq in frequencies_list 
                    if FREQUENCY_HIERARCHY.get(freq, 3) > clock_hierarchy)
            
            # Compute i_idio (1 for clock frequency, 0 for slower frequencies)
            i_idio = torch.tensor([1 if freq == clock else 0 for freq in frequencies_list], 
                                 dtype=torch.float32, device=X.device)
            
            # Validate: slower-frequency series must have tent weights
            if nQ > 0 and not tent_weights_dict:
                raise ValueError(
                    f"DFM.initialize_from_data: mixed_freq=True and {nQ} slower-frequency series detected, "
                    f"but tent_weights_dict is empty. Check that all frequency pairs are in TENT_WEIGHTS_LOOKUP."
                )
        else:
            tent_weights_dict = None
            structures_dict = None
            frequencies_array = None
            nQ = 0
            i_idio = torch.ones(X.shape[1], dtype=torch.float32, device=X.device)
        
        # Convert frequencies to torch tensor
        frequencies_tensor = None
        if frequencies_array is not None:
            frequencies_tensor = torch.tensor(
                [_FREQ_TO_INT.get(f, 3) for f in frequencies_array], 
                dtype=torch.int32, 
                device=X.device
            )
        
        # Store for reuse in EM steps
        self._em_nQ = nQ
        self._em_tent_weights_dict = tent_weights_dict
        self._em_frequencies = frequencies_tensor
        self._em_i_idio = i_idio
        self._em_structures_dict = structures_dict
        self._em_frequencies_list = frequencies_list if self.mixed_freq else None
        self._em_idio_chain_lengths = torch.zeros(X.shape[1], dtype=torch.int32, device=X.device)
        
        # Initialize parameters using EM algorithm
        A, C, Q, R, Z_0, V_0 = self.em.initialize_parameters(
            X,
            r=self.r.to(X.device),
            p=self.p,
            blocks=self.blocks.to(X.device),
            opt_nan=opt_nan,
            R_mat=None,  # Not used when structures_dict is provided
            q=None,      # Not used when structures_dict is provided
            nQ=nQ,
            i_idio=i_idio,
            clock=clock,
            tent_weights_dict=tent_weights_dict,
            frequencies=frequencies_tensor,
            structures_dict=structures_dict,
            frequencies_list=frequencies_list if self.mixed_freq else None,
            idio_chain_lengths=self._em_idio_chain_lengths,
            config=self.config
        )
        
        # Convert numpy arrays to torch tensors for nn.Parameter
        device = X.device
        dtype = X.dtype
        self.A = nn.Parameter(torch.tensor(A, device=device, dtype=dtype))
        self.C = nn.Parameter(torch.tensor(C, device=device, dtype=dtype))
        self.Q = nn.Parameter(torch.tensor(Q, device=device, dtype=dtype))
        self.R = nn.Parameter(torch.tensor(R, device=device, dtype=dtype))
        self.Z_0 = nn.Parameter(torch.tensor(Z_0, device=device, dtype=dtype))
        self.V_0 = nn.Parameter(torch.tensor(V_0, device=device, dtype=dtype))
    
    def _convert_em_params_to_numpy(self, em_params: 'EMStepParams') -> 'EMStepParams':
        """Return a copy of ``em_params`` with its tensors passed through ``_to_numpy``.

        ``p``, ``nQ``, ``clock``, ``config`` and ``frequencies_list`` are
        carried over untouched. ``R_mat``, ``q`` and ``frequencies`` stay None
        when they are None. A missing or empty ``tent_weights_dict`` becomes an
        empty dict, while a missing or empty ``structures_dict`` becomes None.
        """
        from ..ssm.utils import _to_numpy

        def _optional(value):
            # Convert only when a value is present.
            return _to_numpy(value) if value is not None else None

        tent_weights = em_params.tent_weights_dict
        if tent_weights:
            tent_weights_np = {key: _to_numpy(weights) for key, weights in tent_weights.items()}
        else:
            tent_weights_np = {}

        structures = em_params.structures_dict
        if structures:
            structures_np = {
                key: (_to_numpy(pair[0]), _to_numpy(pair[1]))
                for key, pair in structures.items()
            }
        else:
            structures_np = None

        return EMStepParams(
            y=_to_numpy(em_params.y),
            A=_to_numpy(em_params.A),
            C=_to_numpy(em_params.C),
            Q=_to_numpy(em_params.Q),
            R=_to_numpy(em_params.R),
            Z_0=_to_numpy(em_params.Z_0),
            V_0=_to_numpy(em_params.V_0),
            r=_to_numpy(em_params.r),
            p=em_params.p,
            R_mat=_optional(em_params.R_mat),
            q=_optional(em_params.q),
            nQ=em_params.nQ,
            i_idio=_to_numpy(em_params.i_idio),
            blocks=_to_numpy(em_params.blocks),
            tent_weights_dict=tent_weights_np,
            clock=em_params.clock,
            frequencies=_optional(em_params.frequencies),
            idio_chain_lengths=_to_numpy(em_params.idio_chain_lengths),
            config=em_params.config,
            structures_dict=structures_np,
            frequencies_list=em_params.frequencies_list
        )
    
    
    def training_step(self, batch: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]], batch_idx: int) -> torch.Tensor:
        """Lightning training step. EM training is handled by fit_em() in on_train_start()."""
        if self.training_state is not None:
            loglik = self.training_state.loglik
            self.log('loglik', loglik, on_step=True, on_epoch=True, prog_bar=True)
            device = next(self.parameters()).device if len(list(self.parameters())) > 0 else torch.device('cpu')
            return -torch.tensor(loglik, device=device, dtype=torch.float32)
        return torch.tensor(0.0, device=next(self.parameters()).device if len(list(self.parameters())) > 0 else torch.device('cpu'))
    
    def fit_em(
        self,
        X: torch.Tensor,
        Mx: Optional[np.ndarray] = None,
        Wx: Optional[np.ndarray] = None
    ) -> DFMTrainingState:
        """Run the EM algorithm until convergence or ``max_iter`` iterations.

        Called by on_train_start(); the loop runs outside Lightning's own
        training loop. Each iteration converts the current parameters to
        NumPy, runs one EM step, validates the result and only then writes it
        back into the model parameters.

        Parameters
        ----------
        X : torch.Tensor
            Processed data; it is transposed to (N x T) before the EM loop,
            and may contain NaN values.
        Mx : np.ndarray, optional
            Offsets stored on the instance; get_result() adds them back when
            it un-standardizes the smoothed data.
        Wx : np.ndarray, optional
            Scales stored on the instance; get_result() multiplies by them
            when it un-standardizes the smoothed data.

        Returns
        -------
        DFMTrainingState
            Clones of the final parameters plus the last log-likelihood,
            the number of completed iterations and the convergence flag.
            The same object is stored as ``self.training_state``.

        Notes
        -----
        The loop stops early in two cases:

        * An EM step returns NaN parameters or a NaN/Inf log-likelihood. That
          update is discarded, so the model keeps the parameters it had before
          the failing step.
        * After more than 10 iterations the log-likelihood falls more than
          1000 below the best value seen. The best parameters are restored and
          the best log-likelihood is reported.
        """
        import time
        fit_start_time = time.time()
        _logger.info(f"{'='*70}")
        _logger.info(f"Starting DFM EM training (max_iter={self.max_iter}, threshold={self.threshold})")
        _logger.info(f"{'='*70}")
        self.Mx = Mx
        self.Wx = Wx

        # Ensure data is on same device as model (Lightning handles this automatically)
        X = X.to(self.device)

        device = X.device
        dtype = X.dtype

        # Initialize with method=2 (remove >80% NaN rows, then fill)
        from dfm_python.utils.data import rem_nans_spline_torch
        X_init, _ = rem_nans_spline_torch(X, method=2, k=self.nan_k)
        # Any value that is still NaN/Inf after filling is replaced by zero.
        X_init = torch.where(torch.isfinite(X_init), X_init, torch.tensor(0.0, device=device, dtype=dtype))

        # Initialize parameters using method=2 data
        self.initialize_from_data(X_init)

        # Prepare data for EM loop with method=3 (only remove all-NaN rows)
        X_est, _ = rem_nans_spline_torch(X, method=3, k=self.nan_k)
        y = X_est.T  # (N x T) - may contain NaN values

        # Store processed data for get_result()
        self.data_processed = y

        # Loop bookkeeping
        previous_loglik = float('-inf')
        best_loglik = float('-inf')
        best_params = None
        previous_A_norm = None
        previous_C_norm = None
        num_iter = 0
        converged = False
        loglik = float('-inf')
        change = 0.0

        # Create em_params once and reuse across iterations
        em_params = self._create_em_step_params(y, device, dtype)

        while num_iter < self.max_iter and not converged:
            # Point em_params at the current model parameters
            em_params.A = self.A
            em_params.C = self.C
            em_params.Q = self.Q
            em_params.R = self.R
            em_params.Z_0 = self.Z_0
            em_params.V_0 = self.V_0

            # Run one NumPy-based EM step and bring the results back to torch
            with torch.no_grad():
                em_params_np = self._convert_em_params_to_numpy(em_params)
                C_new_np, R_new_np, A_new_np, Q_new_np, Z_0_new_np, V_0_new_np, loglik = self.em.forward(em_params_np)

                C_new = torch.from_numpy(C_new_np).to(device=device, dtype=dtype)
                R_new = torch.from_numpy(R_new_np).to(device=device, dtype=dtype)
                A_new = torch.from_numpy(A_new_np).to(device=device, dtype=dtype)
                Q_new = torch.from_numpy(Q_new_np).to(device=device, dtype=dtype)
                Z_0_new = torch.from_numpy(Z_0_new_np).to(device=device, dtype=dtype)
                V_0_new = torch.from_numpy(V_0_new_np).to(device=device, dtype=dtype)

            # Validate the update BEFORE committing it, so that a failed step
            # cannot leave NaN values in the model or in the training state.
            has_nan = (
                any(
                    bool(torch.isnan(candidate).any())
                    for candidate in (C_new, A_new, Q_new, R_new, Z_0_new, V_0_new)
                )
                or (isinstance(loglik, float) and not np.isfinite(loglik))
            )
            if has_nan:
                _logger.error(f"EM algorithm: NaN/Inf at iteration {num_iter + 1}, stopping")
                _logger.warning("EM: discarding the invalid update; parameters from before this step are kept")
                break

            # Commit the update (EM doesn't use gradients, so we update directly)
            with torch.no_grad():
                self.A.data = A_new
                self.C.data = C_new
                self.Q.data = Q_new
                self.R.data = R_new
                self.Z_0.data = Z_0_new
                self.V_0.data = V_0_new

            # Track parameter stability (diagnostic only; never stops the loop)
            try:
                A_norm = torch.linalg.norm(self.A.data).item()
                C_norm = torch.linalg.norm(self.C.data).item()

                if previous_A_norm is not None and previous_C_norm is not None and num_iter > 5:
                    A_change = abs(A_norm - previous_A_norm) / max(previous_A_norm, 1e-10)
                    C_change = abs(C_norm - previous_C_norm) / max(previous_C_norm, 1e-10)
                    param_change = max(A_change, C_change)

                    if param_change < 1e-6:
                        _logger.debug(f"EM: Parameters stable (change={param_change:.2e}) at iter {num_iter + 1}")
                    elif param_change > 10.0:
                        _logger.warning(f"EM: Parameters changing significantly (change={param_change:.2e}) at iter {num_iter + 1}")

                previous_A_norm = A_norm
                previous_C_norm = C_norm
            except (RuntimeError, ValueError) as exc:
                # Best-effort diagnostic: record the failure and carry on.
                _logger.debug(f"EM: parameter stability check skipped at iter {num_iter + 1}: {exc}")

            # Drop the local references to the step outputs, then release
            # cached GPU memory (no-op without CUDA).
            del C_new, R_new, A_new, Q_new, Z_0_new, V_0_new
            self._clean_gpu_memory()

            # Log memory usage periodically for monitoring
            if num_iter % 10 == 0 and torch.cuda.is_available():
                try:
                    mem_used = torch.cuda.memory_allocated() / (1024**3)  # GB
                    mem_reserved = torch.cuda.memory_reserved() / (1024**3)  # GB
                    _logger.debug(f"DFM EM iteration {num_iter}: GPU memory - used={mem_used:.2f}GB, reserved={mem_reserved:.2f}GB")
                except Exception:
                    pass

            # Track best log-likelihood
            if loglik > best_loglik:
                best_loglik = loglik
                best_params = {
                    'A': self.A.data.clone(),
                    'C': self.C.data.clone(),
                    'Q': self.Q.data.clone(),
                    'R': self.R.data.clone(),
                    'Z_0': self.Z_0.data.clone(),
                    'V_0': self.V_0.data.clone()
                }

            # Check for log-likelihood deterioration
            if num_iter > 10 and loglik < best_loglik - 1000:
                _logger.warning(
                    f"EM: Log-likelihood deteriorated (best: {best_loglik:.4f}, current: {loglik:.4f}). "
                    f"Reverting to best parameters."
                )
                if best_params is not None:
                    with torch.no_grad():
                        self.A.data = best_params['A']
                        self.C.data = best_params['C']
                        self.Q.data = best_params['Q']
                        self.R.data = best_params['R']
                        self.Z_0.data = best_params['Z_0']
                        self.V_0.data = best_params['V_0']
                loglik = best_loglik
                break

            # Check convergence (skipped for the first three iterations)
            if num_iter > 2:
                converged, change = self.em.check_convergence(
                    loglik, previous_loglik, self.threshold, verbose=(num_iter % 10 == 0)
                )
            else:
                change = abs(loglik - previous_loglik) if previous_loglik != float('-inf') else 0.0

            previous_loglik = loglik
            num_iter += 1

            # Log metrics
            self.log('train/loglik', loglik, on_step=False, on_epoch=True)
            self.log('train/em_iteration', float(num_iter), on_step=False, on_epoch=True)
            self.log('train/loglik_change', change, on_step=False, on_epoch=True)

            # Console log: every iteration for the first 10, then every 5th,
            # and always on the iteration that converges.
            should_log = num_iter <= 10 or num_iter % 5 == 0 or converged

            if should_log:
                status = " ✓" if converged else ""
                # Get GPU memory usage if available
                gpu_mem_info = ""
                if torch.cuda.is_available():
                    try:
                        mem_used = torch.cuda.memory_allocated() / (1024**3)  # GB
                        mem_reserved = torch.cuda.memory_reserved() / (1024**3)  # GB
                        gpu_mem_info = f" | GPU: {mem_used:.1f}GB/{mem_reserved:.1f}GB"
                    except Exception:
                        # Memory figures are optional decoration for the log line.
                        pass

                _logger.info(
                    f"EM iteration {num_iter:4d}/{self.max_iter}: "
                    f"loglik={loglik:12.6f}, change={change:10.2e}{status}{gpu_mem_info}"
                )

        # Store final state
        self.training_state = DFMTrainingState(
            A=self.A.data.clone(),
            C=self.C.data.clone(),
            Q=self.Q.data.clone(),
            R=self.R.data.clone(),
            Z_0=self.Z_0.data.clone(),
            V_0=self.V_0.data.clone(),
            loglik=loglik,
            num_iter=num_iter,
            converged=converged
        )

        # Final status
        fit_duration = time.time() - fit_start_time
        status_msg = "converged" if converged else f"stopped (change: {change:.2e})"
        _logger.info(f"{'='*70}")
        _logger.info(f"EM training {status_msg} after {num_iter} iterations")
        _logger.info(f"  Final log-likelihood: {loglik:.6f}")
        _logger.info(f"  Duration: {fit_duration:.2f} seconds ({fit_duration/60:.2f} minutes)")
        _logger.info(f"  Average time per iteration: {fit_duration/max(num_iter, 1):.2f} seconds")
        _logger.info(f"{'='*70}")

        return self.training_state
    
    def get_result(self) -> DFMResult:
        """Assemble a DFMResult from the stored training state.

        Runs one Kalman pass over the processed data with the fitted
        parameters, drops the initial state, and reconstructs the smoothed
        series both in standardized units and rescaled with ``Wx`` / ``Mx``.

        Raises
        ------
        RuntimeError
            If no training state exists, or no processed data was stored.
        """
        state = self.training_state
        if state is None:
            raise RuntimeError(
                "DFM get_result failed: model has not been fitted yet. "
                "Please call fit_em() first."
            )

        # data_processed is already (N x T) format, no need to transpose
        observations = self.data_processed
        if observations is None:
            raise RuntimeError(
                "DFM get_result failed: data not available. "
                "Please ensure fit_em() was called with data."
            )

        # The Kalman filter is NumPy-based, so everything is converted first.
        from ..ssm.utils import _to_numpy
        param_names = ('A', 'C', 'Q', 'R', 'Z_0', 'V_0')
        y_np = _to_numpy(observations)
        filter_args = [_to_numpy(getattr(state, name)) for name in param_names]
        zsmooth_np, _, _, _ = self.kalman.forward(y_np, *filter_args)

        # zsmooth is (m x (T+1)); transpose and skip the initial state -> T x m
        Z = zsmooth_np.T[1:, :]

        # Parameters as NumPy arrays for the result object
        A, C, Q, R, Z_0, V_0 = (
            getattr(state, name).cpu().numpy() for name in param_names
        )
        r = self.r.cpu().numpy()

        # Standardized smoothed data, T x N
        x_sm = Z @ C.T
        n_series = C.shape[0]

        # Un-standardize; NaN scales act as 1 and NaN offsets as 0
        if self.Wx is not None:
            Wx_clean = np.where(np.isnan(self.Wx), 1.0, self.Wx)
        else:
            Wx_clean = np.ones(n_series)
        if self.Mx is not None:
            Mx_clean = np.where(np.isnan(self.Mx), 0.0, self.Mx)
        else:
            Mx_clean = np.zeros(n_series)
        X_sm = x_sm * Wx_clean + Mx_clean  # T x N

        # The result carries the stored Mx/Wx as given (defaults when absent)
        result_Mx = self.Mx if self.Mx is not None else np.zeros(n_series)
        result_Wx = self.Wx if self.Wx is not None else np.ones(n_series)

        if hasattr(self.config, 'get_series_ids'):
            series_ids = self.config.get_series_ids()
        else:
            series_ids = None

        return DFMResult(
            x_sm=x_sm,
            X_sm=X_sm,
            Z=Z,
            C=C,
            R=R,
            A=A,
            Q=Q,
            Mx=result_Mx,
            Wx=result_Wx,
            Z_0=Z_0,
            V_0=V_0,
            r=r,
            p=self.p,
            converged=state.converged,
            num_iter=state.num_iter,
            loglik=state.loglik,
            series_ids=series_ids
        )
    
    def configure_optimizers(self) -> List[torch.optim.Optimizer]:
        """Configure optimizers. EM algorithm doesn't use optimizers."""
        return []
    
    def _clean_gpu_memory(self) -> None:
        """Force GPU memory cleanup with synchronization and garbage collection."""
        if torch.cuda.is_available():
            torch.cuda.synchronize()
            torch.cuda.empty_cache()
            import gc
            gc.collect()
    
    def _log_gpu_memory(self, iteration: int) -> None:
        """Log GPU memory usage for monitoring.
        
        Parameters
        ----------
        iteration : int
            Current iteration number
        """
        if torch.cuda.is_available():
            try:
                mem_used = torch.cuda.memory_allocated() / (1024**3)  # GB
                mem_reserved = torch.cuda.memory_reserved() / (1024**3)  # GB
                _logger.debug(f"EM iteration {iteration}: GPU memory {mem_used:.2f}GB/{mem_reserved:.2f}GB")
            except Exception:
                pass
    
    
    def load_config(
        self,
        source: Optional[Union[str, Path, Dict[str, Any], DFMConfig, ConfigSource]] = None,
        *,
        yaml: Optional[Union[str, Path]] = None,
        mapping: Optional[Dict[str, Any]] = None,
        hydra: Optional[Union[Dict[str, Any], Any]] = None,
        base: Optional[Union[str, Path, Dict[str, Any], ConfigSource]] = None,
        override: Optional[Union[str, Path, Dict[str, Any], ConfigSource]] = None,
    ) -> 'DFM':
        """Load configuration from various sources."""
        # Use common config loading logic
        new_config = self._load_config_common(
            source=source,
            yaml=yaml,
            mapping=mapping,
            hydra=hydra,
            base=base,
            override=override,
        )
        
        # DFM-specific: Initialize r and blocks tensors
        self.r = torch.tensor(
            new_config.factors_per_block if new_config.factors_per_block is not None
            else np.ones(new_config.get_blocks_array().shape[1]),
            dtype=torch.float32
        )
        self.blocks = torch.tensor(new_config.get_blocks_array(), dtype=torch.float32)
        
        return self
    
    
    def on_train_start(self) -> None:
        """Lightning hook: fit the model with EM when training begins.

        Pulls the processed data and the standardization parameters from the
        data module, runs fit_em() on them, then defers to the parent hook.
        """
        datamodule = self._get_datamodule()
        processed = datamodule.get_processed_data()
        offsets, scales = datamodule.get_std_params()

        # The whole EM fit happens here, before Lightning's training loop.
        self.fit_em(processed, Mx=offsets, Wx=scales)

        super().on_train_start()
    
    def predict(
        self,
        horizon: Optional[int] = None,
        *,
        history: Optional[int] = None,
        return_series: bool = True,
        return_factors: bool = True
    ) -> Union[np.ndarray, Tuple[np.ndarray, np.ndarray]]:
        """Forecast the series and/or the factors ``horizon`` steps ahead.

        Parameters
        ----------
        horizon : int, optional
            Number of periods to forecast. When None, it is set to
            ``get_periods_per_year`` of the clock frequency obtained from
            ``get_clock_frequency(self.config, 'm')``.
        history : int, optional
            When given and positive, the starting factor state is requested
            from ``_update_factor_state_with_history``; if that returns None,
            or when ``history`` is None / non-positive, the last row of
            ``result.Z`` is used.
        return_series : bool, default True
            Whether to return the forecasted series.
        return_factors : bool, default True
            Whether to return the forecasted factors.

        Returns
        -------
        np.ndarray or Tuple[np.ndarray, np.ndarray]
            ``(X_forecast, Z_forecast)`` when both flags are set,
            ``X_forecast`` when only ``return_series`` is set, and
            ``Z_forecast`` in every other case.

        Raises
        ------
        ValueError
            If the model is untrained, ``result.Z`` is missing, ``horizon``
            is not positive, or the state, parameters or forecasts contain
            NaN/Inf.
        """
        if self.training_state is None:
            raise ValueError(
                f"{self.__class__.__name__} prediction failed: model has not been trained yet. "
                f"Please call trainer.fit(model, data_module) first"
            )

        # Build the cached result only when it is missing or None
        if getattr(self, '_result', None) is None:
            self._result = self.get_result()
        fitted = self._result

        if getattr(fitted, 'Z', None) is None:
            raise ValueError(
                "DFM prediction failed: result.Z is not available. "
                "This may indicate the model was not properly trained or result object is corrupted."
            )

        # Default horizon is derived from the configured clock frequency
        if horizon is None:
            from ..config.utils import get_periods_per_year
            from ..utils.helpers import get_clock_frequency
            horizon = get_periods_per_year(get_clock_frequency(self.config, 'm'))

        if horizon <= 0:
            raise ValueError(f"horizon must be positive, got {horizon}")

        # Model parameters stored on the result object
        transition = fitted.A
        loadings = fitted.C
        wx_raw = fitted.Wx
        mx_raw = fitted.Mx
        var_order = getattr(fitted, 'p', 1)  # falls back to 1 when the result has no `p`

        # Use the device of the first module parameter, or CPU if there is none
        first_param = next(self.parameters(), None)
        device = torch.device('cpu') if first_param is None else first_param.device
        dtype = torch.float32

        def to_tensor(values):
            # Local shorthand: every conversion shares device and dtype
            return torch.tensor(values, device=device, dtype=dtype)

        # Starting factor state: history-based update first, last stored row otherwise
        last_state = None
        if history is not None and history > 0:
            refreshed = self._update_factor_state_with_history(
                history=history,
                result=fitted,
                kalman_filter=getattr(self, 'kalman', None)
            )
            if isinstance(refreshed, torch.Tensor):
                last_state = refreshed.detach().cpu().numpy()
            else:
                # Either an array-like state or None (update failed)
                last_state = refreshed
        if last_state is None:
            last_state = fitted.Z[-1, :]

        # Refuse to forecast from a non-finite state or non-finite parameters
        if any(np.any(~np.isfinite(arr)) for arr in (last_state, transition, loadings)):
            raise ValueError(
                "DFM prediction failed: model state or parameters contain NaN/Inf. "
                "Check training convergence and data quality."
            )

        # Tensor copies for the tensor-based helper methods
        last_state_t = to_tensor(last_state)
        transition_t = to_tensor(transition)
        loadings_t = to_tensor(loadings)
        wx_t = None if wx_raw is None else to_tensor(wx_raw)
        mx_t = None if mx_raw is None else to_tensor(mx_raw)

        # The second-to-last state is only supplied when p == 2 and it exists
        prev_state_t = None
        if var_order == 2 and fitted.Z.shape[0] >= 2:
            prev_state_t = to_tensor(fitted.Z[-2, :])

        factor_path_t = self._forecast_var_factors(
            Z_last=last_state_t,
            A=transition_t,
            p=var_order,
            horizon=horizon,
            Z_prev=prev_state_t
        )

        # Missing Wx / Mx are replaced by ones / zeros, one entry per row of C
        if wx_t is None:
            wx_t = torch.ones(loadings_t.shape[0], device=device, dtype=dtype)
        if mx_t is None:
            mx_t = torch.zeros(loadings_t.shape[0], device=device, dtype=dtype)

        series_path_t = self._transform_factors_to_observations(
            Z_forecast=factor_path_t,
            C=loadings_t,
            Wx=wx_t,
            Mx=mx_t
        )

        # Outputs are NumPy arrays
        X_forecast = series_path_t.detach().cpu().numpy()
        Z_forecast = factor_path_t.detach().cpu().numpy()

        non_finite = ~np.isfinite(X_forecast)
        if np.any(non_finite):
            nan_count = np.sum(non_finite)
            raise ValueError(
                f"DFM prediction failed: produced {nan_count} NaN/Inf values in forecast. "
                f"Possible numerical instability. "
                f"Please check model parameters and data quality."
            )

        # Best effort: apply an attached scaler's inverse_transform; a failure
        # is logged and the forecast is returned as computed above
        scaler = getattr(self, "scaler", None)
        if scaler is not None and hasattr(scaler, "inverse_transform"):
            try:
                X_forecast = scaler.inverse_transform(X_forecast)
            except Exception as e:
                _logger.warning(
                    f"DFM prediction: scaler.inverse_transform failed, returning unstandardized values. "
                    f"error={e}"
                )

        if return_factors and np.any(~np.isfinite(Z_forecast)):
            raise ValueError(
                "DFM prediction failed: factor forecast contains NaN/Inf. "
                "Check model parameters and training convergence."
            )

        if not return_series:
            return Z_forecast
        if return_factors:
            return X_forecast, Z_forecast
        return X_forecast
    
    def update(
        self,
        X_std: np.ndarray,
        *,
        history: Optional[int] = None,
        kalman_filter: Optional[Any] = None,
        scaler: Optional[Any] = None
    ) -> 'DFM':
        """Update the last factor state from standardized data.

        The state returned by ``_update_factor_state_dfm`` is written in place
        into ``result.Z[-1, :]``. Callers are expected to do all preprocessing
        (masking, imputation, standardization) before calling this method;
        the only cleaning done here is replacing non-finite entries with NaN.

        Parameters
        ----------
        X_std : np.ndarray
            Standardized data, 2D (T x N). Non-ndarray input is converted
            with ``np.asarray``.
        history : int, optional
            When given, positive and smaller than the number of rows, only
            the last ``history`` rows of ``X_std`` are used for the update.
            Otherwise all rows are used.
        kalman_filter : Any, optional
            Kalman filter passed to the state update. When falsy, the
            model's ``kalman`` attribute is used if present, else None.
        scaler : Any, optional
            When not None, stored on the model as ``self.scaler``.
            ``predict()`` applies ``scaler.inverse_transform`` to the
            forecasted series when such an attribute exists.

        Returns
        -------
        DFM
            Self for method chaining.

        Raises
        ------
        ValueError
            If ``X_std`` is not 2D. ``_check_trained()`` runs first and is
            presumably what rejects an untrained model.

        Examples
        --------
        >>> # Update state with new data, then predict
        >>> model.update(X_std).predict(horizon=1)
        >>> # Or update with only recent 12 periods
        >>> model.update(X_std, history=12)
        >>> forecast = model.predict(horizon=6)
        """
        self._check_trained()

        result = self.result

        # Validate input shape
        if not isinstance(X_std, np.ndarray):
            X_std = np.asarray(X_std)
        if X_std.ndim != 2:
            raise ValueError(
                f"DFM update(): X_std must be 2D array (T x N), "
                f"got shape {X_std.shape}"
            )

        # Treat +/-Inf like missing values: everything non-finite becomes NaN
        X_std = np.where(np.isfinite(X_std), X_std, np.nan)

        # Restrict to the most recent rows if requested
        if history is not None and history > 0 and X_std.shape[0] > history:
            X_recent = X_std[-history:, :]
        else:
            X_recent = X_std

        # Store scaler if provided; object.__setattr__ bypasses the class's
        # own __setattr__ hook
        if scaler is not None:
            object.__setattr__(self, "scaler", scaler)

        # Update factor state using Kalman filter
        Z_last_updated = self._update_factor_state_dfm(
            X_recent, result, kalman_filter or getattr(self, 'kalman', None)
        )

        if Z_last_updated is not None:
            # Same conversion as predict(): a tensor that requires grad or
            # lives on another device cannot be assigned into a NumPy array
            # directly, so detach and move it to CPU first.
            if isinstance(Z_last_updated, torch.Tensor):
                Z_last_updated = Z_last_updated.detach().cpu().numpy()
            result.Z[-1, :] = Z_last_updated
        else:
            _logger.warning("DFM update(): Failed to update factor state, keeping current state")

        return self
    
    @property
    def result(self) -> DFMResult:
        """Return the stored model result.

        ``_check_trained()`` is called first on every access; the value
        returned is whatever ``self._result`` holds afterwards.

        Raises
        ------
        ValueError
            If the model has not been trained yet (presumably raised by
            ``_check_trained()`` — confirm against the base class).
        """
        # The trained-state check must run before the cached result is read
        self._check_trained()
        fitted = self._result
        return fitted
    
    def reset(self) -> 'DFM':
        """Reset model state via the parent class.

        Returns
        -------
        DFM
            Self, so the call can be chained.
        """
        # All reset logic lives in the parent; this override only narrows
        # the return value to this instance.
        super().reset()
        return self


