/*
 * Decompiled with CFR 0.152.
 */
package org.linqs.psl.application.learning.weight;

import java.util.List;
import org.linqs.psl.application.learning.weight.WeightLearningApplication;
import org.linqs.psl.config.Config;
import org.linqs.psl.database.Database;
import org.linqs.psl.model.rule.Rule;
import org.linqs.psl.model.rule.WeightedRule;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Base class for weight learners that repeatedly adjust rule weights with
 * gradient-style steps.
 *
 * <p>On every iteration each mutable rule's weight is moved by a step derived from
 * {@code expectedIncompatibility[i] - observedIncompatibility[i]}, minus the L2 and L1
 * regularization terms, divided by a per-rule scaling factor and multiplied by the
 * current base step size. Optional behaviors (step-size decay, momentum, clipping at
 * zero, objective-based step cutting, and weight averaging) are selected through the
 * configuration keys declared on this class.
 *
 * <p>The incompatibility arrays, the rule list, the evaluator, and the MPE-state helpers
 * used here are inherited from {@link WeightLearningApplication}; their exact semantics
 * are defined there and are not visible in this file.
 */
public abstract class VotedPerceptron extends WeightLearningApplication {
    private static final Logger log = LoggerFactory.getLogger(VotedPerceptron.class);

    /** Prefix shared by every configuration key of this class. */
    public static final String CONFIG_PREFIX = "votedperceptron";

    /** Key for the L2 regularization coefficient; the value must be non-negative. */
    public static final String L2_REGULARIZATION_KEY = "votedperceptron.l2regularization";
    public static final double L2_REGULARIZATION_DEFAULT = 0.0;

    /** Key for the L1 regularization coefficient; the value must be non-negative. */
    public static final String L1_REGULARIZATION_KEY = "votedperceptron.l1regularization";
    public static final double L1_REGULARIZATION_DEFAULT = 0.0;

    /** Key for the base step size; the value must be strictly positive. */
    public static final String STEP_SIZE_KEY = "votedperceptron.stepsize";
    public static final double STEP_SIZE_DEFAULT = 0.2;

    /** Key for the momentum coefficient; the value must lie in {@code [0, 1)}. */
    public static final String INERTIA_KEY = "votedperceptron.inertia";
    public static final double INERTIA_DEFAULT = 0.0;

    /**
     * Key for the gradient-scaling flag. The value is stored in {@link #scaleGradient};
     * nothing else in this class reads it.
     */
    public static final String SCALE_GRADIENT_KEY = "votedperceptron.scalegradient";
    public static final boolean SCALE_GRADIENT_DEFAULT = true;

    /**
     * Key for weight averaging: when true, the final weights are the per-step weights
     * summed over the run and divided by {@link #numSteps}.
     */
    public static final String AVERAGE_STEPS_KEY = "votedperceptron.averagesteps";
    public static final boolean AVERAGE_STEPS_DEFAULT = false;

    /** Key for the number of iterations; the value must be strictly positive. */
    public static final String NUM_STEPS_KEY = "votedperceptron.numsteps";
    public static final int NUM_STEPS_DEFAULT = 25;

    /** Key for clamping updated weights at zero. */
    public static final String CLIP_NEGATIVE_WEIGHTS_KEY = "votedperceptron.clipnegativeweights";
    public static final boolean CLIP_NEGATIVE_WEIGHTS_DEFAULT = true;

    /**
     * Key for step cutting: when true (and an evaluator is available), an iteration that
     * increases the training objective is rolled back and the base step size is halved.
     */
    public static final String CUT_OBJECTIVE_KEY = "votedperceptron.cutobjective";
    public static final boolean CUT_OBJECTIVE_DEFAULT = false;

    /** Key for step-size decay: when true, each step is divided by (iteration index + 1). */
    public static final String SCALE_STEP_SIZE_KEY = "votedperceptron.scalestepsize";
    public static final boolean SCALE_STEP_SIZE_DEFAULT = true;

