/*
 * Decompiled with CFR 0.152.
 */
package org.linqs.psl.application.learning.weight.gradient.minimizer;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import org.linqs.psl.application.learning.weight.gradient.GradientDescent;
import org.linqs.psl.config.Options;
import org.linqs.psl.database.AtomStore;
import org.linqs.psl.database.Database;
import org.linqs.psl.model.atom.GroundAtom;
import org.linqs.psl.model.atom.ObservedAtom;
import org.linqs.psl.model.atom.RandomVariableAtom;
import org.linqs.psl.model.atom.UnmanagedObservedAtom;
import org.linqs.psl.model.predicate.Predicate;
import org.linqs.psl.model.predicate.StandardPredicate;
import org.linqs.psl.model.rule.Rule;
import org.linqs.psl.model.rule.WeightedRule;
import org.linqs.psl.model.rule.arithmetic.WeightedArithmeticRule;
import org.linqs.psl.model.rule.arithmetic.WeightedGroundArithmeticRule;
import org.linqs.psl.model.rule.arithmetic.expression.ArithmeticRuleExpression;
import org.linqs.psl.model.rule.arithmetic.expression.coefficient.ConstantNumber;
import org.linqs.psl.reasoner.duallcqp.DualBCDReasoner;
import org.linqs.psl.reasoner.function.FunctionComparator;
import org.linqs.psl.reasoner.term.ReasonerTerm;
import org.linqs.psl.reasoner.term.TermState;
import org.linqs.psl.util.Logger;
import org.linqs.psl.util.MathUtils;

/**
 * Gradient-descent weight learner that, on every iteration, runs two inferences on the
 * training application: a full MAP inference and an "augmented" inference in which one
 * proximity ("prox") rule per unfixed atom is active. The difference between the two
 * objective values (plus the total prox value) is combined with a supervised loss through
 * a squared and a linear penalty coefficient (see {@link #computeLearningLoss()}).
 *
 * Subclasses supply the supervised loss and its gradient with respect to the values of the
 * observed atoms used by the prox rules.
 */
