/*
 * Decompiled with CFR 0.152.
 */
package org.linqs.psl.reasoner.sgd;

import java.util.Arrays;
import java.util.List;
import org.linqs.psl.application.learning.weight.TrainingMap;
import org.linqs.psl.config.Options;
import org.linqs.psl.evaluation.EvaluationInstance;
import org.linqs.psl.model.atom.GroundAtom;
import org.linqs.psl.reasoner.Reasoner;
import org.linqs.psl.reasoner.sgd.term.SGDObjectiveTerm;
import org.linqs.psl.reasoner.term.TermStore;
import org.linqs.psl.util.ArrayUtils;
import org.linqs.psl.util.Logger;
import org.linqs.psl.util.MathUtils;

/**
 * A reasoner that minimizes the objective with term-wise gradient steps.
 *
 * Every pass over the term store visits each active term and immediately steps the
 * non-fixed variables of that term along the negative partial derivative, clamping the
 * result to [0, 1]. The raw step can be adapted per variable with AdaGrad or Adam
 * (see {@link SGDExtension}), and the base learning rate can be annealed over passes
 * (see {@link SGDLearningSchedule}).
 *
 * All hyperparameters are read from {@link Options} when the reasoner is constructed.
 */
public class SGDReasoner extends Reasoner<SGDObjectiveTerm> {
    private static final Logger log = Logger.getLogger(SGDReasoner.class);

    /** Small constant added to the AdaGrad/Adam denominators to avoid dividing by zero. */
    private static final float EPSILON = 1.0E-8f;

    // First-order (gradient norm) stopping criterion.
    private final boolean firstOrderBreak;
    private final float firstOrderTolerance;
    private final float firstOrderNorm;

    /** Gradient evaluated at the previous pass's variable values; allocated on the first pass. */
    private float[] prevGradient;

    // Adam decay rates.
    private final float adamBeta1;
    private final float adamBeta2;

    // Per-variable optimizer state, indexed by atom index and allocated in initForOptimization().
    private float[] accumulatedGradientSquares;
    private float[] accumulatedGradientMean;
    private float[] accumulatedGradientVariance;

    private final float initialLearningRate;
    private final float learningRateInverseScaleExp;

    /** When true, the term's inner potential is recomputed after every single variable step. */
    private final boolean coordinateStep;

    private final SGDLearningSchedule learningSchedule;
    private final SGDExtension sgdExtension;

    /**
     * Build a reasoner configured from the SGD-related {@link Options}.
     *
     * @throws IllegalArgumentException if the configured learning schedule or extension
     *         does not name a constant of the corresponding enum (thrown by {@code valueOf})
     */
    public SGDReasoner() {
        maxIterations = Options.SGD_MAX_ITER.getInt();

        firstOrderBreak = Options.SGD_FIRST_ORDER_BREAK.getBoolean();
        firstOrderTolerance = Options.SGD_FIRST_ORDER_THRESHOLD.getFloat();
        firstOrderNorm = Options.SGD_FIRST_ORDER_NORM.getFloat();

        initialLearningRate = Options.SGD_LEARNING_RATE.getFloat();
        learningRateInverseScaleExp = Options.SGD_INVERSE_TIME_EXP.getFloat();
        learningSchedule = SGDLearningSchedule.valueOf(Options.SGD_LEARNING_SCHEDULE.getString().toUpperCase());
        coordinateStep = Options.SGD_COORDINATE_STEP.getBoolean();

        sgdExtension = SGDExtension.valueOf(Options.SGD_EXTENSION.getString().toUpperCase());
        prevGradient = null;

        adamBeta1 = Options.SGD_ADAM_BETA_1.getFloat();
        adamBeta2 = Options.SGD_ADAM_BETA_2.getFloat();
        accumulatedGradientSquares = null;
        accumulatedGradientMean = null;
        accumulatedGradientVariance = null;
    }