    /** Key for resetting every mutable rule's weight to zero before learning starts. */
    public static final String ZERO_INITIAL_WEIGHTS_KEY = "votedperceptron.zeroinitialweights";
    public static final boolean ZERO_INITIAL_WEIGHTS_DEFAULT = false;

    protected final double l2Regularization;
    protected final double l1Regularization;
    protected final boolean scaleGradient;

    /** Current base step size; halved by {@link #doLearn()} whenever a step is cut. */
    protected double baseStepSize = Config.getDouble(STEP_SIZE_KEY, STEP_SIZE_DEFAULT);
    protected boolean scaleStepSize;
    protected boolean averageSteps;
    protected boolean zeroInitialWeights;
    protected boolean clipNegativeWeights;
    protected boolean cutObjective;
    protected double inertia;

    /** The configured number of steps; {@link #setBudget(double)} scales from this value. */
    protected final int maxNumSteps;

    /** The number of iterations {@link #doLearn()} will actually run. */
    protected int numSteps;

    /** Cached loss for the current iteration; {@code NaN} means "not computed yet". */
    private double currentLoss;

    /**
     * Reads and validates all configuration options.
     *
     * @param rules passed through to the superclass
     * @param rvDB passed through to the superclass
     * @param observedDB passed through to the superclass
     * @param supportsLatentVariables passed through to the superclass
     * @throws IllegalArgumentException if the step size or number of steps is not positive,
     *         the inertia is outside {@code [0, 1)}, or a regularization coefficient is negative
     */
    public VotedPerceptron(List<Rule> rules, Database rvDB, Database observedDB, boolean supportsLatentVariables) {
        super(rules, rvDB, observedDB, supportsLatentVariables);

        // baseStepSize was already read by its field initializer; only validate it here.
        if (this.baseStepSize <= 0.0) {
            throw new IllegalArgumentException("Step size must be positive.");
        }

        this.inertia = Config.getDouble(INERTIA_KEY, INERTIA_DEFAULT);
        if (this.inertia < 0.0 || this.inertia >= 1.0) {
            throw new IllegalArgumentException("Inertia must be in [0, 1), found: " + this.inertia);
        }

        this.numSteps = Config.getInt(NUM_STEPS_KEY, NUM_STEPS_DEFAULT);
        this.maxNumSteps = this.numSteps;
        if (this.numSteps <= 0) {
            throw new IllegalArgumentException("Number of steps must be positive.");
        }

        this.l2Regularization = Config.getDouble(L2_REGULARIZATION_KEY, L2_REGULARIZATION_DEFAULT);
        if (this.l2Regularization < 0.0) {
            throw new IllegalArgumentException("L2 regularization parameter must be non-negative.");
        }

        this.l1Regularization = Config.getDouble(L1_REGULARIZATION_KEY, L1_REGULARIZATION_DEFAULT);
        if (this.l1Regularization < 0.0) {
            throw new IllegalArgumentException("L1 regularization parameter must be non-negative.");
        }

        this.scaleGradient = Config.getBoolean(SCALE_GRADIENT_KEY, SCALE_GRADIENT_DEFAULT);
        this.averageSteps = Config.getBoolean(AVERAGE_STEPS_KEY, AVERAGE_STEPS_DEFAULT);
        this.scaleStepSize = Config.getBoolean(SCALE_STEP_SIZE_KEY, SCALE_STEP_SIZE_DEFAULT);
        this.zeroInitialWeights = Config.getBoolean(ZERO_INITIAL_WEIGHTS_KEY, ZERO_INITIAL_WEIGHTS_DEFAULT);
        this.clipNegativeWeights = Config.getBoolean(CLIP_NEGATIVE_WEIGHTS_KEY, CLIP_NEGATIVE_WEIGHTS_DEFAULT);
        this.cutObjective = Config.getBoolean(CUT_OBJECTIVE_KEY, CUT_OBJECTIVE_DEFAULT);

        this.currentLoss = Double.NaN;
    }

