/*
 * Decompiled with CFR 0.152.
 */
package org.linqs.psl.application.learning.weight.gradient.minimizer;

import java.util.List;
import java.util.Map;
import org.linqs.psl.application.learning.weight.gradient.minimizer.Minimizer;
import org.linqs.psl.database.AtomStore;
import org.linqs.psl.database.Database;
import org.linqs.psl.model.atom.ObservedAtom;
import org.linqs.psl.model.atom.RandomVariableAtom;
import org.linqs.psl.model.rule.Rule;

/**
 * Minimizer whose supervised loss is the squared error over the labeled training atoms.
 *
 * <p>For every (random variable atom, observed truth atom) pair in the training label map, the
 * random variable atom is mapped through the atom store and {@code rvAtomIndexToProxIndex} to an
 * entry of {@code proxRuleObservedAtoms}. That entry's value is compared with the truth value:
 *
 * <pre>
 *   loss            = sum_i (prox_i - truth_i)^2
 *   d loss / d prox_i = 2 * (prox_i - truth_i)
 * </pre>
 */
public class SquaredError extends Minimizer {
    /**
     * Creates a squared-error minimizer. All arguments are forwarded unchanged to {@link Minimizer}.
     *
     * @param rules the rules whose weights are learned
     * @param trainTargetDatabase training target database
     * @param trainTruthDatabase training truth database
     * @param validationTargetDatabase validation target database
     * @param validationTruthDatabase validation truth database
     * @param runValidation whether validation is run (interpreted by {@link Minimizer})
     */
    public SquaredError(List<Rule> rules, Database trainTargetDatabase, Database trainTruthDatabase, Database validationTargetDatabase, Database validationTruthDatabase, boolean runValidation) {
        super(rules, trainTargetDatabase, trainTruthDatabase, validationTargetDatabase, validationTruthDatabase, runValidation);
    }

    /**
     * Computes {@code sum (prox - truth)^2} over every labeled atom in the training map.
     *
     * <p>The sum is accumulated in {@code double} and narrowed to {@code float} only once at the
     * end. Rounding to float after every addition would silently lose precision on large label sets.
     *
     * @return the squared-error supervised loss
     */
    @Override
    protected float computeSupervisedLoss() {
        AtomStore atomStore = this.trainInferenceApplication.getDatabase().getAtomStore();
        double supervisedLoss = 0.0;
        for (Map.Entry<RandomVariableAtom, ObservedAtom> entry : this.trainingMap.getLabelMap().entrySet()) {
            int proxRuleIndex = proxRuleIndexOfLabeledAtom(atomStore, entry.getKey());
            double residual = (double) this.proxRuleObservedAtoms[proxRuleIndex].getValue() - entry.getValue().getValue();
            supervisedLoss += residual * residual;
        }
        return (float) supervisedLoss;
    }

    /**
     * Adds the gradient of the squared-error loss to {@code proxRuleObservedAtomValueGradient}.
     *
     * <p>For each labeled atom the gradient is {@code 2 * (prox - truth)}. It is added to (not
     * assigned into) the slot for that atom's prox index, so gradient already present is kept.
     */
    @Override
    protected void addSupervisedProxRuleObservedAtomValueGradient() {
        AtomStore atomStore = this.trainInferenceApplication.getDatabase().getAtomStore();
        for (Map.Entry<RandomVariableAtom, ObservedAtom> entry : this.trainingMap.getLabelMap().entrySet()) {
            int proxRuleIndex = proxRuleIndexOfLabeledAtom(atomStore, entry.getKey());
            this.proxRuleObservedAtomValueGradient[proxRuleIndex] += 2.0f * (this.proxRuleObservedAtoms[proxRuleIndex].getValue() - entry.getValue().getValue());
        }
    }

    /**
     * Maps a labeled random variable atom to its index in {@code proxRuleObservedAtoms}.
     *
     * <p>NOTE(review): the atom is assumed to be present in {@code rvAtomIndexToProxIndex}. A missing
     * entry unboxes {@code null} and throws {@link NullPointerException}, exactly as the
     * original inline lookup did.
     *
     * @param atomStore the atom store of the training inference database
     * @param randomVariableAtom the labeled random variable atom
     * @return the prox-rule index for the atom
     */
    private int proxRuleIndexOfLabeledAtom(AtomStore atomStore, RandomVariableAtom randomVariableAtom) {
        int atomIndex = atomStore.getAtomIndex(randomVariableAtom);
        return (Integer) this.rvAtomIndexToProxIndex.get(atomIndex);
    }
}