    /**
     * Run the optimization and leave the best variable values found in the term store.
     *
     * The first pass only updates the variables. Every later pass additionally accumulates,
     * while it sweeps the terms, the objective and gradient at the values the previous pass
     * ended with; that is why the trace log reports {@code iteration - 1} and why the stopping
     * test is only evaluated from the second pass on. The lowest objective seen and the values
     * that produced it are tracked and written back to the term store before returning.
     *
     * @param termStore the terms and variable values to optimize; its values are modified in place
     * @param evaluations forwarded unchanged to {@code evaluate} after every pass
     * @param trainingMap forwarded unchanged to {@code evaluate} after every pass
     * @return the lowest objective value observed
     */
    @Override
    public double optimize(TermStore<SGDObjectiveTerm> termStore, List<EvaluationInstance> evaluations, TrainingMap trainingMap) {
        termStore.initForOptimization();
        initForOptimization(termStore);

        float oldObjective = Float.POSITIVE_INFINITY;
        float[] prevVariableValues = null;

        float lowestObjective = Float.POSITIVE_INFINITY;
        float[] lowestVariableValues = null;

        long totalTime = 0L;
        boolean breakSGD = false;
        int iteration = 1;

        while (!breakSGD) {
            long start = System.currentTimeMillis();

            float objective = 0.0f;
            float learningRate = calculateAnnealedLearningRate(iteration);

            if (iteration > 1) {
                Arrays.fill(prevGradient, 0.0f);
            }

            for (SGDObjectiveTerm term : termStore) {
                if (!term.isActive()) {
                    continue;
                }

                if (iteration > 1) {
                    // Objective and gradient are measured at the previous pass's values,
                    // not at the values currently being updated.
                    objective += term.evaluate(prevVariableValues);
                    addTermGradient(term, prevGradient, prevVariableValues, termStore.getVariableAtoms());
                }

                variableUpdate(term, termStore, iteration, learningRate);
            }

            evaluate(termStore, iteration, evaluations, trainingMap);

            if (iteration == 1) {
                // Nothing to compare against yet: just allocate the buffers and snapshot the values.
                float[] currentValues = termStore.getVariableValues();
                prevGradient = new float[currentValues.length];
                prevVariableValues = Arrays.copyOf(currentValues, currentValues.length);
                lowestVariableValues = Arrays.copyOf(currentValues, currentValues.length);
            } else {
                clipGradient(prevVariableValues, prevGradient);
                breakSGD = breakOptimization(iteration, termStore,
                        new Reasoner.ObjectiveResult(objective, 0L),
                        new Reasoner.ObjectiveResult(oldObjective, 0L));

                // The objective just computed belongs to prevVariableValues, so those are the values to keep.
                if (objective < lowestObjective) {
                    lowestObjective = objective;
                    System.arraycopy(prevVariableValues, 0, lowestVariableValues, 0, lowestVariableValues.length);
                }

                System.arraycopy(termStore.getVariableValues(), 0, prevVariableValues, 0, prevVariableValues.length);
                oldObjective = objective;
            }

            long end = System.currentTimeMillis();
            totalTime += end - start;

            if (iteration > 1) {
                log.trace("Iteration {} -- Objective: {}, Violated Constraints: 0, Gradient Norm: {}, Iteration Time: {}, Total Optimization Time: {}",
                        iteration - 1, objective, MathUtils.pNorm(prevGradient, firstOrderNorm), end - start, totalTime);
            }

            ++iteration;
        }

        // The values produced by the last pass have not been scored yet.
        // The loop can only exit from the branch that copies the term store values into
        // prevVariableValues, so prevVariableValues holds exactly those final values here.
        Reasoner.ObjectiveResult finalObjective = computeObjective(termStore);
        if (finalObjective.objective < lowestObjective) {
            lowestObjective = finalObjective.objective;
            lowestVariableValues = prevVariableValues;
        }

        float[] variableValues = termStore.getVariableValues();
        System.arraycopy(lowestVariableValues, 0, variableValues, 0, variableValues.length);

        optimizationComplete(termStore, new Reasoner.ObjectiveResult(lowestObjective, 0L), totalTime, iteration - 1);
        return lowestObjective;
    }

