"""Tests for nowcast() and predict() implementations.

Tests verify that nowcast() and predict() work correctly with dfm-python package only.
"""

import pytest
import numpy as np
import pandas as pd
from datetime import datetime

from dfm_python.models import DFM
from dfm_python.config import DFMConfig, SeriesConfig
from dfm_python import DFMDataModule, DFMTrainer
from dfm_python.config.results import NowcastResult
from dfm_python.utils.time import TimeIndex, parse_timestamp


class TestNowcastImplementation:
    """Tests for ``DFM.nowcast()`` and ``DFM.predict()``.

    Each test trains a small DFM (one block, one factor) on two synthetic
    monthly series with a deliberately short EM run. The assertions target the
    nowcast/predict API -- return types, shapes, finiteness and argument
    validation -- not estimation quality.
    """

    # EM settings written into the config before training so every test fits
    # quickly; none of these tests depend on EM convergence.
    FAST_MAX_ITER = 2
    FAST_THRESHOLD = 1e-2

    @pytest.fixture
    def simple_config(self):
        """Config with two monthly series ('lin' transformation) on one block.

        The block has a single factor, ``ar_lag=1`` and a monthly clock.
        """
        return DFMConfig(
            series=[
                SeriesConfig(series_id='series1', frequency='m', transformation='lin', blocks=[1]),
                SeriesConfig(series_id='series2', frequency='m', transformation='lin', blocks=[1]),
            ],
            blocks={'block1': {'factors': 1, 'ar_lag': 1, 'clock': 'm'}}
        )

    @pytest.fixture
    def simple_data(self):
        """Return ``(df, time_index)``: 50 month-end observations of two random walks."""
        # Seed the *global* NumPy RNG (rather than using a local Generator) so
        # the synthetic data -- and any global-RNG draws made while training --
        # are reproducible across runs.
        np.random.seed(42)
        T, N = 50, 2
        X = np.random.randn(T, N).cumsum(axis=0)
        dates = pd.date_range('2020-01-01', periods=T, freq='ME')
        time_index = TimeIndex([parse_timestamp(d.strftime('%Y-%m-%d')) for d in dates])
        df = pd.DataFrame(X, columns=['series1', 'series2'])
        return df, time_index

    @pytest.fixture
    def fitted_model(self, simple_config, simple_data):
        """Return a DFM trained on ``simple_data`` with a capped EM run."""
        df, time_index = simple_data
        model = DFM()
        model._config = simple_config

        data_module = DFMDataModule(config=simple_config, data=df, time_index=time_index)
        data_module.setup()

        # Limit EM iterations / loosen the convergence threshold so training is
        # fast. The config object is shared with the model and the DataModule,
        # so setting it here (after setup, before fit) mirrors the original tests.
        simple_config.max_iter = self.FAST_MAX_ITER
        simple_config.threshold = self.FAST_THRESHOLD

        DFMTrainer().fit(model, data_module)
        return model

    def test_nowcast_basic(self, fitted_model):
        """nowcast() returns a finite scalar, or a populated NowcastResult."""
        # return_result=False -> plain float.
        value = fitted_model.nowcast('series1', return_result=False)
        assert isinstance(value, (float, np.floating))
        assert np.isfinite(value)

        # return_result=True -> NowcastResult with metadata.
        result = fitted_model.nowcast('series1', return_result=True)
        assert isinstance(result, NowcastResult)
        assert result.target_series == 'series1'
        assert np.isfinite(result.nowcast_value)
        assert result.view_date is not None
        assert result.target_period is not None

    def test_nowcast_with_view_date(self, fitted_model):
        """nowcast() accepts view_date as either a string or a datetime."""
        value_from_str = fitted_model.nowcast('series1', view_date='2020-12-31', return_result=False)
        assert np.isfinite(value_from_str)

        value_from_datetime = fitted_model.nowcast(
            'series1', view_date=datetime(2020, 12, 31), return_result=False
        )
        assert np.isfinite(value_from_datetime)

    def test_nowcast_target_series_not_found(self, fitted_model):
        """nowcast() raises ValueError for a series id that is not in the config."""
        with pytest.raises(ValueError, match="target_series.*not found"):
            fitted_model.nowcast('invalid_series')

    def test_predict_basic(self, fitted_model):
        """predict() with return_series only returns a (horizon, n_series) finite array."""
        forecast = fitted_model.predict(horizon=3, return_series=True, return_factors=False)
        assert isinstance(forecast, np.ndarray)
        assert forecast.shape[0] == 3  # horizon
        assert forecast.shape[1] == 2  # number of series in simple_config
        assert np.all(np.isfinite(forecast))

    def test_predict_with_factors(self, fitted_model):
        """predict() with both flags True returns (series_forecast, factor_forecast)."""
        X_forecast, Z_forecast = fitted_model.predict(horizon=3, return_series=True, return_factors=True)
        assert isinstance(X_forecast, np.ndarray)
        assert isinstance(Z_forecast, np.ndarray)
        assert X_forecast.shape[0] == 3
        assert Z_forecast.shape[0] == 3
        assert np.all(np.isfinite(X_forecast))
        assert np.all(np.isfinite(Z_forecast))

    def test_nowcast_view_date_future(self, fitted_model):
        """A view_date after the end of the data still yields a finite nowcast."""
        # Future view_date should use the latest available data.
        value = fitted_model.nowcast('series1', view_date='2025-12-31', return_result=False)
        assert np.isfinite(value)

    def test_nowcast_view_date_past(self, fitted_model):
        """A view_date inside the sample yields a finite nowcast."""
        # Past view_date should use only data available up to that date.
        value = fitted_model.nowcast('series1', view_date='2020-06-30', return_result=False)
        assert np.isfinite(value)

    def test_nowcast_all_missing_data(self, simple_config, simple_data):
        """Training on an all-NaN panel is rejected with an error."""
        model = DFM()
        model._config = simple_config

        df, time_index = simple_data
        df_missing = df.copy()
        df_missing.iloc[:, :] = np.nan

        data_module = DFMDataModule(config=simple_config, data=df_missing, time_index=time_index)

        simple_config.max_iter = self.FAST_MAX_ITER
        simple_config.threshold = self.FAST_THRESHOLD
        trainer = DFMTrainer()

        # All-NaN data cannot be used for training. The rejection may surface
        # during DataModule setup or during fitting, so both stay inside the
        # raises-block.
        with pytest.raises((ValueError, RuntimeError, TypeError, IndexError)):
            data_module.setup()
            trainer.fit(model, data_module)

    def test_predict_horizon_zero(self, fitted_model):
        """predict(horizon=0) raises ValueError."""
        with pytest.raises(ValueError, match="horizon must be positive"):
            fitted_model.predict(horizon=0)

    def test_predict_horizon_negative(self, fitted_model):
        """predict() with a negative horizon raises ValueError."""
        with pytest.raises(ValueError, match="horizon must be positive"):
            fitted_model.predict(horizon=-1)

    def test_predict_history_zero(self, fitted_model):
        """history=0 is accepted (expected to be treated as 'use full history')."""
        forecast = fitted_model.predict(horizon=3, history=0, return_series=True, return_factors=False)
        assert isinstance(forecast, np.ndarray)
        assert forecast.shape[0] == 3

    def test_predict_history_negative(self, fitted_model):
        """A negative history is accepted (expected to be treated as 'use full history')."""
        forecast = fitted_model.predict(horizon=3, history=-1, return_series=True, return_factors=False)
        assert isinstance(forecast, np.ndarray)
        assert forecast.shape[0] == 3

    def test_predict_history_larger_than_data(self, fitted_model):
        """A history longer than the sample is accepted (all available data is used)."""
        forecast = fitted_model.predict(horizon=3, history=1000, return_series=True, return_factors=False)
        assert isinstance(forecast, np.ndarray)
        assert forecast.shape[0] == 3

    def test_predict_both_returns_false(self, fitted_model):
        """With return_series and return_factors both False, predict() still returns an array.

        The expected fallback is the factor forecast; this test only checks
        the return type and that the leading dimension equals the horizon.
        """
        result = fitted_model.predict(horizon=3, return_series=False, return_factors=False)
        assert isinstance(result, np.ndarray)
        assert result.shape[0] == 3