    /**
     * Runs {@link #numSteps} weight-update iterations over the mutable rules.
     *
     * <p>Each iteration recomputes the expected incompatibility, updates every weight
     * (step-size decay, momentum, and clipping applied in that order), optionally
     * evaluates the training objective, and optionally rolls the iteration back while
     * halving the base step size if the objective got worse. After the last iteration
     * the weights are replaced by their running average when averaging is enabled.
     */
    @Override
    protected void doLearn() {
        // Running sum of the weights produced by each iteration (used only for averaging).
        double[] avgWeights = new double[this.mutableRules.size()];

        this.computeObservedIncompatibility();
        this.setDefaultRandomVariables();

        if (this.zeroInitialWeights) {
            for (WeightedRule rule : this.mutableRules) {
                rule.setWeight(0.0);
            }
        }

        if (log.isDebugEnabled() && this.evaluator != null) {
            // The MPE state is computed before the evaluator is consulted.
            this.computeMPEState();
            this.evaluator.compute(this.trainingMap);

            // The objective is always treated as "lower is better", so negate if needed.
            double objective = this.evaluator.getRepresentativeMetric();
            objective = this.evaluator.isHigherRepresentativeBetter() ? -1.0 * objective : objective;
            log.debug("Initial Training Objective: {}", objective);
        }

        double[] scalingFactor = this.computeScalingFactor();

        // Previous step per rule (for momentum) and the last accepted state (for rollback).
        double[] lastSteps = new double[this.mutableRules.size()];
        double lastObjective = -1.0;
        double[] lastWeights = new double[this.mutableRules.size()];
        for (int i = 0; i < this.mutableRules.size(); ++i) {
            lastWeights[i] = this.mutableRules.get(i).getWeight();
        }

        for (int step = 0; step < this.numSteps; ++step) {
            log.debug("Starting iteration {}", step);

            // Invalidate the cached loss; getLoss() recomputes it on demand.
            this.currentLoss = Double.NaN;

            this.computeExpectedIncompatibility();

            double norm = 0.0;
            for (int i = 0; i < this.mutableRules.size(); ++i) {
                double newWeight = this.mutableRules.get(i).getWeight();
                double currentStep = (this.expectedIncompatibility[i] - this.observedIncompatibility[i]
                        - this.l2Regularization * newWeight
                        - this.l1Regularization) / scalingFactor[i];
                currentStep *= this.baseStepSize;

                if (this.scaleStepSize) {
                    currentStep /= (double)(step + 1);
                }

                // Momentum is applied regardless of clipping. Previously it was only added
                // on the non-clipping path, so the inertia setting had no effect under the
                // default clipNegativeWeights=true and the trace line below was inaccurate.
                currentStep += this.inertia * lastSteps[i];

                if (this.clipNegativeWeights) {
                    newWeight = Math.max(0.0, newWeight + currentStep);
                } else {
                    newWeight += currentStep;
                }

                // lastSteps[i] still holds the previous iteration's step at this point.
                log.trace("Gradient: {} (without momentun: {}), Expected Incomp.: {}, Observed Incomp.: {} -- ({}) {}", currentStep, currentStep - this.inertia * lastSteps[i], this.expectedIncompatibility[i], this.observedIncompatibility[i], i, this.mutableRules.get(i));

                this.mutableRules.get(i).setWeight(newWeight);
                lastSteps[i] = currentStep;
                avgWeights[i] += newWeight;
                norm += Math.pow(this.expectedIncompatibility[i] - this.observedIncompatibility[i], 2.0);
            }

            // The weights changed, so any previously computed MPE state is marked stale.
            this.inMPEState = false;
            this.inLatentMPEState = false;

            norm = Math.sqrt(norm);

            if (log.isDebugEnabled()) {
                // Populate currentLoss so the debug line below can report it.
                this.getLoss();
            }

            double objective = -1.0;
            if ((this.cutObjective || log.isDebugEnabled()) && this.evaluator != null) {
                this.computeMPEState();
                this.evaluator.compute(this.trainingMap);
                objective = this.evaluator.getRepresentativeMetric();
                objective = this.evaluator.isHigherRepresentativeBetter() ? -1.0 * objective : objective;

                if (this.cutObjective && step > 0 && objective > lastObjective) {
                    log.trace("Objective increased: {} -> {}, cutting step size: {} -> {}.", lastObjective, objective, this.baseStepSize, this.baseStepSize / 2.0);
                    this.baseStepSize /= 2.0;
                    objective = lastObjective;

                    // Undo this iteration: drop momentum, remove the rejected weights from
                    // the running sum, and restore the last accepted weights.
                    for (int i = 0; i < this.mutableRules.size(); ++i) {
                        lastSteps[i] = 0.0;
                        avgWeights[i] -= this.mutableRules.get(i).getWeight();
                        this.mutableRules.get(i).setWeight(lastWeights[i]);
                    }
                } else {
                    lastObjective = objective;
                }
            }

            // Snapshot the (possibly restored) weights as the new rollback point.
            for (int i = 0; i < this.mutableRules.size(); ++i) {
                lastWeights[i] = this.mutableRules.get(i).getWeight();
            }

            log.debug("Iteration {} complete. Likelihood: {}. Training Objective: {}, Icomp. L2-norm: {}", step, this.currentLoss, objective, norm);
            log.trace("Model {} ", this.mutableRules);
        }

        if (this.averageSteps) {
            for (int i = 0; i < this.mutableRules.size(); ++i) {
                this.mutableRules.get(i).setWeight(avgWeights[i] / (double)this.numSteps);
            }
        }
    }