    /**
     * Allocate the per-variable state required by the configured {@link SGDExtension}.
     *
     * The arrays start with one slot per unobserved variable; computeVariableStep() passes them
     * through {@code ArrayUtils.ensureCapacity} before each access.
     *
     * @throws IllegalArgumentException if the configured extension is not handled
     */
    @Override
    protected void initForOptimization(TermStore<SGDObjectiveTerm> termStore) {
        super.initForOptimization(termStore);

        switch (sgdExtension) {
            case NONE:
                break;
            case ADAGRAD:
                accumulatedGradientSquares = new float[termStore.getVariableCounts().unobserved];
                break;
            case ADAM: {
                int unobservedCount = termStore.getVariableCounts().unobserved;
                accumulatedGradientMean = new float[unobservedCount];
                accumulatedGradientVariance = new float[unobservedCount];
                break;
            }
            default:
                throw new IllegalArgumentException(String.format("Unsupported SGD Extensions: '%s'", sgdExtension));
        }
    }

    /**
     * Delegate to the superclass and then release the gradient and optimizer-state buffers.
     */
    @Override
    protected void optimizationComplete(TermStore<SGDObjectiveTerm> termStore, Reasoner.ObjectiveResult finalObjective, long totalTime, int iteration) {
        super.optimizationComplete(termStore, finalObjective, totalTime, iteration);

        prevGradient = null;
        accumulatedGradientSquares = null;
        accumulatedGradientMean = null;
        accumulatedGradientVariance = null;
    }

    /**
     * Decide whether optimization should stop after the current pass.
     *
     * The superclass criteria are checked first. Beyond those, this reasoner stops when the
     * first-order break is enabled and the p-norm of the previous gradient is within
     * {@code firstOrderTolerance} of zero, unless full iterations were requested or the
     * objective reports violated constraints.
     *
     * @return true if the optimization loop should terminate
     */
    @Override
    protected boolean breakOptimization(int iteration, TermStore<SGDObjectiveTerm> termStore, Reasoner.ObjectiveResult objective, Reasoner.ObjectiveResult oldObjective) {
        if (super.breakOptimization(iteration, termStore, objective, oldObjective)) {
            return true;
        }

        if (runFullIterations) {
            return false;
        }

        if (objective != null && objective.violatedConstraints > 0L) {
            return false;
        }

        if (!firstOrderBreak) {
            return false;
        }

        float gradientNorm = MathUtils.pNorm(prevGradient, firstOrderNorm);
        if (MathUtils.equals(gradientNorm, 0.0f, firstOrderTolerance)) {
            log.trace("Breaking optimization. Gradient magnitude: {} below tolerance: {}.", gradientNorm, firstOrderTolerance);
            return true;
        }

        return false;
    }

    /**
     * Add one term's partial derivatives, evaluated at the given values, into a gradient buffer.
     *
     * @param term the term to differentiate
     * @param gradient buffer indexed by atom index; entries of non-fixed atoms in the term are incremented
     * @param variableValues the values at which the term's inner potential is evaluated
     * @param variableAtoms atoms indexed by atom index; fixed atoms are skipped
     */
    private void addTermGradient(SGDObjectiveTerm term, float[] gradient, float[] variableValues, GroundAtom[] variableAtoms) {
        int size = term.size();
        int[] variableIndexes = term.getAtomIndexes();
        float innerPotential = term.computeInnerPotential(variableValues);

        for (int i = 0; i < size; i++) {
            int variableIndex = variableIndexes[i];
            if (variableAtoms[variableIndex].isFixed()) {
                continue;
            }

            gradient[variableIndex] += term.computeVariablePartial(i, innerPotential);
        }
    }

    /**
     * Compute the base learning rate for a pass according to the configured schedule.
     *
     * @param iteration the 1-based pass number
     * @return the learning rate to use for this pass
     * @throws IllegalArgumentException if the configured schedule is not handled
     */
    private float calculateAnnealedLearningRate(int iteration) {
        switch (learningSchedule) {
            case CONSTANT:
                return initialLearningRate;
            case STEPDECAY:
                return initialLearningRate / (float)Math.pow(iteration, learningRateInverseScaleExp);
            default:
                throw new IllegalArgumentException(String.format("Illegal value found for SGD learning schedule: '%s'", learningSchedule));
        }
    }

