"""Result structures for Dynamic Factor Model estimation.

This module contains dataclasses for storing DFM estimation results and parameters.
"""

import numpy as np
from dataclasses import dataclass
from typing import Optional, List, Dict, Any

from ..config import DFMConfig


@dataclass
class DFMResult:
    """DFM estimation results structure.
    
    This dataclass contains all outputs from the DFM estimation procedure,
    including estimated parameters, smoothed data, and factors.
    
    Attributes
    ----------
    x_sm : np.ndarray
        Standardized smoothed data matrix (T x N), where T is time periods
        and N is number of series. Data is standardized (zero mean, unit variance).
    X_sm : np.ndarray
        Unstandardized smoothed data matrix (T x N). This is the original-scale
        version of x_sm, computed as X_sm = x_sm * Wx + Mx.
    Z : np.ndarray
        Smoothed factor estimates (T x m), where m is the state dimension.
        Columns represent different factors (common factors and idiosyncratic components).
    C : np.ndarray
        Observation/loading matrix (N x m). Each row corresponds to a series,
        each column to a factor. C[i, j] gives the loading of series i on factor j.
    R : np.ndarray
        Covariance matrix for observation equation residuals (N x N).
        Typically diagonal, representing idiosyncratic variances.
    A : np.ndarray
        Transition matrix (m x m) for the state equation. Describes how factors
        evolve over time: Z_t = A @ Z_{t-1} + error.
    Q : np.ndarray
        Covariance matrix for transition equation residuals (m x m).
        Describes the covariance of factor innovations.
    Mx : np.ndarray
        Series means (N,). Used for standardization: x = (X - Mx) / Wx.
    Wx : np.ndarray
        Series standard deviations (N,). Used for standardization.
    Z_0 : np.ndarray
        Initial state vector (m,). Starting values for factors at t=0.
    V_0 : np.ndarray
        Initial covariance matrix (m x m) for factors. Uncertainty about Z_0.
    r : np.ndarray
        Number of factors per block (n_blocks,). Each element specifies
        how many factors are in each block structure.
    p : int
        Number of lags in the autoregressive structure of factors. Typically p=1.
    converged : bool, optional
        Whether EM algorithm converged. Defaults to False.
    num_iter : int, optional
        Number of EM iterations performed. Defaults to 0.
    loglik : float, optional
        Final log-likelihood value. Defaults to -inf.
    rmse : float, optional
        Overall RMSE on original scale (averaged across all series).
    rmse_per_series : np.ndarray, optional
        RMSE per series on original scale (N,).
    rmse_std : float, optional
        Overall RMSE on standardized scale (averaged across all series).
    rmse_std_per_series : np.ndarray, optional
        RMSE per series on standardized scale (N,).
    series_ids : List[str], optional
        Series identifiers; used as default column names by
        ``to_polars_smoothed``.
    block_names : List[str], optional
        Block names. Stored as metadata only; no method of this class reads it.
    time_index : object, optional
        Time index (typically a TimeIndex); used as the default ``time``
        column by the ``to_polars_*`` methods.
    
    Examples
    --------
    >>> from dfm_python import DFM
    >>> model = DFM()
    >>> Res = model.fit(X, config, threshold=1e-4)
    >>> # Access smoothed factors
    >>> common_factor = Res.Z[:, 0]
    >>> # Access factor loadings for first series
    >>> loadings = Res.C[0, :]
    >>> # Reconstruct smoothed series from factors
    >>> reconstructed = Res.Z @ Res.C.T
    """
    x_sm: np.ndarray      # Standardized smoothed data (T x N)
    X_sm: np.ndarray      # Unstandardized smoothed data (T x N)
    Z: np.ndarray         # Smoothed factors (T x m)
    C: np.ndarray         # Observation matrix (N x m)
    R: np.ndarray         # Covariance for observation residuals (N x N)
    A: np.ndarray         # Transition matrix (m x m)
    Q: np.ndarray         # Covariance for transition residuals (m x m)
    Mx: np.ndarray        # Series means (N,)
    Wx: np.ndarray        # Series standard deviations (N,)
    Z_0: np.ndarray       # Initial state (m,)
    V_0: np.ndarray       # Initial covariance (m x m)
    r: np.ndarray         # Number of factors per block
    p: int                # Number of lags
    converged: bool = False  # Whether EM algorithm converged
    num_iter: int = 0     # Number of iterations completed
    loglik: float = -np.inf  # Final log-likelihood
    rmse: Optional[float] = None  # Overall RMSE (original scale)
    rmse_per_series: Optional[np.ndarray] = None  # RMSE per series (original scale)
    rmse_std: Optional[float] = None  # Overall RMSE (standardized scale)
    rmse_std_per_series: Optional[np.ndarray] = None  # RMSE per series (standardized scale)
    # Optional metadata for object-oriented access
    series_ids: Optional[List[str]] = None
    block_names: Optional[List[str]] = None
    time_index: Optional[object] = None  # Typically a TimeIndex

    # ----------------------------
    # Convenience methods (OOP)
    # ----------------------------
    def num_series(self) -> int:
        """Return number of series (rows in C)."""
        return int(self.C.shape[0])

    def num_state(self) -> int:
        """Return state dimension (columns in Z)."""
        return int(self.Z.shape[1])

    def num_factors(self) -> int:
        """Return number of primary factors (sum of r).

        Falls back to the full state dimension when ``r`` cannot be summed
        to an integer (e.g. it is None or contains NaN/inf).
        """
        try:
            return int(np.sum(self.r))
        except (TypeError, ValueError, OverflowError):
            return self.num_state()

    @staticmethod
    def _time_values(source: object, time_index_type: type) -> list:
        """Normalise a time index into a list of values for a ``time`` column.

        Parameters
        ----------
        source : object
            A ``time_index_type`` instance, any non-string iterable, or an
            object supporting ``len()`` and integer indexing.
        time_index_type : type
            The TimeIndex class. It is passed in because the callers import
            it lazily.

        Returns
        -------
        list
            The time values, or an empty list when ``source`` cannot be
            enumerated. For a ``time_index_type`` instance the result of
            its ``to_list()`` is returned unchanged.
        """
        if isinstance(source, time_index_type):
            return source.to_list()
        if hasattr(source, '__iter__') and not isinstance(source, (str, bytes)):
            return list(source)
        # Last resort: sequence protocol (len + integer indexing).
        try:
            return [source[i] for i in range(len(source))]
        except (TypeError, AttributeError):
            return []

    def to_polars_factors(self, time_index: Optional[object] = None, factor_names: Optional[List[str]] = None) -> Any:
        """Return factors as polars DataFrame.
        
        Parameters
        ----------
        time_index : TimeIndex, list, or compatible, optional
            Time index to use for rows. If None, uses stored time_index if available.
        factor_names : List[str], optional
            Column names. Defaults to F1..Fm.

        Returns
        -------
        polars.DataFrame or np.ndarray
            One column per name (plus a ``time`` column when a non-empty time
            index is available). Best-effort: if polars or the TimeIndex
            helper cannot be imported, or building the frame raises
            ValueError/TypeError, the raw ``Z`` array is returned instead.
        """
        try:
            import polars as pl
            from .time import TimeIndex

            if factor_names is not None:
                cols = factor_names
            else:
                cols = [f"F{i+1}" for i in range(self.num_state())]

            # One DataFrame column per factor.
            df_dict = {col: self.Z[:, i] for i, col in enumerate(cols)}

            # An explicit argument wins over the stored index.
            time_to_use = time_index if time_index is not None else self.time_index
            if time_to_use is not None:
                time_list = self._time_values(time_to_use, TimeIndex)
                if time_list:
                    df_dict['time'] = time_list

            return pl.DataFrame(df_dict)
        except (ImportError, ValueError, TypeError):
            return self.Z

    def to_polars_smoothed(self, time_index: Optional[object] = None, series_ids: Optional[List[str]] = None) -> Any:
        """Return smoothed data (original scale) as polars DataFrame.

        Parameters
        ----------
        time_index : TimeIndex, list, or compatible, optional
            Time index to use for rows. If None, uses stored time_index if available.
        series_ids : List[str], optional
            Column names. Defaults to the stored series_ids, else S1..SN.

        Returns
        -------
        polars.DataFrame or np.ndarray
            One column per series (plus a ``time`` column when a non-empty
            time index is available). Best-effort: if polars or the TimeIndex
            helper cannot be imported, or building the frame raises
            ValueError/TypeError, the raw ``X_sm`` array is returned instead.
        """
        try:
            import polars as pl
            from .time import TimeIndex

            if series_ids is not None:
                cols = series_ids
            elif self.series_ids is not None:
                cols = self.series_ids
            else:
                cols = [f"S{i+1}" for i in range(self.num_series())]

            # One DataFrame column per series.
            df_dict = {col: self.X_sm[:, i] for i, col in enumerate(cols)}

            # An explicit argument wins over the stored index.
            time_to_use = time_index if time_index is not None else self.time_index
            if time_to_use is not None:
                time_list = self._time_values(time_to_use, TimeIndex)
                if time_list:
                    df_dict['time'] = time_list

            return pl.DataFrame(df_dict)
        except (ImportError, ValueError, TypeError):
            return self.X_sm

    def save(self, path: str) -> None:
        """Save result to a pickle file.

        Parameters
        ----------
        path : str
            Destination file path; opened in binary write mode.

        Raises
        ------
        RuntimeError
            If the file cannot be written or the object cannot be pickled.
            The original exception is attached as ``__cause__``.
        """
        import pickle

        try:
            with open(path, 'wb') as f:
                pickle.dump(self, f)
        # Unpicklable members surface as TypeError/AttributeError rather than
        # PickleError, so they are wrapped too for a uniform failure type.
        except (OSError, pickle.PickleError, TypeError, AttributeError) as e:
            raise RuntimeError(f"Failed to save DFMResult to {path}: {e}") from e


