/*
 * Decompiled with CFR 0.152.
 */
package org.linqs.psl.reasoner.admm;

import org.linqs.psl.config.Config;
import org.linqs.psl.model.rule.GroundRule;
import org.linqs.psl.model.rule.WeightedGroundRule;
import org.linqs.psl.reasoner.Reasoner;
import org.linqs.psl.reasoner.admm.term.ADMMObjectiveTerm;
import org.linqs.psl.reasoner.admm.term.ADMMTermStore;
import org.linqs.psl.reasoner.admm.term.LinearConstraintTerm;
import org.linqs.psl.reasoner.admm.term.LocalVariable;
import org.linqs.psl.reasoner.term.TermStore;
import org.linqs.psl.util.MathUtils;
import org.linqs.psl.util.Parallel;
import org.linqs.psl.util.RandUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.util.Locale;

/**
 * A {@link Reasoner} that minimizes the objective stored in an {@link ADMMTermStore} with a
 * consensus-style ADMM (Alternating Direction Method of Multipliers) scheme.
 *
 * <p>Each iteration first asks every term to update its Lagrange multipliers and re-minimize
 * itself against the current consensus values. It then recomputes every consensus value as the
 * mean of its local copies (each shifted by its Lagrange multiplier divided by the step size).
 * That mean is clipped to [{@link #LOWER_BOUND}, {@link #UPPER_BOUND}]. Primal and dual residual
 * statistics are accumulated along the way.
 *
 * <p>Iteration stops at the first of:
 * <ul>
 *   <li>both residuals fall within their tolerances;</li>
 *   <li>{@link #OBJECTIVE_BREAK_KEY} is enabled and the objective, recomputed every
 *       {@link #COMPUTE_PERIOD_KEY} iterations, stops changing;</li>
 *   <li>{@link #MAX_ITER_KEY} iterations have run.</li>
 * </ul>
 *
 * <p>NOTE(review): instances keep per-optimization state (consensus values, residuals) in
 * fields, so one instance should presumably not run concurrent optimizations.
 */
public class ADMMReasoner implements Reasoner {
    private static final Logger log = LoggerFactory.getLogger(ADMMReasoner.class);

    /** Prefix shared by every configuration key of this reasoner. */
    public static final String CONFIG_PREFIX = "admmreasoner";

    /** Maximum number of ADMM iterations. */
    public static final String MAX_ITER_KEY = "admmreasoner.maxiterations";
    public static final int MAX_ITER_DEFAULT = 25000;

    /** Iterations between progress logs and (if enabled) objective recomputations. Must be positive. */
    public static final String COMPUTE_PERIOD_KEY = "admmreasoner.computeperiod";
    public static final int COMPUTE_PERIOD_DEFAULT = 50;

    /** ADMM step size (the penalty weight of the augmented Lagrangian). Must be positive. */
    public static final String STEP_SIZE_KEY = "admmreasoner.stepsize";
    public static final float STEP_SIZE_DEFAULT = 1.0f;

    /** Absolute tolerance used by the primal/dual stopping criteria. Must be positive. */
    public static final String EPSILON_ABS_KEY = "admmreasoner.epsilonabs";
    public static final float EPSILON_ABS_DEFAULT = 1.0E-5f;

    /** Relative tolerance used by the primal/dual stopping criteria. Must be positive. */
    public static final String EPSILON_REL_KEY = "admmreasoner.epsilonrel";
    public static final float EPSILON_REL_DEFAULT = 0.001f;

    /** Whether to stop once the periodically recomputed objective stops changing. */
    public static final String OBJECTIVE_BREAK_KEY = "admmreasoner.objectivebreak";
    public static final boolean OBJECTIVE_BREAK_DEFAULT = true;

    /** Initialization of the consensus values: the name of an {@link InitialValue} (case-insensitive). */
    public static final String INITIAL_CONSENSUS_VALUE_KEY = "admmreasoner.initialconsensusvalue";
    public static final String INITIAL_CONSENSUS_VALUE_DEFAULT = InitialValue.RANDOM.toString();

