/*
 * Decompiled with CFR 0.152.
 */
package org.linqs.psl.reasoner.admm;

import org.linqs.psl.config.Options;
import org.linqs.psl.reasoner.Reasoner;
import org.linqs.psl.reasoner.admm.term.ADMMObjectiveTerm;
import org.linqs.psl.reasoner.admm.term.ADMMTermStore;
import org.linqs.psl.reasoner.admm.term.LinearConstraintTerm;
import org.linqs.psl.reasoner.admm.term.LocalVariable;
import org.linqs.psl.reasoner.term.TermStore;
import org.linqs.psl.util.MathUtils;
import org.linqs.psl.util.Parallel;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Minimizes the objective held in an {@link ADMMTermStore} using the alternating direction
 * method of multipliers (ADMM).
 *
 * <p>Each global variable has a list of {@link LocalVariable} copies in the term store. Every
 * iteration runs two phases, each split into blocks and dispatched through {@code Parallel.count}:
 * <ol>
 *   <li>a term phase, where every term calls {@code updateLagrange} and then {@code minimize}
 *       against the current consensus values, and</li>
 *   <li>a variable phase, where every consensus value is replaced by the mean of
 *       {@code local.value + local.lagrange / stepSize} over its local copies, clamped to
 *       [0, 1], while the primal/dual residual components are accumulated.</li>
 * </ol>
 *
 * <p>The loop stops when {@code breakOptimization} holds for both the cached objective and a
 * freshly computed one. The final consensus values are then written back with
 * {@code termStore.updateVariables()}.
 */
public class ADMMReasoner extends Reasoner {
    private static final Logger log = LoggerFactory.getLogger(ADMMReasoner.class);

    /** Smallest value a consensus variable may take. */
    private static final float LOWER_BOUND = 0.0f;
    /** Largest value a consensus variable may take. */
    private static final float UPPER_BOUND = 1.0f;

    /** Residuals (and, with objective breaking, the objective) are logged every this many iterations. */
    private final int computePeriod;
    /** ADMM step size; also used to scale Lagrange terms and the dual residual. */
    private final float stepSize;
    private float epsilonRel;
    private float epsilonAbs;

    // Per-iteration convergence state. Reset at the start of every iteration and accumulated
    // by VariableWorker through updateIterationVariables().
    private float primalRes;
    private float epsilonPrimal;
    private float dualRes;
    private float epsilonDual;
    private float axNorm;
    private float ayNorm;
    private float bzNorm;
    private float lagrangePenalty;
    private float augmentedLagrangePenalty;

    /** Base iteration cap; the effective cap is scaled by the inherited {@code budget}. */
    private final int maxIterations = Options.ADMM_MAX_ITER.getInt();

    /** Reads step size, compute period and both epsilons from {@link Options}. */
    public ADMMReasoner() {
        this.stepSize = Options.ADMM_STEP_SIZE.getFloat();
        this.computePeriod = Options.ADMM_COMPUTE_PERIOD.getInt();
        this.epsilonAbs = Options.ADMM_EPSILON_ABS.getFloat();
        this.epsilonRel = Options.ADMM_EPSILON_REL.getFloat();
    }

    /** @return the relative tolerance used in the primal/dual stopping thresholds. */
    public float getEpsilonRel() {
        return epsilonRel;
    }

    /** @param epsilonRel the relative tolerance used in the primal/dual stopping thresholds. */
    public void setEpsilonRel(float epsilonRel) {
        this.epsilonRel = epsilonRel;
    }

    /** @return the absolute tolerance (scaled by sqrt(#local variables) during optimization). */
    public float getEpsilonAbs() {
        return epsilonAbs;
    }

    /** @param epsilonAbs the absolute tolerance (scaled by sqrt(#local variables) during optimization). */
    public void setEpsilonAbs(float epsilonAbs) {
        this.epsilonAbs = epsilonAbs;
    }

    /**
     * @return the sum over all local copies of {@code lagrange * (local value - consensus value)},
     *     as accumulated during the most recent iteration (0 before any optimization).
     */
    public float getLagrangianPenalty() {
        return lagrangePenalty;
    }