    /**
     * Computes the regularization penalty for the current weights.
     *
     * @return {@code 0.5 * l2Regularization * sum(w^2) + l1Regularization * sum(|w|)},
     *         or 0 when both coefficients are zero
     */
    protected double computeRegularizer() {
        if (this.l1Regularization == 0.0 && this.l2Regularization == 0.0) {
            return 0.0;
        }

        double l2 = 0.0;
        double l1 = 0.0;
        for (WeightedRule rule : this.mutableRules) {
            l2 += Math.pow(rule.getWeight(), 2.0);
            l1 += Math.abs(rule.getWeight());
        }

        return 0.5 * this.l2Regularization * l2 + this.l1Regularization * l1;
    }

    /**
     * Returns the loss for the current iteration, computing it through the inherited
     * {@code computeLoss()} only if it has not been computed since the cache was last reset.
     *
     * @return the cached or freshly computed loss
     */
    public double getLoss() {
        if (Double.isNaN(this.currentLoss)) {
            this.currentLoss = this.computeLoss();
        }

        return this.currentLoss;
    }

    /**
     * Computes the divisor applied to each rule's step.
     *
     * @return one entry per mutable rule: the count the ground rule store reports for
     *         that rule, but never less than 1.0 (which also prevents division by zero)
     */
    protected double[] computeScalingFactor() {
        double[] factor = new double[this.mutableRules.size()];
        for (int i = 0; i < factor.length; ++i) {
            factor[i] = Math.max(1.0, (double)this.groundRuleStore.count(this.mutableRules.get(i)));
        }

        return factor;
    }

    /**
     * Forwards the budget to the superclass and sets the number of iterations to
     * {@code ceil(budget * maxNumSteps)}.
     *
     * @param budget the budget multiplier applied to the configured number of steps
     */
    @Override
    public void setBudget(double budget) {
        super.setBudget(budget);
        this.numSteps = (int)Math.ceil(budget * (double)this.maxNumSteps);
    }
}