    /**
     * Initialization of the local variables: the name of an {@link InitialValue}
     * (case-insensitive). It is passed to {@code ADMMTermStore.resetLocalVairables}.
     */
    public static final String INITIAL_LOCAL_VALUE_KEY = "admmreasoner.initiallocalvalue";
    public static final String INITIAL_LOCAL_VALUE_DEFAULT = InitialValue.RANDOM.toString();

    /** Lower bound every consensus value is clipped to. */
    private static final float LOWER_BOUND = 0.0f;
    /** Upper bound every consensus value is clipped to. */
    private static final float UPPER_BOUND = 1.0f;

    /** Target number of work blocks per thread when splitting terms and variables. */
    private static final int BLOCKS_PER_THREAD = 4;

    private final int computePeriod;
    private final float stepSize;
    private final boolean objectiveBreak;
    private float epsilonRel;
    private float epsilonAbs;
    private int maxIter;

    // Per-iteration statistics. They are reset by runIteration() and accumulated by the workers
    // through the synchronized updateIterationVariables().
    private float primalRes;
    private float epsilonPrimal;
    private float dualRes;
    private float epsilonDual;
    private float AxNorm;
    private float AyNorm;
    private float BzNorm;
    private float lagrangePenalty;
    private float augmentedLagrangePenalty;

    /** Consensus (global) variable values, indexed by global variable id. */
    private float[] consensusValues;
    private int termBlockSize;
    private int variableBlockSize;

    /**
     * Creates a reasoner configured from {@link Config}.
     *
     * @throws IllegalArgumentException if the configured step size, compute period, absolute
     *     tolerance, or relative tolerance is not positive
     */
    public ADMMReasoner() {
        maxIter = Config.getInt(MAX_ITER_KEY, MAX_ITER_DEFAULT);
        stepSize = requirePositive(Config.getFloat(STEP_SIZE_KEY, STEP_SIZE_DEFAULT), STEP_SIZE_KEY);

        computePeriod = Config.getInt(COMPUTE_PERIOD_KEY, COMPUTE_PERIOD_DEFAULT);
        if (computePeriod <= 0) {
            // A zero period would make `iteration % computePeriod` throw in the middle of optimize().
            throw new IllegalArgumentException("Property " + COMPUTE_PERIOD_KEY + " must be positive.");
        }

        objectiveBreak = Config.getBoolean(OBJECTIVE_BREAK_KEY, OBJECTIVE_BREAK_DEFAULT);
        epsilonAbs = requirePositive(Config.getFloat(EPSILON_ABS_KEY, EPSILON_ABS_DEFAULT), EPSILON_ABS_KEY);
        epsilonRel = requirePositive(Config.getFloat(EPSILON_REL_KEY, EPSILON_REL_DEFAULT), EPSILON_REL_KEY);
    }

    public int getMaxIter() {
        return maxIter;
    }

    public void setMaxIter(int maxIter) {
        this.maxIter = maxIter;
    }

    public float getEpsilonRel() {
        return epsilonRel;
    }

    /**
     * Sets the relative stopping tolerance.
     *
     * @throws IllegalArgumentException if {@code epsilonRel} is not positive (same rule as the
     *     constructor applies to {@link #EPSILON_REL_KEY})
     */
    public void setEpsilonRel(float epsilonRel) {
        this.epsilonRel = requirePositive(epsilonRel, EPSILON_REL_KEY);
    }

    public float getEpsilonAbs() {
        return epsilonAbs;
    }

    /**
     * Sets the absolute stopping tolerance.
     *
     * @throws IllegalArgumentException if {@code epsilonAbs} is not positive (same rule as the
     *     constructor applies to {@link #EPSILON_ABS_KEY})
     */
    public void setEpsilonAbs(float epsilonAbs) {
        this.epsilonAbs = requirePositive(epsilonAbs, EPSILON_ABS_KEY);
    }

