mod utils;

use serial_test::serial;
use statsig_rust::{
    networking::{NetworkClient, RequestArgs},
    output_logger::LogLevel,
    Statsig, StatsigOptions,
};
use std::{sync::Arc, time::Duration};
use utils::{
    mock_log_provider::{MockLogProvider, RecordedLog},
    mock_scrapi::{Endpoint, EndpointStub, Method, MockScrapi, StubData},
};

lazy_static::lazy_static! {
    /// Single mock log sink shared by every test in this file; each test's
    /// fixture setup wires it into Statsig's output logger.
    static ref MOCK_LOG_PROVIDER: Arc<MockLogProvider> = {
        let provider = MockLogProvider::new();
        Arc::new(provider)
    };
}

/// Builds the fixtures shared by the network-failure tests.
///
/// Steps:
/// - clears the global mock log provider;
/// - starts a fresh `MockScrapi`;
/// - routes Statsig's output logger into the mock provider at `LogLevel::Debug`;
/// - stubs `POST` on the DownloadConfigSpecs endpoint to answer with `status` and
///   the body `response`.
///
/// Returns `(log provider, mock server, network client, request args)`. The
/// request args point at the stubbed endpoint and allow 2 retries.
async fn setup(
    response: &str,
    status: u16,
) -> (Arc<MockLogProvider>, MockScrapi, NetworkClient, RequestArgs) {
    let sdk_key = "secret-network_failure_tests";

    // Start from an empty log buffer.
    let log_provider = MOCK_LOG_PROVIDER.clone();
    log_provider.clear();

    let scrapi = MockScrapi::new().await;
    let client = NetworkClient::new(sdk_key, None, None);

    let mut statsig_options = StatsigOptions::new();
    statsig_options.output_logger_provider = Some(log_provider.clone());
    statsig_options.output_log_level = Some(LogLevel::Debug);
    let statsig_options = Arc::new(statsig_options);

    // The Statsig instance is discarded; it is constructed only so the output
    // logger gets initialized with the options above.
    let _ = Statsig::new(sdk_key, Some(statsig_options));

    let dcs_stub = EndpointStub {
        method: Method::POST,
        status,
        response: StubData::String(response.to_owned()),
        ..EndpointStub::with_endpoint(Endpoint::DownloadConfigSpecs)
    };
    scrapi.stub(dcs_stub).await;

    let request_args = RequestArgs {
        url: scrapi.url_for_endpoint(Endpoint::DownloadConfigSpecs),
        retries: 2,
        ..RequestArgs::new()
    };

    (log_provider, scrapi, client, request_args)
}

/// A 400 response is not retried: `post` must fail with a `RequestNotRetryable`
/// error that includes the URL, the status line and the response body.
#[tokio::test]
#[serial]
async fn test_non_retryable_error_result() {
    // Bind the mock server to a name so it lives for the whole test. A bare `_`
    // in this pattern does not bind, so the value would be dropped as soon as
    // this `let` statement ends, i.e. before the request below is sent.
    let (_, _mock_scrapi, network_client, request_args) = setup("read the docs", 400).await;

    let url = request_args.url.clone();
    let response = network_client.post(request_args, None).await;

    let error = match response {
        Err(error) => error,
        Ok(_) => panic!("expected POST to {url} to fail with a non-retryable error, but it succeeded"),
    };
    let expected = format!("RequestNotRetryable: {url} status(400) Bad Request: read the docs");
    assert_eq!(error.to_string(), expected);
}

