/*
 * Decompiled with CFR 0.152.
 */
package org.linqs.psl.application.learning.weight.gradient.minimizer;

import java.util.List;
import java.util.Map;
import org.linqs.psl.application.learning.weight.gradient.minimizer.Minimizer;
import org.linqs.psl.database.AtomStore;
import org.linqs.psl.database.Database;
import org.linqs.psl.model.atom.ObservedAtom;
import org.linqs.psl.model.atom.RandomVariableAtom;
import org.linqs.psl.model.rule.Rule;

/**
 * A {@link Minimizer} whose supervised loss is the binary cross-entropy between each training
 * label and the value of the prox-rule observed atom associated with that label's random
 * variable atom.
 *
 * <p>For a label value {@code y} and a current value {@code p}, the per-atom loss is
 * {@code -(y * log(p) + (1 - y) * log(1 - p))}. Both {@code p} and {@code 1 - p} are clamped
 * from below by {@link #CLAMP_EPSILON} so the logarithm and the division in the gradient stay
 * finite.
 */
public class BinaryCrossEntropy
extends Minimizer {
    /** Lower bound applied to {@code p} and {@code 1 - p} before taking logs or dividing. */
    private static final float CLAMP_EPSILON = 1.0E-4f;

    /**
     * Creates the minimizer; all arguments are forwarded unchanged to {@link Minimizer}.
     *
     * @param rules the rules whose weights are being learned
     * @param trainTargetDatabase training target database
     * @param trainTruthDatabase training truth database
     * @param validationTargetDatabase validation target database
     * @param validationTruthDatabase validation truth database
     * @param runValidation whether validation should be run
     */
    public BinaryCrossEntropy(List<Rule> rules, Database trainTargetDatabase, Database trainTruthDatabase, Database validationTargetDatabase, Database validationTruthDatabase, boolean runValidation) {
        super(rules, trainTargetDatabase, trainTruthDatabase, validationTargetDatabase, validationTruthDatabase, runValidation);
    }

    /**
     * Sums the clamped binary cross-entropy over every (random variable atom, truth atom) pair in
     * the training label map.
     *
     * @return the total supervised loss
     */
    @Override
    protected float computeSupervisedLoss() {
        AtomStore atomStore = this.trainInferenceApplication.getDatabase().getAtomStore();
        // Accumulate in double and narrow once at the end; narrowing to float on every iteration
        // throws away precision as the number of labels grows.
        double supervisedLoss = 0.0;
        for (Map.Entry<RandomVariableAtom, ObservedAtom> entry : this.trainingMap.getLabelMap().entrySet()) {
            float truth = entry.getValue().getValue();
            float predicted = this.proxRuleObservedAtoms[lookupProxRuleIndex(atomStore, entry.getKey())].getValue();
            supervisedLoss -= truth * Math.log(Math.max(predicted, CLAMP_EPSILON))
                    + (1.0f - truth) * Math.log(Math.max(1.0f - predicted, CLAMP_EPSILON));
        }
        return (float)supervisedLoss;
    }

    /**
     * Adds the derivative of the clamped binary cross-entropy with respect to each prox-rule
     * observed atom value, {@code -(y / max(p, eps) - (1 - y) / max(1 - p, eps))}, into
     * {@code proxRuleObservedAtomValueGradient}.
     *
     * <p>The gradient uses the same clamped denominators as the loss, so it stays finite when
     * {@code p} is 0 or 1.
     */
    @Override
    protected void addSupervisedProxRuleObservedAtomValueGradient() {
        AtomStore atomStore = this.trainInferenceApplication.getDatabase().getAtomStore();
        for (Map.Entry<RandomVariableAtom, ObservedAtom> entry : this.trainingMap.getLabelMap().entrySet()) {
            int proxRuleIndex = lookupProxRuleIndex(atomStore, entry.getKey());
            float truth = entry.getValue().getValue();
            float predicted = this.proxRuleObservedAtoms[proxRuleIndex].getValue();
            this.proxRuleObservedAtomValueGradient[proxRuleIndex] -= truth / Math.max(predicted, CLAMP_EPSILON)
                    - (1.0f - truth) / Math.max(1.0f - predicted, CLAMP_EPSILON);
        }
    }

    /**
     * Maps a random variable atom to its index in {@code proxRuleObservedAtoms} via its atom-store
     * index.
     *
     * @param atomStore the atom store of the training inference database
     * @param randomVariableAtom the labeled atom to resolve
     * @return the prox-rule index associated with the atom
     * @throws IllegalStateException if no prox-rule index is registered for the atom
     */
    private int lookupProxRuleIndex(AtomStore atomStore, RandomVariableAtom randomVariableAtom) {
        int atomIndex = atomStore.getAtomIndex(randomVariableAtom);
        Integer proxRuleIndex = (Integer)this.rvAtomIndexToProxIndex.get(atomIndex);
        if (proxRuleIndex == null) {
            throw new IllegalStateException("No prox rule index registered for atom " + randomVariableAtom + " (atom index " + atomIndex + ").");
        }
        return proxRuleIndex;
    }
}