    /** Returns the Lagrangian penalty accumulated during the most recent iteration. */
    public float getLagrangianPenalty() {
        return lagrangePenalty;
    }

    /** Returns the augmented-Lagrangian penalty accumulated during the most recent iteration. */
    public float getAugmentedLagrangianPenalty() {
        return augmentedLagrangePenalty;
    }

    /**
     * Optimizes using the consensus and local initializations named in {@link Config}.
     *
     * @throws IllegalArgumentException if a configured initialization is not an {@link InitialValue}
     *     name, or if {@code baseTermStore} is not an {@link ADMMTermStore}
     */
    @Override
    public void optimize(TermStore baseTermStore) {
        InitialValue initialConsensus = parseInitialValue(
                Config.getString(INITIAL_CONSENSUS_VALUE_KEY, INITIAL_CONSENSUS_VALUE_DEFAULT));
        InitialValue initialLocal = parseInitialValue(
                Config.getString(INITIAL_LOCAL_VALUE_KEY, INITIAL_LOCAL_VALUE_DEFAULT));
        optimize(baseTermStore, initialConsensus, initialLocal);
    }

    /**
     * Runs ADMM until convergence, objective stagnation, or the iteration limit. The final
     * consensus values are then written back through {@code termStore.updateVariables}.
     *
     * @param baseTermStore must be an {@link ADMMTermStore}
     * @param initialConsensus how to initialize the consensus values
     * @param initialLocal how to initialize the local variables
     * @throws IllegalArgumentException if {@code baseTermStore} is not an {@link ADMMTermStore}
     */
    public void optimize(TermStore baseTermStore, InitialValue initialConsensus, InitialValue initialLocal) {
        if (!(baseTermStore instanceof ADMMTermStore)) {
            throw new IllegalArgumentException("ADMMReasoner requires an ADMMTermStore (found " + baseTermStore.getClass().getName() + ").");
        }
        ADMMTermStore termStore = (ADMMTermStore)baseTermStore;

        termStore.resetLocalVairables(initialLocal);

        int numTerms = termStore.size();
        int numVariables = termStore.getNumGlobalVariables();
        log.debug("Performing optimization with {} variables and {} terms.", numVariables, numTerms);

        initConsensusValues(termStore, initialConsensus);

        // Split terms and variables into roughly BLOCKS_PER_THREAD blocks per thread.
        int targetBlocks = Parallel.getNumThreads() * BLOCKS_PER_THREAD;
        termBlockSize = numTerms / targetBlocks + 1;
        variableBlockSize = numVariables / targetBlocks + 1;
        // Exact integer ceiling: a float-based ceil can under-count once counts exceed 2^24,
        // which would silently skip the last block.
        int numTermBlocks = ceilDiv(numTerms, termBlockSize);
        int numVariableBlocks = ceilDiv(numVariables, variableBlockSize);

        // The absolute tolerance scales with sqrt(number of local variables).
        float epsilonAbsTerm = (float)(Math.sqrt(termStore.getNumLocalVariables()) * epsilonAbs);

        float objective = 0.0f;
        float oldObjective = 0.0f;
        int iteration = 1;
        while (shouldContinue(iteration, objective, oldObjective)) {
            runIteration(termStore, numTermBlocks, numVariableBlocks, epsilonAbsTerm);

            if (iteration % computePeriod == 0) {
                if (objectiveBreak) {
                    oldObjective = objective;
                    ObjectiveResult current = computeObjective(termStore);
                    objective = current.objective;
                    log.trace("Iteration {} -- Objective: {}, Feasible: {}, Primal: {}, Dual: {}, Epsilon Primal: {}, Epsilon Dual: {}.",
                            iteration, objective, current.violatedConstraints == 0, primalRes, dualRes, epsilonPrimal, epsilonDual);
                } else {
                    log.trace("Iteration {} -- Primal: {}, Dual: {}, Epsilon Primal: {}, Epsilon Dual: {}.",
                            iteration, primalRes, dualRes, epsilonPrimal, epsilonDual);
                }
            }

            iteration++;
        }

        ObjectiveResult result = computeObjective(termStore);
        if (result.violatedConstraints > 0) {
            log.warn("No feasible solution found. {} constraints violated.", result.violatedConstraints);
        }
        log.info("Optimization completed in {} iterations. Objective: {}, Feasible: {}, Primal res.: {}, Dual res.: {}",
                iteration - 1, result.objective, result.violatedConstraints == 0, primalRes, dualRes);

        termStore.updateVariables(consensusValues);
    }

