/*
 * Decompiled with CFR 0.152.
 */
package org.linqs.psl.evaluation.statistics;

import java.util.Locale;
import java.util.Map;
import java.util.Objects;
import org.linqs.psl.application.learning.weight.TrainingMap;
import org.linqs.psl.config.Options;
import org.linqs.psl.evaluation.statistics.Evaluator;
import org.linqs.psl.model.atom.GroundAtom;
import org.linqs.psl.model.predicate.StandardPredicate;
import org.linqs.psl.util.MathUtils;

/**
 * Evaluator that treats atom values as binary labels.
 *
 * <p>For every (predicted, truth) atom pair, both values are binarized against a threshold:
 * a value {@code >= threshold} is positive, anything below it is negative. The resulting
 * true/false positive/negative counts back accuracy, F-score, and per-class precision/recall.
 *
 * <p>The counts live in mutable, unsynchronized instance fields that are reset and refilled
 * on every {@code compute} call.
 */
public class DiscreteEvaluator
extends Evaluator {
    private final double threshold;
    private final RepresentativeMetric representative;
    private int tp;
    private int fn;
    private int tn;
    private int fp;

    /**
     * Creates an evaluator using the threshold from {@code Options.EVAL_DISCRETE_THRESHOLD}
     * and the representative metric from {@code Options.EVAL_DISCRETE_REPRESENTATIVE}.
     */
    public DiscreteEvaluator() {
        this(Options.EVAL_DISCRETE_THRESHOLD.getDouble());
    }

    /**
     * Creates an evaluator with the given threshold and the representative metric from
     * {@code Options.EVAL_DISCRETE_REPRESENTATIVE}.
     *
     * @param threshold binarization cut-off; must be in [0, 1]
     */
    public DiscreteEvaluator(double threshold) {
        this(threshold, Options.EVAL_DISCRETE_REPRESENTATIVE.getString());
    }

    /**
     * Creates an evaluator, resolving the representative metric by name (case-insensitive).
     *
     * @param threshold binarization cut-off; must be in [0, 1]
     * @param representative name of a {@link RepresentativeMetric} constant
     * @throws IllegalArgumentException if the threshold is out of range or the name does not
     *     match any {@link RepresentativeMetric}
     */
    public DiscreteEvaluator(double threshold, String representative) {
        // Locale.ROOT: default-locale casing (e.g. Turkish dotted/dotless i) would
        // produce names that do not match any enum constant.
        this(threshold, RepresentativeMetric.valueOf(representative.toUpperCase(Locale.ROOT)));
    }

    /**
     * Creates an evaluator.
     *
     * @param threshold binarization cut-off; must be in [0, 1] (inclusive)
     * @param representative metric reported by {@link #getRepMetric()}; must not be null
     * @throws IllegalArgumentException if {@code threshold} is outside [0, 1] or is NaN
     * @throws NullPointerException if {@code representative} is null
     */
    public DiscreteEvaluator(double threshold, RepresentativeMetric representative) {
        // Negated range check so that NaN (for which every comparison is false) is rejected.
        if (!(threshold >= 0.0 && threshold <= 1.0)) {
            throw new IllegalArgumentException("Threshold must be in [0, 1]. Found: " + threshold);
        }
        this.threshold = threshold;
        this.representative = Objects.requireNonNull(representative, "representative");
        this.tp = 0;
        this.fn = 0;
        this.tn = 0;
        this.fp = 0;
    }

    /**
     * Recomputes the confusion counts over every pair in the training map.
     *
     * @param trainingMap source of (predicted, truth) atom pairs
     */
    @Override
    public void compute(TrainingMap trainingMap) {
        this.compute(trainingMap, null);
    }

    /**
     * Recomputes the confusion counts, optionally restricted to one predicate.
     *
     * <p>Entries from {@code getMap(trainingMap)} are read with the key as the predicted
     * atom and the value as the truth atom. Previous counts are discarded.
     *
     * @param trainingMap source of (predicted, truth) atom pairs
     * @param predicate if non-null, only entries whose predicted atom has this (identical)
     *     predicate are counted
     */
    @Override
    public void compute(TrainingMap trainingMap, StandardPredicate predicate) {
        this.tp = 0;
        this.fn = 0;
        this.tn = 0;
        this.fp = 0;

        for (Map.Entry<GroundAtom, GroundAtom> entry : this.getMap(trainingMap)) {
            GroundAtom predictedAtom = entry.getKey();
            if (predicate != null && predictedAtom.getPredicate() != predicate) {
                continue;
            }

            boolean predicted = predictedAtom.getValue() >= this.threshold;
            boolean expected = entry.getValue().getValue() >= this.threshold;

            if (predicted) {
                if (expected) {
                    this.tp++;
                } else {
                    this.fp++;
                }
            } else {
                if (expected) {
                    this.fn++;
                } else {
                    this.tn++;
                }
            }
        }
    }

    /**
     * Returns the value of the configured representative metric for the last computation.
     *
     * @return the representative metric value
     */
    @Override
    public double getRepMetric() {
        switch (this.representative) {
            case F1:
                return this.f1();
            case POSITIVE_PRECISION:
                return this.positivePrecision();
            case NEGATIVE_PRECISION:
                return this.negativePrecision();
            case POSITIVE_RECALL:
                return this.positiveRecall();
            case NEGATIVE_RECALL:
                return this.negativeRecall();
            case ACCURACY:
                return this.accuracy();
            default:
                throw new IllegalStateException("Unknown representative metric: " + this.representative);
        }
    }

    /**
     * Returns the best attainable score for the representative metric (1.0 for all of them).
     *
     * @return the best attainable representative score
     */
    @Override
    public double getBestRepScore() {
        switch (this.representative) {
            case F1:
            case POSITIVE_PRECISION:
            case NEGATIVE_PRECISION:
            case POSITIVE_RECALL:
            case NEGATIVE_RECALL:
            case ACCURACY:
                return 1.0;
            default:
                throw new IllegalStateException("Unknown representative metric: " + this.representative);
        }
    }

    /** All supported representative metrics are maximized. */
    @Override
    public boolean isHigherRepBetter() {
        return true;
    }

    /** Returns the binarization threshold. */
    public double getThreshold() {
        return this.threshold;
    }

    /** Returns TP / (TP + FP), or 0.0 when nothing was predicted positive. */
    public double positivePrecision() {
        if (this.tp + this.fp == 0) {
            return 0.0;
        }
        return (double) this.tp / (this.tp + this.fp);
    }

    /** Returns TN / (TN + FN), or 0.0 when nothing was predicted negative. */
    public double negativePrecision() {
        if (this.tn + this.fn == 0) {
            return 0.0;
        }
        return (double) this.tn / (this.tn + this.fn);
    }

    /** Returns TP / (TP + FN), or 0.0 when there are no positive truth labels. */
    public double positiveRecall() {
        if (this.tp + this.fn == 0) {
            return 0.0;
        }
        return (double) this.tp / (this.tp + this.fn);
    }

    /** Returns TN / (TN + FP), or 0.0 when there are no negative truth labels. */
    public double negativeRecall() {
        if (this.tn + this.fp == 0) {
            return 0.0;
        }
        return (double) this.tn / (this.tn + this.fp);
    }

    /** Returns the F1 score of the positive class. */
    public double f1() {
        return this.fScore(1.0);
    }

    /**
     * Returns the F-beta score of the positive class.
     *
     * @param beta weight of recall relative to precision
     * @return (1 + beta^2) * P * R / (beta^2 * P + R), or 0.0 when the denominator is zero
     */
    public double fScore(double beta) {
        double precision = this.positivePrecision();
        double recall = this.positiveRecall();
        double betaSquared = beta * beta;
        double denom = betaSquared * precision + recall;
        if (MathUtils.isZero(denom)) {
            return 0.0;
        }
        return (1.0 + betaSquared) * (precision * recall) / denom;
    }

    /** Returns (TP + TN) / total, or 0.0 when no atoms were counted. */
    public double accuracy() {
        int numAtoms = this.tp + this.tn + this.fp + this.fn;
        if (numAtoms == 0) {
            return 0.0;
        }
        return (double) (this.tp + this.tn) / numAtoms;
    }

    /** Returns a one-line summary of every statistic this evaluator computes. */
    @Override
    public String getAllStats() {
        return String.format("Accuracy: %f, F1: %f, Positive Class Precision: %f, Positive Class Recall: %f, Negative Class Precision: %f, Negative Class Recall: %f", this.accuracy(), this.f1(), this.positivePrecision(), this.positiveRecall(), this.negativePrecision(), this.negativeRecall());
    }

    /** Metric reported by {@link DiscreteEvaluator#getRepMetric()}. */
    public static enum RepresentativeMetric {
        /** F1 score of the positive class. */
        F1,
        /** Precision of the positive class. */
        POSITIVE_PRECISION,
        /** Precision of the negative class. */
        NEGATIVE_PRECISION,
        /** Recall of the positive class. */
        POSITIVE_RECALL,
        /** Recall of the negative class. */
        NEGATIVE_RECALL,
        /** Fraction of atoms whose binarized prediction matches the binarized truth. */
        ACCURACY;
    }
}

