import torch


class SkipLinear(torch.nn.Module):
    """
    A skip linear layer.

    Every input feature is connected to its own private group of
    ``n_out // n_in`` output nodes through a small linear map (``Core``).
    The outputs of all groups are concatenated along the feature dimension,
    so output columns ``[i * n, (i + 1) * n)`` depend only on input column
    ``i`` (with ``n = n_out // n_in``).

    Notes:
        Code adapted from James D. McCaffrey:
        "Regression Using a PyTorch Neural Network with a Transformer Component"

    Reference:
        https://jamesmccaffrey.wordpress.com/2023/12/01/regression-using-a-pytorch-neural-network-with-a-transformer-component/

    Args:
        n_in (int):
            the input dimension
        n_out (int):
            the output dimension; must be a multiple of ``n_in``

    Raises:
        ValueError:
            if ``n_in`` is smaller than 1 or ``n_out`` is not a multiple of
            ``n_in``

    Examples:
        >>> import torch
        >>> from spotPython.light.transformer.skiplinear import SkipLinear
        >>> n_in, n_out = 2, 4
        >>> sl = SkipLinear(n_in, n_out)
        >>> x = torch.arange(n_in, dtype=torch.float32).reshape(1, n_in)
        >>> x
        tensor([[0., 1.]])
        >>> sl(x).shape
        torch.Size([1, 4])
        >>> len(sl.lst_modules)
        2
        >>> [tuple(core.weights.shape) for core in sl.lst_modules]
        [(2, 1), (2, 1)]
    """

    class Core(torch.nn.Module):
        """
        Linear map from a single input node to ``n`` output nodes.

        Args:
            n (int):
                the number of output nodes fed by the single input node
        """

        def __init__(self, n: int) -> None:
            super().__init__()
            # 1 node to n nodes
            self.weights = torch.nn.Parameter(torch.zeros((n, 1), dtype=torch.float32))
            # One bias per output node. `torch.tensor(n)` would build a 0-dim
            # tensor holding the *value* n, i.e. a single bias shared by all
            # nodes. NOTE: state dicts saved with the former 0-dim `biases`
            # have a different shape for this parameter.
            self.biases = torch.nn.Parameter(torch.zeros(n, dtype=torch.float32))
            lim = 0.01
            torch.nn.init.uniform_(self.weights, -lim, lim)
            torch.nn.init.zeros_(self.biases)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            """
            Apply the linear map.

            Args:
                x (torch.Tensor):
                    tensor of shape ``(batch, 1)``

            Returns:
                torch.Tensor: tensor of shape ``(batch, n)``
            """
            # (batch, 1) @ (1, n) -> (batch, n); the biases broadcast over batch.
            wx = torch.mm(x, self.weights.t())
            return torch.add(wx, self.biases)

    def __init__(self, n_in: int, n_out: int) -> None:
        super().__init__()
        # Reject n_in == 0 explicitly instead of failing with a
        # ZeroDivisionError in the divisibility check below.
        if n_in < 1:
            raise ValueError("n_in must be a positive integer")
        if n_out % n_in != 0:
            raise ValueError("n_out % n_in != 0")
        self.n_in = n_in
        self.n_out = n_out
        n = n_out // n_in  # num nodes per input

        self.lst_modules = torch.nn.ModuleList([SkipLinear.Core(n) for _ in range(n_in)])

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Route each input column through its own ``Core`` and concatenate.

        Args:
            x (torch.Tensor):
                tensor whose second dimension holds at least ``n_in``
                features, typically of shape ``(batch, n_in)``

        Returns:
            torch.Tensor: tensor with ``n_out`` columns; for a
            ``(batch, n_in)`` input the shape is ``(batch, n_out)``
        """
        lst_nodes = [
            core(x[:, i].reshape(-1, 1))
            for i, core in enumerate(self.lst_modules)
        ]
        # A single concatenation also covers n_in == 1 and avoids re-copying
        # the growing result once per input. The result already has
        # n_in * (n_out // n_in) == n_out columns, so no reshape is needed.
        return torch.cat(lst_nodes, dim=1)