    /** No resources to release. */
    @Override
    public void close() {
    }

    /**
     * Computes the incompatibility of {@code groundRule} when its atoms take the values of the
     * local variable copies held by the rule's ADMM terms. Afterwards the term store is reset to
     * the consensus values, even if the computation throws.
     *
     * @param groundRule the rule to evaluate; it is cast to {@link WeightedGroundRule}
     * @param termStore the store holding the rule's terms
     * @param consensusBuffer scratch buffer with one slot per global variable, or {@code null} to
     *     allocate one. Slots of the rule's variables are overwritten, and the whole buffer is
     *     pushed into {@code termStore} temporarily.
     * @return the rule's incompatibility under its local copies
     * @throws IllegalStateException if {@link #optimize} has not been run yet
     */
    public double getDualIncompatibility(GroundRule groundRule, ADMMTermStore termStore, float[] consensusBuffer) {
        if (consensusValues == null) {
            // Without this, the store would be overwritten and then left un-restored by an NPE.
            throw new IllegalStateException("getDualIncompatibility() requires optimize() to have been called first.");
        }
        if (consensusBuffer == null) {
            consensusBuffer = new float[termStore.getNumGlobalVariables()];
        }
        assert consensusBuffer.length == consensusValues.length;

        for (ADMMObjectiveTerm term : termStore.getTerms(groundRule)) {
            for (LocalVariable localVariable : term.getVariables()) {
                consensusBuffer[localVariable.getGlobalId()] = localVariable.getValue();
            }
        }

        termStore.updateVariables(consensusBuffer);
        try {
            return ((WeightedGroundRule)groundRule).getIncompatibility();
        } finally {
            termStore.updateVariables(consensusValues);
        }
    }

    /** Parses an {@link InitialValue} name case-insensitively, independent of the default locale. */
    private static InitialValue parseInitialValue(String name) {
        return InitialValue.valueOf(name.toUpperCase(Locale.ROOT));
    }

    /** Returns {@code value}, or throws if it is not positive; {@code key} names the property. */
    private static float requirePositive(float value, String key) {
        if (value <= 0.0f) {
            throw new IllegalArgumentException("Property " + key + " must be positive.");
        }
        return value;
    }

    /** Exact ceil(numerator / denominator) for a non-negative numerator and positive denominator. */
    private static int ceilDiv(int numerator, int denominator) {
        int quotient = numerator / denominator;
        return (numerator % denominator == 0) ? quotient : quotient + 1;
    }

    /**
     * Reports whether another iteration should run. The first iteration always runs unless
     * {@code maxIter} is below 1. After that, iteration continues while a residual exceeds its
     * tolerance and, if objective breaking is enabled, the (non-zero) objective is still changing.
     */
    private boolean shouldContinue(int iteration, float objective, float oldObjective) {
        if (iteration > maxIter) {
            return false;
        }

        // Written as `>` so that NaN residuals count as "within tolerance" and stop the loop.
        boolean residualsOpen = iteration == 1 || primalRes > epsilonPrimal || dualRes > epsilonDual;
        if (!residualsOpen) {
            return false;
        }

        boolean objectiveStalled = objectiveBreak
                && !MathUtils.isZero(oldObjective)
                && MathUtils.equals(objective, oldObjective);
        return !objectiveStalled;
    }