/// A non-retryable 400 failure must also be reported as a `Warn` log whose
/// message contains the full `RequestNotRetryable` error text.
#[tokio::test]
#[serial]
async fn test_non_retryable_std_out() {
    // `_mock_scrapi` keeps the mock server alive until the test ends. A bare `_`
    // would drop it at the end of this statement, before the request is sent.
    let (mock_log_provider, _mock_scrapi, network_client, request_args) =
        setup("read the docs", 400).await;

    let url = request_args.url.clone();
    // Only the emitted logs are checked here, so the result is ignored.
    let _ = network_client.post(request_args, None).await;

    // Move the recorded logs out so the lock is released before asserting.
    let logs = {
        let mut guard = mock_log_provider
            .logs
            .try_lock_for(Duration::from_secs(1))
            .expect("timed out acquiring the mock log provider lock");
        std::mem::take(&mut *guard)
    };

    let expected = format!("RequestNotRetryable: {url} status(400) Bad Request: read the docs");
    let found = logs
        .iter()
        .any(|log| matches!(log, RecordedLog::Warn(_, msg) if msg.contains(&expected)));
    assert!(found, "expected a Warn log containing {expected:?}");
}

/// A 500 response is retried until the retry budget is spent. The resulting
/// `RetriesExhausted` error must report the attempt count and the JSON body.
#[tokio::test]
#[serial]
async fn test_exhaust_retries_error_result() {
    // Bind the mock server so it outlives the request. A bare `_` here would
    // drop it as soon as this `let` statement ends.
    let (_, _mock_scrapi, network_client, request_args) = setup("{}", 500).await;

    let url = request_args.url.clone();
    let response = network_client.post(request_args, None).await;

    let error = match response {
        Err(error) => error,
        Ok(_) => panic!("expected POST to {url} to exhaust its retries, but it succeeded"),
    };
    // 1 initial attempt + 2 retries (from `setup`) = 3 attempts.
    let expected =
        format!("RetriesExhausted: {url} status(500) attempts(3) Internal Server Error: {{}}");
    assert_eq!(error.to_string(), expected);
}

/// Same as the JSON-body retry-exhaustion case, but with a non-text-looking
/// body. The expected error message carries no `: <body>` suffix.
#[tokio::test]
#[serial]
async fn test_exhaust_retries_error_result_binary_response() {
    // NOTE(review): Rust string literals are UTF-8, so this "binary" payload
    // appears to be U+FFFD replacement characters mixed with whitespace/control
    // characters rather than raw bytes. Confirm it still exercises the
    // binary-body path.
    let data = "����            H	  � �        �  ";
    // Bind the mock server so it outlives the request. A bare `_` here would
    // drop it as soon as this `let` statement ends.
    let (_, _mock_scrapi, network_client, request_args) = setup(data, 500).await;

    let url = request_args.url.clone();
    let response = network_client.post(request_args, None).await;

    let error = match response {
        Err(error) => error,
        Ok(_) => panic!("expected POST to {url} to exhaust its retries, but it succeeded"),
    };
    // 1 initial attempt + 2 retries (from `setup`) = 3 attempts.
    let expected = format!("RetriesExhausted: {url} status(500) attempts(3) Internal Server Error");
    assert_eq!(error.to_string(), expected);
}

/// While retrying a 500 response, the first two failed attempts must each
/// produce an `Info` log naming the status code and the attempt number.
/// The final attempt (3/3) is not checked here.
#[tokio::test]
#[serial]
async fn test_exhaust_retries_std_out() {
    // `_mock_scrapi` keeps the mock server alive until the test ends. A bare `_`
    // would drop it at the end of this statement, before the request is sent.
    let (mock_log_provider, _mock_scrapi, network_client, request_args) = setup("{}", 500).await;

    // Only the emitted logs are checked here, so the result is ignored.
    let _ = network_client.post(request_args, None).await;

    // Move the recorded logs out so the lock is released before asserting.
    let logs = {
        let mut guard = mock_log_provider
            .logs
            .try_lock_for(Duration::from_secs(1))
            .expect("timed out acquiring the mock log provider lock");
        std::mem::take(&mut *guard)
    };

    for attempt in ["1/3", "2/3"] {
        let expected = format!("Network request failed with status code 500 (attempt {attempt})");
        let found = logs
            .iter()
            .any(|log| matches!(log, RecordedLog::Info(_, msg) if msg.contains(&expected)));
        assert!(found, "expected an Info log containing {expected:?}");
    }
}
