/*
 * Decompiled with CFR 0.152.
 */
package model.tree;

import data.catalog.Catalog;
import data.feature.LinkFeature;
import data.feature.SimpleFeature;
import data.instance.Instance;
import data.instance.Instances;
import data.parameter.NumericShiftFunction;
import data.value.Value;
import eval.SimpleEvaluation;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.Random;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.Semaphore;
import model.Model;
import model.ModelOptions;
import model.inference.hc.AggregateBase;
import model.tree.DecisionTree;
import model.tree.RunnableDecisionTree;
import org.jdom2.Content;
import org.jdom2.Element;
import util.FileWrite;
import util.GlobalRandom;

public class RandomForest
extends Model {
    protected HashMap<RunnableDecisionTree, Double> forest;
    protected HashMap<RunnableDecisionTree, Instances> tests;
    protected Instances allInsts;
    protected double oob;

    public RandomForest(ModelOptions o) {
        this.opts = o;
        this.forest = new HashMap();
        this.oob = 0.0;
    }

    public RandomForest(ModelOptions opt, Element root, Catalog cat) {
        this(opt);
        for (Element tr : root.getChildren("tree")) {
            this.forest.put(new RunnableDecisionTree(opt, tr, cat), 1.0);
        }
    }

    @Override
    public void build(Instances instsTrain, Catalog cat) {
        this.allInsts = instsTrain;
        this.tests = new HashMap();
        CountDownLatch cdl = new CountDownLatch(this.opts.treesInForest);
        int proc = Runtime.getRuntime().availableProcessors();
        Semaphore sem = new Semaphore(proc);
        int n = 1;
        while (n <= this.opts.treesInForest) {
            Instances train = new Instances();
            int i = 0;
            while (i < instsTrain.size()) {
                train.add((Instance)instsTrain.get(GlobalRandom.instance().nextInt(instsTrain.size())));
                ++i;
            }
            Instances test = new Instances();
            int i2 = 0;
            while (i2 < instsTrain.size()) {
                if (!train.contains(instsTrain.get(i2))) {
                    test.add((Instance)instsTrain.get(i2));
                }
                ++i2;
            }
            ModelOptions optsIndiv = this.opts.clone();
            Random gr = new Random(GlobalRandom.instance().nextLong());
            RunnableDecisionTree dt2 = new RunnableDecisionTree(optsIndiv, cat, train, gr, sem, cdl, n);
            dt2.addToName("_submodel_" + n);
            this.tests.put(dt2, test);
            ++n;
        }
        this.tests.keySet().stream().sorted(new Comparator<RunnableDecisionTree>(){

            @Override
            public int compare(RunnableDecisionTree dt1, RunnableDecisionTree dt2) {
                return Integer.compare(dt1.getIndex(), dt2.getIndex());
            }
        }).forEach(dt -> new Thread((Runnable)dt, dt.name()).start());
        try {
            cdl.await();
        }
        catch (InterruptedException e) {
            e.printStackTrace();
        }
        for (RunnableDecisionTree dt3 : this.tests.keySet()) {
            this.forest.put(dt3, 1.0);
        }
        this.initShifts(instsTrain, cat);
        this.oob = this.computeOob(cat);
    }

    protected double computeOob(Catalog cat) {
        double res = 0.0;
        if (this.opts.mode.equals("classification")) {
            int numInsts = 0;
            for (Instance inst : this.allInsts) {
                HashMap<Value, Double> scores = new HashMap<Value, Double>();
                for (Map.Entry<RunnableDecisionTree, Double> ent : this.forest.entrySet()) {
                    if (!this.tests.get(ent.getKey()).contains(inst)) continue;
                    HashSet<Value> cls = ent.getKey().classifyMajority(inst, cat);
                    double norm = cls.size();
                    for (Value cl : cls) {
                        if (!scores.containsKey(cl)) {
                            scores.put(cl, 0.0);
                        }
                        scores.put(cl, (Double)scores.get(cl) + ent.getValue() / norm);
                    }
                }
                ArrayList<Value> outcomes = new ArrayList<Value>();
                double bestScore = -1.0;
                for (Map.Entry ent : scores.entrySet()) {
                    if (!((Double)ent.getValue() >= bestScore)) continue;
                    if ((Double)ent.getValue() > bestScore) {
                        outcomes.clear();
                    }
                    bestScore = (Double)ent.getValue();
                    outcomes.add((Value)ent.getKey());
                }
                Value pred = null;
                pred = outcomes.isEmpty() ? null : (outcomes.size() == 1 ? (Value)outcomes.get(0) : (Value)outcomes.get(GlobalRandom.instance().nextInt(outcomes.size())));
                if (pred == null) continue;
                ++numInsts;
                if (!pred.equals(inst.getLabel())) continue;
                res += 1.0;
            }
            res /= (double)numInsts;
        } else if (this.opts.mode.equals("regression")) {
            int numInsts = 0;
            for (Instance inst : this.allInsts) {
                double sOutputs = 0.0;
                double sWeights = 0.0;
                for (Map.Entry<RunnableDecisionTree, Double> ent : this.forest.entrySet()) {
                    if (!this.tests.get(ent.getKey()).contains(inst)) continue;
                    double w = ent.getValue();
                    sOutputs += w * ent.getKey().classify(inst, cat).getNumericValue();
                    sWeights += w;
                }
                if (sWeights == 0.0) continue;
                ++numInsts;
                double diff = (sOutputs /= sWeights) - inst.getLabel().getNumericValue();
                res += diff * diff;
            }
            res = Math.sqrt(res / (double)numInsts);
        }
        return res;
    }

    @Override
    public Value classify(Instance inst, Catalog cat) {
        if (this.opts.mode.equals("classification")) {
            HashMap<Value, Double> scores = new HashMap<Value, Double>();
            for (Map.Entry<RunnableDecisionTree, Double> ent : this.forest.entrySet()) {
                HashSet<Value> cls = ent.getKey().classifyMajority(inst, cat);
                double norm = cls.size();
                for (Value cl : cls) {
                    if (!scores.containsKey(cl)) {
                        scores.put(cl, 0.0);
                    }
                    scores.put(cl, (Double)scores.get(cl) + ent.getValue() / norm);
                }
            }
            ArrayList<Value> outcomes = new ArrayList<Value>();
            double bestScore = -1.0;
            for (Map.Entry ent : scores.entrySet()) {
                if (!((Double)ent.getValue() >= bestScore)) continue;
                if ((Double)ent.getValue() > bestScore) {
                    outcomes.clear();
                }
                bestScore = (Double)ent.getValue();
                outcomes.add((Value)ent.getKey());
            }
            if (outcomes.isEmpty()) {
                return null;
            }
            if (outcomes.size() == 1) {
                return (Value)outcomes.get(0);
            }
            return (Value)outcomes.get(GlobalRandom.instance().nextInt(outcomes.size()));
        }
        if (this.opts.mode.equals("regression")) {
            double sOutputs = 0.0;
            double sWeights = 0.0;
            for (Map.Entry<RunnableDecisionTree, Double> ent : this.forest.entrySet()) {
                double w = ent.getValue();
                sOutputs += w * ent.getKey().classify(inst, cat).getNumericValue();
                sWeights += w;
            }
            return new Value(sOutputs /= sWeights);
        }
        return null;
    }

    @Override
    public Model clone() {
        return new RandomForest(this.opts.clone());
    }

    @Override
    public String toString() {
        StringBuilder sb = new StringBuilder();
        sb.append("OOB error : " + this.oob + "\n\n");
        this.forest.keySet().stream().sorted(new Comparator<RunnableDecisionTree>(){

            @Override
            public int compare(RunnableDecisionTree dt1, RunnableDecisionTree dt2) {
                return Integer.compare(dt1.getIndex(), dt2.getIndex());
            }
        }).forEach(dt -> {
            sb.append(String.valueOf(dt.name()) + ", weight " + this.forest.get(dt) + "\n");
            sb.append(dt.toString());
        });
        return new String(sb);
    }

    public HashMap<SimpleFeature, Double> mainFeatureImportance(Catalog cat) throws CloneNotSupportedException {
        HashMap<SimpleFeature, Double> res = new HashMap<SimpleFeature, Double>();
        HashSet<SimpleFeature> keys = new HashSet<SimpleFeature>();
        for (DecisionTree decisionTree : this.forest.keySet()) {
            HashSet<SimpleFeature> features = decisionTree.getMainFeatures();
            keys.addAll(features);
        }
        Catalog catalog = cat.clone();
        catalog.setCoupled(false);
        for (SimpleFeature sf : keys) {
            double permScore = 0.0;
            HashSet<SimpleFeature> set = new HashSet<SimpleFeature>();
            set.add(sf);
            for (DecisionTree decisionTree : this.forest.keySet()) {
                Instances test = this.tests.get(decisionTree);
                catalog.permuteFeatures(test, null, set);
                SimpleEvaluation evalNorm = new SimpleEvaluation(this.opts.mode);
                evalNorm.evaluateModel(decisionTree, test, cat, false);
                SimpleEvaluation evalPerm = new SimpleEvaluation(this.opts.mode);
                evalPerm.evaluateModel(decisionTree, test, catalog, false);
                if (this.opts.mode.equals("classification")) {
                    permScore += (Double)evalNorm.getStatistics().getStat("accuracy") - (Double)evalPerm.getStatistics().getStat("accuracy");
                } else if (this.opts.mode.equals("regression")) {
                    permScore += (Double)evalNorm.getStatistics().getStat("rmse") - (Double)evalPerm.getStatistics().getStat("rmse");
                }
                catalog.clearPermutations();
            }
            res.put(sf, permScore /= (double)this.forest.size());
            FileWrite.writeToFile(String.valueOf(this.name()) + "mainAttrImportance", sf + ";" + permScore + "\n");
        }
        return res;
    }

    public HashMap<AggregateBase, Double> aggregateProcessImportance(Catalog cat) throws CloneNotSupportedException {
        HashMap<AggregateBase, Double> res = new HashMap<AggregateBase, Double>();
        HashSet<AggregateBase> keys = new HashSet<AggregateBase>();
        for (DecisionTree decisionTree : this.forest.keySet()) {
            HashMap<AggregateBase, HashSet<SimpleFeature>> features = decisionTree.getComplexAggregates();
            keys.addAll(features.keySet());
        }
        Catalog catalog = cat.clone();
        catalog.setCoupled(false);
        for (AggregateBase ab : keys) {
            double permScore = 0.0;
            for (DecisionTree decisionTree : this.forest.keySet()) {
                Instances test = this.tests.get(decisionTree);
                catalog.permuteBlocks(test, ab);
                SimpleEvaluation evalNorm = new SimpleEvaluation(this.opts.mode);
                evalNorm.evaluateModel(decisionTree, test, cat, false);
                SimpleEvaluation evalPerm = new SimpleEvaluation(this.opts.mode);
                evalPerm.evaluateModel(decisionTree, test, catalog, false);
                if (this.opts.mode.equals("classification")) {
                    permScore += (Double)evalNorm.getStatistics().getStat("accuracy") - (Double)evalPerm.getStatistics().getStat("accuracy");
                } else if (this.opts.mode.equals("regression")) {
                    permScore += (Double)evalNorm.getStatistics().getStat("rmse") - (Double)evalPerm.getStatistics().getStat("rmse");
                }
                catalog.clearPermutations();
            }
            res.put(ab, permScore /= (double)this.forest.size());
            FileWrite.writeToFile(String.valueOf(this.name()) + "processImportance", ab.getFunction() + ";" + (ab.getFeature() == null ? "null" : ab.getFeature().toString()) + ";" + permScore + "\n");
        }
        return res;
    }

    public HashMap<LinkFeature, HashMap<HashSet<SimpleFeature>, Double>> conditionImportance(Catalog cat) throws CloneNotSupportedException {
        HashMap<LinkFeature, HashMap<HashSet<SimpleFeature>, Double>> res = new HashMap<LinkFeature, HashMap<HashSet<SimpleFeature>, Double>>();
        HashMap keys = new HashMap();
        for (DecisionTree decisionTree : this.forest.keySet()) {
            HashMap<AggregateBase, HashSet<SimpleFeature>> features = decisionTree.getComplexAggregates();
            for (Map.Entry<AggregateBase, HashSet<SimpleFeature>> ent : features.entrySet()) {
                LinkFeature lf = ent.getKey().getLink();
                if (!keys.containsKey(lf)) {
                    keys.put(lf, new HashSet());
                }
                HashSet setSfs = new HashSet();
                setSfs.add(new HashSet());
                for (SimpleFeature sf : ent.getValue()) {
                    HashSet<SimpleFeature> toAdd = new HashSet<SimpleFeature>();
                    toAdd.add(sf);
                    ((HashSet)keys.get(lf)).add(toAdd);
                }
            }
        }
        Catalog catalog = cat.clone();
        catalog.setCoupled(false);
        for (Map.Entry ent : keys.entrySet()) {
            HashMap<HashSet, Double> toAdd = new HashMap<HashSet, Double>();
            for (HashSet set : (HashSet)ent.getValue()) {
                System.out.println(set.toString());
                double permScore = 0.0;
                for (DecisionTree decisionTree : this.forest.keySet()) {
                    Instances test = this.tests.get(decisionTree);
                    catalog.permuteFeatures(test, (LinkFeature)ent.getKey(), set);
                    SimpleEvaluation evalNorm = new SimpleEvaluation(this.opts.mode);
                    evalNorm.evaluateModel(decisionTree, test, cat, false);
                    SimpleEvaluation evalPerm = new SimpleEvaluation(this.opts.mode);
                    evalPerm.evaluateModel(decisionTree, test, catalog, false);
                    if (this.opts.mode.equals("classification")) {
                        permScore += (Double)evalNorm.getStatistics().getStat("accuracy") - (Double)evalPerm.getStatistics().getStat("accuracy");
                    } else if (this.opts.mode.equals("regression")) {
                        permScore += (Double)evalNorm.getStatistics().getStat("rmse") - (Double)evalPerm.getStatistics().getStat("rmse");
                    }
                    catalog.clearPermutations();
                }
                toAdd.put(set, permScore /= (double)this.forest.size());
                FileWrite.writeToFile(String.valueOf(this.name()) + "conditionImportance", String.valueOf(set.toString()) + ";" + permScore + "\n");
            }
            res.put((LinkFeature)ent.getKey(), toAdd);
        }
        return res;
    }

    public HashMap<AggregateBase, HashMap<HashSet<SimpleFeature>, Double>> variableImportance(Catalog cat) throws CloneNotSupportedException {
        HashMap<AggregateBase, HashMap<HashSet<SimpleFeature>, Double>> res = new HashMap<AggregateBase, HashMap<HashSet<SimpleFeature>, Double>>();
        HashMap keys = new HashMap();
        for (DecisionTree decisionTree : this.forest.keySet()) {
            HashMap<AggregateBase, HashSet<SimpleFeature>> features = decisionTree.getComplexAggregates();
            for (Map.Entry<AggregateBase, HashSet<SimpleFeature>> ent : features.entrySet()) {
                HashSet setSfs = new HashSet();
                setSfs.add(new HashSet());
                for (SimpleFeature sf : ent.getValue()) {
                    HashSet<SimpleFeature> toAdd = new HashSet<SimpleFeature>();
                    toAdd.add(sf);
                    setSfs.add(toAdd);
                }
                if (keys.containsKey(ent.getKey())) {
                    ((HashSet)keys.get(ent.getKey())).addAll(setSfs);
                    continue;
                }
                keys.put(ent.getKey(), setSfs);
            }
        }
        Catalog catalog = cat.clone();
        catalog.setCoupled(true);
        for (Map.Entry ent : keys.entrySet()) {
            AggregateBase ab = (AggregateBase)ent.getKey();
            System.out.println(ab.toString());
            HashSet vals = (HashSet)ent.getValue();
            HashMap<HashSet, Double> toAdd = new HashMap<HashSet, Double>();
            for (HashSet set : vals) {
                System.out.println(set.toString());
                double permScore = 0.0;
                for (DecisionTree decisionTree : this.forest.keySet()) {
                    Instances test = this.tests.get(decisionTree);
                    catalog.permuteBlocks(test, ab);
                    catalog.permuteFeatures(test, ab.getLink(), set);
                    SimpleEvaluation evalNorm = new SimpleEvaluation(this.opts.mode);
                    evalNorm.evaluateModel(decisionTree, test, cat, false);
                    SimpleEvaluation evalPerm = new SimpleEvaluation(this.opts.mode);
                    evalPerm.evaluateModel(decisionTree, test, catalog, false);
                    if (this.opts.mode.equals("classification")) {
                        permScore += (Double)evalNorm.getStatistics().getStat("accuracy") - (Double)evalPerm.getStatistics().getStat("accuracy");
                    } else if (this.opts.mode.equals("regression")) {
                        permScore += (Double)evalNorm.getStatistics().getStat("rmse") - (Double)evalPerm.getStatistics().getStat("rmse");
                    }
                    catalog.clearPermutations();
                }
                FileWrite.writeToFile(String.valueOf(this.name()) + "variableImportance", ab.getFunction() + ";" + (ab.getFeature() == null ? "null" : ab.getFeature().toString()) + ";" + set.toString() + ";" + (permScore /= (double)this.forest.size()) + "\n");
                toAdd.put(set, permScore);
            }
            res.put(ab, toAdd);
        }
        return res;
    }

    @Override
    public String name() {
        return this.opts.name;
    }

    @Override
    public HashSet<NumericShiftFunction> getParameters() {
        return null;
    }

    public Instances getAllInstances() {
        return this.allInsts;
    }

    @Override
    public Element toXMLElement() {
        Element el = new Element("forest");
        el.setAttribute("name", this.name());
        for (RunnableDecisionTree dt : this.forest.keySet()) {
            el.addContent((Content)dt.toXMLElement());
        }
        return el;
    }

    @Override
    public HashMap<NumericShiftFunction, double[]> getShifts(Instances insts, Catalog cat) {
        HashMap<NumericShiftFunction, double[]> shifts = new HashMap<NumericShiftFunction, double[]>();
        for (DecisionTree decisionTree : this.forest.keySet()) {
            shifts.putAll(decisionTree.getShifts(insts, cat));
        }
        return shifts;
    }

    @Override
    public void initShifts(Instances insts, Catalog cat) {
        for (DecisionTree decisionTree : this.forest.keySet()) {
            decisionTree.initShifts(insts, cat);
        }
    }

    @Override
    public void deployShifts(Instances insts, Catalog cat) {
        for (DecisionTree decisionTree : this.forest.keySet()) {
            decisionTree.deployShifts(insts, cat);
        }
    }

    @Override
    public HashSet<NumericShiftFunction> getShifts() {
        HashSet<NumericShiftFunction> res = new HashSet<NumericShiftFunction>();
        for (DecisionTree decisionTree : this.forest.keySet()) {
            res.addAll(decisionTree.getShifts());
        }
        return res;
    }
}

