"""Deep Dynamic Factor Model (DDFM) using PyTorch.

This module implements a PyTorch-based Deep Dynamic Factor Model that uses
a nonlinear encoder (autoencoder) to extract factors, while maintaining
linear dynamics and decoder for interpretability and compatibility with
Kalman filtering.

DDFM is a PyTorch Lightning module that inherits from BaseFactorModel.
"""

import numpy as np
from typing import Optional, Tuple, Union, List, Dict, Any, TYPE_CHECKING
from datetime import datetime, timedelta
from pathlib import Path
import polars as pl
from dataclasses import dataclass
from ..logger import get_logger

import torch
import torch.nn as nn
import torch.optim as optim
import pytorch_lightning as pl
# PyTorch is mandatory - no optional import needed

from .base import BaseFactorModel
from ..config import (
    DFMConfig, DEFAULT_BLOCK_NAME,
    make_config_source, ConfigSource, MergedConfigSource,
)
from ..config.results import DDFMResult
from ..config.utils import get_periods_per_year
from ..utils.data import rem_nans_spline
from ..utils.helpers import (
    safe_get_attr,
    get_clock_frequency,
    resolve_param,
)
from ..encoder.vae import Encoder, Decoder, extract_decoder_params
from ..utils.statespace import estimate_idiosyncratic_dynamics
from ..nowcast.dataview import DataView

if TYPE_CHECKING:
    from ..lightning import DFMDataModule

_logger = get_logger(__name__)


@dataclass()
class DDFMTrainingState:
    """Container for the progress of a DDFM training run.

    Attributes
    ----------
    factors : np.ndarray
        Factor estimates held in this state.
    prediction : np.ndarray
        Model prediction held in this state.
    converged : bool
        Flag recording whether training is considered converged.
    num_iter : int
        Iteration counter.
    training_loss : float, optional
        Training loss, or None when not recorded (the default).
    """

    factors: np.ndarray  # factor estimates
    prediction: np.ndarray  # model prediction
    converged: bool  # convergence flag
    num_iter: int  # iteration counter
    training_loss: Optional[float] = None  # optional loss value; None when not recorded




