/*
 * Decompiled with CFR 0.152.
 */
package org.linqs.psl.application.learning.weight.gradient.optimalvalue;

import java.util.List;
import org.linqs.psl.application.learning.weight.gradient.optimalvalue.OptimalValue;
import org.linqs.psl.database.Database;
import org.linqs.psl.model.rule.Rule;
import org.linqs.psl.model.rule.WeightedRule;

/**
 * Weight learner built on {@link OptimalValue}.
 *
 * <p>Each iteration gathers two per-rule incompatibility vectors: the inherited
 * {@code latentInferenceIncompatibility} and {@link #MAPIncompatibility}, which is
 * filled in after {@code computeMAPStateWithWarmStart} has been run on the training
 * inference application. The learning loss, the weight gradient contribution, and
 * the atom gradients are all expressed as "latent minus MAP" differences.
 */
public class StructuredPerceptron extends OptimalValue {
    /** Per-rule incompatibility recorded after the MAP step; one slot per mutable rule. */
    protected float[] MAPIncompatibility;

    /**
     * Forwards every argument to the {@link OptimalValue} constructor and then
     * allocates {@link #MAPIncompatibility} with one entry per mutable rule.
     *
     * @param rules forwarded unchanged to the superclass
     * @param trainTargetDatabase forwarded unchanged to the superclass
     * @param trainTruthDatabase forwarded unchanged to the superclass
     * @param validationTargetDatabase forwarded unchanged to the superclass
     * @param validationTruthDatabase forwarded unchanged to the superclass
     * @param runValidation forwarded unchanged to the superclass
     */
    public StructuredPerceptron(List<Rule> rules,
                                Database trainTargetDatabase, Database trainTruthDatabase,
                                Database validationTargetDatabase, Database validationTruthDatabase,
                                boolean runValidation) {
        super(rules, trainTargetDatabase, trainTruthDatabase,
                validationTargetDatabase, validationTruthDatabase, runValidation);

        MAPIncompatibility = new float[mutableRules.size()];
    }

    /**
     * Refreshes the per-iteration statistics: the latent-inference incompatibility
     * first, then the full-inference (MAP) incompatibility and gradients.
     */
    @Override
    protected void computeIterationStatistics() {
        computeLatentInferenceIncompatibility();
        computeFullInferenceIncompatibility();
    }

    /**
     * Runs the warm-started MAP computation on the training inference application,
     * flags that the learner is in the training MAP state, stores the current
     * incompatibility into {@link #MAPIncompatibility}, and asks the reasoner to
     * fill {@code MAPRVAtomGradient} and {@code MAPDeepAtomGradient}.
     */
    private void computeFullInferenceIncompatibility() {
        computeMAPStateWithWarmStart(trainInferenceApplication, trainMAPTermState, trainMAPAtomValueState);
        inTrainingMAPState = true;

        computeCurrentIncompatibility(MAPIncompatibility);
        trainInferenceApplication.getReasoner().parallelComputeGradient(
                trainInferenceApplication.getTermStore(), MAPRVAtomGradient, MAPDeepAtomGradient);
    }

    /**
     * Sums, over all mutable rules, the rule weight times the difference between
     * the latent-inference incompatibility and the MAP incompatibility.
     *
     * @return the accumulated weighted incompatibility difference
     */
    @Override
    protected float computeLearningLoss() {
        float loss = 0.0f;

        for (int ruleIndex = 0; ruleIndex < mutableRules.size(); ruleIndex++) {
            WeightedRule rule = (WeightedRule) mutableRules.get(ruleIndex);
            loss += rule.getWeight()
                    * (latentInferenceIncompatibility[ruleIndex] - MAPIncompatibility[ruleIndex]);
        }

        return loss;
    }

    /**
     * Adds, for every mutable rule, the difference between the latent-inference
     * incompatibility and the MAP incompatibility onto {@code weightGradient}.
     */
    @Override
    protected void addLearningLossWeightGradient() {
        for (int ruleIndex = 0; ruleIndex < mutableRules.size(); ruleIndex++) {
            weightGradient[ruleIndex] +=
                    (latentInferenceIncompatibility[ruleIndex] - MAPIncompatibility[ruleIndex]);
        }
    }

    /**
     * Overwrites {@code rvAtomGradient} and {@code deepAtomGradient} with the
     * element-wise difference between the latent gradients and the MAP gradients.
     * Both arrays are walked using the length of {@code rvAtomGradient}.
     */
    @Override
    protected void computeTotalAtomGradient() {
        for (int atomIndex = 0; atomIndex < rvAtomGradient.length; atomIndex++) {
            rvAtomGradient[atomIndex] = rvLatentAtomGradient[atomIndex] - MAPRVAtomGradient[atomIndex];
            deepAtomGradient[atomIndex] = deepLatentAtomGradient[atomIndex] - MAPDeepAtomGradient[atomIndex];
        }
    }
}

