from enum import Enum

from code_loader.default_metrics import mean_absolute_percentage_error_dimension_reduced, \
    mean_absolute_error_dimension_reduced, mean_squared_logarithmic_error_dimension_reduced, \
    mean_squared_error_dimension_reduced, categorical_crossentropy, binary_crossentropy
from code_loader.inner_leap_binder.leapbinder_decorators import tensorleap_custom_loss


class LossName(Enum):
    """Names of the built-in losses that this module registers.

    Each member's value is the same string as its member name.
    """

    # Regression-style losses.
    MeanSquaredError = 'MeanSquaredError'
    MeanSquaredLogarithmicError = 'MeanSquaredLogarithmicError'
    MeanAbsoluteError = 'MeanAbsoluteError'
    MeanAbsolutePercentageError = 'MeanAbsolutePercentageError'
    # Cross-entropy losses.
    CategoricalCrossentropy = 'CategoricalCrossentropy'
    BinaryCrossentropy = 'BinaryCrossentropy'


# Map each LossName member's name (a plain string) to the metric function
# from code_loader.default_metrics that implements it. Insertion order follows
# the pairs listed below.
loss_name_to_function = {
    member.name: metric
    for member, metric in (
        (LossName.MeanSquaredError, mean_squared_error_dimension_reduced),
        (LossName.MeanSquaredLogarithmicError, mean_squared_logarithmic_error_dimension_reduced),
        (LossName.MeanAbsoluteError, mean_absolute_error_dimension_reduced),
        (LossName.MeanAbsolutePercentageError, mean_absolute_percentage_error_dimension_reduced),
        (LossName.CategoricalCrossentropy, categorical_crossentropy),
        (LossName.BinaryCrossentropy, binary_crossentropy),
    )
}


def _register_loss(loss_name, loss_fn):
    """Register a two-argument wrapper around *loss_fn* under *loss_name*.

    Defining the wrapper inside this factory binds *loss_fn* once per call.
    A function defined directly in the loop body would look up the loop
    variable only when it is *invoked*. Every registered loss would then call
    whichever function the loop saw last.

    Args:
        loss_name: Name passed to ``tensorleap_custom_loss``.
        loss_fn: Callable taking ``(ground_truth, prediction)``.

    Returns:
        Whatever ``tensorleap_custom_loss(loss_name)`` returns for the wrapper.
    """
    @tensorleap_custom_loss(loss_name)
    def loss_func(ground_truth, prediction):
        return loss_fn(ground_truth, prediction)

    return loss_func


# Register every entry of loss_name_to_function. The module-level names
# loss_name / func / loss_func are still bound afterwards, as before.
for loss_name, func in loss_name_to_function.items():
    loss_func = _register_loss(loss_name, func)