class DDFMModel:
    """Deep Dynamic Factor Model using PyTorch (low-level implementation).
    
    This class implements a DDFM with:
    - Nonlinear encoder (MLP) to extract factors from observations
    - Linear decoder for interpretability
    - Linear factor dynamics (VAR)
    - Kalman filtering for final smoothing
    
    The model is trained using gradient descent (Adam optimizer) to minimize
    reconstruction error, then factor dynamics are estimated via OLS, and
    final smoothing is performed using Kalman filter.
    
    Fitted state (``_result``, ``_config``, ``_data``, ``_time``,
    ``_original_data``, ``_data_frame``) and the ``_nowcast_ref`` callable are
    NOT created by ``__init__``. They are expected to be assigned externally
    (by the high-level ``DDFM`` class / fitting). Methods here therefore read
    them defensively and raise ``ValueError`` when they are missing.
    
    .. note::
        This is the low-level implementation used internally by the ``DDFM`` class.
        For high-level API with Lightning training, use the ``DDFM`` class (defined below)
        which is a PyTorch Lightning module.
        
        Example:
        
        .. code-block:: python
        
            from dfm_python import DDFM, DFMDataModule, DDFMTrainer
            
            model = DDFM()
            model.load_config('config/ddfm.yaml')
            dm = DFMDataModule(config=model.config, data=df_processed)
            dm.setup()
            trainer = DDFMTrainer(max_epochs=100)
            trainer.fit(model, dm)
    """
    
    def __init__(
        self,
        encoder_layers: Optional[List[int]] = None,
        num_factors: Optional[int] = None,
        activation: str = 'tanh',
        use_batch_norm: bool = True,
        learning_rate: float = 0.001,
        epochs: int = 100,
        batch_size: int = 32,
        factor_order: int = 1,
        use_idiosyncratic: bool = True,
        min_obs_idio: int = 5,
        lags_input: int = 0,
        max_iter: int = 200,
        tolerance: float = 0.0005,
        disp: int = 10,
        seed: Optional[int] = None,
        **kwargs
    ):
        """Initialize DDFM model.
        
        Parameters
        ----------
        encoder_layers : List[int], optional
            Hidden layer dimensions for encoder. Default: [64, 32]
        num_factors : int, optional
            Number of factors. If None, will be inferred from config during fit.
        activation : str
            Activation function ('tanh', 'relu', 'sigmoid'). Default: 'tanh'
        use_batch_norm : bool
            Whether to use batch normalization in encoder. Default: True
        learning_rate : float
            Learning rate for Adam optimizer. Default: 0.001
        epochs : int
            Number of epochs per MCMC iteration. Default: 100
        batch_size : int
            Batch size for training. Default: 32
        factor_order : int
            VAR lag order for factor dynamics (1 or 2). Default: 1
        use_idiosyncratic : bool
            Whether to model idiosyncratic components with AR(1) dynamics. Default: True
        min_obs_idio : int
            Minimum number of observations required for idio AR(1) estimation. Default: 5
        lags_input : int
            Number of lags of inputs on encoder (default 0, i.e. same inputs and outputs). Default: 0
        max_iter : int
            Maximum number of MCMC iterations. Default: 200
        tolerance : float
            Convergence tolerance. Default: 0.0005
        disp : int
            Display intermediate results every 'disp' iterations. Default: 10
        seed : int, optional
            Random seed for the NumPy RNG. When None, the fixed seed 3 is used.
        **kwargs
            Accepted for compatibility with the high-level wrapper; ignored here.
        
        Raises
        ------
        ValueError
            If ``factor_order`` is not 1 or 2.
        """
        if factor_order not in (1, 2):
            raise ValueError(
                f"DDFM initialization failed: factor_order must be 1 or 2, got {factor_order}. "
                f"Please provide a valid factor_order value (1 for VAR(1) or 2 for VAR(2))."
            )
        
        self.encoder_layers = encoder_layers or [64, 32]
        self.num_factors = num_factors
        self.activation = activation
        self.use_batch_norm = use_batch_norm
        self.learning_rate = learning_rate
        self.epochs = epochs  # Epochs per MCMC iteration
        self.batch_size = batch_size
        self.factor_order = factor_order
        self.use_idiosyncratic = use_idiosyncratic
        self.min_obs_idio = min_obs_idio
        self.lags_input = lags_input
        self.max_iter = max_iter
        self.tolerance = tolerance
        self.disp = disp
        
        # PyTorch modules (will be initialized in fit)
        self.encoder: Optional[Encoder] = None
        self.decoder: Optional[Decoder] = None
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        
        # Random number generator for MC sampling
        self.rng = np.random.RandomState(seed if seed is not None else 3)
    
    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------
    def _data_as_array(self) -> np.ndarray:
        """Return the stored ``self._data`` as a NumPy array.
        
        Objects exposing ``to_numpy()`` (e.g. data frames) are converted with it;
        anything else goes through ``np.asarray``.
        """
        data = self._data
        if hasattr(data, 'to_numpy'):
            return data.to_numpy()
        return np.asarray(data)
    
    def _require_nowcast(self):
        """Return the nowcast callable stored in ``_nowcast_ref``.
        
        Raises
        ------
        ValueError
            If ``_nowcast_ref`` has not been set (it is expected to be injected
            by the high-level ``DDFM`` class).
        """
        nowcast_obj = getattr(self, '_nowcast_ref', None)
        if nowcast_obj is None:
            raise ValueError(
                "DDFM nowcast failed: requires high-level DDFM instance. "
                "Please call nowcast() from DDFM class, not DDFMModel."
            )
        return nowcast_obj
    
    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------
    def predict(
        self,
        horizon: Optional[int] = None,
        *,
        return_series: bool = True,
        return_factors: bool = True
    ) -> Union[np.ndarray, Tuple[np.ndarray, np.ndarray]]:
        """Forecast future values by iterating the fitted factor VAR.
        
        Factors are propagated deterministically with ``f_t = A1 f_{t-1}``
        (VAR(1)) or ``f_t = A1 f_{t-1} + A2 f_{t-2}`` (VAR(2), where
        ``A = [A1 | A2]``). They are then mapped to observations with
        ``X = (Z @ C.T) * Wx + Mx``. For VAR(2) with a single stored factor
        observation, only ``A1`` is iterated.
        
        Parameters
        ----------
        horizon : int, optional
            Number of periods ahead to forecast. If None, defaults to 1 year
            of periods based on the config's clock frequency, or 12 when no
            config is stored.
        return_series : bool, default True
            Whether to return forecasted series.
        return_factors : bool, default True
            Whether to return forecasted factors.
            
        Returns
        -------
        np.ndarray or Tuple[np.ndarray, np.ndarray]
            ``(X_forecast, Z_forecast)`` when both flags are set, otherwise
            ``X_forecast`` if ``return_series`` else ``Z_forecast``.
        
        Raises
        ------
        ValueError
            If the model is not fitted, ``horizon`` is not positive, or the
            stored VAR order is not 1 or 2.
        """
        # _result / _config are assigned after fitting and are not created in
        # __init__; getattr keeps an unfitted model on the documented ValueError
        # path instead of raising AttributeError.
        result = getattr(self, '_result', None)
        if result is None:
            raise ValueError(
                "DDFM prediction failed: model has not been fitted yet. "
                "Please call trainer.fit(model, data_module) first."
            )
        
        # Default horizon
        if horizon is None:
            config = getattr(self, '_config', None)
            if config is not None:
                clock = get_clock_frequency(config, 'm')
                horizon = get_periods_per_year(clock)
            else:
                horizon = 12  # Default to 12 periods if no config
        
        if horizon <= 0:
            raise ValueError(
                f"DDFM prediction failed: horizon must be a positive integer, got {horizon}. "
                f"Please provide a positive integer value for the forecast horizon."
            )
        
        # Extract parameters
        A = result.A  # Factor dynamics (m x m) for VAR(1) or (m x 2m) for VAR(2)
        C = result.C
        Wx = result.Wx
        Mx = result.Mx
        Z_hist = result.Z
        Z_last = Z_hist[-1, :]  # Last factor estimate (m,)
        m = Z_last.shape[0]
        p = result.p  # VAR order
        
        if p == 1:
            A1, A2 = A, None
        elif p == 2:
            A1 = A[:, :m]
            # Not enough history for the second lag: fall back to iterating A1 only.
            A2 = A[:, m:] if Z_hist.shape[0] >= 2 else None
        else:
            raise ValueError(
                f"DDFM prediction failed: unsupported VAR order {p}. "
                f"Only VAR(1) and VAR(2) are supported. Please use factor_order=1 or factor_order=2."
            )
        
        # Deterministic forecast; lags are read back from the stored rows so the
        # recursion always uses the float64 forecast buffer.
        Z_forecast = np.zeros((horizon, m))
        f_lag1 = Z_last
        f_lag2 = Z_hist[-2, :] if A2 is not None else None
        for h in range(horizon):
            f_next = A1 @ f_lag1
            if A2 is not None:
                f_next = f_next + A2 @ f_lag2
            Z_forecast[h, :] = f_next
            f_lag2, f_lag1 = f_lag1, Z_forecast[h, :]
        
        # Transform to observations
        X_forecast_std = Z_forecast @ C.T
        X_forecast = X_forecast_std * Wx + Mx
        
        if return_series and return_factors:
            return X_forecast, Z_forecast
        if return_series:
            return X_forecast
        return Z_forecast
    
    def generate_dataset(
        self,
        target_series: str,
        periods: List[datetime],
        backward: int = 0,
        forward: int = 0,
        dataview: Optional['DataView'] = None
    ) -> Dict[str, Any]:
        """Generate dataset for DFM evaluation.
        
        Requires data stored during fit (``self._data`` etc.) and a nowcast
        callable in ``_nowcast_ref``.
        
        Parameters
        ----------
        target_series : str
            Series identifier resolved with ``find_series_index``.
        periods : List[datetime]
            Target periods to evaluate.
        backward : int, default 0
            When > 0, also nowcast from data views ``backward, ..., 0`` weeks
            before each period. The last of these (0 weeks back) becomes the
            baseline.
        forward : int, default 0
            Accepted for interface compatibility; not used by this implementation.
        dataview : DataView, optional
            Data-view factory to use. If None, one is built from the stored data.
        
        Returns
        -------
        dict
            Keys ``'X'`` (features: latest fitted factors plus the nan-mean of the
            last row of the view), ``'y_baseline'``, ``'y_actual'``, ``'y_target'``
            (actual minus baseline), ``'metadata'`` and ``'backward_results'``.
        
        Raises
        ------
        ValueError
            If no data is stored or ``_nowcast_ref`` is missing.
        """
        if getattr(self, '_data', None) is None:
            raise ValueError(
                "DDFM generate_dataset failed: model has not been fitted with DataModule yet. "
                "Please call trainer.fit(model, data_module) first to store data."
            )
        
        from ..utils.helpers import find_series_index
        from ..utils.time import find_time_index
        
        i_series = find_series_index(self._config, target_series)
        X_features, y_baseline, y_actual, metadata, backward_results = [], [], [], [], []
        
        if dataview is not None:
            dataview_factory = dataview
        else:
            dataview_factory = DataView.from_arrays(
                X=self._data_as_array(), Time=self._time,
                Z=self._original_data, config=self._config,
                X_frame=self._data_frame
            )
        if dataview_factory.config is None:
            dataview_factory.config = self._config
        
        # Loop invariants: a separate array for actual-value lookups (not shared
        # with the DataView) and the latest fitted factors.
        data_array = self._data_as_array()
        result = getattr(self, '_result', None)
        if result is not None and hasattr(result, 'Z'):
            latest_factors = result.Z[-1, :] if result.Z.shape[0] > 0 else np.zeros(result.Z.shape[1])
        else:
            latest_factors = np.array([])
        
        for period in periods:
            view_obj = dataview_factory.with_view_date(period)
            X_view, _time_view, _ = view_obj.materialize()
            nowcast_obj = self._require_nowcast()
            
            if backward > 0:
                nowcasts, data_view_dates = [], []
                for weeks_back in range(backward, -1, -1):
                    data_view_date = period - timedelta(weeks=weeks_back)
                    view_past = dataview_factory.with_view_date(data_view_date)
                    # Result unused; call kept in case materialize() resolves view_date.
                    view_past.materialize()
                    resolved_date = view_past.view_date or data_view_date
                    nowcasts.append(nowcast_obj(
                        target_series=target_series,
                        view_date=resolved_date,
                        target_period=period
                    ))
                    data_view_dates.append(resolved_date)
                baseline_nowcast = nowcasts[-1]
                backward_results.append({
                    'nowcasts': np.array(nowcasts),
                    'data_view_dates': data_view_dates,
                    'target_date': period
                })
            else:
                baseline_nowcast = nowcast_obj(
                    target_series=target_series,
                    view_date=view_obj.view_date or period,
                    target_period=period
                )
            
            y_baseline.append(baseline_nowcast)
            t_idx = find_time_index(self._time, period)
            actual_val = np.nan
            if t_idx is not None and t_idx < data_array.shape[0] and i_series < data_array.shape[1]:
                actual_val = data_array[t_idx, i_series]
            y_actual.append(actual_val)
            
            # Extract features
            mean_residual = np.nanmean(X_view[-1, :]) if X_view.shape[0] > 0 else 0.0
            X_features.append(np.concatenate([latest_factors, [mean_residual]]))
            metadata.append({'period': period, 'target_series': target_series})
        
        return {
            'X': np.array(X_features),
            'y_baseline': np.array(y_baseline),
            'y_actual': np.array(y_actual),
            'y_target': np.array(y_actual) - np.array(y_baseline),
            'metadata': metadata,
            'backward_results': backward_results
        }
    
    def get_state(
        self,
        t: Union[int, datetime],
        target_series: str,
        lookback: Optional[int] = None
    ) -> Dict[str, Any]:
        """Get DFM state at time t.
        
        Requires data stored during fit (``self._data`` etc.) and a nowcast
        callable in ``_nowcast_ref``.
        
        Parameters
        ----------
        t : int or datetime
            Time point, normalized with ``convert_to_timestamp``.
        target_series : str
            Series identifier resolved with ``find_series_index``.
        lookback : int, optional
            Number of periods of history (ending at ``t``) to include. Defaults
            to one year of periods at the config's clock frequency.
        
        Returns
        -------
        dict
            Keys ``'time'``, ``'target_series'``, ``'baseline_nowcast'``,
            ``'baseline_forecast'`` (the baseline nowcast repeated for each
            history step), ``'actual_history'``, ``'residuals'`` (actual minus
            forecast) and ``'factors_history'``.
        
        Raises
        ------
        ValueError
            If no data is stored, ``_nowcast_ref`` is missing, or ``t`` is not
            in the stored time index.
        """
        if getattr(self, '_data', None) is None:
            raise ValueError(
                "DDFM get_state failed: model has not been fitted with DataModule yet. "
                "Please call trainer.fit(model, data_module) first to store data."
            )
        
        from ..utils.helpers import find_series_index
        from ..utils.time import find_time_index, convert_to_timestamp
        
        if lookback is None:
            clock = get_clock_frequency(self._config, 'm')
            lookback = get_periods_per_year(clock)
        
        t = convert_to_timestamp(t, self._time, None)
        i_series = find_series_index(self._config, target_series)
        X_data = self._data_as_array()
        
        # The original data-view construction here called an undefined name
        # (`create_data_view`) and its result was never used, so it was removed.
        baseline_nowcast = self._require_nowcast()(
            target_series=target_series, view_date=t, target_period=None
        )
        
        t_idx = find_time_index(self._time, t)
        if t_idx is None:
            raise ValueError(
                f"DDFM get_state failed: time {t} not found in model_instance._time. "
                f"Please provide a valid time value that exists in the model's time index."
            )
        
        result = getattr(self, '_result', None)
        Z_fit = result.Z if result is not None and hasattr(result, 'Z') else None
        
        baseline_forecast, actual_history, residuals, factors_history = [], [], [], []
        start = max(0, t_idx - lookback + 1)
        stop = min(t_idx + 1, X_data.shape[0])  # rows beyond the stored data are skipped
        for i in range(start, stop):
            forecast_val = baseline_nowcast
            actual_val = X_data[i, i_series] if i_series < X_data.shape[1] else np.nan
            baseline_forecast.append(forecast_val)
            actual_history.append(actual_val)
            residuals.append(actual_val - forecast_val)
            if Z_fit is not None and i < Z_fit.shape[0]:
                factors_history.append(Z_fit[i, :])
            else:
                factors_history.append(np.array([]))
        
        return {
            'time': t,
            'target_series': target_series,
            'baseline_nowcast': baseline_nowcast,
            'baseline_forecast': np.array(baseline_forecast),
            'actual_history': np.array(actual_history),
            'residuals': np.array(residuals),
            'factors_history': factors_history
        }

