/*
 * Decompiled with CFR 0.152.
 */
package org.linqs.psl.application.learning.weight.maxlikelihood;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Random;
import org.linqs.psl.application.learning.weight.VotedPerceptron;
import org.linqs.psl.config.Options;
import org.linqs.psl.database.Database;
import org.linqs.psl.model.Model;
import org.linqs.psl.model.atom.GroundAtom;
import org.linqs.psl.model.atom.RandomVariableAtom;
import org.linqs.psl.model.rule.GroundRule;
import org.linqs.psl.model.rule.Rule;
import org.linqs.psl.model.rule.WeightedGroundRule;
import org.linqs.psl.model.rule.WeightedRule;
import org.linqs.psl.util.Parallel;
import org.linqs.psl.util.RandUtils;

public class MaxPiecewisePseudoLikelihood
extends VotedPerceptron {
    private final int maxNumSamples;
    private int numSamples;
    private List<Map<RandomVariableAtom, List<WeightedGroundRule>>> ruleRandomVariableMap;
    private Random[] rands;

    public MaxPiecewisePseudoLikelihood(Model model, Database rvDB, Database observedDB) {
        this(model.getRules(), rvDB, observedDB);
    }

    public MaxPiecewisePseudoLikelihood(List<Rule> rules, Database rvDB, Database observedDB) {
        super(rules, rvDB, observedDB);
        this.numSamples = this.maxNumSamples = Options.WLA_MPPLE_NUM_SAMPLES.getInt();
        this.rands = new Random[Parallel.getNumThreads()];
        for (int i = 0; i < Parallel.getNumThreads(); ++i) {
            this.rands[i] = new Random(RandUtils.nextLong());
        }
        this.ruleRandomVariableMap = null;
        this.averageSteps = false;
    }

    @Override
    protected void postInitGroundModel() {
        this.populateRandomVariableMap();
    }

    private void populateRandomVariableMap() {
        this.ruleRandomVariableMap = new ArrayList<Map<RandomVariableAtom, List<WeightedGroundRule>>>();
        for (Rule rule : this.mutableRules) {
            HashMap groundRuleMap = new HashMap();
            for (GroundRule groundRule : this.inference.getGroundRuleStore().getGroundRules(rule)) {
                for (GroundAtom atom : groundRule.getAtoms()) {
                    if (!(atom instanceof RandomVariableAtom)) continue;
                    RandomVariableAtom rva = (RandomVariableAtom)atom;
                    if (!groundRuleMap.containsKey(rva)) {
                        groundRuleMap.put(rva, new ArrayList());
                    }
                    ((List)groundRuleMap.get(atom)).add((WeightedGroundRule)groundRule);
                }
            }
            this.ruleRandomVariableMap.add(groundRuleMap);
        }
    }

    @Override
    protected void computeExpectedIncompatibility() {
        this.setLabeledRandomVariables();
        Parallel.count(this.mutableRules.size(), new Parallel.Worker<Integer>(){

            @Override
            public void work(int ruleIndex, Integer ignore) {
                WeightedRule rule = (WeightedRule)MaxPiecewisePseudoLikelihood.this.mutableRules.get(ruleIndex);
                Map groundRuleMap = (Map)MaxPiecewisePseudoLikelihood.this.ruleRandomVariableMap.get(ruleIndex);
                double accumulateIncompatibility = 0.0;
                double weight = rule.getWeight();
                for (RandomVariableAtom atom : groundRuleMap.keySet()) {
                    List groundRules = (List)groundRuleMap.get(atom);
                    double numerator = 0.0;
                    double denominator = 1.0E-6;
                    for (int sampleIndex = 0; sampleIndex < MaxPiecewisePseudoLikelihood.this.numSamples; ++sampleIndex) {
                        float sample = MaxPiecewisePseudoLikelihood.this.rands[this.id].nextFloat();
                        double energy = 0.0;
                        for (int i = 0; i < groundRules.size(); ++i) {
                            energy += ((WeightedGroundRule)groundRules.get(i)).getIncompatibility(atom, sample);
                        }
                        numerator += Math.exp(-1.0 * weight * energy) * energy;
                        denominator += Math.exp(-1.0 * weight * energy);
                    }
                    accumulateIncompatibility += numerator / denominator;
                }
                ((MaxPiecewisePseudoLikelihood)MaxPiecewisePseudoLikelihood.this).expectedIncompatibility[ruleIndex] = accumulateIncompatibility;
            }
        });
    }

    @Override
    public double computeLoss() {
        this.setLabeledRandomVariables();
        final double[] losses = new double[this.mutableRules.size()];
        Parallel.count(this.mutableRules.size(), new Parallel.Worker<Integer>(){

            @Override
            public void work(int ruleIndex, Integer ignore) {
                Map groundRuleMap = (Map)MaxPiecewisePseudoLikelihood.this.ruleRandomVariableMap.get(ruleIndex);
                WeightedRule rule = (WeightedRule)MaxPiecewisePseudoLikelihood.this.mutableRules.get(ruleIndex);
                double weight = rule.getWeight();
                for (RandomVariableAtom atom : groundRuleMap.keySet()) {
                    List groundRules = (List)groundRuleMap.get(atom);
                    double expInc = 0.0;
                    for (int sampleIndex = 0; sampleIndex < MaxPiecewisePseudoLikelihood.this.numSamples; ++sampleIndex) {
                        float sample = MaxPiecewisePseudoLikelihood.this.rands[this.id].nextFloat();
                        double energy = 0.0;
                        for (int i = 0; i < groundRules.size(); ++i) {
                            energy -= ((WeightedGroundRule)groundRules.get(i)).getIncompatibility(atom, sample);
                        }
                        expInc += Math.exp(weight * energy);
                    }
                    double obsInc = 0.0;
                    for (int i = 0; i < groundRules.size(); ++i) {
                        obsInc += -1.0 * weight * ((WeightedGroundRule)groundRules.get(i)).getIncompatibility();
                    }
                    expInc = -1.0 * Math.log(expInc / (double)MaxPiecewisePseudoLikelihood.this.numSamples);
                    int n = ruleIndex;
                    losses[n] = losses[n] + (obsInc + expInc);
                }
                int n = ruleIndex;
                losses[n] = losses[n] + -0.5 * MaxPiecewisePseudoLikelihood.this.l2Regularization * Math.pow(weight, 2.0);
            }
        });
        double loss = 0.0;
        for (double ruleLoss : losses) {
            loss += ruleLoss;
        }
        return loss;
    }

    @Override
    protected void computeObservedIncompatibility() {
        this.setLabeledRandomVariables();
        for (int ruleIndex = 0; ruleIndex < this.mutableRules.size(); ++ruleIndex) {
            WeightedRule rule = (WeightedRule)this.mutableRules.get(ruleIndex);
            Map<RandomVariableAtom, List<WeightedGroundRule>> groundRuleMap = this.ruleRandomVariableMap.get(ruleIndex);
            double weight = rule.getWeight();
            double obsInc = 0.0;
            for (RandomVariableAtom atom : groundRuleMap.keySet()) {
                for (WeightedGroundRule groundRule : groundRuleMap.get(atom)) {
                    obsInc += groundRule.getIncompatibility();
                }
            }
            this.observedIncompatibility[ruleIndex] = obsInc;
        }
    }

    @Override
    public void setBudget(double budget) {
        super.setBudget(budget);
        this.numSamples = (int)Math.ceil(budget * (double)this.maxNumSamples);
    }
}