@dataclass
class DFMParams:
    """DFM estimation parameter overrides.
    
    All parameters are optional. If None, the corresponding value
    from DFMConfig will be used during parameter resolution.
    
    This dataclass groups all parameter overrides that can be passed
    to `_dfm_core()` and `_prepare_data_and_params()` to reduce
    function parameter count and improve readability.
    """
    threshold: Optional[float] = None
    max_iter: Optional[int] = None
    ar_lag: Optional[int] = None
    nan_method: Optional[int] = None
    nan_k: Optional[int] = None
    clock: Optional[str] = None
    clip_ar_coefficients: Optional[bool] = None
    ar_clip_min: Optional[float] = None
    ar_clip_max: Optional[float] = None
    clip_data_values: Optional[bool] = None
    data_clip_threshold: Optional[float] = None
    use_regularization: Optional[bool] = None
    regularization_scale: Optional[float] = None
    min_eigenvalue: Optional[float] = None
    max_eigenvalue: Optional[float] = None
    use_damped_updates: Optional[bool] = None
    damping_factor: Optional[float] = None
    
    @classmethod
    def from_kwargs(cls, **kwargs) -> 'DFMParams':
        """Create DFMParams from keyword arguments.
        
        Filters kwargs to only include valid parameter names,
        ignoring any extra arguments.

        The set of valid names is derived from the dataclass fields of
        ``cls``, so it stays in sync when fields are added and also covers
        fields declared by subclasses.

        Parameters
        ----------
        **kwargs
            Arbitrary keyword arguments; keys that are not init-able fields
            of ``cls`` are silently dropped.

        Returns
        -------
        DFMParams
            An instance of ``cls`` with the recognised overrides set and
            every other field left at its default (None).
        """
        # Local import: the module itself only imports `dataclass`.
        from dataclasses import fields

        accepted = {f.name for f in fields(cls) if f.init}
        overrides = {k: v for k, v in kwargs.items() if k in accepted}
        return cls(**overrides)


