import math

import numpy as np
from pytest import fixture
from scipy.interpolate import RectBivariateSpline

from simuls_tpm.utils import grid_function


@fixture
def fun():
    """Provide the analytic test surface f(x, y) = x * exp(-y).

    The returned callable works elementwise on NumPy arrays, so the tests
    can evaluate it on paired ``x``/``y`` arrays or hand it to
    ``grid_function`` to sample it on a tensor grid.
    """

    def _surface(x, y):
        return x * np.exp(-y)

    return _surface


def test_RectBivariateSpline(fun):
    """Check RectBivariateSpline interpolation of ``fun`` sampled by ``grid_function``.

    The surface is sampled on the regular grid {0.1, 0.2, ..., 0.9} along
    both axes. Evaluating at grid nodes checks that the array returned by
    ``grid_function`` has the layout RectBivariateSpline expects
    (``z[i, j]`` is the value at ``(x[i], y[j])``). Evaluating strictly
    between nodes checks the interpolation itself, which node-only points
    cannot do because an interpolating spline matches its data there by
    construction.
    """
    n_points = 9
    grid1 = np.arange(1, n_points + 1) / (n_points + 1.0)
    z_grid = grid_function(fun, grid1, grid1)
    biv_spline = RectBivariateSpline(grid1, grid1, z_grid)

    # Points that coincide with grid nodes: the default s=0 spline
    # interpolates, so the values must match the samples tightly.
    x_nodes = np.array([0.4, 0.7, 0.1])
    y_nodes = np.array([0.9, 0.2, 0.5])
    z_nodes = biv_spline(x=x_nodes, y=y_nodes, grid=False)
    assert np.allclose(z_nodes, fun(x_nodes, y_nodes))

    # Points strictly between nodes (inside the data range): a cubic spline
    # on a 0.1 spacing is accurate to far better than 1e-3 relative here,
    # but it is not exact, hence the explicit tolerance.
    x_mid = np.array([0.45, 0.72, 0.15])
    y_mid = np.array([0.85, 0.23, 0.55])
    z_mid = biv_spline(x=x_mid, y=y_mid, grid=False)
    assert np.allclose(z_mid, fun(x_mid, y_mid), rtol=1e-3)


def test_RectBivariateSpline_dy(fun):
    """Check the spline's first partial derivative with respect to ``y``.

    For the fixture surface f(x, y) = x * exp(-y), df/dy equals -f(x, y).
    The spline derivative at a few grid points is compared with that value
    to a relative tolerance of 1e-3.
    """
    grid_size = 9
    nodes = np.arange(1, grid_size + 1) / (grid_size + 1.0)
    spline = RectBivariateSpline(nodes, nodes, grid_function(fun, nodes, nodes))

    xs, ys = np.array([0.4, 0.7, 0.1]), np.array([0.9, 0.2, 0.5])
    expected = -fun(xs, ys)
    actual = spline(x=xs, y=ys, dy=1, grid=False)
    assert np.allclose(actual, expected, rtol=1e-3)


def test_RectBivariateSpline_integral(fun):
    """Check the spline's double integral over the unit square.

    The sampling grid spans [0, 1] inclusive (nine points, spacing 1/8),
    so the integration domain coincides with the data range. The exact
    value is int_0^1 int_0^1 x * exp(-y) dy dx = (1 - 1/e) / 2.
    """
    n_points = 9
    unit_grid = np.linspace(0.0, 1.0, n_points)
    spline = RectBivariateSpline(
        unit_grid, unit_grid, grid_function(fun, unit_grid, unit_grid)
    )

    expected = (1.0 - 1.0 / math.e) / 2.0
    actual = spline.integral(0.0, 1.0, 0.0, 1.0)
    assert np.allclose(actual, expected, rtol=1e-3)