    /**
     * Runs one ADMM iteration: term updates, then consensus updates. It then finalizes the
     * residuals and tolerances used by {@link #shouldContinue}.
     */
    private void runIteration(ADMMTermStore termStore, int numTermBlocks, int numVariableBlocks, float epsilonAbsTerm) {
        primalRes = 0.0f;
        dualRes = 0.0f;
        AxNorm = 0.0f;
        AyNorm = 0.0f;
        BzNorm = 0.0f;
        lagrangePenalty = 0.0f;
        augmentedLagrangePenalty = 0.0f;

        Parallel.count(numTermBlocks, new TermWorker(termStore, termBlockSize));
        Parallel.count(numVariableBlocks, new VariableWorker(termStore, variableBlockSize));

        primalRes = (float)Math.sqrt(primalRes);
        dualRes = (float)(stepSize * Math.sqrt(dualRes));
        epsilonPrimal = (float)(epsilonAbsTerm + epsilonRel * Math.max(Math.sqrt(AxNorm), Math.sqrt(BzNorm)));
        epsilonDual = (float)(epsilonAbsTerm + epsilonRel * Math.sqrt(AyNorm));
    }

    /**
     * Scans every term once.
     *
     * <p>Non-constraint terms contribute {@code 1 - evaluate()} to the objective. A
     * {@link LinearConstraintTerm} counts as violated when {@code evaluate() > 0}.
     */
    private static ObjectiveResult computeObjective(ADMMTermStore termStore) {
        float objective = 0.0f;
        int violatedConstraints = 0;
        for (ADMMObjectiveTerm term : termStore) {
            if (term instanceof LinearConstraintTerm) {
                if (term.evaluate() > 0.0f) {
                    violatedConstraints++;
                }
            } else {
                objective += 1.0f - term.evaluate();
            }
        }
        return new ObjectiveResult(objective, violatedConstraints);
    }

    /** Allocates {@link #consensusValues} and fills it according to {@code initialConsensus}. */
    private void initConsensusValues(ADMMTermStore termStore, InitialValue initialConsensus) {
        consensusValues = new float[termStore.getNumGlobalVariables()];

        if (initialConsensus == InitialValue.ZERO) {
            // Java zero-initializes new arrays, so there is nothing left to do.
        } else if (initialConsensus == InitialValue.RANDOM) {
            for (int i = 0; i < consensusValues.length; i++) {
                consensusValues[i] = RandUtils.nextFloat();
            }
        } else if (initialConsensus == InitialValue.ATOM) {
            termStore.getAtomValues(consensusValues);
        } else {
            throw new IllegalStateException("Unknown initial consensus value: " + initialConsensus);
        }
    }

    /** Adds one worker's partial sums into the per-iteration statistics. */
    private synchronized void updateIterationVariables(float primalRes, float dualRes, float AxNorm, float BzNorm, float AyNorm, float lagrangePenalty, float augmentedLagrangePenalty) {
        this.primalRes += primalRes;
        this.dualRes += dualRes;
        this.AxNorm += AxNorm;
        this.AyNorm += AyNorm;
        this.BzNorm += BzNorm;
        this.lagrangePenalty += lagrangePenalty;
        this.augmentedLagrangePenalty += augmentedLagrangePenalty;
    }

    /** Objective value and number of violated linear constraints from one term scan. */
    private static final class ObjectiveResult {
        private final float objective;
        private final int violatedConstraints;

        private ObjectiveResult(float objective, int violatedConstraints) {
            this.objective = objective;
            this.violatedConstraints = violatedConstraints;
        }
    }

    /**
     * Recomputes one block of consensus values from their local copies. It accumulates the
     * residual and penalty statistics for that block.
     */
    private class VariableWorker extends Parallel.Worker<Integer> {
        private ADMMTermStore termStore;
        private int blockSize;

        public VariableWorker(ADMMTermStore termStore, int blockSize) {
            this.termStore = termStore;
            this.blockSize = blockSize;
        }

        @Override
        public Object clone() {
            return new VariableWorker(termStore, blockSize);
        }

