from typing_extensions import override

import torch

from deeplotx.nn.linear_regression import LinearRegression


class SoftmaxRegression(LinearRegression):
    """``LinearRegression`` whose output is normalized with a softmax.

    Construction is delegated entirely to ``LinearRegression``; the only
    difference is in :meth:`forward`, which applies ``torch.softmax`` over
    the last dimension of the parent model's output.
    """

    def __init__(self, input_dim: int, output_dim: int, num_heads: int = 1, num_layers: int = 1,
                 expansion_factor: int | float = 1.5, bias: bool = True, dropout_rate: float = 0.1,
                 model_name: str | None = None, device: str | None = None, dtype: torch.dtype | None = None, **kwargs):
        """Forward every argument, plus any extra ``**kwargs``, unchanged to ``LinearRegression``.

        Args:
            input_dim: Passed through to ``LinearRegression``.
            output_dim: Passed through; the softmax in :meth:`forward` is taken
                over the last dimension of the parent's output.
            num_heads, num_layers, expansion_factor, bias, dropout_rate:
                Passed through to ``LinearRegression`` unchanged.
            model_name, device, dtype: Passed through to ``LinearRegression``
                unchanged.
            **kwargs: Any additional keyword arguments for the parent.
        """
        super().__init__(
            model_name=model_name,
            device=device,
            dtype=dtype,
            input_dim=input_dim,
            output_dim=output_dim,
            num_heads=num_heads,
            num_layers=num_layers,
            expansion_factor=expansion_factor,
            bias=bias,
            dropout_rate=dropout_rate,
            **kwargs,
        )

    @override
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the parent model on ``x`` and softmax-normalize the result.

        Args:
            x: Input tensor. It is first passed through
                ``self.ensure_device_and_dtype`` with the model's ``device`` and
                ``dtype`` (presumably moving/casting it to match — confirm in
                the base class).

        Returns:
            ``torch.softmax`` of the parent's output along the last dimension,
            computed with ``dtype=self.dtype``.
        """
        # NOTE(review): outputs are normalized probabilities, not logits. Losses
        # such as nn.CrossEntropyLoss expect raw logits, so pairing them with
        # this output would apply softmax twice. Confirm how callers train this.
        prepared = self.ensure_device_and_dtype(x, device=self.device, dtype=self.dtype)
        scores = super().forward(prepared)
        return torch.softmax(scores, dim=-1, dtype=self.dtype)