# ============================================================================
# High-level API Classes
# ============================================================================
from ..config.results import DFMResult
from ..utils.helpers import (
    safe_get_method,
    safe_get_attr,
    get_clock_frequency,
)
from ..utils.time import TimeIndex

if TYPE_CHECKING:
    from ..nowcast import Nowcast
    from ..lightning import DFMDataModule

class DDFM(BaseFactorModel):
    """High-level API for Deep Dynamic Factor Model (PyTorch Lightning module).
    
    This class is a PyTorch Lightning module that can be used with standard
    Lightning training patterns. It inherits from BaseFactorModel and implements
    DDFM training using autoencoder and MCMC procedure.
    
    Example (Standard Lightning Pattern):
        >>> from dfm_python import DDFM, DFMDataModule, DDFMTrainer
        >>> import polars as pl
        >>> 
        >>> # Step 1: Load and preprocess data
        >>> df = pl.read_csv('data/finance.csv')
        >>> df_processed = df.select([col for col in df.columns if col != 'date'])
        >>> 
        >>> # Step 2: Create DataModule
        >>> dm = DFMDataModule(config_path='config/ddfm_config.yaml', data=df_processed)
        >>> dm.setup()
        >>> 
        >>> # Step 3: Create model and load config
        >>> model = DDFM(encoder_layers=[64, 32], num_factors=2)
        >>> model.load_config('config/ddfm_config.yaml')
        >>> 
        >>> # Step 4: Create trainer and fit
        >>> trainer = DDFMTrainer(max_epochs=100)
        >>> trainer.fit(model, dm)
        >>> 
        >>> # Step 5: Predict
        >>> Xf, Zf = model.predict(horizon=6)
    
    Note on GPU Memory Usage:
        DDFM typically uses less GPU memory than DFM because:
        1. DDFM uses batch training (batch_size=32), processing data in small chunks
        2. DFM uses EM algorithm with Kalman filtering, which stores large covariance
           matrices on GPU: V (m x m x T+1), R (N x N), Q (m x m) for all time steps
        3. DDFM's neural network (encoder/decoder) is relatively small compared to
           the large covariance matrices in DFM's Kalman smoother
        4. DDFM processes data incrementally, while DFM processes the full dataset
           simultaneously during Kalman smoothing
        
        For example, with T=8000, N=22, m=2:
        - DFM: V matrix alone is (2 x 2 x 8001) = ~128KB, plus R (22 x 22) = ~4KB,
          plus all intermediate matrices during Kalman smoothing
        - DDFM: Processes batches of 32 samples at a time, so only (32 x 22) = ~3KB
          per batch on GPU, plus small encoder/decoder weights
    """
    
    def __init__(
        self,
        config: Optional[DFMConfig] = None,
        encoder_layers: Optional[List[int]] = None,
        num_factors: Optional[int] = None,
        activation: str = 'tanh',
        use_batch_norm: bool = True,
        learning_rate: float = 0.001,
        epochs: int = 100,
        batch_size: int = 32,
        factor_order: int = 1,
        use_idiosyncratic: bool = True,
        min_obs_idio: int = 5,
        max_iter: int = 200,
        tolerance: float = 0.0005,
        disp: int = 10,
        seed: Optional[int] = None,
        **kwargs
    ):
        """Build a DDFM Lightning module; networks are created later, once data dims are known.

        Parameters
        ----------
        config : DFMConfig, optional
            Model configuration. When omitted, a single-series / single-block
            placeholder is created and is meant to be replaced via load_config().
        encoder_layers : List[int], optional
            Encoder hidden sizes; ``[64, 32]`` when not given.
        num_factors : int, optional
            Factor count. When None, ``config.num_factors`` is used if present
            and not None, otherwise 1.
        activation : str, default 'tanh'
            Encoder activation ('tanh', 'relu', 'sigmoid').
        use_batch_norm : bool, default True
            Enable batch normalization inside the encoder.
        learning_rate : float, default 0.001
            Adam learning rate.
        epochs : int, default 100
            Epochs run per MCMC iteration (stored as ``epochs_per_iter``).
        batch_size : int, default 32
            Mini-batch size.
        factor_order : int, default 1
            Factor VAR order (1 or 2).
        use_idiosyncratic : bool, default True
            Model idiosyncratic components.
        min_obs_idio : int, default 5
            Minimum observations for idiosyncratic AR(1) estimation.
        max_iter : int, default 200
            Upper bound on MCMC iterations.
        tolerance : float, default 0.0005
            Convergence tolerance.
        disp : int, default 10
            Progress display interval, in iterations.
        seed : int, optional
            NumPy RNG seed; 3 is used when None.
        **kwargs
            Forwarded to BaseFactorModel and to the low-level DDFMModel.
        """
        super().__init__(**kwargs)

        if config is None:
            # DDFM ignores block structure, yet the config schema still needs one
            # series and one block, so a throwaway placeholder is built here.
            from ..config.schema import DFMConfig, SeriesConfig, DEFAULT_BLOCK_NAME
            placeholder_series = SeriesConfig(
                series_id='placeholder', frequency='m', transformation='lin', blocks=[1]
            )
            config = DFMConfig(
                series=[placeholder_series],
                blocks={DEFAULT_BLOCK_NAME: {'factors': 1, 'ar_lag': 1, 'clock': 'm'}},
            )

        self._config = config

        # Network / optimisation hyper-parameters
        self.encoder_layers = encoder_layers or [64, 32]
        self.activation = activation
        self.use_batch_norm = use_batch_norm
        self.learning_rate = learning_rate
        self.epochs_per_iter = epochs
        self.batch_size = batch_size

        # Dynamics / MCMC hyper-parameters
        self.factor_order = factor_order
        self.use_idiosyncratic = use_idiosyncratic
        self.min_obs_idio = min_obs_idio
        self.max_iter = max_iter
        self.tolerance = tolerance
        self.disp = disp

        # Factor count: an explicit argument wins, else the config's num_factors,
        # else 1. The flag records whether the caller supplied it directly.
        factors_given = num_factors is not None
        if factors_given:
            self.num_factors = num_factors
        else:
            config_factors = getattr(config, 'num_factors', None)
            self.num_factors = 1 if config_factors is None else config_factors
        self._num_factors_explicit = factors_given

        # Encoder/decoder are created in setup() (or later) once input_dim is known.
        self.encoder: Optional[Encoder] = None
        self.decoder: Optional[Decoder] = None

        # Standardisation statistics and processed data, filled during training.
        self.Mx: Optional[np.ndarray] = None
        self.Wx: Optional[np.ndarray] = None
        self.data_processed: Optional[torch.Tensor] = None

        # MCMC bookkeeping.
        self.current_mcmc_data: Optional[torch.Tensor] = None
        self.mcmc_iteration: int = 0

        # RNG for Monte Carlo sampling.
        self.rng = np.random.RandomState(3 if seed is None else seed)

        # Low-level helper object; receives the raw (un-defaulted) arguments.
        self._model_impl = DDFMModel(
            encoder_layers=encoder_layers,
            num_factors=num_factors,
            activation=activation,
            use_batch_norm=use_batch_norm,
            learning_rate=learning_rate,
            epochs=epochs,
            batch_size=batch_size,
            factor_order=factor_order,
            use_idiosyncratic=use_idiosyncratic,
            min_obs_idio=min_obs_idio,
            max_iter=max_iter,
            tolerance=tolerance,
            disp=disp,
            seed=seed,
            **kwargs
        )
        self._data_module: Optional['DFMDataModule'] = None
        self._nowcast: Optional['Nowcast'] = None
    
    def setup(self, stage: Optional[str] = None) -> None:
        """Initialize encoder and decoder when data dimensions are known.
        
        Lightning calls this before configure_optimizers(), so the networks are
        built here when a datamodule is attached and its processed data is
        available. Otherwise this is a no-op and initialization is left to
        on_train_start().
        
        Parameters
        ----------
        stage : str, optional
            Lightning stage name; not used here.
        """
        # In recent Lightning versions, ``LightningModule.trainer`` raises
        # RuntimeError (not AttributeError) when no Trainer is attached, so
        # hasattr() alone is not a safe probe.
        try:
            trainer = getattr(self, 'trainer', None)
        except RuntimeError:
            trainer = None
        if trainer is None:
            return
        
        datamodule = getattr(trainer, 'datamodule', None)
        if datamodule is None:
            return
        self._data_module = datamodule
        
        # Only the data fetch is guarded: if the datamodule is not ready yet,
        # defer to on_train_start(). Errors while building the networks propagate.
        try:
            X_torch = self._data_module.get_processed_data()
            input_dim = X_torch.shape[1]
            device = X_torch.device
        except (RuntimeError, AttributeError) as e:
            _logger.debug(f"Could not initialize networks in setup(): {e}. Will initialize in on_train_start()")
            return
        
        if self.encoder is None or self.decoder is None:
            self.initialize_networks(input_dim)
            # Move to same device as data
            self.encoder = self.encoder.to(device)
            self.decoder = self.decoder.to(device)
            _logger.debug(f"Initialized encoder/decoder in setup() with input_dim={input_dim}")
    
    def initialize_networks(self, input_dim: int) -> None:
        """Create fresh encoder/decoder modules sized for ``input_dim`` series.

        The encoder maps ``input_dim`` inputs through ``self.encoder_layers``
        hidden sizes to ``self.num_factors`` outputs. The decoder maps the
        factors back to ``input_dim`` outputs with a bias term. Both are stored
        on ``self``, replacing any existing modules.

        Parameters
        ----------
        input_dim : int
            Number of input features (number of series).
        """
        n_factors = self.num_factors

        self.encoder = Encoder(
            input_dim=input_dim,
            hidden_dims=self.encoder_layers,
            output_dim=n_factors,
            activation=self.activation,
            use_batch_norm=self.use_batch_norm,
        )

        # Plain decoder from factors back to every series (no block structure).
        self.decoder = Decoder(input_dim=n_factors, output_dim=input_dim, use_bias=True)
    
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward pass through encoder and decoder.
        
        Parameters
        ----------
        x : torch.Tensor
            Input data, either 3-D ``(batch_size, T, N)`` or passed to the
            encoder unchanged (e.g. ``(T, N)``).
            
        Returns
        -------
        reconstructed : torch.Tensor
            Decoder output. For 3-D input it is reshaped back to
            ``(batch_size, T, N)``.
        
        Raises
        ------
        RuntimeError
            If the encoder or decoder has not been initialized.
        """
        if self.encoder is None or self.decoder is None:
            raise RuntimeError(
                "DDFM forward pass failed: encoder and decoder must be initialized. "
                "Please ensure the model is properly initialized before calling forward()."
            )
        
        if x.ndim == 3:
            batch_size, T, N = x.shape
            # reshape (not view) so non-contiguous inputs, e.g. transposed or
            # sliced tensors, are accepted instead of raising.
            factors = self.encoder(x.reshape(batch_size * T, N))
            reconstructed = self.decoder(factors)
            return reconstructed.reshape(batch_size, T, N)
        
        factors = self.encoder(x)
        return self.decoder(factors)
    
    def training_step(self, batch: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]], batch_idx: int) -> torch.Tensor:
        """Training step for autoencoder.
        
        This is used for standard autoencoder training and also called
        during MCMC procedure for each MC sample.
        
        Missing data (NaN values) are handled by masking them in the loss function,
        similar to the original DDFM implementation (mse_missing). NaN entries in
        the *input* are zero-filled before the forward pass, because a single NaN
        would otherwise propagate through the encoder and make the loss NaN.
        
        Parameters
        ----------
        batch : torch.Tensor or tuple/list
            Data tensor, or a ``(data, target)`` sequence. A one-element sequence
            is treated as ``data`` with ``target = data``. Tuples and lists are
            both accepted because DataLoader's default collate yields lists.
            Data may contain NaN values which are masked in the loss.
        batch_idx : int
            Batch index
            
        Returns
        -------
        loss : torch.Tensor
            Sum of squared errors over observed (non-NaN) target entries,
            divided by their count (plus 1e-8 to avoid division by zero).
        """
        # Handle tuple, list (default collate output) and single-tensor batches.
        if isinstance(batch, (tuple, list)):
            data = batch[0]
            target = batch[1] if len(batch) > 1 else data
        else:
            data = batch
            target = data  # For autoencoder, target is same as input
        
        # Move to the model's device. If there are no parameters yet, forward()
        # raises its own descriptive error instead of a bare StopIteration here.
        first_param = next(self.parameters(), None)
        if first_param is not None:
            data = data.to(first_param.device)
            target = target.to(first_param.device)
        
        # Zero-fill missing inputs only; the loss below still ignores them.
        data_in = torch.where(torch.isnan(data), torch.zeros_like(data), data)
        reconstructed = self.forward(data_in)
        
        # mse_missing: squared error on observed entries only.
        observed = ~torch.isnan(target)
        target_clean = torch.where(observed, target, torch.zeros_like(target))
        diff = torch.where(observed, reconstructed - target_clean, torch.zeros_like(reconstructed))
        n_observed = observed.sum().to(diff.dtype)
        loss = torch.sum(diff ** 2) / (n_observed + 1e-8)  # Avoid division by zero
        
        # Log metrics
        self.log('train_loss', loss, on_step=True, on_epoch=True, prog_bar=True)
        
        return loss
    
    
    def _validate_factors_shape(self, factors: np.ndarray, operation: str = "operation") -> np.ndarray:
        """Validate and normalize factors shape.

        A 1D array is interpreted as row-major ``(T, num_factors)`` data and
        reshaped accordingly.

        Parameters
        ----------
        factors : np.ndarray
            Factors array to validate (anything ``np.asarray`` accepts).
        operation : str
            Operation name for error messages.

        Returns
        -------
        np.ndarray
            Validated factors array (2D, shape T x num_factors).

        Raises
        ------
        ValueError
            If factors are empty/0-dimensional, a 1D array cannot be split
            into rows of ``self.num_factors``, or the result is not 2D.
        """
        factors = np.asarray(factors)
        if factors.ndim == 0 or factors.size == 0:
            raise ValueError(
                f"DDFM {operation} failed: factors is empty or invalid (shape: {factors.shape}). "
                f"This indicates training did not complete properly. "
                f"Please check training logs and ensure fit_mcmc() completed successfully."
            )
        if factors.ndim == 1:
            # Empty input was rejected above, so only divisibility needs checking.
            num_factors = self.num_factors
            if num_factors <= 0 or factors.size % num_factors != 0:
                raise ValueError(
                    f"DDFM {operation} failed: cannot reshape 1D factors of length "
                    f"{factors.size} into shape (T, {num_factors})"
                )
            factors = factors.reshape(-1, num_factors)
        if factors.ndim != 2:
            raise ValueError(
                f"DDFM {operation} failed: factors must be 2D array (T x m), got shape {factors.shape}"
            )
        return factors
    
    def _estimate_var_with_fallback(
        self, 
        factors: np.ndarray, 
        factor_order: int
    ) -> Tuple[np.ndarray, np.ndarray]:
        """Estimate VAR dynamics with error handling and fallback.

        When the factors are unusable (0-dimensional, empty, fewer than two
        time steps, or zero columns) or the estimator raises ``ValueError`` /
        ``LinAlgError``, a fallback is returned: identity on the first lag
        (plus a zero second-lag block for VAR(2), so A keeps the documented
        ``m x 2m`` layout) and a tiny diagonal innovation covariance.

        Parameters
        ----------
        factors : np.ndarray
            Factors array (T x m)
        factor_order : int
            VAR order (1 or 2)

        Returns
        -------
        A : np.ndarray
            Transition matrix (m x m, or m x 2m for VAR(2))
        Q : np.ndarray
            Innovation covariance (m x m)

        Raises
        ------
        ValueError
            If the factors are usable but ``factor_order`` is not 1 or 2.
            (Previously this error was raised inside the ``try`` and silently
            swallowed into the fallback.)
        """
        def _fallback(num_factors) -> Tuple[np.ndarray, np.ndarray]:
            # Identity dynamics with near-zero noise; at least one factor.
            m = int(num_factors) if num_factors and num_factors > 0 else 1
            A_fb = np.eye(m)
            if factor_order == 2:
                A_fb = np.hstack([A_fb, np.zeros((m, m))])
            return A_fb, np.eye(m) * 1e-6

        factors = np.asarray(factors)
        default_m = getattr(self, 'num_factors', 1)

        # Validate factors shape - check for 0-dimensional array first
        if factors.ndim == 0:
            _logger.warning(
                "Factors is 0-dimensional array for VAR estimation. "
                "Using identity matrix for A and small covariance for Q."
            )
            return _fallback(default_m)

        if factors.size == 0 or factors.ndim < 2 or factors.shape[0] < 2 or factors.shape[1] == 0:
            _logger.warning(
                f"Insufficient or invalid factors shape {factors.shape} for VAR estimation. "
                f"Using identity matrix for A and small covariance for Q."
            )
            m = factors.shape[1] if factors.ndim == 2 and factors.shape[1] > 0 else default_m
            return _fallback(m)

        # Validate the order outside the try so the error is not swallowed below
        if factor_order not in (1, 2):
            raise ValueError(
                f"DDFM VAR estimation failed: factor_order must be 1 or 2, got {factor_order}"
            )

        from ..utils.statespace import estimate_var1, estimate_var2

        m = factors.shape[1]
        estimator = estimate_var1 if factor_order == 1 else estimate_var2
        try:
            A_f, Q_f = estimator(factors)
        except (ValueError, np.linalg.LinAlgError) as e:
            _logger.warning(
                f"VAR({factor_order}) estimation failed: {e}. "
                f"Using identity matrix for A and small covariance for Q."
            )
            return _fallback(m)

        # Normalize Q to (m, m). A scalar variance for a single-factor model
        # (or a flattened matrix) is reshaped rather than discarded.
        Q_f = np.asarray(Q_f)
        if Q_f.shape != (m, m):
            if Q_f.size == m * m:
                Q_f = Q_f.reshape(m, m)
            else:
                _logger.warning(
                    f"VAR estimation returned Q with unexpected shape {Q_f.shape}. "
                    f"Using identity matrix instead."
                )
                Q_f = np.eye(m) * 1e-6

        return A_f, Q_f
    
    def configure_optimizers(self):
        """Build the Adam optimizer Lightning uses for the autoencoder.

        Returns
        -------
        list
            Single-element list holding the optimizer (Lightning accepts a
            list/dict/tuple). If the encoder or decoder is still ``None``, a
            warning is logged and the optimizer wraps a throw-away zero
            parameter so Lightning still receives a valid optimizer.
        """
        networks_ready = self.encoder is not None and self.decoder is not None
        if networks_ready:
            trainable = [*self.encoder.parameters(), *self.decoder.parameters()]
        else:
            # Safety net only: setup() is expected to have built the networks.
            _logger.warning(
                "Encoder/decoder not initialized in configure_optimizers(). "
                "Creating placeholder optimizer. This may indicate an issue with setup()."
            )
            trainable = [nn.Parameter(torch.zeros(1))]
        return [torch.optim.Adam(trainable, lr=self.learning_rate)]
    
    def _create_autoencoder_optimizer(self) -> torch.optim.Optimizer:
        """Return a new Adam optimizer over the encoder and decoder weights.

        Internal helper for manual training loops such as ``fit_mcmc()``;
        the Lightning trainer goes through ``configure_optimizers()`` instead.

        Returns
        -------
        torch.optim.Optimizer
            Adam optimizer with ``lr=self.learning_rate``.

        Raises
        ------
        RuntimeError
            If either network has not been built yet.
        """
        if self.encoder is None or self.decoder is None:
            raise RuntimeError(
                "Encoder and decoder must be initialized before creating optimizer. "
                "Call initialize_networks() first."
            )
        params = list(self.encoder.parameters())
        params.extend(self.decoder.parameters())
        return torch.optim.Adam(params, lr=self.learning_rate)
    
    def fit_mcmc(
        self,
        X: torch.Tensor,
        x_clean: torch.Tensor,
        missing_mask: np.ndarray,
        Mx: Optional[np.ndarray] = None,
        Wx: Optional[np.ndarray] = None,
        max_iter: Optional[int] = None,
        tolerance: Optional[float] = None,
        disp: Optional[int] = None,
        seed: Optional[int] = None,
    ) -> DDFMTrainingState:
        """Run MCMC iterative training procedure.

        Each iteration:

        1. estimates idiosyncratic dynamics from the current residuals (only
           when ``self.use_idiosyncratic``);
        2. subtracts the conditional AR-idio mean from the data;
        3. draws ``self.epochs_per_iter`` Monte Carlo idio samples and trains
           the autoencoder for one epoch on each corrupted sample;
        4. averages the encoder outputs over the samples to get the factors;
        5. checks convergence via the MSE between consecutive decoder
           predictions, then re-imputes missing entries and residuals.

        A single Adam optimizer is created once and reused for every epoch of
        every iteration, so Adam's moment estimates carry over between MC
        samples instead of being reset at each epoch.

        Parameters
        ----------
        X : torch.Tensor
            Standardized data with missing values (T x N)
        x_clean : torch.Tensor
            Clean data (interpolated, T x N) used for initial training
        missing_mask : np.ndarray
            Missing data mask (T x N), True where data is missing. Converted
            to ``bool`` so integer 0/1 masks are not used as fancy indices.
        Mx : np.ndarray, optional
            Mean values for unstandardization (N,)
        Wx : np.ndarray, optional
            Standard deviation values for unstandardization (N,)
        max_iter : int, optional
            Maximum number of MCMC iterations (uses self.max_iter if None)
        tolerance : float, optional
            Convergence tolerance (uses self.tolerance if None)
        disp : int, optional
            Log progress every ``disp`` iterations (uses self.disp if None).
            A resolved value of 0 or None disables the periodic progress log.
        seed : int, optional
            Random seed for the MC sampler. If None, a seed is drawn from
            ``self.rng`` when it has ``randint``; otherwise 3 is used.

        Returns
        -------
        DDFMTrainingState
            Final training state with factors and convergence info (also
            stored on ``self.training_state``).
        """
        self.Mx = Mx
        self.Wx = Wx
        self.data_processed = X

        device = X.device
        dtype = X.dtype
        T, N = X.shape

        # Use instance attributes if not provided
        max_iter = max_iter if max_iter is not None else self.max_iter
        tolerance = tolerance if tolerance is not None else self.tolerance
        disp = disp if disp is not None else self.disp

        # Boolean copy: with an integer 0/1 mask, indexing below would be
        # integer fancy indexing instead of masking.
        missing_mask = np.asarray(missing_mask, dtype=bool)

        # Initialize networks if not done
        if self.encoder is None or self.decoder is None:
            self.initialize_networks(N)

        # Ensure encoder and decoder are on the data's device even if they
        # were initialized elsewhere (e.g. on_train_start()).
        self.encoder = self.encoder.to(device)
        self.decoder = self.decoder.to(device)

        # Random number generator for MC sampling
        if seed is None:
            model_rng = getattr(self, 'rng', None)
            seed = model_rng.randint(0, 2**31) if hasattr(model_rng, 'randint') else 3
        rng = np.random.RandomState(seed)

        # Convert to numpy for MCMC procedure (some operations are easier in numpy)
        x_standardized_np = X.cpu().numpy()
        x_clean_np = x_clean.cpu().numpy()

        # Initialize data structures
        data_mod_only_miss = x_standardized_np.copy()  # Original with missing values
        data_mod = x_clean_np.copy()  # Clean data (will be modified during MCMC)

        # Initial prediction
        x_tensor = x_clean.to(device)
        self.encoder.eval()
        self.decoder.eval()
        with torch.no_grad():
            factors_init = self.encoder(x_tensor).cpu().numpy()
            factors_tensor = torch.tensor(factors_init, device=device, dtype=dtype)
            prediction_iter = self.decoder(factors_tensor).cpu().numpy()

        # Initialize factors
        factors = factors_init.copy()

        # Update missing values with initial prediction
        bool_miss = missing_mask
        if bool_miss.any():
            data_mod_only_miss[bool_miss] = prediction_iter[bool_miss]

        # Initial residuals
        eps = data_mod_only_miss - prediction_iter

        # MCMC loop state
        iter_count = 0
        not_converged = True
        prediction_prev_iter = None
        delta = float('inf')
        loss_now = float('inf')

        # One optimizer for the whole procedure so Adam's state persists
        # across MC-sample epochs (a fresh optimizer per epoch resets it).
        optimizer = self._create_autoencoder_optimizer()

        _logger.info(f"Starting MCMC training: max_iter={max_iter}, tolerance={tolerance}, epochs_per_iter={self.epochs_per_iter}")

        while not_converged and iter_count < max_iter:
            iter_count += 1
            self.mcmc_iteration = iter_count

            # Get idiosyncratic distribution
            if self.use_idiosyncratic:
                A_eps, Q_eps = estimate_idiosyncratic_dynamics(eps, missing_mask, self.min_obs_idio)
                # Convert to format expected by MCMC procedure
                phi = A_eps if A_eps.ndim == 2 else np.diag(A_eps) if A_eps.ndim == 1 else np.eye(N)
                mu_eps = np.zeros(N)
                if Q_eps.ndim == 2:
                    std_eps = np.sqrt(np.diag(Q_eps))
                elif Q_eps.ndim == 1:
                    std_eps = np.sqrt(Q_eps)
                else:
                    std_eps = np.ones(N) * 0.1
            else:
                phi = np.zeros((N, N))
                mu_eps = np.zeros(N)
                std_eps = np.ones(N) * 1e-8

            # Subtract conditional AR-idio mean from x
            if self.use_idiosyncratic and eps.shape[0] > 1:
                data_mod[1:] = data_mod_only_miss[1:] - eps[:-1, :] @ phi
                data_mod[:1] = data_mod_only_miss[:1]
            else:
                data_mod = data_mod_only_miss.copy()

            # Generate MC samples for idio (dims = epochs_per_iter x T x N).
            # NOTE(review): np.diag(std_eps) is passed as the *covariance*, so
            # draws have variance std_eps rather than std_eps**2. Kept as-is
            # to preserve results; confirm whether variances were intended.
            eps_draws = np.zeros((self.epochs_per_iter, T, N))
            for t in range(T):
                eps_draws[:, t, :] = rng.multivariate_normal(
                    mu_eps, np.diag(std_eps), size=self.epochs_per_iter
                )

            # Noisy (corrupted) inputs, one per MC sample
            x_sim_den = np.zeros((self.epochs_per_iter, T, N))

            factors_samples = []
            for i in range(self.epochs_per_iter):
                # Corrupt input data by subtracting sampled idio innovations
                x_sim_den[i] = data_mod - eps_draws[i]

                # Train autoencoder on corrupted sample (1 epoch)
                x_sample = torch.tensor(x_sim_den[i], device=device, dtype=dtype)
                dataloader = torch.utils.data.DataLoader(
                    torch.utils.data.TensorDataset(x_sample, x_sample),
                    batch_size=self.batch_size,
                    shuffle=True,
                )

                self.encoder.train()
                self.decoder.train()
                for batch_data, batch_target in dataloader:
                    optimizer.zero_grad()
                    reconstructed = self.forward(batch_data)
                    # Missing-aware loss (mse_missing) for consistency with
                    # training_step; in this loop missing values are already filled.
                    mask = torch.where(torch.isnan(batch_target), torch.zeros_like(batch_target), torch.ones_like(batch_target))
                    target_clean = torch.where(torch.isnan(batch_target), torch.zeros_like(batch_target), batch_target)
                    reconstructed_masked = reconstructed * mask
                    squared_diff = (target_clean - reconstructed_masked) ** 2
                    loss = torch.sum(squared_diff) / (torch.sum(mask) + 1e-8)
                    loss.backward()
                    optimizer.step()

                # Extract factors from this sample
                self.encoder.eval()
                with torch.no_grad():
                    factors_samples.append(self.encoder(x_sample).cpu().numpy())

            # Update factors: average over all MC samples
            factors = np.mean(np.array(factors_samples), axis=0)  # T x num_factors

            # Prediction for convergence check
            self.decoder.eval()
            with torch.no_grad():
                factors_tensor = torch.tensor(factors, device=device, dtype=dtype)
                prediction_iter = self.decoder(factors_tensor).cpu().numpy()

            if iter_count > 1:
                # MSE between consecutive predictions on non-missing values
                mask = ~np.isnan(data_mod_only_miss)
                if np.sum(mask) > 0:
                    mse = np.nanmean((prediction_prev_iter[mask] - prediction_iter[mask]) ** 2)
                    delta = mse
                    loss_now = mse
                else:
                    delta = float('inf')
                    loss_now = float('inf')

                # disp of 0/None disables periodic logging (avoids modulo by zero)
                if disp and iter_count % disp == 0:
                    _logger.info(
                        f"Iteration {iter_count}/{max_iter}: loss={loss_now:.6f}, delta={delta:.6f}"
                    )

                if delta < tolerance:
                    not_converged = False
                    _logger.info(
                        f"Convergence achieved in {iter_count} iterations: "
                        f"loss={loss_now:.6f}, delta={delta:.6f} < {tolerance}"
                    )
            else:
                # First iteration: compute initial loss
                mask = ~np.isnan(data_mod_only_miss)
                if np.sum(mask) > 0:
                    loss_now = np.nanmean((data_mod_only_miss[mask] - prediction_iter[mask]) ** 2)
                else:
                    loss_now = float('inf')

            # Store previous prediction for convergence checking
            prediction_prev_iter = prediction_iter.copy()

            # Update missing values with current prediction
            if bool_miss.any():
                data_mod_only_miss[bool_miss] = prediction_iter[bool_miss]

            # Update residuals
            eps = data_mod_only_miss - prediction_iter

        if not_converged:
            delta_str = f"{delta:.6f}" if iter_count > 1 else "N/A"
            _logger.warning(
                f"Convergence not achieved within {max_iter} iterations. "
                f"Final delta: {delta_str}"
            )

        converged = not not_converged

        # Validate and normalize factors shape before storing
        factors = self._validate_factors_shape(factors, operation="fit_mcmc")

        # Store final state
        self.training_state = DDFMTrainingState(
            factors=factors,
            prediction=prediction_iter,
            converged=converged,
            num_iter=iter_count,
            training_loss=loss_now
        )

        return self.training_state
    
    def get_result(self) -> DDFMResult:
        """Extract DDFMResult from trained model.

        Builds a factor-only state space from the training state: decoder
        weights give ``C``/``bias``, a VAR fitted to the factors gives
        ``A``/``Q``, and ``R`` is the per-series residual variance (floored
        at 1e-8). NaN residuals, e.g. from missing values in the data passed
        to ``fit_mcmc()``, are ignored when computing ``R``.

        Returns
        -------
        DDFMResult
            Estimation results with parameters, factors, and diagnostics

        Raises
        ------
        RuntimeError
            If the model has not been fitted or the networks are missing.
        """
        import warnings

        if self.training_state is None:
            raise RuntimeError(
                "DDFM get_result failed: model has not been fitted yet. "
                "Please call fit_mcmc() first."
            )

        if self.encoder is None or self.decoder is None:
            raise RuntimeError(
                "DDFM get_result failed: encoder and decoder must be initialized. "
                "Please ensure the model is properly initialized before getting results."
            )

        # Extract decoder parameters (C, bias)
        C, bias = extract_decoder_params(self.decoder)

        # Get factors and prediction
        factors = self.training_state.factors  # T x num_factors
        prediction_iter = self.training_state.prediction  # T x N

        # Validate and normalize factors shape
        factors = self._validate_factors_shape(factors, operation="get_result")

        # Convert to numpy
        C = C.cpu().numpy() if isinstance(C, torch.Tensor) else C
        bias = bias.cpu().numpy() if isinstance(bias, torch.Tensor) else bias

        # Compute residuals against the data fitted in fit_mcmc()
        if self.data_processed is not None:
            x_standardized = self.data_processed.cpu().numpy()
            if x_standardized.shape != prediction_iter.shape:
                _logger.warning(
                    f"Shape mismatch in get_result: data_processed {x_standardized.shape} vs "
                    f"prediction {prediction_iter.shape}. Using prediction shape for residuals."
                )
                residuals = np.zeros_like(prediction_iter)
            else:
                residuals = x_standardized - prediction_iter
        else:
            residuals = np.zeros_like(prediction_iter)

        # Estimate factor dynamics (VAR) with error handling
        A_f, Q_f = self._estimate_var_with_fallback(factors, self.factor_order)

        # For DDFM, we use simplified state-space (factor-only)
        A = A_f
        Q = Q_f
        Z_0 = factors[0, :]
        # np.cov returns a 0-d array for a single factor; keep V_0 a matrix
        V_0 = np.atleast_2d(np.cov(factors.T))

        # Estimate R from residuals, ignoring NaNs. fmax (unlike maximum)
        # maps an all-NaN column's NaN variance to the floor instead of
        # propagating it; the all-NaN RuntimeWarning is suppressed.
        with warnings.catch_warnings():
            warnings.simplefilter("ignore", category=RuntimeWarning)
            R_diag = np.nanvar(residuals, axis=0)
        R = np.diag(np.fmax(R_diag, 1e-8))

        # Compute smoothed data
        x_sm = prediction_iter  # T x N (standardized)

        # Unstandardize
        Wx_clean = np.where(np.isnan(self.Wx), 1.0, self.Wx) if self.Wx is not None else np.ones(C.shape[0])
        Mx_clean = np.where(np.isnan(self.Mx), 0.0, self.Mx) if self.Mx is not None else np.zeros(C.shape[0])
        X_sm = x_sm * Wx_clean + Mx_clean  # T x N (unstandardized)

        # Create result object
        r = np.array([self.num_factors])

        result = DDFMResult(
            x_sm=x_sm,
            X_sm=X_sm,
            Z=factors,  # T x m
            C=C,
            R=R,
            A=A,
            Q=Q,
            Mx=self.Mx if self.Mx is not None else np.zeros(C.shape[0]),
            Wx=self.Wx if self.Wx is not None else np.ones(C.shape[0]),
            Z_0=Z_0,
            V_0=V_0,
            r=r,
            p=self.factor_order,
            converged=self.training_state.converged,
            num_iter=self.training_state.num_iter,
            loglik=None,  # DDFM doesn't compute loglik in same way
            series_ids=self.config.get_series_ids() if hasattr(self.config, 'get_series_ids') else None,
            block_names=None,  # DDFM does not use block structure
            training_loss=self.training_state.training_loss,
            encoder_layers=self.encoder_layers,
            use_idiosyncratic=self.use_idiosyncratic,
        )

        return result
    
    def on_train_start(self) -> None:
        """Lightning hook: run MCMC training when training starts.

        If the trainer has a DataModule attached, this method:

        1. stores it on ``self._data_module`` for later use (nowcast, predict);
        2. reads the processed data and standardization parameters from it;
        3. builds the encoder/decoder if ``setup()`` has not;
        4. fills missing values with ``rem_nans_spline``;
        5. runs ``fit_mcmc()``.

        If no DataModule is attached (``trainer.datamodule`` absent or
        ``None``), a warning is logged and MCMC training is skipped. The
        parent hook is always called afterwards.
        """
        data_module = getattr(self.trainer, 'datamodule', None)
        if data_module is None:
            _logger.warning(
                "DDFM on_train_start: trainer has no DataModule attached; skipping MCMC training. "
                "Pass a DataModule via trainer.fit(model, datamodule=...)."
            )
        else:
            # Store data_module reference for later use (nowcast, predict, etc.)
            self._data_module = data_module

            # Get processed data and standardization params from DataModule
            X_torch = data_module.get_processed_data()
            Mx, Wx = data_module.get_standardization_params()

            # The processed data may be a tensor or a numpy array
            if isinstance(X_torch, torch.Tensor):
                device = X_torch.device
                X_np = X_torch.cpu().numpy()
            else:
                device = torch.device('cpu')
                X_np = X_torch

            # Initialize encoder/decoder if not already done in setup()
            if self.encoder is None or self.decoder is None:
                input_dim = X_torch.shape[1]
                self.initialize_networks(input_dim)
                # Move to same device as data
                self.encoder = self.encoder.to(device)
                self.decoder = self.decoder.to(device)
                _logger.debug(f"Initialized encoder/decoder in on_train_start() with input_dim={input_dim}")

            # Handle case where standardization params might be None
            if Mx is None or Wx is None:
                N = X_torch.shape[1]
                Mx = np.zeros(N, dtype=np.float32)
                Wx = np.ones(N, dtype=np.float32)

            # Fill missing data. nan_method=1 (the default here) fills gaps
            # with spline interpolation + moving average WITHOUT trimming rows,
            # so the cleaned data keeps the DataModule's number of time steps.
            # Method 2 would trim rows that are mostly NaN and change T.
            nan_method = safe_get_attr(self.config, 'nan_method', 1)
            nan_k = safe_get_attr(self.config, 'nan_k', 3)
            x_clean, missing_mask = rem_nans_spline(
                X_np,
                method=nan_method,
                k=nan_k
            )
            x_clean_torch = torch.tensor(x_clean, dtype=torch.float32, device=device)

            # The filled data is passed as both X and x_clean; missing_mask
            # (original NaN positions, same shape as the data with method=1)
            # tells fit_mcmc which entries to re-impute each iteration.
            self.fit_mcmc(
                X=x_clean_torch,
                x_clean=x_clean_torch,
                missing_mask=missing_mask,
                Mx=Mx,
                Wx=Wx,
            )

        super().on_train_start()
    
    @property
    def nowcast(self) -> 'Nowcast':
        """Lazily build and cache the nowcasting manager.

        The manager is created on first access and reused afterwards.

        Raises
        ------
        ValueError
            On first access, if the configuration, the DataModule, or the
            training state is still missing.
        """
        if self._nowcast is not None:
            return self._nowcast

        if self._config is None:
            raise ValueError(
                "DDFM nowcast access failed: configuration has not been loaded yet. "
                "Please call load_config() first."
            )
        if self._data_module is None:
            raise ValueError(
                "DDFM nowcast access failed: DataModule has not been provided yet. "
                "Please provide DataModule via trainer.fit() before accessing nowcast."
            )
        if self.training_state is None:
            raise ValueError(
                "DDFM nowcast access failed: model has not been trained yet. "
                "Please call trainer.fit() first."
            )

        from ..nowcast.nowcast import Nowcast
        manager = Nowcast(model=self, data_module=self._data_module)
        self._nowcast = manager
        return manager
    
    def load_config(
        self,
        source: Optional[Union[str, Path, Dict[str, Any], DFMConfig, ConfigSource]] = None,
        *,
        yaml: Optional[Union[str, Path]] = None,
        mapping: Optional[Dict[str, Any]] = None,
        hydra: Optional[Union[Dict[str, Any], Any]] = None,
        base: Optional[Union[str, Path, Dict[str, Any], ConfigSource]] = None,
        override: Optional[Union[str, Path, Dict[str, Any], ConfigSource]] = None,
    ) -> 'DDFM':
        """Load configuration from various sources."""
        # Preserve explicitly set num_factors if it was set during initialization
        preserved_num_factors = None
        if hasattr(self, '_num_factors_explicit') and self._num_factors_explicit:
            preserved_num_factors = self.num_factors
        
        # Use common config loading logic
        self._load_config_common(
            source=source,
            yaml=yaml,
            mapping=mapping,
            hydra=hydra,
            base=base,
            override=override,
        )
        
        # Restore preserved num_factors if it was explicitly set
        if preserved_num_factors is not None:
            self.num_factors = preserved_num_factors
            # Keep the flag set since it's still explicitly set
            self._num_factors_explicit = True
        
        # DDFM-specific initialization is handled in __init__ or on_train_start
        # No additional setup needed here
        
        return self
    
    def predict(
        self,
        horizon: Optional[int] = None,
        *,
        return_series: bool = True,
        return_factors: bool = True
    ) -> Union[np.ndarray, Tuple[np.ndarray, np.ndarray]]:
        """Forecast future values.

        This method can be called after training. It converts the training
        state into a result via ``get_result()`` (cached on ``self._result``),
        hands that result to the low-level ``_model_impl``, and delegates the
        forecast to ``_model_impl.predict()``.

        Raises
        ------
        ValueError
            If the model has not been trained, or no ``_model_impl`` is
            available to produce the forecast.
        """
        if self.training_state is None:
            error_msg = self._format_error_message(
                operation="prediction",
                reason="model has not been trained yet",
                guidance="Please call trainer.fit(model, data_module) first"
            )
            raise ValueError(error_msg)

        # Convert training state to result format for prediction
        if getattr(self, '_result', None) is None:
            self._result = self.get_result()

        # The forecast is produced by the low-level implementation; fail with
        # a clear message instead of an AttributeError on None.
        model_impl = getattr(self, '_model_impl', None)
        if model_impl is None:
            error_msg = self._format_error_message(
                operation="prediction",
                reason="low-level model implementation (_model_impl) is not available",
                guidance="Please ensure the model was constructed correctly before predicting"
            )
            raise ValueError(error_msg)

        # Share the result so _model_impl.predict() can use it
        model_impl._result = self._result

        return model_impl.predict(
            horizon=horizon,
            return_series=return_series,
            return_factors=return_factors
        )
    
    @property
    def result(self) -> DDFMResult:
        """Get model result from training state.
        
        Raises
        ------
        ValueError
            If model has not been trained yet
        """
        # Check if trained and extract result from training state if needed
        self._check_trained()
        return self._result
    
    @property
    def config(self) -> DFMConfig:
        """Get model configuration."""
        if not hasattr(self, '_config') or self._config is None:
            raise ValueError("Model configuration not set. Call load_config() or pass config to __init__() first.")
        return self._config
    
    def plot(self, **kwargs) -> 'DDFM':
        """Placeholder for common visualizations.

        .. note::
            Plotting is not implemented yet; this only verifies the model is
            trained and logs a warning. Use external plotting libraries
            (matplotlib, plotly) on the model results instead.

        Returns
        -------
        DDFM
            ``self``, for chaining.

        Raises
        ------
        ValueError
            If the model has not been trained.
        """
        trained = self.training_state is not None
        if not trained:
            raise ValueError(
                "DDFM plot failed: model has not been trained yet. "
                "Please call trainer.fit() first."
            )
        _logger.warning("Plot functionality not yet implemented. Use external plotting libraries with model results.")
        return self
    
    def reset(self) -> 'DDFM':
        """Clear configuration, data module, cached result, nowcast and training state.

        ``training_state`` is only cleared if the attribute already exists.

        Returns
        -------
        DDFM
            ``self``, for chaining.
        """
        for attr_name in ('_config', '_data_module', '_result', '_nowcast'):
            setattr(self, attr_name, None)
        if hasattr(self, 'training_state'):
            self.training_state = None
        return self
    
    def load_pickle(self, path: Union[str, Path], **kwargs) -> 'DDFM':
        """Load a saved model from pickle file.

        Restores ``_config`` and ``_result`` from a mapping payload with keys
        ``'config'`` and ``'result'`` (a missing key becomes ``None``).

        Note: DataModule is not saved in pickle. Users must create a new DataModule
        and call train() with it after loading the model.

        .. warning::
            Unpickling can execute arbitrary code. Only load files from
            trusted sources.

        Parameters
        ----------
        path : str or Path
            File written by the matching save routine.
        **kwargs
            Ignored; accepted for interface compatibility.

        Returns
        -------
        DDFM
            ``self``, for chaining.

        Raises
        ------
        TypeError
            If the unpickled payload is not a mapping.
        """
        import pickle  # Import locally to avoid unnecessary dependency
        from collections.abc import Mapping

        # SECURITY: pickle.load executes arbitrary code embedded in the file;
        # never call this on untrusted input.
        with open(path, 'rb') as f:
            payload = pickle.load(f)

        if not isinstance(payload, Mapping):
            raise TypeError(
                f"DDFM load_pickle failed: expected a dict payload in {path!s}, "
                f"got {type(payload).__name__}."
            )

        self._config = payload.get('config')
        self._result = payload.get('result')
        # Note: data_module is not loaded - users must provide it via train()
        return self
    
    def generate_dataset(
        self,
        target_series: str,
        periods: List[datetime],
        backward: int = 0,
        forward: int = 0,
        dataview: Optional['DataView'] = None
    ) -> Dict[str, Any]:
        """Generate dataset for DFM evaluation.

        Forwards to ``self._model_impl.generate_dataset()``. For the duration
        of that call only, the implementation is given the high-level
        nowcast manager as ``_nowcast_ref``; the attribute is removed again
        afterwards, even if the call raises.
        """
        # Temporarily expose the nowcast manager to the implementation
        setattr(self._model_impl, '_nowcast_ref', self.nowcast)
        try:
            return self._model_impl.generate_dataset(
                target_series=target_series,
                periods=periods,
                backward=backward,
                forward=forward,
                dataview=dataview
            )
        finally:
            impl = self._model_impl
            if hasattr(impl, '_nowcast_ref'):
                delattr(impl, '_nowcast_ref')
    
    def get_state(
        self,
        t: Union[int, datetime],
        target_series: str,
        lookback: Optional[int] = None
    ) -> Dict[str, Any]:
        """Get DFM state at time t.

        Forwards to ``self._model_impl.get_state()``. For the duration of
        that call only, the implementation is given the high-level nowcast
        manager as ``_nowcast_ref``; the attribute is removed again
        afterwards, even if the call raises.
        """
        # Temporarily expose the nowcast manager to the implementation
        setattr(self._model_impl, '_nowcast_ref', self.nowcast)
        try:
            return self._model_impl.get_state(
                t=t,
                target_series=target_series,
                lookback=lookback
            )
        finally:
            impl = self._model_impl
            if hasattr(impl, '_nowcast_ref'):
                delattr(impl, '_nowcast_ref')

