/*
 * Decompiled with CFR 0.152.
 */
package org.linqs.psl.evaluation.statistics;

import java.util.HashMap;
import java.util.HashSet;
import java.util.Locale;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import org.linqs.psl.application.learning.weight.TrainingMap;
import org.linqs.psl.config.Options;
import org.linqs.psl.evaluation.statistics.Evaluator;
import org.linqs.psl.model.atom.GroundAtom;
import org.linqs.psl.model.predicate.StandardPredicate;
import org.linqs.psl.model.term.Constant;
import org.linqs.psl.util.MathUtils;
import org.linqs.psl.util.RandUtils;
import org.linqs.psl.util.StringUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Computes categorical accuracy for a single predicate.
 *
 * <p>Some argument positions of the evaluated predicate are designated as "category" positions
 * (see {@link #setVirtualCategoryIndexes(int...)}). Target atoms that agree on every
 * non-category argument form one group. Within each group, the atom with the highest value is
 * taken as the predicted category, with ties broken at random. Every truth atom of the predicate
 * whose value is at least 1.0 then counts as a hit if it is one of the predicted atoms.
 * Otherwise it counts as a miss. Membership is decided by {@code Set.contains}, i.e.
 * {@code GroundAtom} equality.
 */
public class CategoricalEvaluator
extends Evaluator {
    private static final Logger log = LoggerFactory.getLogger(CategoricalEvaluator.class);

    /** Delimiter used to split the category indexes read from {@code Options.EVAL_CAT_CATEGORY_INDEXES}. */
    public static final String DELIM = ":";

    /** Category indexes as supplied by the caller; negative values count back from the predicate's arity. */
    private Set<Integer> virtualCategoryIndexes;
    private final RepresentativeMetric representative;
    /** Name of the predicate evaluated by {@link #compute(TrainingMap)}; null when not configured. */
    private final String defaultPredicate;
    private int hits;
    private int misses;

    /**
     * Builds an evaluator entirely from configuration.
     *
     * <p>The representative metric comes from {@code Options.EVAL_CAT_REPRESENTATIVE}. The
     * category indexes come from {@code Options.EVAL_CAT_CATEGORY_INDEXES}, split on
     * {@link #DELIM}.
     */
    public CategoricalEvaluator() {
        // Delegate to the String constructor so the configured metric name is parsed
        // case-insensitively, the same way as every other String-based entry point.
        this(Options.EVAL_CAT_REPRESENTATIVE.getString(), StringUtils.splitInt(Options.EVAL_CAT_CATEGORY_INDEXES.getString(), DELIM));
    }

    /**
     * Builds an evaluator with explicit category indexes and the configured representative metric.
     *
     * @param rawCategoryIndexes category argument positions; negative values count from the end
     */
    public CategoricalEvaluator(int ... rawCategoryIndexes) {
        this(Options.EVAL_CAT_REPRESENTATIVE.getString(), rawCategoryIndexes);
    }

    /**
     * Builds an evaluator from a metric name, which is matched case-insensitively.
     *
     * @param representative name of a {@link RepresentativeMetric} constant
     * @param rawCategoryIndexes category argument positions; negative values count from the end
     * @throws IllegalArgumentException if the name matches no metric or no indexes are given
     */
    public CategoricalEvaluator(String representative, int ... rawCategoryIndexes) {
        // Locale.ROOT keeps the upper-casing independent of the JVM's default locale.
        this(RepresentativeMetric.valueOf(representative.toUpperCase(Locale.ROOT)), rawCategoryIndexes);
    }

    /**
     * Builds an evaluator with an explicit metric and explicit category indexes.
     *
     * @param representative the metric reported by {@link #getRepMetric()}
     * @param rawCategoryIndexes category argument positions; negative values count from the end
     * @throws IllegalArgumentException if no category indexes are given
     */
    public CategoricalEvaluator(RepresentativeMetric representative, int ... rawCategoryIndexes) {
        this.representative = representative;
        this.setVirtualCategoryIndexes(rawCategoryIndexes);
        this.defaultPredicate = Options.EVAL_CAT_DEFAULT_PREDICATE.getString();
        this.hits = 0;
        this.misses = 0;
    }

    /**
     * Replaces the set of category argument positions.
     *
     * <p>The indexes are stored as given. Negative values are resolved against a concrete
     * predicate's arity only at evaluation time.
     *
     * @param rawCategoryIndexes category argument positions; must be non-null and non-empty
     * @throws IllegalArgumentException if {@code rawCategoryIndexes} is null or empty
     */
    public void setVirtualCategoryIndexes(int ... rawCategoryIndexes) {
        if (rawCategoryIndexes == null || rawCategoryIndexes.length == 0) {
            throw new IllegalArgumentException("Found no category indexes.");
        }
        this.virtualCategoryIndexes = new HashSet<Integer>(rawCategoryIndexes.length);
        for (int catIndex : rawCategoryIndexes) {
            this.virtualCategoryIndexes.add(catIndex);
        }
        log.debug("Virtual category indexes: [{}].", StringUtils.join(", ", this.virtualCategoryIndexes.toArray()));
    }

    /**
     * Evaluates the predicate named by {@code Options.EVAL_CAT_DEFAULT_PREDICATE}.
     *
     * @throws UnsupportedOperationException if no default predicate was configured
     */
    @Override
    public void compute(TrainingMap trainingMap) {
        if (this.defaultPredicate == null) {
            throw new UnsupportedOperationException("CategoricalEvaluators must have a default predicate set (through config).");
        }
        this.compute(trainingMap, StandardPredicate.get(this.defaultPredicate));
    }

    /**
     * Recomputes the hit and miss counts for {@code predicate}.
     *
     * <p>Only truth atoms of {@code predicate} with a value of at least 1.0 are scored.
     *
     * @throws NullPointerException if {@code predicate} is null
     */
    @Override
    public void compute(TrainingMap trainingMap, StandardPredicate predicate) {
        // Fail fast even when assertions are disabled; a null would otherwise surface later
        // as an NPE inside getTrueCategoryIndexes.
        Objects.requireNonNull(predicate, "predicate");
        this.hits = 0;
        this.misses = 0;
        Set<GroundAtom> predictedCategories = this.getPredictedCategories(trainingMap, predicate);
        for (GroundAtom truthAtom : trainingMap.getAllTruths()) {
            // Only observed-true atoms of the evaluated predicate identify a correct category.
            if (truthAtom.getPredicate() != predicate || truthAtom.getValue() < 1.0) {
                continue;
            }
            if (predictedCategories.contains(truthAtom)) {
                ++this.hits;
            } else {
                ++this.misses;
            }
        }
    }

    @Override
    public double getRepMetric() {
        switch (this.representative) {
            case ACCURACY:
                return this.accuracy();
            default:
                throw new IllegalStateException("Unknown representative metric: " + this.representative);
        }
    }

    @Override
    public double getBestRepScore() {
        switch (this.representative) {
            case ACCURACY:
                return 1.0;
            default:
                throw new IllegalStateException("Unknown representative metric: " + this.representative);
        }
    }

    @Override
    public boolean isHigherRepBetter() {
        return true;
    }

    /**
     * Returns hits / (hits + misses) from the last {@code compute} call, or 0.0 when nothing
     * was scored.
     */
    public double accuracy() {
        if (this.hits + this.misses == 0) {
            return 0.0;
        }
        return (double)this.hits / (double)(this.hits + this.misses);
    }

    @Override
    public String getAllStats() {
        return String.format("Categorical Accuracy: %f", this.accuracy());
    }

    /**
     * Resolves the configured (possibly negative) category indexes against {@code predicate}'s arity.
     *
     * @throws RuntimeException if a resolved index falls outside {@code [0, arity)}
     */
    private Set<Integer> getTrueCategoryIndexes(StandardPredicate predicate) {
        Set<Integer> categoryIndexes = new HashSet<Integer>();
        for (Integer rawIndex : this.virtualCategoryIndexes) {
            int index = rawIndex;
            // Negative indexes count back from the last argument, e.g. -1 is the last argument.
            if (index < 0) {
                index += predicate.getArity();
            }
            if (index < 0 || index >= predicate.getArity()) {
                // Report the index exactly as configured, not the arity-shifted value.
                throw new RuntimeException(String.format("Categorical index (%d) out of bounds for %s/%d.", rawIndex, predicate.getName(), predicate.getArity()));
            }
            categoryIndexes.add(index);
        }
        log.trace("True category indexes for {}: [{}].", predicate.getName(), StringUtils.join(", ", categoryIndexes.toArray()));
        return categoryIndexes;
    }

    /**
     * Picks one predicted atom per group of target atoms of {@code predicate}.
     *
     * <p>A group is the set of atoms that share all non-category arguments. The winner is the
     * highest-valued atom in the group.
     *
     * @return the predicted atoms; empty if no target atom belongs to {@code predicate}
     */
    protected Set<GroundAtom> getPredictedCategories(TrainingMap trainingMap, StandardPredicate predicate) {
        // The tree root stays null until the first matching atom is seen. After that it is either
        // a Map<Constant, Object> with one level per non-category argument, or a single
        // GroundAtom when every argument is a category argument.
        Object predictionTree = null;
        Set<Integer> categoryIndexes = this.getTrueCategoryIndexes(predicate);
        for (GroundAtom atom : this.getTargets(trainingMap)) {
            if (atom.getPredicate() != predicate) {
                continue;
            }
            predictionTree = this.putPredictedCategories(predictionTree, atom, 0, categoryIndexes);
        }
        Set<GroundAtom> rtn = new HashSet<GroundAtom>();
        this.collectPredictedCategories(predictionTree, rtn);
        return rtn;
    }

    /**
     * Inserts {@code atom} into the prediction tree rooted at {@code currentNode}, starting at argument {@code argIndex}.
     *
     * <p>Category arguments get no tree level, so atoms that differ only in them compete for the
     * same leaf. A leaf holds the best atom seen so far.
     *
     * @return the (possibly new) node that should replace {@code currentNode}
     */
    private Object putPredictedCategories(Object currentNode, GroundAtom atom, int argIndex, Set<Integer> categoryIndexes) {
        assert (argIndex <= atom.getArity());
        if (categoryIndexes.contains(argIndex)) {
            return this.putPredictedCategories(currentNode, atom, argIndex + 1, categoryIndexes);
        }
        if (argIndex == atom.getArity()) {
            return this.chooseBetterAtom((GroundAtom)currentNode, atom);
        }

        Map<Constant, Object> children;
        if (currentNode == null) {
            children = new HashMap<Constant, Object>();
        } else {
            // Every interior node of the tree is created above as a Map<Constant, Object>.
            @SuppressWarnings("unchecked")
            Map<Constant, Object> existing = (Map<Constant, Object>)currentNode;
            children = existing;
        }
        Constant arg = atom.getArguments()[argIndex];
        children.put(arg, this.putPredictedCategories(children.get(arg), atom, argIndex + 1, categoryIndexes));
        return children;
    }

    /**
     * Returns whichever of {@code oldBest} and {@code candidate} should hold a leaf.
     *
     * <p>A null {@code oldBest} always loses. On a (MathUtils-)equal value, a coin flip decides.
     * NOTE: because each new tied atom wins with probability 1/2, the choice is not uniform when
     * more than two atoms tie.
     */
    private GroundAtom chooseBetterAtom(GroundAtom oldBest, GroundAtom candidate) {
        if (oldBest == null) {
            return candidate;
        }
        if (candidate.getValue() > oldBest.getValue()) {
            return candidate;
        }
        if (MathUtils.equals(candidate.getValue(), oldBest.getValue())) {
            return RandUtils.nextBoolean() ? candidate : oldBest;
        }
        return oldBest;
    }

    /**
     * Adds every leaf atom reachable from {@code node} to {@code result}.
     *
     * <p>{@code node} may be null (empty tree), a {@code GroundAtom} leaf, or an interior map.
     */
    private void collectPredictedCategories(Object node, Set<GroundAtom> result) {
        if (node == null) {
            return;
        }
        if (node instanceof GroundAtom) {
            result.add((GroundAtom)node);
            return;
        }
        Map<?, ?> children = (Map<?, ?>)node;
        for (Object child : children.values()) {
            this.collectPredictedCategories(child, result);
        }
    }

    /** Metrics that can be reported as this evaluator's representative score. */
    public enum RepresentativeMetric {
        ACCURACY
    }
}