    /**
     * Take one gradient step on every non-fixed variable of a single term.
     *
     * Values are written directly into the term store's value array and clamped to [0, 1].
     * Unless coordinate stepping is enabled, all partials of the term are computed from the
     * inner potential as it was before any variable of the term moved.
     */
    private void variableUpdate(SGDObjectiveTerm term, TermStore<SGDObjectiveTerm> termStore, int iteration, float learningRate) {
        GroundAtom[] variableAtoms = termStore.getVariableAtoms();
        float[] variableValues = termStore.getVariableValues();

        int size = term.size();
        int[] variableIndexes = term.getAtomIndexes();
        float innerPotential = term.computeInnerPotential(variableValues);

        for (int i = 0; i < size; i++) {
            int variableIndex = variableIndexes[i];
            if (variableAtoms[variableIndex].isFixed()) {
                continue;
            }

            float partial = term.computeVariablePartial(i, innerPotential);
            float variableStep = computeVariableStep(variableIndex, iteration, learningRate, partial);

            variableValues[variableIndex] = Math.max(0.0f, Math.min(1.0f, variableValues[variableIndex] - variableStep));

            if (coordinateStep) {
                // Refresh the potential so the next variable sees the value just written.
                innerPotential = term.computeInnerPotential(variableValues);
            }
        }
    }

    /**
     * Turn a partial derivative into the step to subtract from a variable, updating the
     * optimizer state of the configured {@link SGDExtension} for that variable.
     *
     * @param variableIndex atom index of the variable being stepped
     * @param iteration the 1-based pass number, used as the exponent of the Adam bias correction
     * @param learningRate the base learning rate for this pass
     * @param partial the partial derivative of the current term with respect to the variable
     * @return the step size (to be subtracted from the variable's value)
     * @throws IllegalArgumentException if the configured extension is not handled
     */
    private float computeVariableStep(int variableIndex, int iteration, float learningRate, float partial) {
        float step;

        switch (sgdExtension) {
            case NONE:
                step = partial * learningRate;
                break;
            case ADAGRAD: {
                accumulatedGradientSquares = ArrayUtils.ensureCapacity(accumulatedGradientSquares, variableIndex);
                accumulatedGradientSquares[variableIndex] += partial * partial;

                float adaptedLearningRate = learningRate / (float)Math.sqrt(accumulatedGradientSquares[variableIndex] + EPSILON);
                step = partial * adaptedLearningRate;
                break;
            }
            case ADAM: {
                accumulatedGradientMean = ArrayUtils.ensureCapacity(accumulatedGradientMean, variableIndex);
                accumulatedGradientMean[variableIndex] = adamBeta1 * accumulatedGradientMean[variableIndex] + (1.0f - adamBeta1) * partial;

                accumulatedGradientVariance = ArrayUtils.ensureCapacity(accumulatedGradientVariance, variableIndex);
                accumulatedGradientVariance[variableIndex] = adamBeta2 * accumulatedGradientVariance[variableIndex] + (1.0f - adamBeta2) * partial * partial;

                // Bias-correct both moment estimates using the pass number as the exponent.
                float biasedGradientMean = accumulatedGradientMean[variableIndex] / (1.0f - (float)Math.pow(adamBeta1, iteration));
                float biasedGradientVariance = accumulatedGradientVariance[variableIndex] / (1.0f - (float)Math.pow(adamBeta2, iteration));

                float adaptedLearningRate = learningRate / ((float)Math.sqrt(biasedGradientVariance) + EPSILON);
                step = biasedGradientMean * adaptedLearningRate;
                break;
            }
            default:
                throw new IllegalArgumentException(String.format("Unsupported SGD Extensions: '%s'", sgdExtension));
        }

        return step;
    }

    /**
     * How the base learning rate changes from pass to pass.
     */
    public enum SGDLearningSchedule {
        /** Always use the initial learning rate. */
        CONSTANT,
        /** Divide the initial learning rate by the pass number raised to the configured exponent. */
        STEPDECAY
    }

    /**
     * Per-variable adaptation applied to the raw gradient step.
     */
    public enum SGDExtension {
        /** Plain step: partial times learning rate. */
        NONE,
        /** Scale the learning rate by the inverse root of the accumulated squared partials. */
        ADAGRAD,
        /** Step along a bias-corrected running mean, scaled by a bias-corrected running second moment. */
        ADAM
    }
}