public abstract class Minimizer
extends GradientDescent {
    private static final Logger log = Logger.getLogger(Minimizer.class);

    // Warm-start state (term state and atom values) for the latent inference,
    // i.e. the inference run while labeled random variables are fixed to their labels.
    protected float[] latentInferenceIncompatibility;
    protected TermState[] latentInferenceTermState;
    protected float[] latentInferenceAtomValueState;

    // Per-mutable-rule statistics gathered after the full MAP inference.
    protected float[] mapIncompatibility;
    protected float[] mapSquaredIncompatibility;

    // Per-mutable-rule statistics, warm-start state, and per-atom gradients
    // gathered after the augmented inference (prox rules active).
    protected float[] augmentedInferenceIncompatibility;
    protected float[] augmentedInferenceSquaredIncompatibility;
    protected TermState[] augmentedInferenceTermState;
    protected float[] augmentedInferenceAtomValueState;
    protected float[] augmentedRVAtomGradient;
    protected float[] augmentedDeepAtomGradient;

    // Index maps between atom store indexes and prox rule indexes.
    // rvAtomIndexToProxIndex holds -1 for atoms that were fixed when the prox rules were built.
    protected List<Integer> rvAtomIndexToProxIndex;
    protected List<Integer> proxIndexToRVAtomIndex;

    // One prox rule per unfixed atom, the observed atom it ties that atom to,
    // the atom store index of that observed atom, and the gradient of the loss
    // with respect to the observed atom's value.
    protected WeightedArithmeticRule[] proxRules;
    protected UnmanagedObservedAtom[] proxRuleObservedAtoms;
    protected int[] proxRuleObservedAtomIndexes;
    protected float[] proxRuleObservedAtomValueGradient;
    protected final float proxRuleWeight;

    // Current and final tolerances used by gradientStep() / breakOptimization().
    protected float parameterMovementTolerance;
    protected float finalParameterMovementTolerance;
    protected float constraintTolerance;
    protected float finalConstraintTolerance;

    // False until the prox rule observed atoms have been seeded by initializeProximityRuleConstants().
    protected boolean initializedProxRuleConstants;

    // Counter incremented each time gradientStep() updates the penalty coefficients/tolerances.
    protected int outerIteration;

    // Penalty coefficients weighting the squared and linear objective-difference terms of the loss.
    protected final float initialSquaredPenaltyCoefficient;
    protected float squaredPenaltyCoefficient;
    protected float squaredPenaltyCoefficientIncreaseRate;
    protected final float initialLinearPenaltyCoefficient;
    protected float linearPenaltyCoefficient;

    /**
     * Creates the learner. All arguments are forwarded unchanged to the
     * {@link GradientDescent} constructor; the remaining configuration is read from
     * the MINIMIZER_* entries of {@link Options}.
     */
    public Minimizer(List<Rule> rules, Database trainTargetDatabase, Database trainTruthDatabase, Database validationTargetDatabase, Database validationTruthDatabase, boolean runValidation) {
        super(rules, trainTargetDatabase, trainTruthDatabase, validationTargetDatabase, validationTruthDatabase, runValidation);

        int mutableRuleCount = this.mutableRules.size();

        this.latentInferenceIncompatibility = new float[mutableRuleCount];
        this.latentInferenceTermState = null;
        this.latentInferenceAtomValueState = null;

        this.mapIncompatibility = new float[mutableRuleCount];
        this.mapSquaredIncompatibility = new float[mutableRuleCount];

        this.augmentedInferenceIncompatibility = new float[mutableRuleCount];
        this.augmentedInferenceSquaredIncompatibility = new float[mutableRuleCount];
        this.augmentedInferenceTermState = null;
        this.augmentedInferenceAtomValueState = null;

        this.rvAtomIndexToProxIndex = new ArrayList<Integer>();
        this.proxIndexToRVAtomIndex = new ArrayList<Integer>();

        // The prox rule structures are sized and filled in postInitGroundModel().
        this.proxRules = null;
        this.proxRuleObservedAtoms = null;
        this.proxRuleObservedAtomValueGradient = null;
        this.proxRuleWeight = Options.MINIMIZER_PROX_RULE_WEIGHT.getFloat();

        this.initialSquaredPenaltyCoefficient = Options.MINIMIZER_INITIAL_SQUARED_PENALTY.getFloat();
        this.squaredPenaltyCoefficient = this.initialSquaredPenaltyCoefficient;
        this.squaredPenaltyCoefficientIncreaseRate = Options.MINIMIZER_SQUARED_PENALTY_INCREASE_RATE.getFloat();

        this.initialLinearPenaltyCoefficient = Options.MINIMIZER_INITIAL_LINEAR_PENALTY.getFloat();
        this.linearPenaltyCoefficient = this.initialLinearPenaltyCoefficient;

        this.parameterMovementTolerance = 1.0f / this.initialSquaredPenaltyCoefficient;
        this.finalParameterMovementTolerance = Options.MINIMIZER_FINAL_PARAMETER_MOVEMENT_CONVERGENCE_TOLERANCE.getFloat();

        // NOTE(review): the exponent here is the float literal 0.1f while gradientStep() uses the
        // double literal 0.1; kept as-is to preserve the exact initial value.
        this.constraintTolerance = (float)(1.0 / Math.pow(this.initialSquaredPenaltyCoefficient, 0.1f));
        this.finalConstraintTolerance = Options.MINIMIZER_OBJECTIVE_DIFFERENCE_TOLERANCE.getFloat();

        this.initializedProxRuleConstants = false;
        this.outerIteration = 1;
    }

    /**
     * Builds one inactive prox rule per unfixed atom of the training atom store, registers the
     * matching observed atom in the atom store, adds the LTE and GTE ground rules for it
     * to the term store, and then snapshots the term store / atom values as the initial
     * warm-start state for both the latent and the augmented inference.
     *
     * Constant merging on the term generator is switched off while the prox terms are added and
     * is restored afterwards, even if adding a term fails.
     */
    @Override
    protected void postInitGroundModel() {
        AtomStore atomStore = this.trainInferenceApplication.getTermStore().getDatabase().getAtomStore();

        int unFixedAtomCount = 0;
        for (GroundAtom atom : atomStore) {
            if (!atom.isFixed()) {
                ++unFixedAtomCount;
            }
        }

        boolean originalMergeConstants = this.trainInferenceApplication.getTermStore().getTermGenerator().getMergeConstants();
        this.trainInferenceApplication.getTermStore().getTermGenerator().setMergeConstants(false);
        try {
            this.proxRules = new WeightedArithmeticRule[unFixedAtomCount];
            this.proxRuleObservedAtoms = new UnmanagedObservedAtom[unFixedAtomCount];
            this.proxRuleObservedAtomIndexes = new int[unFixedAtomCount];
            this.proxRuleObservedAtomValueGradient = new float[unFixedAtomCount];

            // Atoms are appended to the store inside the loop, so the bound is captured up front.
            int originalAtomCount = atomStore.size();
            int proxRuleIndex = 0;
            for (int i = 0; i < originalAtomCount; ++i) {
                GroundAtom atom = atomStore.getAtom(i);
                if (atom.isFixed()) {
                    this.rvAtomIndexToProxIndex.add(-1);
                    continue;
                }

                this.rvAtomIndexToProxIndex.add(proxRuleIndex);
                this.proxIndexToRVAtomIndex.add(i);

                StandardPredicate proxPredicate = StandardPredicate.get(
                        String.format("augmented%s", atom.getPredicate().getName()),
                        atom.getPredicate().getArgumentTypes());
                if (Predicate.get(proxPredicate.getName()) == null) {
                    Predicate.registerPredicate(proxPredicate);
                } else {
                    assert (Predicate.get(proxPredicate.getName()).equals(proxPredicate)) : "The 'augmented' prefix on predicate names is reserved for weight learning functionality.";
                }

                // The observed atom starts at the current value of the atom it shadows.
                UnmanagedObservedAtom proxObservedAtom = new UnmanagedObservedAtom(proxPredicate, atom.getArguments(), atom.getValue());
                this.proxRuleObservedAtoms[proxRuleIndex] = proxObservedAtom;
                atomStore.addAtom(proxObservedAtom);
                this.proxRuleObservedAtomIndexes[proxRuleIndex] = atomStore.getAtomIndex(proxObservedAtom);

                // Rule expression: 1.0 * atom - 1.0 * proxObservedAtom = 0.0
                WeightedArithmeticRule proxRule = new WeightedArithmeticRule(
                        new ArithmeticRuleExpression(
                                Arrays.asList(new ConstantNumber(1.0f), new ConstantNumber(-1.0f)),
                                Arrays.asList(atom, proxObservedAtom),
                                FunctionComparator.EQ,
                                new ConstantNumber(0.0f)),
                        this.proxRuleWeight, true);
                this.proxRules[proxRuleIndex] = proxRule;
                proxRule.setActive(false);

                // The equality is grounded as a pair of inequalities (LTE and GTE).
                this.trainInferenceApplication.getTermStore().add(new WeightedGroundArithmeticRule(
                        proxRule,
                        Arrays.asList(Float.valueOf(1.0f), Float.valueOf(-1.0f)),
                        Arrays.asList(atom, proxObservedAtom),
                        FunctionComparator.LTE, 0.0f));
                this.trainInferenceApplication.getTermStore().add(new WeightedGroundArithmeticRule(
                        proxRule,
                        Arrays.asList(Float.valueOf(1.0f), Float.valueOf(-1.0f)),
                        Arrays.asList(atom, proxObservedAtom),
                        FunctionComparator.GTE, 0.0f));

                ++proxRuleIndex;
            }
        } finally {
            // Always hand the term generator back in the configuration it was found in.
            this.trainInferenceApplication.getTermStore().getTermGenerator().setMergeConstants(originalMergeConstants);
        }

        super.postInitGroundModel();

        float[] atomValues = this.trainInferenceApplication.getDatabase().getAtomStore().getAtomValues();

        this.latentInferenceTermState = this.trainInferenceApplication.getTermStore().saveState();
        this.latentInferenceAtomValueState = Arrays.copyOf(atomValues, atomValues.length);

        this.augmentedInferenceTermState = this.trainInferenceApplication.getTermStore().saveState();
        this.augmentedInferenceAtomValueState = Arrays.copyOf(atomValues, atomValues.length);

        this.augmentedRVAtomGradient = new float[atomValues.length];
        this.augmentedDeepAtomGradient = new float[atomValues.length];
    }

    /**
     * Decides whether learning should stop.
     *
     * @return true once {@code iteration} exceeds {@code maxNumSteps}, or, when full iterations
     *         are not forced, once the objective difference drops below the final constraint
     *         tolerance. The {@code objective} and {@code oldObjective} arguments are not used.
     */
    @Override
    protected boolean breakOptimization(int iteration, float objective, float oldObjective) {
        if (iteration > this.maxNumSteps) {
            log.trace("Breaking Weight Learning. Reached maximum number of iterations: {}", this.maxNumSteps);
            return true;
        }

        if (this.runFullIterations) {
            return false;
        }

        float totalObjectiveDifference = this.computeObjectiveDifference();
        if (totalObjectiveDifference < this.finalConstraintTolerance) {
            log.trace("Breaking Weight Learning. Objective difference {} is less than final constraint tolerance {}.", Float.valueOf(totalObjectiveDifference), Float.valueOf(this.finalConstraintTolerance));
            return true;
        }

        return false;
    }

    /**
     * Takes one step on the weights, the prox rule observed atom values, and the atoms, then
     * updates the penalty coefficients and tolerances when the accumulated parameter movement
     * has fallen below the current movement tolerance.
     */
    @Override
    protected void gradientStep(int iteration) {
        this.parameterMovement = 0.0f;
        this.parameterMovement += this.weightGradientStep(iteration);
        this.parameterMovement += this.internalParameterGradientStep(iteration);
        this.parameterMovement += this.atomGradientStep();

        float totalObjectiveDifference = this.computeObjectiveDifference();

        if (iteration > 0 && this.parameterMovement < this.parameterMovementTolerance) {
            ++this.outerIteration;

            if (totalObjectiveDifference < this.constraintTolerance) {
                // Both final tolerances are met: leave the coefficients untouched (and skip the trace below).
                if (totalObjectiveDifference < this.finalConstraintTolerance && this.parameterMovement < this.finalParameterMovementTolerance) {
                    return;
                }

                // Constraint is satisfied to the current tolerance: update the linear coefficient
                // and tighten both tolerances.
                this.linearPenaltyCoefficient += 2.0f * this.squaredPenaltyCoefficient * totalObjectiveDifference;
                this.constraintTolerance = (float)((double)this.constraintTolerance / Math.pow(this.squaredPenaltyCoefficient, 0.9));
                this.parameterMovementTolerance /= this.squaredPenaltyCoefficient;
            } else {
                // Constraint is not satisfied: grow the squared coefficient and reset the
                // tolerances from its new value.
                this.squaredPenaltyCoefficient = this.squaredPenaltyCoefficientIncreaseRate * this.squaredPenaltyCoefficient;
                this.constraintTolerance = (float)(1.0 / Math.pow(this.squaredPenaltyCoefficient, 0.1));
                this.parameterMovementTolerance = 1.0f / this.squaredPenaltyCoefficient;
            }
        }

        log.trace("Outer iteration: {}, Objective Difference: {}, Parameter Movement: {}, Squared Penalty Coefficient: {}, Linear Penalty Coefficient: {}, Constraint Tolerance: {}, parameterMovementTolerance: {}.", this.outerIteration, Float.valueOf(totalObjectiveDifference), Float.valueOf(this.parameterMovement), Float.valueOf(this.squaredPenaltyCoefficient), Float.valueOf(this.linearPenaltyCoefficient), Float.valueOf(this.constraintTolerance), Float.valueOf(this.parameterMovementTolerance));
    }

    /**
     * Moves every prox rule observed atom value against its gradient, clamps the result to
     * [0, 1], and mirrors the new value into the atom store values and the augmented
     * inference warm-start state.
     *
     * @return the summed absolute change of the observed atom values.
     */
    @Override
    protected float internalParameterGradientStep(int iteration) {
        float totalMovement = 0.0f;
        float stepSize = this.computeStepSize(iteration);
        float[] atomValues = this.trainInferenceApplication.getTermStore().getDatabase().getAtomStore().getAtomValues();

        for (int i = 0; i < this.proxRules.length; ++i) {
            float oldValue = this.proxRuleObservedAtoms[i].getValue();
            float newValue = Math.min(Math.max(oldValue - stepSize * this.proxRuleObservedAtomValueGradient[i], 0.0f), 1.0f);
            totalMovement += Math.abs(oldValue - newValue);

            int observedAtomIndex = this.proxRuleObservedAtomIndexes[i];
            this.proxRuleObservedAtoms[i]._assumeValue(newValue);
            atomValues[observedAtomIndex] = newValue;
            this.augmentedInferenceAtomValueState[observedAtomIndex] = newValue;
        }

        return totalMovement;
    }

    /**
     * Seeds the prox rule observed atoms. Runs a latent inference with the labeled random
     * variables fixed to their labels, copies the resulting atom values into the augmented
     * warm-start state and into each observed atom, and finally overwrites the observed atoms
     * of labeled random variables with the label values.
     *
     * The labeled random variables are restored in the atom store even if inference fails.
     */
    protected void initializeProximityRuleConstants() {
        this.fixLabeledRandomVariables();
        try {
            log.trace("Performing Latent Inference.");
            this.computeMAPStateWithWarmStart(this.trainInferenceApplication, this.latentInferenceTermState, this.latentInferenceAtomValueState);
            this.inTrainingMAPState = true;
        } finally {
            this.unfixLabeledRandomVariables();
        }

        AtomStore atomStore = this.trainInferenceApplication.getDatabase().getAtomStore();
        float[] atomValues = atomStore.getAtomValues();

        System.arraycopy(this.latentInferenceAtomValueState, 0, this.augmentedInferenceAtomValueState, 0, this.latentInferenceAtomValueState.length);

        // Every observed atom starts at the latent inference value of the atom it shadows.
        for (int i = 0; i < this.proxRules.length; ++i) {
            float latentValue = this.latentInferenceAtomValueState[this.proxIndexToRVAtomIndex.get(i)];
            int observedAtomIndex = this.proxRuleObservedAtomIndexes[i];

            this.proxRuleObservedAtoms[i]._assumeValue(latentValue);
            atomValues[observedAtomIndex] = latentValue;
            this.augmentedInferenceAtomValueState[observedAtomIndex] = latentValue;
        }

        // Labeled atoms use the label value instead.
        for (Map.Entry<RandomVariableAtom, ObservedAtom> entry : this.trainingMap.getLabelMap().entrySet()) {
            RandomVariableAtom randomVariableAtom = entry.getKey();
            float labelValue = entry.getValue().getValue();

            int rvAtomIndex = atomStore.getAtomIndex(randomVariableAtom);
            int proxRuleIndex = this.rvAtomIndexToProxIndex.get(rvAtomIndex);
            int observedAtomIndex = this.proxRuleObservedAtomIndexes[proxRuleIndex];

            this.proxRuleObservedAtoms[proxRuleIndex]._assumeValue(labelValue);
            atomValues[observedAtomIndex] = labelValue;
            this.augmentedInferenceAtomValueState[observedAtomIndex] = labelValue;
        }

        this.initializedProxRuleConstants = true;
    }

    /**
     * Gathers the statistics for one iteration: full inference, (first call only) prox constant
     * initialization, augmented inference, and the observed atom value gradient.
     */
    @Override
    protected void computeIterationStatistics() {
        this.computeFullInferenceStatistics();

        if (!this.initializedProxRuleConstants) {
            this.initializeProximityRuleConstants();
        }

        this.computeAugmentedInferenceStatistics();
        this.computeProxRuleObservedAtomValueGradient();
    }

    /**
     * Fills {@code rvAtomGradient} and {@code deepAtomGradient} from the difference between the
     * augmented and MAP per-atom gradients, scaled by the squared penalty term (coefficient
     * times objective difference) plus the linear penalty coefficient.
     */
    @Override
    protected void computeTotalAtomGradient() {
        float totalEnergyDifference = this.computeObjectiveDifference();
        int atomCount = this.trainInferenceApplication.getDatabase().getAtomStore().size();

        for (int i = 0; i < atomCount; ++i) {
            float rvGradientDifference = this.augmentedRVAtomGradient[i] - this.MAPRVAtomGradient[i];
            float deepGradientDifference = this.augmentedDeepAtomGradient[i] - this.MAPDeepAtomGradient[i];

            this.rvAtomGradient[i] = this.squaredPenaltyCoefficient * totalEnergyDifference * rvGradientDifference + this.linearPenaltyCoefficient * rvGradientDifference;
            this.deepAtomGradient[i] = this.squaredPenaltyCoefficient * totalEnergyDifference * deepGradientDifference + this.linearPenaltyCoefficient * deepGradientDifference;
        }
    }

    /**
     * Recomputes the observed atom value gradient from scratch: zero it, then add the
     * supervised part (subclass) and the penalty part.
     */
    protected void computeProxRuleObservedAtomValueGradient() {
        Arrays.fill(this.proxRuleObservedAtomValueGradient, 0.0f);
        this.addSupervisedProxRuleObservedAtomValueGradient();
        this.addAugmentedLagrangianProxRuleConstantsGradient();
    }

    /**
     * Runs the full MAP inference and records the per-rule incompatibilities, squared
     * incompatibilities, and the reasoner's optimal value gradients.
     */
    private void computeFullInferenceStatistics() {
        log.trace("Running Inference.");
        this.computeMAPStateWithWarmStart(this.trainInferenceApplication, this.trainMAPTermState, this.trainMAPAtomValueState);
        this.inTrainingMAPState = true;

        this.computeCurrentIncompatibility(this.mapIncompatibility);
        this.computeCurrentSquaredIncompatibility(this.mapSquaredIncompatibility);
        this.trainInferenceApplication.getReasoner().computeOptimalValueGradient(this.trainInferenceApplication.getTermStore(), this.MAPRVAtomGradient, this.MAPDeepAtomGradient);
    }

    /**
     * Runs inference with the prox rules active and records the per-rule incompatibilities,
     * squared incompatibilities, and the reasoner's optimal value gradients.
     *
     * The prox rules are deactivated again even if inference fails.
     */
    protected void computeAugmentedInferenceStatistics() {
        this.activateAugmentedInferenceProxTerms();
        try {
            log.trace("Running Augmented Inference.");
            this.computeMAPStateWithWarmStart(this.trainInferenceApplication, this.augmentedInferenceTermState, this.augmentedInferenceAtomValueState);
            this.inTrainingMAPState = true;

            this.computeCurrentIncompatibility(this.augmentedInferenceIncompatibility);
            this.computeCurrentSquaredIncompatibility(this.augmentedInferenceSquaredIncompatibility);
            this.trainInferenceApplication.getReasoner().computeOptimalValueGradient(this.trainInferenceApplication.getTermStore(), this.augmentedRVAtomGradient, this.augmentedDeepAtomGradient);
        } finally {
            this.deactivateAugmentedInferenceProxTerms();
        }
    }

    /** Activates every prox rule and marks the training MAP state as stale. */
    private void activateAugmentedInferenceProxTerms() {
        for (WeightedArithmeticRule proxRule : this.proxRules) {
            proxRule.setActive(true);
        }

        this.inTrainingMAPState = false;
    }

    /** Deactivates every prox rule and marks the training MAP state as stale. */
    private void deactivateAugmentedInferenceProxTerms() {
        for (WeightedArithmeticRule proxRule : this.proxRules) {
            proxRule.setActive(false);
        }

        this.inTrainingMAPState = false;
    }

    /**
     * Computes the learning loss:
     * {@code squaredPenalty / 2 * d^2 + linearPenalty * d + supervisedLoss},
     * where {@code d} is the objective difference. The total prox value is computed for the
     * trace message only.
     */
    @Override
    protected float computeLearningLoss() {
        float totalObjectiveDifference = this.computeObjectiveDifference();
        float supervisedLoss = this.computeSupervisedLoss();
        float totalProxValue = this.computeTotalProxValue(new float[this.proxRuleObservedAtoms.length]);

        log.trace("Total Prox Loss: {}, Total objective difference: {}, Supervised Loss: {}", Float.valueOf(totalProxValue), Float.valueOf(totalObjectiveDifference), Float.valueOf(supervisedLoss));

        return this.squaredPenaltyCoefficient / 2.0f * (float)Math.pow(totalObjectiveDifference, 2.0) + this.linearPenaltyCoefficient * totalObjectiveDifference + supervisedLoss;
    }

    /**
     * @return the total energy difference between the augmented and MAP inference plus the
     *         total prox value.
     */
    private float computeObjectiveDifference() {
        // The per-rule / per-prox outputs are not needed here; scratch arrays absorb them.
        float totalEnergyDifference = this.computeTotalEnergyDifference(new float[this.mutableRules.size()]);
        float totalProxValue = this.computeTotalProxValue(new float[this.proxRuleObservedAtoms.length]);

        return totalEnergyDifference + totalProxValue;
    }

    /** @return the supervised part of the learning loss. */
    protected abstract float computeSupervisedLoss();

    /**
     * Adds the penalty terms of the loss to the observed atom value gradient. For each prox
     * rule the base gradient is {@code 2 * proxRuleWeight * (observedValue - augmentedValue)};
     * it is added once scaled by the linear coefficient and once scaled by the squared
     * coefficient times the objective difference.
     */
    protected void addAugmentedLagrangianProxRuleConstantsGradient() {
        float totalEnergyDifference = this.computeTotalEnergyDifference(new float[this.mutableRules.size()]);

        float[] proxRuleIncompatibility = new float[this.proxRuleObservedAtoms.length];
        float totalProxValue = this.computeTotalProxValue(proxRuleIncompatibility);

        for (int i = 0; i < this.proxRuleObservedAtoms.length; ++i) {
            float moreauGradient = 2.0f * this.proxRuleWeight * proxRuleIncompatibility[i];

            // Two separate accumulations, matching the order of the float additions.
            this.proxRuleObservedAtomValueGradient[i] += this.linearPenaltyCoefficient * moreauGradient;
            this.proxRuleObservedAtomValueGradient[i] += this.squaredPenaltyCoefficient * (totalEnergyDifference + totalProxValue) * moreauGradient;
        }
    }

    /** Adds the supervised loss gradient to {@code proxRuleObservedAtomValueGradient}. */
    protected abstract void addSupervisedProxRuleObservedAtomValueGradient();

    /**
     * Adds the penalty terms of the loss to the weight gradient, using the per-rule
     * incompatibility difference between the augmented and MAP inference.
     */
    @Override
    protected void addLearningLossWeightGradient() {
        int mutableRuleCount = this.mutableRules.size();

        float[] incompatibilityDifference = new float[mutableRuleCount];
        float totalEnergyDifference = this.computeTotalEnergyDifference(incompatibilityDifference);
        float totalProxValue = this.computeTotalProxValue(new float[this.proxRuleObservedAtoms.length]);

        for (int i = 0; i < mutableRuleCount; ++i) {
            // Two separate accumulations, matching the order of the float additions.
            this.weightGradient[i] += this.linearPenaltyCoefficient * incompatibilityDifference[i];
            this.weightGradient[i] += this.squaredPenaltyCoefficient * (totalEnergyDifference + totalProxValue) * incompatibilityDifference[i];
        }
    }

    /**
     * Computes the weighted energy difference between the augmented and the MAP inference.
     * When the reasoner is a {@link DualBCDReasoner}, its regularization parameter is folded
     * in: per rule (added to the weight for squared rules, applied to the squared
     * incompatibility difference otherwise) and per unfixed atom (squared atom values).
     *
     * @param incompatibilityDifference output; receives augmented minus MAP incompatibility
     *        for every mutable rule.
     * @return the total energy difference.
     */
    private float computeTotalEnergyDifference(float[] incompatibilityDifference) {
        float regularizationParameter = 0.0f;
        if (this.trainInferenceApplication.getReasoner() instanceof DualBCDReasoner) {
            regularizationParameter = (float)((DualBCDReasoner)this.trainInferenceApplication.getReasoner()).regularizationParameter;
        }

        float totalEnergyDifference = 0.0f;
        for (int i = 0; i < this.mutableRules.size(); ++i) {
            float difference = this.augmentedInferenceIncompatibility[i] - this.mapIncompatibility[i];
            incompatibilityDifference[i] = difference;

            WeightedRule rule = (WeightedRule)this.mutableRules.get(i);
            if (rule.isSquared()) {
                totalEnergyDifference += (rule.getWeight() + regularizationParameter) * difference;
            } else {
                totalEnergyDifference += rule.getWeight() * difference;
                totalEnergyDifference += regularizationParameter * (this.augmentedInferenceSquaredIncompatibility[i] - this.mapSquaredIncompatibility[i]);
            }
        }

        GroundAtom[] atoms = this.trainInferenceApplication.getDatabase().getAtomStore().getAtoms();
        int atomCount = this.trainInferenceApplication.getDatabase().getAtomStore().size();

        float augmentedInferenceLCQPRegularization = 0.0f;
        float fullInferenceLCQPRegularization = 0.0f;
        for (int i = 0; i < atomCount; ++i) {
            if (atoms[i].isFixed()) {
                continue;
            }

            // float += double: the sum is formed in double and narrowed back to float.
            augmentedInferenceLCQPRegularization += regularizationParameter * Math.pow(this.augmentedInferenceAtomValueState[i], 2.0);
            fullInferenceLCQPRegularization += regularizationParameter * Math.pow(this.trainMAPAtomValueState[i], 2.0);
        }

        return totalEnergyDifference + (augmentedInferenceLCQPRegularization - fullInferenceLCQPRegularization);
    }

    /**
     * Computes {@code proxRuleWeight * sum_i (observedValue_i - augmentedValue_i)^2}.
     *
     * @param proxRuleIncompatibility output; receives the unsquared difference for every prox rule.
     * @return the weighted sum of squared differences.
     */
    private float computeTotalProxValue(float[] proxRuleIncompatibility) {
        float sumOfSquares = 0.0f;
        for (int i = 0; i < this.proxRules.length; ++i) {
            proxRuleIncompatibility[i] = this.proxRuleObservedAtoms[i].getValue() - this.augmentedInferenceAtomValueState[this.proxIndexToRVAtomIndex.get(i)];

            // float += double: the sum is formed in double and narrowed back to float.
            sumOfSquares += Math.pow(proxRuleIncompatibility[i], 2.0);
        }

        return this.proxRuleWeight * sumOfSquares;
    }

    /**
     * Zeroes, in place, the gradient entries that point outside the [0, 1] box: a positive
     * entry for an observed atom at 0 and a negative entry for an observed atom at 1
     * (a descent step on such an entry would be clamped away).
     */
    private void clipProxRuleObservedAtomValueGradient(float[] gradient) {
        for (int i = 0; i < gradient.length; ++i) {
            float value = this.proxRuleObservedAtoms[i].getValue();

            if (MathUtils.isZero(value) && gradient[i] > 0.0f) {
                gradient[i] = 0.0f;
            } else if (MathUtils.equals(value, 1.0f) && gradient[i] < 0.0f) {
                gradient[i] = 0.0f;
            }
        }
    }

    /**
     * @return the superclass gradient norm plus the p-norm (p = {@code stoppingGradientNorm})
     *         of the box-clipped observed atom value gradient.
     */
    @Override
    protected float computeGradientNorm() {
        float gradientNorm = super.computeGradientNorm();

        // Clip a copy so the stored gradient is left untouched.
        float[] boxClippedGradient = this.proxRuleObservedAtomValueGradient.clone();
        this.clipProxRuleObservedAtomValueGradient(boxClippedGradient);

        return gradientNorm + MathUtils.pNorm(boxClippedGradient, this.stoppingGradientNorm);
    }

    /**
     * Sums, per mutable rule, the squared hinge loss of every term in the training term store
     * at the current atom values. Terms whose rule is not a {@link WeightedRule} or is not in
     * {@code ruleIndexMap} are skipped.
     *
     * @param incompatibilityArray output; reset to zero and then filled, indexed by rule index.
     */
    protected void computeCurrentSquaredIncompatibility(float[] incompatibilityArray) {
        Arrays.fill(incompatibilityArray, 0.0f);

        float[] atomValues = this.trainInferenceApplication.getDatabase().getAtomStore().getAtomValues();
        for (Object rawTerm : this.trainInferenceApplication.getTermStore()) {
            ReasonerTerm term = (ReasonerTerm)rawTerm;
            if (!(term.getRule() instanceof WeightedRule)) {
                continue;
            }

            Integer index = (Integer)this.ruleIndexMap.get((WeightedRule)term.getRule());
            if (index == null) {
                continue;
            }

            int ruleIndex = index;
            incompatibilityArray[ruleIndex] += term.evaluateSquaredHingeLoss(atomValues);
        }
    }

    /**
     * Replaces every labeled random variable atom in the atom store with its observed (label)
     * atom, and writes the label value into the atom store values, the latent warm-start
     * state, and the random variable atom itself. Marks the training MAP state as stale.
     */
    protected void fixLabeledRandomVariables() {
        AtomStore atomStore = this.trainInferenceApplication.getTermStore().getDatabase().getAtomStore();

        for (Map.Entry<RandomVariableAtom, ObservedAtom> entry : this.trainingMap.getLabelMap().entrySet()) {
            RandomVariableAtom randomVariableAtom = entry.getKey();
            ObservedAtom observedAtom = entry.getValue();

            int atomIndex = atomStore.getAtomIndex(randomVariableAtom);
            atomStore.getAtoms()[atomIndex] = observedAtom;
            atomStore.getAtomValues()[atomIndex] = observedAtom.getValue();
            this.latentInferenceAtomValueState[atomIndex] = observedAtom.getValue();
            randomVariableAtom.setValue(observedAtom.getValue());
        }

        this.inTrainingMAPState = false;
    }

    /**
     * Puts every labeled random variable atom back into its slot of the atom store (undoing
     * the atom swap of {@link #fixLabeledRandomVariables()}; values are not touched).
     * Marks the training MAP state as stale.
     */
    protected void unfixLabeledRandomVariables() {
        AtomStore atomStore = this.trainInferenceApplication.getDatabase().getAtomStore();

        for (RandomVariableAtom randomVariableAtom : this.trainingMap.getLabelMap().keySet()) {
            int atomIndex = atomStore.getAtomIndex(randomVariableAtom);
            atomStore.getAtoms()[atomIndex] = randomVariableAtom;
        }

        this.inTrainingMAPState = false;
    }
}