    /**
     * @return the sum over all local copies of {@code 0.5 * stepSize * (local value - consensus value)^2},
     *     as accumulated during the most recent iteration (0 before any optimization).
     */
    public float getAugmentedLagrangianPenalty() {
        return augmentedLagrangePenalty;
    }

    /**
     * Runs ADMM on the given store until {@code breakOptimization} allows stopping, then writes
     * the consensus values back via {@code termStore.updateVariables()}.
     *
     * @param baseTermStore the store to optimize; must be an {@link ADMMTermStore}
     * @throws IllegalArgumentException if {@code baseTermStore} is not an {@link ADMMTermStore}
     */
    @Override
    public void optimize(TermStore baseTermStore) {
        if (!(baseTermStore instanceof ADMMTermStore)) {
            throw new IllegalArgumentException("ADMMReasoner requires an ADMMTermStore (found " + baseTermStore.getClass().getName() + ").");
        }
        ADMMTermStore termStore = (ADMMTermStore)baseTermStore;
        termStore.initForOptimization();

        int numTerms = termStore.size();
        int numVariables = termStore.getNumGlobalVariables();
        log.debug("Performing optimization with {} variables and {} terms.", numVariables, numTerms);

        // Split work into about four blocks per thread (presumably so the pool can balance uneven blocks).
        int blocksPerPhase = Parallel.getNumThreads() * 4;
        int termBlockSize = numTerms / blocksPerPhase + 1;
        int variableBlockSize = numVariables / blocksPerPhase + 1;
        int numTermBlocks = ceilDiv(numTerms, termBlockSize);
        int numVariableBlocks = ceilDiv(numVariables, variableBlockSize);

        // The absolute tolerance is scaled by the square root of the total number of local copies.
        float epsilonAbsTerm = (float)(Math.sqrt(termStore.getNumLocalVariables()) * epsilonAbs);

        ObjectiveResult objective = null;
        ObjectiveResult oldObjective = null;

        if (log.isTraceEnabled()) {
            objective = computeObjective(termStore, false);
            log.trace("Iteration {} -- Objective: {}, Feasible: {}.", 0, objective.objective, objective.violatedConstraints == 0);
        }

        int iteration = 1;
        while (true) {
            primalRes = 0.0f;
            dualRes = 0.0f;
            axNorm = 0.0f;
            ayNorm = 0.0f;
            bzNorm = 0.0f;
            lagrangePenalty = 0.0f;
            augmentedLagrangePenalty = 0.0f;

            Parallel.count(numTermBlocks, new TermWorker(termStore, termBlockSize));
            Parallel.count(numVariableBlocks, new VariableWorker(termStore, variableBlockSize));

            // Workers accumulated squared sums; turn them into norms and stopping thresholds.
            primalRes = (float)Math.sqrt(primalRes);
            dualRes = (float)(stepSize * Math.sqrt(dualRes));
            epsilonPrimal = (float)(epsilonAbsTerm + epsilonRel * Math.max(Math.sqrt(axNorm), Math.sqrt(bzNorm)));
            epsilonDual = (float)(epsilonAbsTerm + epsilonRel * Math.sqrt(ayNorm));

            if (iteration % computePeriod == 0) {
                if (!objectiveBreak) {
                    log.trace("Iteration {} -- Primal: {}, Dual: {}, Epsilon Primal: {}, Epsilon Dual: {}.", iteration, primalRes, dualRes, epsilonPrimal, epsilonDual);
                } else {
                    oldObjective = objective;
                    objective = computeObjective(termStore, false);
                    log.trace("Iteration {} -- Objective: {}, Feasible: {}, Primal: {}, Dual: {}, Epsilon Primal: {}, Epsilon Dual: {}.", iteration, objective.objective, objective.violatedConstraints == 0, primalRes, dualRes, epsilonPrimal, epsilonDual);
                }
            }

            termStore.iterationComplete();
            iteration++;

            if (breakOptimization(iteration, objective, oldObjective)) {
                // Before stopping, refresh the objective so violated constraints are seen with the
                // current values (this also guarantees `objective` is non-null after the loop).
                objective = computeObjective(termStore, false);
                if (breakOptimization(iteration, objective, oldObjective)) {
                    break;
                }
            }
        }

        log.info("Optimization completed in {} iterations. Objective: {}, Feasible: {}, Primal res.: {}, Dual res.: {}", iteration - 1, objective.objective, objective.violatedConstraints == 0, primalRes, dualRes);

        if (objective.violatedConstraints > 0) {
            log.warn("No feasible solution found. {} constraints violated.", objective.violatedConstraints);
            computeObjective(termStore, true);
        }

        termStore.updateVariables();
    }

