/*
 * Decompiled with CFR 0.152.
 */
package org.linqs.psl.evaluation.statistics;

import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import org.linqs.psl.application.learning.weight.TrainingMap;
import org.linqs.psl.config.Options;
import org.linqs.psl.evaluation.statistics.Evaluator;
import org.linqs.psl.model.atom.GroundAtom;
import org.linqs.psl.model.predicate.StandardPredicate;

/**
 * Computes ranking metrics for predicted atoms against their truth atoms.
 *
 * <p>The metrics are AUROC and AUPRC, with AUPRC available for either the positive or the
 * negative class.
 *
 * <p>A truth atom is a positive example when its value is strictly greater than the threshold.
 * Every other truth atom is a negative example. The same rule is used when counting class
 * members and when labelling an individual prediction, so recall always finishes at exactly 1.
 *
 * <p>The curve computations walk {@code predicted} in its natural order. They treat the front of
 * the list as the most confident positive predictions, which presumes that
 * {@link GroundAtom}'s natural ordering sorts by value in descending order (TODO confirm against
 * {@code GroundAtom.compareTo}). Tied prediction values are not grouped; each atom is one step
 * on the curve.
 */
public class AUCEvaluator
extends Evaluator {
    private final double threshold;
    private final RepresentativeMetric representative;

    /** Truth atoms gathered by the last {@code compute} call, in natural order. */
    private List<GroundAtom> truth;

    /** Predicted atoms gathered by the last {@code compute} call, in natural order. */
    private List<GroundAtom> predicted;

    /**
     * Maps an atom to the first equal atom in {@link #truth}.
     *
     * <p>This replaces a linear {@code truth.indexOf(...)} search, so labelling every prediction
     * is O(n) instead of O(n^2).
     */
    private Map<GroundAtom, GroundAtom> truthIndex;

    /** Creates an evaluator using the configured AUC threshold and representative metric. */
    public AUCEvaluator() {
        this(Options.EVAL_AUC_THRESHOLD.getDouble());
    }

    /**
     * Creates an evaluator using the configured representative metric.
     *
     * @param threshold truth values strictly above this are positive; must be in [0, 1]
     */
    public AUCEvaluator(double threshold) {
        this(threshold, Options.EVAL_AUC_REPRESENTATIVE.getString());
    }

    /**
     * Creates an evaluator from a representative-metric name.
     *
     * @param threshold truth values strictly above this are positive; must be in [0, 1]
     * @param representative case-insensitive name of a {@link RepresentativeMetric} constant
     * @throws IllegalArgumentException if the threshold is out of range or the name is unknown
     */
    public AUCEvaluator(double threshold, String representative) {
        // Locale.ROOT keeps parsing stable under locales with special casing rules
        // (e.g. Turkish dotted/dotless i).
        this(threshold, RepresentativeMetric.valueOf(representative.toUpperCase(Locale.ROOT)));
    }

    /**
     * Creates an evaluator.
     *
     * @param threshold truth values strictly above this are positive; must be in [0, 1]
     * @param representative metric reported by {@link #getRepMetric()}
     * @throws IllegalArgumentException if {@code threshold} is outside [0, 1] or is NaN
     */
    public AUCEvaluator(double threshold, RepresentativeMetric representative) {
        // The negated form also rejects NaN, for which every comparison is false.
        if (!(threshold >= 0.0 && threshold <= 1.0)) {
            throw new IllegalArgumentException("Threshold must be in [0, 1]. Found: " + threshold);
        }
        this.threshold = threshold;
        this.representative = representative;
        this.truth = new ArrayList<>();
        this.predicted = new ArrayList<>();
        this.truthIndex = new HashMap<>();
    }

    @Override
    public void compute(TrainingMap trainingMap) {
        compute(trainingMap, null);
    }

    /**
     * Collects (predicted, truth) pairs from {@code getMap(trainingMap)} for later metric calls.
     *
     * @param trainingMap source of the atom pairs (key: predicted atom, value: truth atom)
     * @param predicate if non-null, only pairs whose predicted atom has this exact predicate
     *     instance are kept
     */
    @Override
    public void compute(TrainingMap trainingMap, StandardPredicate predicate) {
        int expectedSize = trainingMap.getLabelMap().size();
        List<GroundAtom> newTruth = new ArrayList<>(expectedSize);
        List<GroundAtom> newPredicted = new ArrayList<>(expectedSize);

        for (Map.Entry<GroundAtom, GroundAtom> entry : getMap(trainingMap)) {
            if (predicate != null && entry.getKey().getPredicate() != predicate) {
                continue;
            }
            newTruth.add(entry.getValue());
            newPredicted.add(entry.getKey());
        }

        Collections.sort(newTruth);
        Collections.sort(newPredicted);

        // Build the index from the sorted list with putIfAbsent so a lookup resolves to the same
        // truth atom that truth.indexOf(...) would have found (the first equal one).
        Map<GroundAtom, GroundAtom> newIndex = new HashMap<>(newTruth.size() * 4 / 3 + 1);
        for (GroundAtom truthAtom : newTruth) {
            newIndex.putIfAbsent(truthAtom, truthAtom);
        }

        truth = newTruth;
        predicted = newPredicted;
        truthIndex = newIndex;
    }

    /** Returns the value of the metric selected as representative at construction. */
    @Override
    public double getRepMetric() {
        switch (representative) {
            case AUROC:
                return auroc();
            case POSITIVE_AUPRC:
                return positiveAUPRC();
            case NEGATIVE_AUPRC:
                return negativeAUPRC();
            default:
                throw new IllegalStateException("Unknown representative metric: " + representative);
        }
    }

    /** Returns the best attainable value of the representative metric (1.0 for every metric). */
    @Override
    public double getBestRepScore() {
        switch (representative) {
            case AUROC:
            case POSITIVE_AUPRC:
            case NEGATIVE_AUPRC:
                return 1.0;
            default:
                throw new IllegalStateException("Unknown representative metric: " + representative);
        }
    }

    /** All metrics computed here are better when higher. */
    @Override
    public boolean isHigherRepBetter() {
        return true;
    }

    /** Returns the threshold above which a truth value counts as positive. */
    public double getThreshold() {
        return threshold;
    }

    /** Returns the area under the precision-recall curve for the positive class. */
    public double positiveAUPRC() {
        return auprc(true);
    }

    /**
     * Returns the area under the precision-recall curve for the negative class.
     *
     * <p>The negative class is ranked from the lowest prediction value to the highest.
     */
    public double negativeAUPRC() {
        return auprc(false);
    }

    /**
     * Computes AUPRC for one class using trapezoidal interpolation from (recall 0, precision 1).
     *
     * @param positiveIsTrue true to score the positive class, false to score the negative class
     * @return the area, or 0.0 when the chosen class has no members
     */
    private double auprc(boolean positiveIsTrue) {
        int classSize = 0;
        for (GroundAtom atom : truth) {
            if (isPositive(atom) == positiveIsTrue) {
                classSize++;
            }
        }

        if (classSize == 0) {
            return 0.0;
        }

        double area = 0.0;
        int tp = 0;
        int fp = 0;
        // Precision is the Y-axis, recall the X-axis.
        double prevPrecision = 1.0;
        double prevRecall = 0.0;

        int size = predicted.size();
        for (int rank = 0; rank < size; rank++) {
            // The front of `predicted` holds the most confident positives, so the most
            // confident negatives are read from the back.
            GroundAtom atom = predicted.get(positiveIsTrue ? rank : size - 1 - rank);
            Boolean label = getLabel(atom);
            if (label == null) {
                continue;
            }

            if (label.booleanValue() == positiveIsTrue) {
                tp++;
            } else {
                fp++;
            }

            double precision = (double) tp / (double) (tp + fp);
            double recall = (double) tp / (double) classSize;
            area += (recall - prevRecall) * (prevPrecision + precision) / 2.0;
            prevPrecision = precision;
            prevRecall = recall;
        }

        return area;
    }

    /**
     * Returns the area under the ROC curve (true-positive rate against false-positive rate).
     *
     * @return the area; 0.0 when there are no positives, 1.0 when there are no negatives
     */
    public double auroc() {
        int totalPositives = 0;
        for (GroundAtom atom : truth) {
            if (isPositive(atom)) {
                totalPositives++;
            }
        }
        int totalNegatives = truth.size() - totalPositives;

        if (totalPositives == 0) {
            return 0.0;
        }
        if (totalNegatives == 0) {
            return 1.0;
        }

        double area = 0.0;
        int tp = 0;
        int fp = 0;
        double prevTpr = 0.0;
        double prevFpr = 0.0;

        for (GroundAtom atom : predicted) {
            Boolean label = getLabel(atom);
            if (label == null) {
                continue;
            }

            if (label.booleanValue()) {
                tp++;
            } else {
                fp++;
            }

            double tpr = (double) tp / (double) totalPositives;
            double fpr = (double) fp / (double) totalNegatives;
            area += (fpr - prevFpr) * (prevTpr + tpr) / 2.0;
            prevTpr = tpr;
            prevFpr = fpr;
        }

        // Close the curve at (1, 1); this adds nothing when every prediction was labelled.
        area += (1.0 - prevFpr) * (prevTpr + 1.0) / 2.0;
        return area;
    }

    @Override
    public String getAllStats() {
        return String.format("AUROC: %f, Positive Class AUPRC: %f, Negative Class AUPRC: %f", auroc(), positiveAUPRC(), negativeAUPRC());
    }

    /**
     * Looks up the truth label for a predicted atom.
     *
     * @return true/false for the matching truth atom, or null when no equal truth atom exists
     */
    private Boolean getLabel(GroundAtom atom) {
        GroundAtom truthAtom = truthIndex.get(atom);
        if (truthAtom == null) {
            return null;
        }
        return isPositive(truthAtom);
    }

    /** Single positive/negative rule shared by class counting and per-atom labelling. */
    private boolean isPositive(GroundAtom truthAtom) {
        return truthAtom.getValue() > threshold;
    }

    /** Metric reported by {@link #getRepMetric()}. */
    public static enum RepresentativeMetric {
        AUROC,
        POSITIVE_AUPRC,
        NEGATIVE_AUPRC;

    }
}

