/*
 * Decompiled with CFR 0.152.
 */
package org.linqs.psl.application.learning.weight.maxlikelihood;

import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.linqs.psl.application.learning.weight.VotedPerceptron;
import org.linqs.psl.application.learning.weight.maxlikelihood.SimplexSampler;
import org.linqs.psl.config.Config;
import org.linqs.psl.database.Database;
import org.linqs.psl.grounding.AtomRegisterGroundRuleStore;
import org.linqs.psl.model.Model;
import org.linqs.psl.model.rule.Rule;
import org.linqs.psl.model.rule.WeightedGroundRule;
import org.linqs.psl.model.rule.WeightedRule;
import org.linqs.psl.reasoner.term.blocker.ConstraintBlockerTerm;
import org.linqs.psl.reasoner.term.blocker.ConstraintBlockerTermGenerator;
import org.linqs.psl.reasoner.term.blocker.ConstraintBlockerTermStore;

/**
 * A {@link VotedPerceptron} whose expected incompatibility is estimated block by block:
 * for every non-empty {@link ConstraintBlockerTerm} in the term store, a set of candidate
 * assignments for the block's atoms is generated, each incident weighted ground rule is
 * evaluated under every candidate, and the per-rule incompatibilities are averaged with
 * weights {@code exp(-sum(weight * incompatibility))}.
 *
 * <p>Configuration is read through {@link Config} using the {@code *_KEY} constants below.
 * Note that the key prefix is spelled {@code "maxspeudolikelihood"}; the spelling is part
 * of the runtime configuration contract and is kept as-is.
 */
public class MaxPseudoLikelihood extends VotedPerceptron {
    public static final String CONFIG_PREFIX = "maxspeudolikelihood";

    /** Config key: when true, blocks are enumerated over 0/1 assignments instead of sampled. */
    public static final String BOOLEAN_KEY = "maxspeudolikelihood.bool";
    public static final boolean BOOLEAN_DEFAULT = false;

    /** Config key: samples per block atom in the non-boolean (sampled) mode. Must be positive. */
    public static final String NUM_SAMPLES_KEY = "maxspeudolikelihood.numsamples";
    public static final int NUM_SAMPLES_DEFAULT = 10;

    /** Config key: minimum width. Must be positive. Only read and validated by this class. */
    public static final String MIN_WIDTH_KEY = "maxspeudolikelihood.minwidth";
    public static final double MIN_WIDTH_DEFAULT = 0.01;

    /** Lower bound on the number of samples drawn for a block in the non-boolean mode. */
    private static final int MIN_CONTINUOUS_SAMPLES = 150;

    private final boolean bool;
    private final double minWidth;
    /** The configured sample count; {@link #setBudget(double)} scales relative to this. */
    private final int maxNumSamples;
    /** The sample count currently in effect (configured value, possibly scaled by the budget). */
    private int numSamples;

    /**
     * Creates a learner over all rules of {@code model}.
     *
     * @param model the model whose rules are learned
     * @param rvDB passed through to the superclass
     * @param observedDB passed through to the superclass
     * @throws IllegalArgumentException if the configured sample count or minimum width is not positive
     */
    public MaxPseudoLikelihood(Model model, Database rvDB, Database observedDB) {
        this(model.getRules(), rvDB, observedDB);
    }

    /**
     * Creates a learner over {@code rules}.
     *
     * <p>Besides reading its own options, this constructor overwrites the global
     * {@code weightlearning.groundrulestore}, {@code weightlearning.termstore} and
     * {@code weightlearning.termgenerator} properties so that the blocker-based
     * implementations are used, and turns {@code cutObjective} off.
     *
     * @param rules the rules to learn weights for
     * @param rvDB passed through to the superclass
     * @param observedDB passed through to the superclass
     * @throws IllegalArgumentException if the configured sample count or minimum width is not positive
     */
    public MaxPseudoLikelihood(List<Rule> rules, Database rvDB, Database observedDB) {
        // NOTE(review): the meaning of the trailing `false` is defined by VotedPerceptron; confirm there.
        super(rules, rvDB, observedDB, false);

        // Options are read here, in the same order the field initializers used to run.
        this.bool = Config.getBoolean(BOOLEAN_KEY, BOOLEAN_DEFAULT);

        this.maxNumSamples = Config.getInt(NUM_SAMPLES_KEY, NUM_SAMPLES_DEFAULT);
        this.numSamples = this.maxNumSamples;
        if (this.numSamples <= 0) {
            throw new IllegalArgumentException("Number of samples must be positive integer.");
        }

        this.minWidth = Config.getDouble(MIN_WIDTH_KEY, MIN_WIDTH_DEFAULT);
        if (this.minWidth <= 0.0) {
            throw new IllegalArgumentException("Minimum width must be positive double.");
        }

        Config.setProperty("weightlearning.groundrulestore", AtomRegisterGroundRuleStore.class.getName());
        Config.setProperty("weightlearning.termstore", ConstraintBlockerTermStore.class.getName());
        Config.setProperty("weightlearning.termgenerator", ConstraintBlockerTermGenerator.class.getName());
        this.cutObjective = false;
    }

    /**
     * Recomputes {@code expectedIncompatibility} from scratch by summing, over every
     * non-empty block of the term store, the weighted average incompatibility of each
     * mutable rule across that block's candidate assignments.
     *
     * @throws IllegalArgumentException if the term store is not a {@link ConstraintBlockerTermStore}
     */
    @Override
    protected void computeExpectedIncompatibility() {
        if (!(this.termStore instanceof ConstraintBlockerTermStore)) {
            throw new IllegalArgumentException("ConstraintBlockerTermStore required.");
        }
        ConstraintBlockerTermStore blocker = (ConstraintBlockerTermStore)this.termStore;

        for (int i = 0; i < this.expectedIncompatibility.length; ++i) {
            this.expectedIncompatibility[i] = 0.0;
        }

        for (ConstraintBlockerTerm block : blocker) {
            if (block.size() == 0) {
                continue;
            }

            double[][] samples = drawSamples(block);
            Map<WeightedRule, double[]> incompatibilities = evaluateIncompatibilities(block, samples);
            accumulateExpectation(incompatibilities, samples.length);
        }
    }