    /**
     * Decides whether the main loop may stop.
     *
     * <p>In order of precedence:
     * <ul>
     *   <li>stop once {@code iteration} exceeds {@code maxIterations * budget};</li>
     *   <li>never stop while {@code objective} reports violated constraints;</li>
     *   <li>stop when both residuals are below their thresholds;</li>
     *   <li>with {@code objectiveBreak} enabled, stop when {@code objective} equals
     *       {@code oldObjective} within {@code tolerance}.</li>
     * </ul>
     * {@code budget}, {@code objectiveBreak} and {@code tolerance} are inherited from {@link Reasoner}.
     */
    private boolean breakOptimization(int iteration, ObjectiveResult objective, ObjectiveResult oldObjective) {
        if (iteration > (int)(maxIterations * budget)) {
            return true;
        }

        if (objective != null && objective.violatedConstraints > 0) {
            return false;
        }

        if (iteration > 1 && primalRes < epsilonPrimal && dualRes < epsilonDual) {
            return true;
        }

        return objectiveBreak && oldObjective != null && MathUtils.equals(objective.objective, oldObjective.objective, tolerance);
    }

    @Override
    public void close() {
    }

    /**
     * Evaluates all terms at the current consensus values.
     *
     * @param logViolatedConstraints if true, each violated constraint's ground rule is logged at trace level
     * @return the summed value of all non-constraint terms, plus the number of
     *     {@link LinearConstraintTerm}s whose evaluation is strictly positive
     */
    private ObjectiveResult computeObjective(ADMMTermStore termStore, boolean logViolatedConstraints) {
        float objective = 0.0f;
        int violatedConstraints = 0;
        float[] consensusValues = termStore.getConsensusValues();

        for (ADMMObjectiveTerm term : termStore) {
            if (!(term instanceof LinearConstraintTerm)) {
                objective += term.evaluate(consensusValues);
                continue;
            }

            // Constraint terms do not contribute to the objective; a positive value means violated.
            if (term.evaluate(consensusValues) > 0.0f) {
                violatedConstraints++;
                if (logViolatedConstraints) {
                    log.trace("    {}", term.getGroundRule());
                }
            }
        }

        return new ObjectiveResult(objective, violatedConstraints);
    }

    /**
     * Adds one worker's partial sums into the iteration totals. Synchronized because workers
     * dispatched through {@code Parallel.count} may call this concurrently.
     */
    private synchronized void updateIterationVariables(float primalRes, float dualRes, float axNorm, float bzNorm, float ayNorm, float lagrangePenalty, float augmentedLagrangePenalty) {
        this.primalRes += primalRes;
        this.dualRes += dualRes;
        this.axNorm += axNorm;
        this.ayNorm += ayNorm;
        this.bzNorm += bzNorm;
        this.lagrangePenalty += lagrangePenalty;
        this.augmentedLagrangePenalty += augmentedLagrangePenalty;
    }

    /** Exact {@code ceil(numerator / denominator)} for {@code numerator >= 0} and {@code denominator > 0}. */
    private static int ceilDiv(int numerator, int denominator) {
        return numerator / denominator + (numerator % denominator == 0 ? 0 : 1);
    }

    /** Objective value together with the number of violated hard constraints. */
    private static class ObjectiveResult {
        public final float objective;
        public final int violatedConstraints;

        public ObjectiveResult(float objective, int violatedConstraints) {
            this.objective = objective;
            this.violatedConstraints = violatedConstraints;
        }
    }

    /**
     * Variable phase: updates one block of consensus values from their local copies and reports
     * the block's residual/penalty partial sums.
     */
    private class VariableWorker extends Parallel.Worker<Integer> {
        private final ADMMTermStore termStore;
        private final int blockSize;
        private final float[] consensusValues;