@dataclass
class EMAlgorithmParams:
    """Parameters for EM algorithm execution.
    
    This dataclass groups all parameters required for running the EM algorithm,
    reducing function parameter count and improving readability.
    
    Every field must be supplied when constructing an instance: none of them
    declares a default value. Fields annotated ``Optional[...]`` (``R_mat``,
    ``q``, ``frequencies``) may be passed as None, but still have to be
    passed explicitly.

    Field order is part of the constructor interface (positional arguments
    follow declaration order), so do not reorder fields.
    """
    # Data
    # NOTE(review): the exact roles of y vs. y_est are not visible in this
    # module; confirm against the code that builds this object.
    y: np.ndarray
    y_est: np.ndarray
    
    # Model parameters
    # Same names as the DFMResult attributes (transition A, loadings C,
    # covariances Q/R, initial state Z_0/V_0, factors per block r, lags p);
    # presumably the starting values the EM iterations update — TODO confirm.
    A: np.ndarray
    C: np.ndarray
    Q: np.ndarray
    R: np.ndarray
    Z_0: np.ndarray
    V_0: np.ndarray
    r: np.ndarray
    p: int
    
    # Structure parameters
    # NOTE(review): the semantics of these fields are defined by the EM code,
    # not here. R_mat, q and frequencies may be None.
    R_mat: Optional[np.ndarray]
    q: Optional[np.ndarray]
    nQ: int
    i_idio: np.ndarray
    blocks: np.ndarray
    tent_weights_dict: Dict[str, np.ndarray]  # string keys mapping to weight arrays
    clock: str
    frequencies: Optional[np.ndarray]
    idio_chain_lengths: np.ndarray
    
    # Config and algorithm parameters
    # threshold, max_iter, use_damped_updates and damping_factor share their
    # names with the DFMParams overrides; here they are required, non-None
    # values (presumably already resolved against the config — confirm).
    config: DFMConfig
    threshold: float
    max_iter: int
    use_damped_updates: bool
    damping_factor: float

