/*
 * Decompiled with CFR 0.152.
 */
package org.linqs.psl.reasoner.admm;

import java.util.List;
import org.linqs.psl.application.learning.weight.TrainingMap;
import org.linqs.psl.config.Options;
import org.linqs.psl.evaluation.EvaluationInstance;
import org.linqs.psl.model.atom.GroundAtom;
import org.linqs.psl.reasoner.Reasoner;
import org.linqs.psl.reasoner.admm.term.ADMMObjectiveTerm;
import org.linqs.psl.reasoner.admm.term.ADMMTermStore;
import org.linqs.psl.reasoner.term.TermStore;
import org.linqs.psl.util.Logger;
import org.linqs.psl.util.Parallel;

/**
 * A {@link Reasoner} over {@link ADMMObjectiveTerm}s that requires an {@link ADMMTermStore}.
 *
 * <p>Each iteration of {@link #optimize} runs two parallel passes: one over blocks of terms
 * (updating each active term's lagrange values and then minimizing it), and one over blocks of
 * consensus variables (recomputing each consensus value from its local copies and accumulating
 * the residual/norm sums used by the stopping test).
 */
public class ADMMReasoner
extends Reasoner<ADMMObjectiveTerm> {
    private static final Logger log = Logger.getLogger(ADMMReasoner.class);

    // Clamp range applied to every non-fixed consensus value.
    private static final float LOWER_BOUND = 0.0f;
    private static final float UPPER_BOUND = 1.0f;

    // The objective is recomputed (and logged/evaluated) every computePeriod iterations.
    private int computePeriod;
    // Passed to the term updates, divides the lagrange values, and scales the dual residual.
    private final float stepSize;
    // When true, optimization may stop once both residuals fall below their tolerances.
    private boolean primalDualBreak;
    private double epsilonRel;
    private double epsilonAbs;

    // Per-iteration quantities. The residuals and norms are reset at the top of each iteration,
    // filled in as sums of squares by the variable workers, and then square-rooted in optimize().
    private double primalRes;
    private double epsilonPrimal;
    private double dualRes;
    private double epsilonDual;
    private double AxNorm;
    private double AyNorm;
    private double BzNorm;

    // Number of terms / variables handled by one unit of parallel work.
    private long termBlockSize;
    private long variableBlockSize;

    /**
     * Reads all ADMM settings from {@link Options}.
     */
    public ADMMReasoner() {
        maxIterations = Options.ADMM_MAX_ITER.getInt();
        primalDualBreak = Options.ADMM_PRIMAL_DUAL_BREAK.getBoolean();
        stepSize = Options.ADMM_STEP_SIZE.getFloat();
        computePeriod = Options.ADMM_COMPUTE_PERIOD.getInt();
        epsilonAbs = Options.ADMM_EPSILON_ABS.getDouble();
        epsilonRel = Options.ADMM_EPSILON_REL.getDouble();
    }

    /**
     * Runs the iterative optimization until {@link #breakOptimization} signals a stop
     * (and no override for violated constraints applies).
     *
     * @param baseTermStore the terms to optimize; must be an {@link ADMMTermStore}
     * @param evaluations forwarded to {@code evaluate} whenever the objective is computed
     * @param trainingMap forwarded to {@code evaluate} whenever the objective is computed
     * @return the objective value computed on the final iteration
     * @throws IllegalArgumentException if {@code baseTermStore} is not an {@link ADMMTermStore}
     */
    @Override
    public double optimize(TermStore<ADMMObjectiveTerm> baseTermStore, List<EvaluationInstance> evaluations, TrainingMap trainingMap) {
        if (!(baseTermStore instanceof ADMMTermStore)) {
            throw new IllegalArgumentException("ADMMReasoner requires an ADMMTermStore (found " + baseTermStore.getClass().getName() + ").");
        }

        ADMMTermStore termStore = (ADMMTermStore)baseTermStore;
        termStore.initForOptimization();
        initForOptimization(termStore);

        long numTerms = termStore.size();
        int numVariables = termStore.getNumVariables();

        // Split the work into roughly four blocks per thread; the +1 keeps the block size positive.
        termBlockSize = numTerms / (long)(Parallel.getNumThreads() * 4) + 1L;
        variableBlockSize = numVariables / (Parallel.getNumThreads() * 4) + 1;

        long numTermBlocks = (long)Math.ceil((double)numTerms / (double)termBlockSize);
        long numVariableBlocks = (long)Math.ceil((double)numVariables / (double)variableBlockSize);

        // The absolute part of both tolerances is constant across iterations.
        double epsilonAbsTerm = Math.sqrt(termStore.getNumLocalVariables()) * epsilonAbs;

        Reasoner.ObjectiveResult objective = null;
        Reasoner.ObjectiveResult oldObjective = null;

        long totalTime = 0L;
        int iteration = 0;
        boolean finished = false;

        do {
            iteration++;

            long start = System.currentTimeMillis();
            resetIterationVariables();

            // Term pass first, then the consensus pass that also accumulates residuals and norms.
            Parallel.count(numTermBlocks, new TermWorker(termStore, termBlockSize));
            Parallel.count(numVariableBlocks, new VariableWorker(termStore, variableBlockSize, numVariables));

            // The workers accumulated sums of squares; turn them into residuals and tolerances.
            primalRes = Math.sqrt(primalRes);
            dualRes = (double)stepSize * Math.sqrt(dualRes);
            epsilonPrimal = epsilonAbsTerm + epsilonRel * Math.max(Math.sqrt(AxNorm), Math.sqrt(BzNorm));
            epsilonDual = epsilonAbsTerm + epsilonRel * Math.sqrt(AyNorm);

            long end = System.currentTimeMillis();
            long iterationTime = end - start;
            totalTime += iterationTime;

            finished = breakOptimization(iteration, termStore, objective, oldObjective);

            // The objective is always recomputed on the iteration we intend to stop on,
            // so it is never null when the loop exits.
            if (iteration % computePeriod == 0 || finished) {
                oldObjective = objective;
                objective = parallelComputeObjective(termStore);

                // Keep going while constraints are violated and the iteration budget allows it.
                if (objective.violatedConstraints > 0L && iteration <= (int)((double)maxIterations * budget)) {
                    finished = false;
                }

                log.trace("Iteration {} -- Objective: {}, Violated Constraints: {}, Primal: {}, Dual: {}, Epsilon Primal: {}, Epsilon Dual: {}, Iteration Time: {}, Total Optimization Time: {}.", iteration, Float.valueOf(objective.objective), objective.violatedConstraints, primalRes, dualRes, epsilonPrimal, epsilonDual, iterationTime, totalTime);
                evaluate(termStore, iteration, evaluations, trainingMap);
            }
        } while (!finished);

        optimizationComplete(termStore, objective, totalTime, iteration);
        return objective.objective;
    }

    /**
     * Decides whether optimization should stop after the given iteration.
     *
     * <p>The parent's decision to stop always wins. Otherwise this never stops when full
     * iterations were requested or when the last computed objective reports violated
     * constraints, and stops only if the primal/dual break is enabled, this is not the first
     * iteration, and both residuals are strictly below their tolerances.
     *
     * @return true if the optimization loop should end
     */
    @Override
    protected boolean breakOptimization(int iteration, TermStore<ADMMObjectiveTerm> termStore, Reasoner.ObjectiveResult objective, Reasoner.ObjectiveResult oldObjective) {
        if (super.breakOptimization(iteration, termStore, objective, oldObjective)) {
            return true;
        }

        if (runFullIterations) {
            return false;
        }

        boolean constraintsViolated = objective != null && objective.violatedConstraints > 0L;
        if (constraintsViolated) {
            return false;
        }

        boolean residualsBelowTolerance = primalDualBreak
                && iteration > 1
                && primalRes < epsilonPrimal
                && dualRes < epsilonDual;
        if (!residualsBelowTolerance) {
            return false;
        }

        log.trace("Breaking optimization. Primal residual: {} below tolerance: {} and dual residual: {} below tolerance: {}.", primalRes, epsilonPrimal, dualRes, epsilonDual);
        return true;
    }

    /**
     * Zeroes the residual and norm accumulators before the workers of a new iteration run.
     */
    private void resetIterationVariables() {
        primalRes = 0.0;
        dualRes = 0.0;
        AxNorm = 0.0;
        AyNorm = 0.0;
        BzNorm = 0.0;
    }

    /**
     * Adds one worker block's partial sums to the shared accumulators.
     * Synchronized because multiple variable workers report into the same fields.
     */
    private synchronized void updateIterationVariables(double primalRes, double dualRes, double AxNorm, double BzNorm, double AyNorm) {
        this.primalRes += primalRes;
        this.dualRes += dualRes;
        this.AxNorm += AxNorm;
        this.BzNorm += BzNorm;
        this.AyNorm += AyNorm;
    }

    /**
     * Parallel worker that recomputes the consensus values for one block of variables
     * and reports that block's contribution to the residuals and norms.
     */
    private class VariableWorker
    extends Parallel.Worker<Long> {
        private final ADMMTermStore termStore;
        private final long blockSize;
        private final int numVariables;
        private final float[] consensusValues;
        private final GroundAtom[] consensusAtoms;

        public VariableWorker(ADMMTermStore termStore, long blockSize, int numVariables) {
            this.termStore = termStore;
            this.blockSize = blockSize;
            this.numVariables = numVariables;
            this.consensusValues = termStore.getVariableValues();
            this.consensusAtoms = termStore.getVariableAtoms();
        }

        /**
         * Builds a fresh worker from the same store and sizes instead of copying this one.
         */
        @Override
        public Object clone() {
            return new VariableWorker(this.termStore, this.blockSize, this.numVariables);
        }

        /**
         * Processes the variables of block {@code blockIndex}; the second argument is unused.
         */
        @Override
        public void work(long blockIndex, Long ignore) {
            // Block-local partial sums, pushed to the reasoner once at the end.
            double primalSum = 0.0;
            double dualSum = 0.0;
            double axSum = 0.0;
            double bzSum = 0.0;
            double aySum = 0.0;

            for (int offset = 0; (long)offset < this.blockSize; offset++) {
                int variableIndex = (int)(blockIndex * this.blockSize + (long)offset);
                if (variableIndex >= this.numVariables) {
                    // The last block may be only partially filled.
                    break;
                }

                List<ADMMTermStore.LocalRecord> localRecords = this.termStore.getLocalRecords(variableIndex);
                if (localRecords == null) {
                    continue;
                }

                // First pass: gather the local copies held by active terms.
                double total = 0.0;
                int activeCopies = 0;
                for (ADMMTermStore.LocalRecord record : localRecords) {
                    ADMMObjectiveTerm term = (ADMMObjectiveTerm)this.termStore.get(record.termIndex);
                    if (!term.isActive()) {
                        continue;
                    }

                    float localValue = term.getVariableValue(record.variableIndex);
                    float localLagrange = term.getVariableLagrange(record.variableIndex);

                    total += (double)(localValue + localLagrange / ADMMReasoner.this.stepSize);
                    axSum += (double)(localValue * localValue);
                    aySum += (double)(localLagrange * localLagrange);
                    activeCopies++;
                }

                if (activeCopies == 0) {
                    continue;
                }

                // Fixed atoms keep their current value; everything else becomes the clamped average.
                float newConsensusValue;
                if (this.consensusAtoms[variableIndex].isFixed()) {
                    newConsensusValue = this.consensusValues[variableIndex];
                } else {
                    float average = (float)(total / (double)activeCopies);
                    newConsensusValue = Math.max(Math.min(average, ADMMReasoner.UPPER_BOUND), ADMMReasoner.LOWER_BOUND);
                }

                float consensusDelta = this.consensusValues[variableIndex] - newConsensusValue;
                dualSum += (double)(consensusDelta * consensusDelta * (float)activeCopies);
                bzSum += (double)(newConsensusValue * newConsensusValue * (float)activeCopies);
                this.consensusValues[variableIndex] = newConsensusValue;

                // Second pass: distance of every active local copy from the new consensus value.
                for (ADMMTermStore.LocalRecord record : localRecords) {
                    ADMMObjectiveTerm term = (ADMMObjectiveTerm)this.termStore.get(record.termIndex);
                    if (!term.isActive()) {
                        continue;
                    }

                    float localDelta = term.getVariableValue(record.variableIndex) - newConsensusValue;
                    primalSum += (double)(localDelta * localDelta);
                }
            }

            ADMMReasoner.this.updateIterationVariables(primalSum, dualSum, axSum, bzSum, aySum);
        }
    }

    /**
     * Parallel worker that updates one block of terms against the current consensus values.
     */
    private class TermWorker
    extends Parallel.Worker<Long> {
        private final ADMMTermStore termStore;
        private final long blockSize;
        private final float[] consensusValues;

        public TermWorker(ADMMTermStore termStore, long blockSize) {
            this.termStore = termStore;
            this.blockSize = blockSize;
            this.consensusValues = termStore.getVariableValues();
        }

        /**
         * Builds a fresh worker from the same store and block size instead of copying this one.
         */
        @Override
        public Object clone() {
            return new TermWorker(this.termStore, this.blockSize);
        }

        /**
         * Processes the terms of block {@code blockIndex}; the second argument is unused.
         */
        @Override
        public void work(long blockIndex, Long ignore) {
            long numTerms = this.termStore.size();

            for (int offset = 0; (long)offset < this.blockSize; offset++) {
                long termIndex = blockIndex * this.blockSize + (long)offset;
                if (termIndex >= numTerms) {
                    // The last block may be only partially filled.
                    break;
                }

                ADMMObjectiveTerm term = (ADMMObjectiveTerm)this.termStore.get(termIndex);
                if (!term.isActive()) {
                    continue;
                }

                // Order matters: the lagrange update runs before the term is minimized.
                term.updateLagrange(ADMMReasoner.this.stepSize, this.consensusValues);
                term.minimize(ADMMReasoner.this.stepSize, this.consensusValues);
            }
        }
    }
}