        public VariableWorker(ADMMTermStore termStore, int blockSize) {
            this.termStore = termStore;
            this.blockSize = blockSize;
            this.consensusValues = termStore.getConsensusValues();
        }

        @Override
        public Object clone() {
            return new VariableWorker(termStore, blockSize);
        }

        @Override
        public void work(int blockIndex, Integer ignore) {
            int numVariables = termStore.getNumGlobalVariables();
            int start = blockIndex * blockSize;
            // Written as start + min(...) so the end index cannot overflow past numVariables.
            int end = start + Math.min(blockSize, numVariables - start);
            float step = ADMMReasoner.this.stepSize;

            float primalResInc = 0.0f;
            float dualResInc = 0.0f;
            float axNormInc = 0.0f;
            float bzNormInc = 0.0f;
            float ayNormInc = 0.0f;
            float lagrangePenaltyInc = 0.0f;
            float augmentedLagrangePenaltyInc = 0.0f;

            for (int variableIndex = start; variableIndex < end; variableIndex++) {
                int numLocalVariables = termStore.getLocalVariables(variableIndex).size();
                if (numLocalVariables == 0) {
                    // The mean below would be 0/0 = NaN. Math.min/max propagate NaN, which would
                    // poison this consensus value and the dual residual / Bz norm, so the residual
                    // stopping test could never pass. Leave the value unchanged instead.
                    continue;
                }

                // First pass: the new consensus value is the mean of (value + lagrange / stepSize).
                float total = 0.0f;
                for (int localIndex = 0; localIndex < numLocalVariables; localIndex++) {
                    LocalVariable local = termStore.getLocalVariables(variableIndex).get(localIndex);
                    total += local.getValue() + local.getLagrange() / step;
                    axNormInc += local.getValue() * local.getValue();
                    ayNormInc += local.getLagrange() * local.getLagrange();
                }

                float newConsensusValue = Math.max(Math.min(total / numLocalVariables, UPPER_BOUND), LOWER_BOUND);
                float consensusDelta = consensusValues[variableIndex] - newConsensusValue;
                dualResInc += consensusDelta * consensusDelta * numLocalVariables;
                bzNormInc += newConsensusValue * newConsensusValue * numLocalVariables;
                consensusValues[variableIndex] = newConsensusValue;

                // Second pass: measure how far each local copy is from the new consensus value.
                for (int localIndex = 0; localIndex < numLocalVariables; localIndex++) {
                    LocalVariable local = termStore.getLocalVariables(variableIndex).get(localIndex);
                    float residual = local.getValue() - newConsensusValue;
                    primalResInc += residual * residual;
                    lagrangePenaltyInc += local.getLagrange() * residual;
                    augmentedLagrangePenaltyInc = (float)(augmentedLagrangePenaltyInc + 0.5 * step * Math.pow(residual, 2.0));
                }
            }

            updateIterationVariables(primalResInc, dualResInc, axNormInc, bzNormInc, ayNormInc, lagrangePenaltyInc, augmentedLagrangePenaltyInc);
        }
    }

    /** Term phase: for one block of terms, updates Lagrange values and then minimizes each term. */
    private class TermWorker extends Parallel.Worker<Integer> {
        private final ADMMTermStore termStore;
        private final int blockSize;
        private final float[] consensusValues;

        public TermWorker(ADMMTermStore termStore, int blockSize) {
            this.termStore = termStore;
            this.blockSize = blockSize;
            this.consensusValues = termStore.getConsensusValues();
        }

        @Override
        public Object clone() {
            return new TermWorker(termStore, blockSize);
        }

        @Override
        public void work(int blockIndex, Integer ignore) {
            int numTerms = termStore.size();
            int start = blockIndex * blockSize;
            int end = start + Math.min(blockSize, numTerms - start);
            float step = ADMMReasoner.this.stepSize;

            for (int termIndex = start; termIndex < end; termIndex++) {
                ADMMObjectiveTerm term = termStore.get(termIndex);
                term.updateLagrange(step, consensusValues);
                term.minimize(step, consensusValues);
            }
        }
    }
}