        @Override
        public void work(int blockIndex, Integer ignore) {
            int numVariables = termStore.getNumGlobalVariables();

            float primalResInc = 0.0f;
            float dualResInc = 0.0f;
            float AxNormInc = 0.0f;
            float BzNormInc = 0.0f;
            float AyNormInc = 0.0f;
            float lagrangePenaltyInc = 0.0f;
            float augmentedLagrangePenaltyInc = 0.0f;

            for (int innerBlockIndex = 0; innerBlockIndex < blockSize; innerBlockIndex++) {
                int variableIndex = blockIndex * blockSize + innerBlockIndex;
                if (variableIndex >= numVariables) {
                    break;
                }

                int numLocalVariables = termStore.getLocalVariables(variableIndex).size();

                // First pass: average the (multiplier-shifted) local copies and gather norms.
                float total = 0.0f;
                for (int localVarIndex = 0; localVarIndex < numLocalVariables; localVarIndex++) {
                    LocalVariable localVariable = termStore.getLocalVariables(variableIndex).get(localVarIndex);
                    total += localVariable.getValue() + localVariable.getLagrange() / ADMMReasoner.this.stepSize;
                    AxNormInc += localVariable.getValue() * localVariable.getValue();
                    AyNormInc += localVariable.getLagrange() * localVariable.getLagrange();
                }

                float newConsensusValue = total / numLocalVariables;
                newConsensusValue = Math.max(Math.min(newConsensusValue, UPPER_BOUND), LOWER_BOUND);

                float diff = ADMMReasoner.this.consensusValues[variableIndex] - newConsensusValue;
                dualResInc += diff * diff * numLocalVariables;
                BzNormInc += newConsensusValue * newConsensusValue * numLocalVariables;
                // Each variable index belongs to exactly one block, so this write is unshared.
                ADMMReasoner.this.consensusValues[variableIndex] = newConsensusValue;

                // Second pass: distance of each local copy from the new consensus value.
                for (int localVarIndex = 0; localVarIndex < numLocalVariables; localVarIndex++) {
                    LocalVariable localVariable = termStore.getLocalVariables(variableIndex).get(localVarIndex);
                    diff = localVariable.getValue() - newConsensusValue;
                    primalResInc += diff * diff;
                    lagrangePenaltyInc += localVariable.getLagrange() * (localVariable.getValue() - newConsensusValue);
                    augmentedLagrangePenaltyInc += 0.5 * ADMMReasoner.this.stepSize * Math.pow(localVariable.getValue() - newConsensusValue, 2.0);
                }
            }

            ADMMReasoner.this.updateIterationVariables(primalResInc, dualResInc, AxNormInc, BzNormInc, AyNormInc, lagrangePenaltyInc, augmentedLagrangePenaltyInc);
        }
    }

    /**
     * For one block of terms, calls {@code updateLagrange} and then {@code minimize} on each term
     * against the current consensus values.
     */
    private class TermWorker extends Parallel.Worker<Integer> {
        private ADMMTermStore termStore;
        private int blockSize;

        public TermWorker(ADMMTermStore termStore, int blockSize) {
            this.termStore = termStore;
            this.blockSize = blockSize;
        }

        @Override
        public Object clone() {
            return new TermWorker(termStore, blockSize);
        }

        @Override
        public void work(int blockIndex, Integer ignore) {
            int numTerms = termStore.size();

            for (int innerBlockIndex = 0; innerBlockIndex < blockSize; innerBlockIndex++) {
                int termIndex = blockIndex * blockSize + innerBlockIndex;
                if (termIndex >= numTerms) {
                    break;
                }

                ADMMObjectiveTerm term = termStore.get(termIndex);
                term.updateLagrange(ADMMReasoner.this.stepSize, ADMMReasoner.this.consensusValues);
                term.minimize(ADMMReasoner.this.stepSize, ADMMReasoner.this.consensusValues);
            }
        }
    }

    /** Strategies for initializing consensus values or local variables. */
    public enum InitialValue {
        ZERO,
        RANDOM,
        ATOM
    }
}

