from numpy import dot

from neuralpy.core.properties import ListProperty
from neuralpy.algorithms.feedforward import FeedForwardNetwork
from neuralpy.network.learning import SupervisedLearning
from . import optimization_types


# Public names exported by ``from <this module> import *``.
__all__ = (
    "Backpropagation",
)


class Backpropagation(SupervisedLearning, FeedForwardNetwork):
    """ Backpropagation algorithm.

    Parameters
    ----------
    {optimizations}
    {raw_predict_param}
    {full_params}

    Methods
    -------
    {supervised_train}
    {full_methods}

    Examples
    --------
    >>> import numpy as np
    >>> from neuralpy.algorithms import Backpropagation
    >>>
    >>> x_train = np.array([[1, 2], [3, 4]])
    >>> y_train = np.array([[1], [0]])
    >>>
    >>> bpnet = Backpropagation((2, 3, 1), verbose=False, step=0.1)
    >>> bpnet.train(x_train, y_train)
    """

    __opt_params = """optimizations : list or None
        The list of optimization algorithms. ``None`` by default.
        If this option is not empty it will generate new class which
        will inherit from every class in this list. Supports two types
        of optimization algorithms: weight update and step update.
    """
    shared_docs = {"optimizations": __opt_params}

    optimizations = ListProperty(default=None)

    def __new__(cls, connection, options=None, **kwargs):
        """Create an instance, mixing in requested optimization classes.

        If ``optimizations`` is missing or empty, a plain instance of
        ``cls`` is returned. Otherwise a new class is created on the fly
        whose bases are the optimization classes followed by ``cls``, and
        an instance of that class is returned.

        Raises
        ------
        ValueError
            If an optimization class has no ``optimization_type`` listed
            in ``optimization_types``, or if two optimization classes
            share the same optimization type.
        """
        # Argument `options` is a simple hack for `__reduce__`.
        # `__reduce__` can't restore a class with keyword arguments, so
        # it puts them into a `dict` passed as `options` and we translate
        # it back to kwargs here. The same hack is used in `__init__`.
        if options is None:
            options = kwargs

        # Materialize once: the value is iterated several times below, so
        # a one-shot iterator (e.g. a generator) would otherwise be
        # exhausted by the validation loop and silently dropped.
        optimizations = tuple(options.get('optimizations', None) or ())
        if not optimizations:
            return super(Backpropagation, cls).__new__(cls)

        found_types = []
        for optimization_class in optimizations:
            opt_class_type = getattr(optimization_class, 'optimization_type',
                                     None)
            if opt_class_type not in optimization_types:
                # Fall back to repr() so that a non-class object reports
                # this ValueError rather than an AttributeError.
                class_name = getattr(optimization_class, '__name__',
                                     repr(optimization_class))
                raise ValueError("Invalid optimization class `{}`".format(
                                 class_name))

            if opt_class_type in found_types:
                raise ValueError(
                    "There can be only one optimization class with "
                    "type `{}`".format(optimization_types[opt_class_type])
                )

            found_types.append(opt_class_type)

        # Build a new class which inherits from all optimization classes
        # and the main class. Optimizations are listed first so they come
        # before `cls` in the MRO and their overrides take precedence.
        new_class_name = cls.__name__ + ''.join(
            class_.__name__ for class_ in optimizations
        )
        new_class = type(new_class_name, optimizations + (cls,), {})

        return super(Backpropagation, new_class).__new__(new_class)

    def __init__(self, connection, options=None, **kwargs):
        """Initialize the network.

        ``options`` (when given) replaces ``kwargs``; it is how
        ``__reduce__`` passes keyword options back during unpickling.
        """
        if options is None:
            options = kwargs
        super(Backpropagation, self).__init__(connection, **options)

    def get_gradient(self, output_train, target_train):
        """Backpropagate the error derivative and compute weight gradients.

        Walks ``self.train_layers`` from the last layer to the first. For
        each layer, ``delta`` is the activation derivative (evaluated at
        the stored summated data) multiplied by the incoming error signal,
        and the gradient is ``dot(layer_input.T, delta)``.

        Side effects: sets ``self.delta`` and ``self.gradients``, both
        ordered from the first train layer to the last.

        Returns
        -------
        list
            ``self.gradients``, one entry per train layer.
        """
        self.delta = deltas = []
        self.gradients = gradients = []

        update = self.error.deriv(output_train, target_train)

        for i, layer in enumerate(reversed(self.train_layers), 1):
            summated_data = self.summated_data[-i]
            current_layer_input = self.layer_outputs[-i]

            deriv = layer.activation_function.deriv(summated_data)

            delta = deriv * update
            # Error signal propagated to the preceding layer; bias weights
            # are left out (via `weight_without_bias`).
            update = dot(delta, layer.weight_without_bias.T)

            gradients.append(dot(current_layer_input.T, delta))
            deltas.append(delta)

        # Layers were visited last-to-first; restore first-to-last order
        # in place (avoids repeated O(n) `insert(0, ...)` calls).
        gradients.reverse()
        deltas.reverse()

        return self.gradients

    def learn(self, output_train, target_train):
        """Return per-layer weight deltas: the negated gradients."""
        gradients = self.get_gradient(output_train, target_train)
        return [-gradient for gradient in gradients]

    def update_weights(self, weight_deltas):
        """Add each layer's step-scaled delta to that layer's weights.

        ``weight_deltas`` is indexed in the same order as
        ``self.train_layers``.
        """
        layer_weight_update = self.layer_weight_update
        for i, layer in enumerate(self.train_layers):
            layer.weight += layer_weight_update(weight_deltas[i], i)

    def layer_step(self, layer_number):
        """Return the step size for a layer.

        This base implementation ignores ``layer_number`` and returns the
        global ``self.step``; it is a hook for per-layer overrides.
        """
        return self.step

    def layer_weight_update(self, delta, layer_number):
        """Scale ``delta`` by the step size of layer ``layer_number``."""
        return self.layer_step(layer_number) * delta

    def get_class_name(self):
        """Return the display name of this algorithm."""
        return 'Backpropagation'

    def __reduce__(self):
        """Support pickling by reconstructing from connection and options.

        Classes generated dynamically by ``__new__`` are not importable,
        so the base ``Backpropagation`` class is used as the constructor;
        the ``optimizations`` option (if set) rebuilds the mixed class.
        """
        # NOTE(review): this always reconstructs as `Backpropagation`, so
        # a subclass that does not override `__reduce__` would unpickle as
        # the base class -- confirm subclasses override this.
        options = {}
        for name, option in self.options.items():
            value = getattr(self, name)
            # Default values are not always valid for the option's type,
            # so options still equal to their default are left out and
            # fall back to the default on reconstruction.
            if value != option.value.default:
                options[name] = value
        return (Backpropagation, (self.connection, options))
