/*
 * Decompiled with CFR 0.152.
 */
package org.linqs.psl.evaluation.statistics;

import java.util.Locale;
import java.util.Map;
import java.util.Objects;
import org.linqs.psl.application.learning.weight.TrainingMap;
import org.linqs.psl.config.Config;
import org.linqs.psl.evaluation.statistics.Evaluator;
import org.linqs.psl.model.atom.GroundAtom;
import org.linqs.psl.model.predicate.StandardPredicate;

/**
 * Evaluates continuous-valued atoms by accumulating absolute and squared errors over the
 * atom pairs returned by {@link TrainingMap#getFullMap()}.
 *
 * <p>Two statistics are exposed: mean absolute error ({@link #mae()}) and mean squared error
 * ({@link #mse()}). Which one {@link #getRepresentativeMetric()} reports is fixed at
 * construction time. The no-arg constructor reads it from the {@value #REPRESENTATIVE_KEY}
 * config option, defaulting to {@value #DEFAULT_REPRESENTATIVE}. For both metrics, lower is
 * better.
 *
 * <p>The accumulators are plain mutable fields that every {@code compute} call overwrites. No
 * synchronization is performed.
 */
public class ContinuousEvaluator extends Evaluator {
    public static final String CONFIG_PREFIX = "continuousevaluator";

    /** Config key naming the representative metric (a {@link RepresentativeMetric} name). */
    public static final String REPRESENTATIVE_KEY = CONFIG_PREFIX + ".representative";

    /** Representative metric used when {@link #REPRESENTATIVE_KEY} is not set. */
    public static final String DEFAULT_REPRESENTATIVE = "MSE";

    private final RepresentativeMetric representative;

    // Accumulators populated by compute(); reset at the start of every computation.
    private int count;
    private double absoluteError;
    private double squaredError;

    /**
     * Creates an evaluator whose representative metric comes from the
     * {@value #REPRESENTATIVE_KEY} config option (default {@value #DEFAULT_REPRESENTATIVE}).
     */
    public ContinuousEvaluator() {
        this(Config.getString(REPRESENTATIVE_KEY, DEFAULT_REPRESENTATIVE));
    }

    /**
     * Creates an evaluator from the name of a representative metric.
     *
     * @param representative name of a {@link RepresentativeMetric} constant; it is upper-cased
     *     before lookup, so {@code "mae"} and {@code "MAE"} are equivalent
     * @throws IllegalArgumentException if the name matches no {@link RepresentativeMetric}
     * @throws NullPointerException if {@code representative} is null
     */
    public ContinuousEvaluator(String representative) {
        // Locale.ROOT keeps the enum-name lookup independent of the JVM's default locale.
        this(RepresentativeMetric.valueOf(representative.toUpperCase(Locale.ROOT)));
    }

    /**
     * Creates an evaluator that reports the given metric as its representative.
     *
     * @param representative the metric returned by {@link #getRepresentativeMetric()}
     * @throws NullPointerException if {@code representative} is null
     */
    public ContinuousEvaluator(RepresentativeMetric representative) {
        // Fail fast here rather than later inside getRepresentativeMetric()'s switch.
        this.representative = Objects.requireNonNull(representative, "representative");
        reset();
    }

    /** Computes statistics over every atom pair in {@code trainingMap}, regardless of predicate. */
    @Override
    public void compute(TrainingMap trainingMap) {
        compute(trainingMap, null);
    }

    /**
     * Recomputes the error statistics from scratch over the pairs in
     * {@code trainingMap.getFullMap()}.
     *
     * @param trainingMap source of (key atom, value atom) pairs to compare
     * @param predicate if non-null, only pairs whose key atom has this predicate are counted
     */
    @Override
    public void compute(TrainingMap trainingMap, StandardPredicate predicate) {
        reset();
        for (Map.Entry<GroundAtom, GroundAtom> entry : trainingMap.getFullMap()) {
            // Predicates are compared by reference.
            // NOTE(review): presumably predicate instances are canonical; confirm.
            if (predicate != null && entry.getKey().getPredicate() != predicate) {
                continue;
            }

            // The difference is computed once and reused. abs() and squaring make the operand
            // order irrelevant.
            double error = entry.getValue().getValue() - entry.getKey().getValue();
            count++;
            absoluteError += Math.abs(error);
            squaredError += error * error;
        }
    }

    /**
     * Returns the metric selected at construction time.
     *
     * @return {@link #mae()} or {@link #mse()}
     */
    @Override
    public double getRepresentativeMetric() {
        switch (representative) {
            case MAE:
                return mae();
            case MSE:
                return mse();
            default:
                throw new IllegalStateException("Unknown representative metric: " + representative);
        }
    }

    /** Both MAE and MSE are error measures, so smaller values are better. */
    @Override
    public boolean isHigherRepresentativeBetter() {
        return false;
    }

    /**
     * Returns the mean absolute error from the last {@code compute} call.
     *
     * @return the mean absolute error, or {@code 0.0} if no pairs were counted
     */
    public double mae() {
        if (count == 0) {
            return 0.0;
        }
        return absoluteError / count;
    }

    /**
     * Returns the mean squared error from the last {@code compute} call.
     *
     * @return the mean squared error, or {@code 0.0} if no pairs were counted
     */
    public double mse() {
        if (count == 0) {
            return 0.0;
        }
        return squaredError / count;
    }

    /** Returns both statistics as a single human-readable line. */
    @Override
    public String getAllStats() {
        return String.format("MAE: %f, MSE: %f", mae(), mse());
    }

    /** Clears the pair count and both error sums. */
    private void reset() {
        count = 0;
        absoluteError = 0.0;
        squaredError = 0.0;
    }

    /** The statistic reported by {@link ContinuousEvaluator#getRepresentativeMetric()}. */
    public enum RepresentativeMetric {
        /** Mean absolute error. */
        MAE,
        /** Mean squared error. */
        MSE
    }
}