    /**
     * Builds the candidate assignments for one block. Each row holds at least
     * {@code block.size()} values; callers read only the first {@code block.size()}.
     */
    private double[][] drawSamples(ConstraintBlockerTerm block) {
        int blockSize = block.size();

        if (!this.bool) {
            double[][] samples = new double[Math.max(this.numSamples * blockSize, MIN_CONTINUOUS_SAMPLES)][];
            SimplexSampler simplexSampler = new SimplexSampler();
            for (int sampleIndex = 0; sampleIndex < samples.length; ++sampleIndex) {
                // NOTE(review): the sampler is asked for samples.length coordinates, yet only the
                // first block.size() of them are ever applied to atoms. Confirm against
                // SimplexSampler whether block.size() was intended here.
                samples[sampleIndex] = simplexSampler.getNext(samples.length);
            }
            return samples;
        }

        // Boolean mode: one candidate per atom with exactly that atom set to 1 ...
        boolean exactlyOne = block.getExactlyOne();
        double[][] samples = new double[exactlyOne ? blockSize : blockSize + 1][];
        for (int atomIndex = 0; atomIndex < blockSize; ++atomIndex) {
            samples[atomIndex] = new double[blockSize];
            samples[atomIndex][atomIndex] = 1.0;
        }
        // ... plus the all-zero candidate when the block is not flagged "exactly one".
        if (!exactlyOne) {
            samples[samples.length - 1] = new double[blockSize];
        }
        return samples;
    }

    /**
     * Evaluates every incident weighted ground rule of {@code block} under every candidate
     * assignment and sums the results per rule. The atoms' values are overwritten while
     * evaluating and restored to their prior values before returning.
     *
     * @return for each rule, an array indexed by sample holding the summed incompatibility
     *         of that rule's incident ground rules under the sample
     */
    private Map<WeightedRule, double[]> evaluateIncompatibilities(ConstraintBlockerTerm block, double[][] samples) {
        int blockSize = block.size();
        Map<WeightedRule, double[]> incompatibilities = new HashMap<WeightedRule, double[]>();

        float[] originalState = new float[blockSize];
        for (int i = 0; i < blockSize; ++i) {
            originalState[i] = block.getAtoms()[i].getValue();
        }

        for (WeightedGroundRule groundRule : block.getIncidentGRs()) {
            if (!(groundRule instanceof WeightedGroundRule)) {
                continue;
            }

            WeightedRule rule = (WeightedRule)groundRule.getRule();
            double[] perSample = incompatibilities.get(rule);
            if (perSample == null) {
                perSample = new double[samples.length];
                incompatibilities.put(rule, perSample);
            }

            for (int sampleIndex = 0; sampleIndex < samples.length; ++sampleIndex) {
                for (int i = 0; i < blockSize; ++i) {
                    block.getAtoms()[i].setValue((float)samples[sampleIndex][i]);
                }
                perSample[sampleIndex] += groundRule.getIncompatibility();
            }
        }

        for (int i = 0; i < blockSize; ++i) {
            block.getAtoms()[i].setValue(originalState[i]);
        }

        return incompatibilities;
    }

    /**
     * Adds one block's contribution to {@code expectedIncompatibility}: for each mutable rule,
     * the sum over samples of {@code density * incompatibility} divided by the sum of densities,
     * where {@code density = exp(-sum over rules of weight * incompatibility)} for that sample.
     * Rules whose weighted sum is not strictly positive (including NaN) contribute nothing.
     */
    private void accumulateExpectation(Map<WeightedRule, double[]> incompatibilities, int sampleCount) {
        Map<WeightedRule, Double> weightedIncompatibility = new HashMap<WeightedRule, Double>();
        double partition = 0.0;

        for (int sampleIndex = 0; sampleIndex < sampleCount; ++sampleIndex) {
            double energy = 0.0;
            for (Map.Entry<WeightedRule, double[]> entry : incompatibilities.entrySet()) {
                energy -= entry.getKey().getWeight() * entry.getValue()[sampleIndex];
            }

            double density = Math.exp(energy);
            partition += density;

            for (Map.Entry<WeightedRule, double[]> entry : incompatibilities.entrySet()) {
                Double running = weightedIncompatibility.get(entry.getKey());
                double previous = (running == null) ? 0.0 : running.doubleValue();
                weightedIncompatibility.put(entry.getKey(), previous + density * entry.getValue()[sampleIndex]);
            }
        }

        for (int ruleIndex = 0; ruleIndex < this.mutableRules.size(); ++ruleIndex) {
            WeightedRule rule = (WeightedRule)this.mutableRules.get(ruleIndex);
            Double weighted = weightedIncompatibility.get(rule);
            if (weighted != null && weighted.doubleValue() > 0.0) {
                this.expectedIncompatibility[ruleIndex] += weighted.doubleValue() / partition;
            }
        }
    }

    /**
     * Forwards the budget to the superclass and rescales the per-atom sample count to
     * {@code ceil(budget * configuredSampleCount)}. The rescaled value is not validated;
     * in the sampled mode the total per block never drops below the fixed sample floor.
     *
     * @param budget the budget factor applied to the configured sample count
     */
    @Override
    public void setBudget(double budget) {
        super.setBudget(budget);
        this.numSamples = (int)Math.ceil(budget * (double)this.maxNumSamples);
    }
}

