/*
 * Decompiled with CFR 0.152.
 */
package org.linqs.psl.application.learning.weight.bayesian;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Locale;
import org.linqs.psl.application.learning.weight.WeightLearningApplication;
import org.linqs.psl.application.learning.weight.bayesian.GaussianProcessKernel;
import org.linqs.psl.config.Config;
import org.linqs.psl.database.Database;
import org.linqs.psl.model.Model;
import org.linqs.psl.model.rule.Rule;
import org.linqs.psl.model.rule.WeightedRule;
import org.linqs.psl.util.FloatMatrix;
import org.linqs.psl.util.ListUtils;
import org.linqs.psl.util.Parallel;
import org.linqs.psl.util.RandUtils;
import org.linqs.psl.util.StringUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Weight learning that searches a pool of candidate weight configurations.
 *
 * Each iteration picks the most promising unexplored configuration
 * (see {@link #getNextPoint}), evaluates it (see {@link #getFunctionValue}),
 * and then re-predicts a value and a spread ("std") for every remaining
 * configuration from the kernel similarities to the configurations explored so far
 * (see {@link #predictFnValAndStd(float[], List)}).
 * The best configuration observed is installed as the final rule weights.
 */
public class GaussianProcessPrior extends WeightLearningApplication {
    private static final Logger log = LoggerFactory.getLogger(GaussianProcessPrior.class);

    public static final String CONFIG_PREFIX = "gpp";

    /** Config key for the kernel type (matched case-insensitively against KernelType names). */
    public static final String KERNEL_KEY = "gpp.kernel";
    public static final String KERNEL_DEFAULT = GaussianProcessKernel.KernelType.SQUARED_EXP.toString();

    /** Config key for the maximum number of configurations that will be evaluated. */
    public static final String MAX_ITERATIONS_KEY = "gpp.maxiterations";
    public static final int MAX_ITERATIONS_DEFAULT = 25;

    /** Config key for the size of the candidate pool (random mode) or the grid budget (grid mode). */
    public static final String MAX_CONFIGS_KEY = "gpp.maxconfigs";
    public static final int MAX_CONFIGS_DEFAULT = 1000000;

    /** Config key for the divisor applied to the predicted value in the acquisition score. */
    public static final String EXPLORATION_KEY = "gpp.explore";
    public static final float EXPLORATION_DEFAULT = 2.0f;

    /** Config key choosing random candidate generation (true) over a regular grid (false). */
    public static final String RANDOM_CONFIGS_ONLY_KEY = "gpp.randomConfigsOnly";
    public static final boolean RANDOM_CONFIGS_ONLY_DEFAULT = true;

    /** Config key enabling termination once every remaining candidate has a small predicted std. */
    public static final String EARLY_STOPPING_KEY = "gpp.earlyStopping";
    public static final boolean EARLY_STOPPING_DEFAULT = true;

    /** Resolution of randomly drawn weights: values are k / MAX_RAND_INT_VAL for k in [1, MAX_RAND_INT_VAL]. */
    public static final int MAX_RAND_INT_VAL = 100000000;

    /** Predicted-std threshold used by early stopping. */
    public static final float SMALL_VALUE = 0.4f;

    // Locale.ROOT keeps the enum lookup independent of the JVM's default locale
    // (a locale-sensitive upper-casing can change letters such as 'i').
    private GaussianProcessKernel.KernelType kernelType = GaussianProcessKernel.KernelType.valueOf(
            Config.getString(KERNEL_KEY, KERNEL_DEFAULT).toUpperCase(Locale.ROOT));
    private int maxIterations = Config.getInt(MAX_ITERATIONS_KEY, MAX_ITERATIONS_DEFAULT);
    private int maxConfigs = Config.getInt(MAX_CONFIGS_KEY, MAX_CONFIGS_DEFAULT);
    private float exploration = Config.getFloat(EXPLORATION_KEY, EXPLORATION_DEFAULT);
    private boolean randomConfigsOnly = Config.getBoolean(RANDOM_CONFIGS_ONLY_KEY, RANDOM_CONFIGS_ONLY_DEFAULT);
    private boolean earlyStopping = Config.getBoolean(EARLY_STOPPING_KEY, EARLY_STOPPING_DEFAULT);

    // Smallest weight used as the grid origin (unless the kernel space is OS, where 0 is used).
    private float minConfigVal = 1.0E-8f;

    // Inverse of the kernel matrix over all explored configurations.
    private FloatMatrix knownDataStdInv;
    private GaussianProcessKernel kernel;
    private GaussianProcessKernel.Space space = GaussianProcessKernel.Space.valueOf(
            Config.getString("gppker.space", GaussianProcessKernel.SPACE_DEFAULT));

    // Candidates not yet evaluated, and candidates already evaluated (in evaluation order).
    private List<WeightConfig> configs;
    private List<WeightConfig> exploredConfigs;

    // Column vector of the observed function values, aligned with exploredConfigs.
    private FloatMatrix blasYKnown;

    public GaussianProcessPrior(List<Rule> rules, Database rvDB, Database observedDB) {
        super(rules, rvDB, observedDB, false);
    }

    public GaussianProcessPrior(Model model, Database rvDB, Database observedDB) {
        this(model.getRules(), rvDB, observedDB);
    }

    /**
     * Rebuild the candidate pool and forget all explored configurations.
     */
    private void reset() {
        configs = getConfigs();
        exploredConfigs = new ArrayList<WeightConfig>();
    }

    protected void setKnownDataStdInvForTest(FloatMatrix data) {
        this.knownDataStdInv = data;
    }

    protected void setKernelForTest(GaussianProcessKernel kernel) {
        this.kernel = kernel;
    }

    protected void setBlasYKnownForTest(FloatMatrix blasYKnown) {
        this.blasYKnown = blasYKnown;
    }

    /**
     * Run the search loop and install the best configuration found.
     *
     * The loop ends when the iteration budget is used up, when no candidates remain,
     * or (with early stopping enabled) when every remaining candidate's predicted std
     * is at most {@link #SMALL_VALUE}.
     * If no configuration was evaluated at all, the rule weights are left untouched.
     */
    @Override
    protected void doLearn() {
        kernel = GaussianProcessKernel.makeKernel(kernelType, this);
        reset();

        List<Float> exploredFnVal = new ArrayList<Float>();
        WeightConfig bestConfig = null;
        float bestVal = 0.0f;
        boolean allStdSmall = false;

        int iteration = 0;
        while (iteration < maxIterations && !configs.isEmpty() && !(earlyStopping && allStdSmall)) {
            // Move the chosen candidate from the unexplored pool to the explored list.
            int nextPoint = getNextPoint(configs, iteration);
            WeightConfig config = configs.remove(nextPoint);
            exploredConfigs.add(config);

            // The evaluated point is now known exactly: observed value, no remaining spread.
            float fnVal = getFunctionValue(config);
            exploredFnVal.add(fnVal);
            config.valueAndStd.value = fnVal;
            config.valueAndStd.std = 0.0f;

            if (bestConfig == null || fnVal > bestVal) {
                bestVal = fnVal;
                bestConfig = config;
            }

            log.info(String.format("Iteration %d -- Config Picked: %s, Current Best Config: %s.",
                    iteration + 1, config, bestConfig));

            refitKnownData(exploredFnVal);
            predictRemainingConfigs();
            allStdSmall = allRemainingStdSmall();

            iteration++;
        }

        if (bestConfig == null) {
            // Nothing was evaluated (no iterations allowed or no candidates); there is nothing to install.
            log.warn("No weight configuration was evaluated; leaving rule weights unchanged.");
            return;
        }

        setWeights(bestConfig);
        log.info(String.format("Total number of iterations completed: %d. Stopped early: %s.",
                iteration, earlyStopping && allStdSmall));
        log.info("Best config: " + bestConfig);
    }

    /**
     * Recompute the inverse kernel matrix over the explored configurations and
     * the column vector of their observed function values.
     */
    private void refitKnownData(List<Float> exploredFnVal) {
        int numKnown = exploredFnVal.size();

        FloatMatrix kernelMatrix = FloatMatrix.zeroes(numKnown, numKnown);
        for (int i = 0; i < numKnown; i++) {
            for (int j = 0; j < numKnown; j++) {
                kernelMatrix.set(i, j, kernel.kernel(exploredConfigs.get(i).config, exploredConfigs.get(j).config));
            }
        }

        knownDataStdInv = kernelMatrix.inverse();
        blasYKnown = FloatMatrix.columnVector(ListUtils.toPrimitiveFloatArray(exploredFnVal), false);
    }

    /**
     * Refresh the predicted value/std of every unexplored configuration.
     * A single worker (and therefore a single set of scratch buffers) is used sequentially.
     */
    private void predictRemainingConfigs() {
        ComputePredictionFunctionValueWorker fnValWorker = new ComputePredictionFunctionValueWorker();

        int index = 0;
        for (WeightConfig weightConfig : configs) {
            fnValWorker.work(index, weightConfig);
            index++;
        }
    }

    /**
     * @return true if no unexplored configuration has a predicted std above {@link #SMALL_VALUE}
     *         (vacuously true when none remain).
     */
    private boolean allRemainingStdSmall() {
        for (WeightConfig remaining : configs) {
            if (remaining.valueAndStd.std > SMALL_VALUE) {
                return false;
            }
        }

        return true;
    }

    /**
     * Copy the configuration's weights onto the mutable rules (by position)
     * and mark the MPE state as stale.
     */
    private void setWeights(WeightConfig config) {
        for (int i = 0; i < mutableRules.size(); i++) {
            ((WeightedRule)mutableRules.get(i)).setWeight(config.config[i]);
        }

        inMPEState = false;
    }

    /**
     * Build the pool of candidate weight configurations.
     *
     * In random mode this is {@code maxConfigs} random configurations.
     * In grid mode every weight is stepped from the minimum value up to 1.0 using
     * roughly {@code maxConfigs^(1 / numMutableRules)} steps per weight,
     * and every combination of steps is emitted.
     *
     * @return a new, mutable list of candidates
     */
    protected List<WeightConfig> getConfigs() {
        if (randomConfigsOnly) {
            log.debug("Generating random configs.");
            return getRandomConfigs();
        }

        int numMutableRules = mutableRules.size();

        float max = 1.0f;
        float min = (space == GaussianProcessKernel.Space.OS) ? 0.0f : minConfigVal;

        int numPerSplit = (int)Math.exp(Math.log(maxConfigs) / (double)numMutableRules);
        if (numPerSplit < 5) {
            log.warn("Note not picking random points and large number of rules will yield bad exploration.");
        }

        // A non-positive budget would otherwise produce an infinite step size.
        numPerSplit = Math.max(1, numPerSplit);
        float inc = max / (float)numPerSplit;

        float[] configArray = new float[numMutableRules];
        Arrays.fill(configArray, min);
        WeightConfig cursor = new WeightConfig(configArray);

        // Emit a snapshot of the cursor, then advance it like an odometer until it wraps.
        List<WeightConfig> gridConfigs = new ArrayList<WeightConfig>();
        do {
            gridConfigs.add(new WeightConfig(cursor));
        } while (advanceGridPoint(cursor.config, min, max, inc));

        return gridConfigs;
    }

    /**
     * Advance a grid point in place, odometer style: bump the first coordinate still below max;
     * coordinates that have reached max are reset to min and the carry moves to the next one.
     *
     * @return false once every coordinate has reached max (including the zero-dimension case),
     *         meaning the grid is exhausted
     */
    private static boolean advanceGridPoint(float[] point, float min, float max, float inc) {
        for (int dim = 0; dim < point.length; dim++) {
            if (point[dim] < max) {
                point[dim] += inc;
                return true;
            }

            point[dim] = min;
        }

        return false;
    }

    /**
     * @return for each mutable rule, the ground rule store's count for that rule, floored at 1
     */
    protected int[] computeScalingFactor() {
        int[] factor = new int[mutableRules.size()];
        for (int i = 0; i < factor.length; i++) {
            factor[i] = Math.max(1, groundRuleStore.count((Rule)mutableRules.get(i)));
        }

        return factor;
    }

    /**
     * Draw {@code maxConfigs} configurations whose weights are each
     * (RandUtils.nextInt(MAX_RAND_INT_VAL) + 1) / MAX_RAND_INT_VAL.
     */
    private List<WeightConfig> getRandomConfigs() {
        int numMutableRules = mutableRules.size();
        List<WeightConfig> randomConfigs = new ArrayList<WeightConfig>();

        for (int i = 0; i < maxConfigs; i++) {
            float[] weights = new float[numMutableRules];
            for (int j = 0; j < numMutableRules; j++) {
                weights[j] = (float)(RandUtils.nextInt(MAX_RAND_INT_VAL) + 1) / (float)MAX_RAND_INT_VAL;
            }

            randomConfigs.add(new WeightConfig(weights));
        }

        return randomConfigs;
    }

    /**
     * Convenience overload of
     * {@link #predictFnValAndStd(float[], List, float[], float[], float[], FloatMatrix, FloatMatrix, FloatMatrix, FloatMatrix)}
     * that allocates fresh scratch buffers.
     */
    protected ValueAndStd predictFnValAndStd(float[] x, List<WeightConfig> xKnown) {
        int numKnown = blasYKnown.size();

        // The multiplication buffer holds a (1 x numKnown) row times the (numKnown x numKnown) inverse,
        // so it is sized by the number of known points (same as the worker's buffer), not by x.length.
        return predictFnValAndStd(x, xKnown,
                new float[numKnown], new float[x.length], new float[x.length],
                new FloatMatrix(), new FloatMatrix(), new FloatMatrix(),
                FloatMatrix.zeroes(1, numKnown));
    }

    /**
     * Predict the value and spread at {@code x} from the explored configurations.
     *
     * With k = [kernel(x, xKnown_i)] as a row vector and Kinv = knownDataStdInv:
     * value = (k * Kinv) . blasYKnown, and std = kernel(x, x) - (k * Kinv) . k.
     * All buffer/shell arguments are scratch space that is overwritten.
     *
     * @param x the configuration to predict for
     * @param xKnown the explored configurations; the first {@code xyStdData.length} entries are read
     * @param xyStdData scratch array receiving the kernel values between x and each known point
     * @return a new ValueAndStd holding the prediction
     */
    protected ValueAndStd predictFnValAndStd(float[] x, List<WeightConfig> xKnown, float[] xyStdData, float[] kernelBuffer1, float[] kernelBuffer2, FloatMatrix kernelMatrixShell1, FloatMatrix kernelMatrixShell2, FloatMatrix xyStdMatrixShell, FloatMatrix mulBuffer) {
        for (int i = 0; i < xyStdData.length; i++) {
            xyStdData[i] = kernel.kernel(x, xKnown.get(i).config,
                    kernelBuffer1, kernelBuffer2, kernelMatrixShell1, kernelMatrixShell2);
        }

        // View the kernel values as a (1 x numKnown) row vector.
        xyStdMatrixShell.assume(xyStdData, 1, xyStdData.length);
        FloatMatrix product = xyStdMatrixShell.mul(knownDataStdInv, mulBuffer, false, false, 1.0f, 0.0f);

        ValueAndStd fnAndStd = new ValueAndStd();
        fnAndStd.value = product.dot(blasYKnown);
        fnAndStd.std = kernel.kernel(x, x, kernelBuffer1, kernelBuffer2, kernelMatrixShell1, kernelMatrixShell2)
                - product.dot(xyStdMatrixShell);

        return fnAndStd;
    }

    /**
     * Evaluate a configuration: install its weights, recompute the MPE state, and run the evaluator.
     *
     * @return the evaluator's representative metric, negated when lower is better,
     *         so that larger is always better for the search
     */
    protected float getFunctionValue(WeightConfig config) {
        setWeights(config);
        computeMPEState();

        evaluator.compute(trainingMap);
        double score = evaluator.getRepresentativeMetric();
        if (!evaluator.isHigherRepresentativeBetter()) {
            score = -1.0 * score;
        }

        return (float)score;
    }

    /**
     * Choose the candidate with the largest acquisition score (value / exploration + std).
     * Ties keep the earliest candidate.
     *
     * @param configs the candidates to choose from
     * @param iteration the current iteration (not used by this implementation)
     * @return the index of the chosen candidate, or -1 if {@code configs} is empty
     */
    protected int getNextPoint(List<WeightConfig> configs, int iteration) {
        int bestConfig = -1;
        float curBestVal = -Float.MAX_VALUE;

        for (int i = 0; i < configs.size(); i++) {
            ValueAndStd prediction = configs.get(i).valueAndStd;
            float curVal = prediction.value / exploration + prediction.std;

            // The first candidate is always taken so a valid index is returned even for NaN scores.
            if (bestConfig == -1 || curVal > curBestVal) {
                curBestVal = curVal;
                bestConfig = i;
            }
        }

        return bestConfig;
    }

    /**
     * One candidate assignment of weights (one per mutable rule, by position)
     * together with its current predicted or observed value and std.
     */
    protected static class WeightConfig {
        public float[] config;
        public ValueAndStd valueAndStd;

        /** Wrap (not copy) the given weights with value 0 and std 1. */
        public WeightConfig(float[] config) {
            this(config, 0.0f, 1.0f);
        }

        /** Copy constructor: the weight array is copied, as are the value and std. */
        public WeightConfig(WeightConfig config) {
            this(Arrays.copyOf(config.config, config.config.length), config.valueAndStd.value, config.valueAndStd.std);
        }

        /** Wrap (not copy) the given weights with the given value and std. */
        public WeightConfig(float[] config, float val, float std) {
            this.config = config;
            this.valueAndStd = new ValueAndStd(val, std);
        }

        @Override
        public String toString() {
            return String.format("(weights: [%s], val: %f, std: %f)",
                    StringUtils.join(", ", config), valueAndStd.value, valueAndStd.std);
        }
    }

    /**
     * A (value, std) pair; defaults to value 0 and std 1.
     */
    protected static class ValueAndStd {
        float value;
        float std;

        ValueAndStd() {
            this(0.0f, 1.0f);
        }

        ValueAndStd(float value, float std) {
            this.value = value;
            this.std = std;
        }
    }

    /**
     * Worker that recomputes the prediction for one unexplored configuration,
     * reusing its own scratch buffers across calls.
     * Buffers are sized from the enclosing instance's state at construction time,
     * so a worker must be created after the known data has been refit.
     */
    private class ComputePredictionFunctionValueWorker extends Parallel.Worker<WeightConfig> {
        private float[] xyStdData;
        private FloatMatrix xyStdMatrixShell;
        private float[] kernelBuffer1;
        private float[] kernelBuffer2;
        private FloatMatrix kernelMatrixShell1;
        private FloatMatrix kernelMatrixShell2;
        private FloatMatrix mulBuffer;

        public ComputePredictionFunctionValueWorker() {
            int numKnown = blasYKnown.size();
            int numWeights = mutableRules.size();

            this.xyStdData = new float[numKnown];
            this.xyStdMatrixShell = new FloatMatrix();
            this.kernelBuffer1 = new float[numWeights];
            this.kernelBuffer2 = new float[numWeights];
            this.kernelMatrixShell1 = new FloatMatrix();
            this.kernelMatrixShell2 = new FloatMatrix();
            this.mulBuffer = FloatMatrix.zeroes(1, numKnown);
        }

        /** Each clone gets its own freshly allocated buffers. */
        @Override
        public Object clone() {
            return new ComputePredictionFunctionValueWorker();
        }

        /**
         * Replace the prediction of the unexplored configuration at {@code index}.
         * The target is looked up by index in the enclosing instance's pool; {@code item} is not read.
         */
        @Override
        public void work(int index, WeightConfig item) {
            WeightConfig target = configs.get(index);
            target.valueAndStd = predictFnValAndStd(target.config, exploredConfigs,
                    xyStdData, kernelBuffer1, kernelBuffer2,
                    kernelMatrixShell1, kernelMatrixShell2, xyStdMatrixShell, mulBuffer);
        }
    }
}

