/*
 * Decompiled with CFR 0.152.
 */
package org.linqs.psl.reasoner.admm.term;

import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
import org.linqs.psl.model.rule.FakeRule;
import org.linqs.psl.model.rule.Rule;
import org.linqs.psl.model.rule.WeightedRule;
import org.linqs.psl.reasoner.admm.term.LocalVariable;
import org.linqs.psl.reasoner.function.FunctionComparator;
import org.linqs.psl.reasoner.term.Hyperplane;
import org.linqs.psl.reasoner.term.ReasonerTerm;
import org.linqs.psl.util.FloatMatrix;
import org.linqs.psl.util.HashCode;
import org.linqs.psl.util.MathUtils;
import org.linqs.psl.util.RandUtils;

/**
 * A local objective term optimized by the ADMM reasoner.
 * <p>
 * Every term is built around the hyperplane value
 * {@code sum_i coefficients[i] * x_i - constant} over its local variables. The {@link TermType}
 * chosen at construction decides whether that value is used as a linear loss, hinge loss,
 * squared (hinge) loss, a "deter" loss, or a hard linear constraint. {@link #minimize} and
 * {@link #evaluate} dispatch on that type.
 */
public class ADMMObjectiveTerm
implements ReasonerTerm {
    protected final TermType termType;
    protected final Rule rule;
    protected int size;
    private float[] coefficients;
    private LocalVariable[] variables;
    private boolean squared;
    private boolean hinge;
    private float deterConstant;
    private float constant;
    private FunctionComparator comparator;

    // Scratch buffer for project(); only allocated for multi-variable hinge/constraint terms.
    private float[] consensusOptimizer;
    // coefficients / ||coefficients||; only allocated for multi-variable hinge/constraint terms.
    private float[] unitNormal;
    // ||coefficients||, cached so project() does not have to divide by unitNormal[0] (0/0 when
    // the first coefficient is zero).
    private float normalLength;

    /**
     * Cholesky factors shared by all squared-loss terms with identical (weight, stepSize,
     * coefficients). A concurrent map is used because terms may be minimized from several
     * threads; an instance-level lock cannot protect this static map.
     * NOTE(review): nothing in this class evicts entries, and keys include weight and stepSize.
     */
    private static final ConcurrentMap<LowerTriangleKey, FloatMatrix> lowerTriangleCache =
            new ConcurrentHashMap<LowerTriangleKey, FloatMatrix>();

    private ADMMObjectiveTerm(Hyperplane<LocalVariable> hyperplane, Rule rule, boolean squared, boolean hinge, boolean collectiveDeter, float deterConstant, FunctionComparator comparator) {
        this.rule = rule;
        this.squared = squared;
        this.hinge = hinge;
        this.deterConstant = deterConstant;
        this.comparator = comparator;
        this.size = hyperplane.size();
        this.variables = (LocalVariable[])hyperplane.getVariables();
        this.coefficients = hyperplane.getCoefficients();
        this.constant = hyperplane.getConstant();

        // getTermType() reads the fields assigned above, so it must run after them.
        this.termType = getTermType(collectiveDeter);

        // Only these two types ever call project(), which needs the unit normal.
        if (termType == TermType.HingeLossTerm || termType == TermType.LinearConstraintTerm) {
            initUnitNormal();
        }
    }

    /** Creates a hard constraint {@code hyperplane (comparator) 0}. */
    public static ADMMObjectiveTerm createLinearConstraintTerm(Hyperplane<LocalVariable> hyperplane, Rule rule, FunctionComparator comparator) {
        return new ADMMObjectiveTerm(hyperplane, rule, false, false, false, 0.0f, comparator);
    }

    /** Creates a loss of {@code weight * hyperplane}. */
    public static ADMMObjectiveTerm createLinearLossTerm(Hyperplane<LocalVariable> hyperplane, Rule rule) {
        return new ADMMObjectiveTerm(hyperplane, rule, false, false, false, 0.0f, null);
    }

    /** Creates a loss of {@code weight * max(0, hyperplane)}. */
    public static ADMMObjectiveTerm createHingeLossTerm(Hyperplane<LocalVariable> hyperplane, Rule rule) {
        return new ADMMObjectiveTerm(hyperplane, rule, false, true, false, 0.0f, null);
    }

    /** Creates a loss of {@code weight * hyperplane^2}. */
    public static ADMMObjectiveTerm createSquaredLinearLossTerm(Hyperplane<LocalVariable> hyperplane, Rule rule) {
        return new ADMMObjectiveTerm(hyperplane, rule, true, false, false, 0.0f, null);
    }

    /** Creates a loss of {@code weight * max(0, hyperplane)^2}. */
    public static ADMMObjectiveTerm createSquaredHingeLossTerm(Hyperplane<LocalVariable> hyperplane, Rule rule) {
        return new ADMMObjectiveTerm(hyperplane, rule, true, true, false, 0.0f, null);
    }

    /**
     * Creates a collective deter term weighted by {@code deterWeight}.
     * If {@code deterConstant} is zero the term degrades to a plain linear loss (see getTermType).
     */
    public static ADMMObjectiveTerm createCollectiveDeterTerm(Hyperplane<LocalVariable> hyperplane, float deterWeight, float deterConstant) {
        return new ADMMObjectiveTerm(hyperplane, new FakeRule(deterWeight, false), false, false, true, deterConstant, null);
    }

    /**
     * Creates an independent deter term weighted by {@code deterWeight}.
     * If {@code deterConstant} is zero the term degrades to a plain linear loss (see getTermType).
     */
    public static ADMMObjectiveTerm createIndependentDeterTerm(Hyperplane<LocalVariable> hyperplane, float deterWeight, float deterConstant) {
        return new ADMMObjectiveTerm(hyperplane, new FakeRule(deterWeight, false), false, false, false, deterConstant, null);
    }

    /**
     * ADMM dual update: {@code lagrange += stepSize * (localValue - consensusValue)} for every
     * local variable of this term.
     *
     * @param stepSize the ADMM step size (rho)
     * @param consensusValues consensus values indexed by each variable's global id
     */
    public void updateLagrange(float stepSize, float[] consensusValues) {
        for (int i = 0; i < size; ++i) {
            LocalVariable variable = variables[i];
            variable.setLagrange(variable.getLagrange() + stepSize * (variable.getValue() - consensusValues[variable.getGlobalId()]));
        }
    }

    /** Returns the backing array of local variables (not a copy). */
    public LocalVariable[] getVariables() {
        return variables;
    }

    @Override
    public int size() {
        return size;
    }

    /** Replaces {@code oldValue}'s contribution to the hyperplane constant with {@code newValue}. */
    @Override
    public void adjustConstant(float oldValue, float newValue) {
        constant = constant - oldValue + newValue;
    }

    public boolean isConstraint() {
        return termType == TermType.LinearConstraintTerm;
    }

    /** Deter terms are non-convex; every other type is convex. */
    @Override
    public boolean isConvex() {
        return termType != TermType.DeterCollectiveTerm && termType != TermType.DeterIndependentTerm;
    }

    /**
     * Resolves the term type from the constructor flags. Precedence: a comparator means a
     * constraint, then a non-zero deter constant means a deter term, then squared/hinge.
     */
    private TermType getTermType(boolean collectiveDeter) {
        if (comparator != null) {
            return TermType.LinearConstraintTerm;
        }
        if (!MathUtils.isZero(deterConstant)) {
            return collectiveDeter ? TermType.DeterCollectiveTerm : TermType.DeterIndependentTerm;
        }
        if (squared) {
            return hinge ? TermType.SquaredHingeLossTerm : TermType.SquaredLinearLossTerm;
        }
        return hinge ? TermType.HingeLossTerm : TermType.LinearLossTerm;
    }

    /**
     * Performs this term's local ADMM minimization step, writing the result into the local
     * variables.
     *
     * @param stepSize the ADMM step size (rho)
     * @param consensusValues consensus values indexed by each variable's global id
     */
    public void minimize(float stepSize, float[] consensusValues) {
        float weight = getWeight();
        switch (termType) {
            case LinearConstraintTerm:
                minimizeConstraint(stepSize, consensusValues);
                break;
            case LinearLossTerm:
                minimizeLinearLoss(stepSize, weight, consensusValues);
                break;
            case HingeLossTerm:
                minimizeHingeLoss(stepSize, weight, consensusValues);
                break;
            case SquaredLinearLossTerm:
                minimizeSquaredLinearLoss(stepSize, weight, consensusValues);
                break;
            case SquaredHingeLossTerm:
                minimizeSquaredHingeLoss(stepSize, weight, consensusValues);
                break;
            case DeterCollectiveTerm:
                minimizeCollectiveDeter(stepSize, weight, consensusValues);
                break;
            case DeterIndependentTerm:
                minimizeIndependentDeter(stepSize, weight, consensusValues);
                break;
            default:
                throw new IllegalStateException("Unknown term type.");
        }
    }

    /**
     * Evaluates the term at the local variables' current values.
     * Constraints evaluate to 0 when satisfied and +infinity otherwise.
     */
    public float evaluate() {
        float weight = getWeight();
        switch (termType) {
            case LinearConstraintTerm:
                return evaluateConstraint();
            case LinearLossTerm:
                return evaluateLinearLoss(weight);
            case HingeLossTerm:
                return evaluateHingeLoss(weight);
            case SquaredLinearLossTerm:
                return evaluateSquaredLinearLoss(weight);
            case SquaredHingeLossTerm:
                return evaluateSquaredHingeLoss(weight);
            case DeterCollectiveTerm:
                return evaluateCollectiveDeter(weight);
            case DeterIndependentTerm:
                return evaluateIndependentDeter(weight);
            default:
                throw new IllegalStateException("Unknown term type.");
        }
    }

    /**
     * Evaluates the term at the given consensus values (indexed by each variable's global id).
     * Constraints evaluate to 0 when satisfied and +infinity otherwise.
     */
    public float evaluate(float[] consensusValues) {
        float weight = getWeight();
        switch (termType) {
            case LinearConstraintTerm:
                return evaluateConstraint(consensusValues);
            case LinearLossTerm:
                return evaluateLinearLoss(weight, consensusValues);
            case HingeLossTerm:
                return evaluateHingeLoss(weight, consensusValues);
            case SquaredLinearLossTerm:
                return evaluateSquaredLinearLoss(weight, consensusValues);
            case SquaredHingeLossTerm:
                return evaluateSquaredHingeLoss(weight, consensusValues);
            case DeterCollectiveTerm:
                return evaluateCollectiveDeter(weight, consensusValues);
            case DeterIndependentTerm:
                return evaluateIndependentDeter(weight, consensusValues);
            default:
                throw new IllegalStateException("Unknown term type.");
        }
    }

    /**
     * Sets every local variable to {@code consensus - lagrange / stepSize} (the minimizer of the
     * ADMM penalty alone) and returns the resulting {@code sum_i coefficients[i] * x_i}
     * (the constant is NOT subtracted).
     */
    private float moveToPenaltyMinimizer(float stepSize, float[] consensusValues) {
        float total = 0.0f;
        for (int i = 0; i < size; ++i) {
            LocalVariable variable = variables[i];
            variable.setValue(consensusValues[variable.getGlobalId()] - variable.getLagrange() / stepSize);
            total += coefficients[i] * variable.getValue();
        }
        return total;
    }

    /**
     * For inequality constraints, keeps the penalty-only minimizer when it already satisfies the
     * constraint; otherwise (and always for equality) projects onto the hyperplane.
     */
    private void minimizeConstraint(float stepSize, float[] consensusValues) {
        if (!comparator.equals(FunctionComparator.EQ)) {
            float total = moveToPenaltyMinimizer(stepSize, consensusValues);
            boolean satisfied = (comparator.equals(FunctionComparator.LTE) && total <= constant)
                    || (comparator.equals(FunctionComparator.GTE) && total >= constant);
            if (satisfied) {
                return;
            }
        }
        project(stepSize, consensusValues);
    }

    private float evaluateConstraint() {
        return evaluateConstraint(null);
    }

    /**
     * Returns 0 if the constraint holds, +infinity otherwise. Equality uses a 0.005 tolerance.
     *
     * @param consensusValues consensus values, or {@code null} to use the local variable values
     */
    private float evaluateConstraint(float[] consensusValues) {
        float value = (consensusValues == null) ? computeInnerPotential() : computeInnerPotential(consensusValues);

        if (comparator.equals(FunctionComparator.EQ)) {
            return MathUtils.isZero((double)value, 0.005) ? 0.0f : Float.POSITIVE_INFINITY;
        }
        if (comparator.equals(FunctionComparator.LTE)) {
            return (value <= 0.0f) ? 0.0f : Float.POSITIVE_INFINITY;
        }
        if (comparator.equals(FunctionComparator.GTE)) {
            return (value >= 0.0f) ? 0.0f : Float.POSITIVE_INFINITY;
        }
        throw new IllegalStateException("Unknown comparison function.");
    }

    /** Closed-form step: {@code x = consensus - lagrange / step - weight * coefficient / step}. */
    private void minimizeLinearLoss(float stepSize, float weight, float[] consensusValues) {
        for (int i = 0; i < size; ++i) {
            LocalVariable variable = variables[i];
            float value = consensusValues[variable.getGlobalId()] - variable.getLagrange() / stepSize - weight * coefficients[i] / stepSize;
            variable.setValue(value);
        }
    }

    private float evaluateLinearLoss(float weight) {
        return weight * computeInnerPotential();
    }

    private float evaluateLinearLoss(float weight, float[] consensusValues) {
        return weight * computeInnerPotential(consensusValues);
    }

    /**
     * Hinge step in three cases:
     * 1. the penalty-only minimizer lies in the hinge's flat region (total <= constant): keep it;
     * 2. the linear-loss step stays in the linear region (total >= constant): keep it;
     * 3. otherwise the optimum is on the hinge: project onto the hyperplane.
     */
    private void minimizeHingeLoss(float stepSize, float weight, float[] consensusValues) {
        float total = moveToPenaltyMinimizer(stepSize, consensusValues);
        if (total <= constant) {
            return;
        }

        total = 0.0f;
        for (int i = 0; i < size; ++i) {
            LocalVariable variable = variables[i];
            variable.setValue(consensusValues[variable.getGlobalId()] - variable.getLagrange() / stepSize - weight * coefficients[i] / stepSize);
            total += coefficients[i] * variable.getValue();
        }
        if (total >= constant) {
            return;
        }

        project(stepSize, consensusValues);
    }

    private float evaluateHingeLoss(float weight) {
        return weight * Math.max(0.0f, computeInnerPotential());
    }

    private float evaluateHingeLoss(float weight, float[] consensusValues) {
        return weight * Math.max(0.0f, computeInnerPotential(consensusValues));
    }

    private void minimizeSquaredLinearLoss(float stepSize, float weight, float[] consensusValues) {
        minWeightedSquaredHyperplane(stepSize, weight, consensusValues);
    }

    private float evaluateSquaredLinearLoss(float weight) {
        return weight * (float)Math.pow(computeInnerPotential(), 2.0);
    }

    private float evaluateSquaredLinearLoss(float weight, float[] consensusValues) {
        return weight * (float)Math.pow(computeInnerPotential(consensusValues), 2.0);
    }

    /** Keeps the penalty-only minimizer if it is in the flat region, else solves the squared system. */
    private void minimizeSquaredHingeLoss(float stepSize, float weight, float[] consensusValues) {
        float total = moveToPenaltyMinimizer(stepSize, consensusValues);
        if (total <= constant) {
            return;
        }
        minWeightedSquaredHyperplane(stepSize, weight, consensusValues);
    }

    private float evaluateSquaredHingeLoss(float weight) {
        return weight * (float)Math.pow(Math.max(0.0f, computeInnerPotential()), 2.0);
    }

    private float evaluateSquaredHingeLoss(float weight, float[] consensusValues) {
        return weight * (float)Math.pow(Math.max(0.0f, computeInnerPotential(consensusValues)), 2.0);
    }

    /**
     * If the mean absolute distance of the consensus values from the uniform value
     * {@code 1 / size} is at most {@code deterConstant}, sets one randomly chosen local variable
     * to 1 and the rest to 0; otherwise leaves the local variables untouched.
     * {@code stepSize} and {@code weight} are not used.
     */
    private void minimizeCollectiveDeter(float stepSize, float weight, float[] consensusValues) {
        float deterValue = 1.0f / (float)size;
        float distance = 0.0f;
        for (int i = 0; i < size; ++i) {
            distance += Math.abs(deterValue - consensusValues[variables[i].getGlobalId()]);
        }
        distance /= (float)size;
        if (distance > deterConstant) {
            return;
        }

        int upPoint = RandUtils.nextInt(size);
        for (int i = 0; i < size; ++i) {
            variables[i].setValue(i == upPoint ? 1.0f : 0.0f);
        }
    }

    private float evaluateCollectiveDeter(float weight) {
        float deterValue = 1.0f / (float)size;
        float value = 0.0f;
        for (int i = 0; i < size; ++i) {
            float variableValue = variables[i].getValue();
            value += (variableValue > deterValue) ? (1.0f - variableValue) : variableValue;
        }
        return weight * (1.0f / (float)size) * value;
    }

    private float evaluateCollectiveDeter(float weight, float[] consensusValues) {
        float deterValue = 1.0f / (float)size;
        float value = 0.0f;
        for (int i = 0; i < size; ++i) {
            float variableValue = consensusValues[variables[i].getGlobalId()];
            value += (variableValue > deterValue) ? (1.0f - variableValue) : variableValue;
        }
        return weight * (1.0f / (float)size) * value;
    }

    /**
     * Takes a linear-loss style step, with the sign of the weight term chosen by whether the
     * variable's current local value is above {@code deterConstant}.
     */
    private void minimizeIndependentDeter(float stepSize, float weight, float[] consensusValues) {
        for (int i = 0; i < size; ++i) {
            LocalVariable variable = variables[i];
            float base = consensusValues[variable.getGlobalId()] - variable.getLagrange() / stepSize;
            float value = (variable.getValue() > deterConstant)
                    ? base + weight * coefficients[i] / stepSize
                    : base - weight * coefficients[i] / stepSize;
            variable.setValue(value);
        }
    }

    private float evaluateIndependentDeter(float weight) {
        float rawDissatisfaction = computeInnerPotential();
        float dissatisfaction = 1.0f - Math.abs(rawDissatisfaction - deterConstant);
        return weight * dissatisfaction;
    }

    private float evaluateIndependentDeter(float weight, float[] consensusValues) {
        float rawDissatisfaction = computeInnerPotential(consensusValues);
        float dissatisfaction = 1.0f - Math.abs(rawDissatisfaction - deterConstant);
        return weight * dissatisfaction;
    }

    /**
     * Precomputes the unit normal of the hyperplane and its length for project().
     * Single-variable terms need neither, because project() solves them directly.
     * NOTE(review): if every coefficient is zero the normal is undefined (division by zero).
     */
    private void initUnitNormal() {
        if (size == 1) {
            consensusOptimizer = null;
            unitNormal = null;
            return;
        }

        consensusOptimizer = new float[size];
        unitNormal = new float[size];

        float sumOfSquares = 0.0f;
        for (int i = 0; i < size; ++i) {
            sumOfSquares += coefficients[i] * coefficients[i];
        }
        normalLength = (float)Math.sqrt(sumOfSquares);

        for (int i = 0; i < size; ++i) {
            unitNormal[i] = coefficients[i] / normalLength;
        }
    }

    /** Returns {@code sum_i coefficients[i] * localValue_i - constant}. */
    private float computeInnerPotential() {
        float value = 0.0f;
        for (int i = 0; i < size; ++i) {
            value += coefficients[i] * variables[i].getValue();
        }
        return value - constant;
    }

    /** Returns {@code sum_i coefficients[i] * consensus[globalId_i] - constant}. */
    private float computeInnerPotential(float[] consensusValues) {
        float value = 0.0f;
        for (int i = 0; i < size; ++i) {
            value += coefficients[i] * consensusValues[variables[i].getGlobalId()];
        }
        return value - constant;
    }

    /**
     * Sets the local variables to the orthogonal projection of the penalty-only minimizer
     * {@code consensus - lagrange / stepSize} onto the hyperplane {@code coefficients . x = constant}.
     * Only valid after initUnitNormal() has run.
     */
    private void project(float stepSize, float[] consensusValues) {
        // With a single variable the hyperplane is a point.
        if (size == 1) {
            variables[0].setValue(constant / coefficients[0]);
            return;
        }

        for (int i = 0; i < size; ++i) {
            consensusOptimizer[i] = consensusValues[variables[i].getGlobalId()] - variables[i].getLagrange() / stepSize;
        }

        // Signed distance from the hyperplane: unitNormal . x - constant / ||coefficients||.
        // The cached length is used because coefficients[0] / unitNormal[0] is 0/0 = NaN
        // whenever the first coefficient is zero.
        float multiplier = -1.0f * constant / normalLength;
        for (int i = 0; i < size; ++i) {
            multiplier += consensusOptimizer[i] * unitNormal[i];
        }

        for (int i = 0; i < size; ++i) {
            variables[i].setValue(consensusOptimizer[i] - multiplier * unitNormal[i]);
        }
    }

    /**
     * Solves {@code (2 * weight * a a^T + stepSize * I) x = stepSize * consensus - lagrange
     * + 2 * weight * constant * a} for the local variables, where {@code a} is the coefficient
     * vector. Sizes 1 and 2 are solved directly; larger systems use a cached Cholesky factor.
     */
    private void minWeightedSquaredHyperplane(float stepSize, float weight, float[] consensusValues) {
        // Right-hand side, stored temporarily in the local variables.
        for (int i = 0; i < size; ++i) {
            float value = stepSize * consensusValues[variables[i].getGlobalId()] - variables[i].getLagrange() + 2.0f * weight * coefficients[i] * constant;
            variables[i].setValue(value);
        }

        if (size == 1) {
            LocalVariable variable = variables[0];
            float coefficient = coefficients[0];
            variable.setValue(variable.getValue() / (2.0f * weight * coefficient * coefficient + stepSize));
            return;
        }

        if (size == 2) {
            LocalVariable variable0 = variables[0];
            LocalVariable variable1 = variables[1];
            float coefficient0 = coefficients[0];
            float coefficient1 = coefficients[1];

            float a0 = 2.0f * weight * coefficient0 * coefficient0 + stepSize;
            float b1 = 2.0f * weight * coefficient1 * coefficient1 + stepSize;
            float a1b0 = 2.0f * weight * coefficient0 * coefficient1;

            // 2x2 Gaussian elimination followed by back substitution.
            variable1.setValue(variable1.getValue() - a1b0 * variable0.getValue() / a0);
            variable1.setValue(variable1.getValue() / (b1 - a1b0 * a1b0 / a0));
            variable0.setValue((variable0.getValue() - a1b0 * variable1.getValue()) / a0);
            return;
        }

        FloatMatrix lowerTriangle = fetchLowerTriangle(stepSize, weight);

        // Forward substitution: L y = b.
        for (int i = 0; i < size; ++i) {
            float newValue = variables[i].getValue();
            for (int j = 0; j < i; ++j) {
                newValue -= lowerTriangle.get(i, j) * variables[j].getValue();
            }
            variables[i].setValue(newValue / lowerTriangle.get(i, i));
        }

        // Backward substitution: L^T x = y.
        for (int i = size - 1; i >= 0; --i) {
            float newValue = variables[i].getValue();
            for (int j = size - 1; j > i; --j) {
                newValue -= lowerTriangle.get(j, i) * variables[j].getValue();
            }
            variables[i].setValue(newValue / lowerTriangle.get(i, i));
        }
    }

    /**
     * Returns the cached Cholesky factor for this term's system, computing and caching it on a
     * miss. Safe to call concurrently. Two threads may both compute the same factor, but only
     * one is stored and both receive that stored instance.
     */
    private FloatMatrix fetchLowerTriangle(float stepSize, float weight) {
        LowerTriangleKey key = new LowerTriangleKey(weight, stepSize, Arrays.copyOf(coefficients, size));

        FloatMatrix lowerTriangle = lowerTriangleCache.get(key);
        if (lowerTriangle != null) {
            return lowerTriangle;
        }

        lowerTriangle = computeLowerTriangle(stepSize, weight);
        FloatMatrix existing = lowerTriangleCache.putIfAbsent(key, lowerTriangle);
        return (existing != null) ? existing : lowerTriangle;
    }

    /**
     * Builds {@code 2 * weight * a a^T + stepSize * I} and factors it in place with
     * {@code FloatMatrix.choleskyDecomposition(true)}. Does not touch the cache.
     */
    private FloatMatrix computeLowerTriangle(float stepSize, float weight) {
        FloatMatrix matrix = FloatMatrix.zeroes(size, size);
        for (int i = 0; i < size; ++i) {
            for (int j = i; j < size; ++j) {
                if (i == j) {
                    matrix.set(i, i, 2.0f * weight * coefficients[i] * coefficients[i] + stepSize);
                } else {
                    float coefficient = 2.0f * weight * coefficients[i] * coefficients[j];
                    matrix.set(i, j, coefficient);
                    matrix.set(j, i, coefficient);
                }
            }
        }
        matrix.choleskyDecomposition(true);
        return matrix;
    }

    /** Returns the rule's weight for weighted rules, +infinity otherwise (including a null rule). */
    private float getWeight() {
        if (rule != null && rule.isWeighted()) {
            return ((WeightedRule)rule).getWeight();
        }
        return Float.POSITIVE_INFINITY;
    }

    /**
     * Exact-value cache key for the Cholesky factor. It compares weight, stepSize and the
     * (already truncated) coefficients bit-for-bit, so distinct systems can never share an
     * entry, unlike a bare int hash.
     */
    private static final class LowerTriangleKey {
        private final float weight;
        private final float stepSize;
        private final float[] coefficients;
        private final int hash;

        LowerTriangleKey(float weight, float stepSize, float[] coefficients) {
            this.weight = weight;
            this.stepSize = stepSize;
            this.coefficients = coefficients;

            int h = Float.floatToIntBits(weight);
            h = 31 * h + Float.floatToIntBits(stepSize);
            h = 31 * h + Arrays.hashCode(coefficients);
            this.hash = h;
        }

        @Override
        public boolean equals(Object other) {
            if (this == other) {
                return true;
            }
            if (!(other instanceof LowerTriangleKey)) {
                return false;
            }
            LowerTriangleKey that = (LowerTriangleKey)other;
            return Float.floatToIntBits(weight) == Float.floatToIntBits(that.weight)
                    && Float.floatToIntBits(stepSize) == Float.floatToIntBits(that.stepSize)
                    && Arrays.equals(coefficients, that.coefficients);
        }

        @Override
        public int hashCode() {
            return hash;
        }
    }

    public static enum TermType {
        LinearConstraintTerm,
        LinearLossTerm,
        HingeLossTerm,
        SquaredLinearLossTerm,
        SquaredHingeLossTerm,
        DeterCollectiveTerm,
        DeterIndependentTerm;

    }
}

