/*
 * Decompiled with CFR 0.152.
 */
package org.linqs.psl.application.learning.weight;

import java.lang.reflect.Constructor;
import java.lang.reflect.InvocationTargetException;
import java.util.ArrayList;
import java.util.List;
import org.linqs.psl.application.ModelApplication;
import org.linqs.psl.application.inference.InferenceApplication;
import org.linqs.psl.application.learning.weight.TrainingMap;
import org.linqs.psl.config.Options;
import org.linqs.psl.database.Database;
import org.linqs.psl.evaluation.EvaluationInstance;
import org.linqs.psl.model.deep.DeepModelPredicate;
import org.linqs.psl.model.predicate.DeepPredicate;
import org.linqs.psl.model.predicate.Predicate;
import org.linqs.psl.model.rule.Rule;
import org.linqs.psl.model.rule.WeightedRule;
import org.linqs.psl.util.Logger;
import org.linqs.psl.util.RandUtils;
import org.linqs.psl.util.Reflection;

/**
 * Base class for weight learning applications.
 *
 * <p>Holds the training/validation target and truth databases and splits the supplied rules into
 * all rules and mutable ({@link WeightedRule}) rules. On first use it lazily builds inference
 * applications, training maps and deep-model bookkeeping. Subclasses implement {@link #doLearn()}.
 */
public abstract class WeightLearningApplication
implements ModelApplication {
    public static final String DELIM = ":";
    private static final Logger log = Logger.getLogger(WeightLearningApplication.class);
    protected Database trainTargetDatabase;
    protected Database trainTruthDatabase;
    protected Database validationTargetDatabase;
    protected Database validationTruthDatabase;
    protected List<DeepPredicate> deepPredicates;
    protected List<DeepModelPredicate> deepModelPredicates;
    // Copies of the deep models bound to the validation database's atom store.
    protected List<DeepModelPredicate> validationDeepModelPredicates;
    protected boolean runValidation;
    protected List<Rule> allRules;
    // Subset of allRules whose weights may be changed by learning.
    protected List<WeightedRule> mutableRules;
    protected TrainingMap trainingMap;
    protected TrainingMap validationMap;
    protected InferenceApplication trainInferenceApplication;
    protected InferenceApplication validationInferenceApplication;
    protected EvaluationInstance evaluation;
    private boolean groundModelInit;
    // Cache flags: when true, the corresponding MAP state is not recomputed.
    protected boolean inTrainingMAPState;
    protected boolean inValidationMAPState;

    /**
     * Creates a weight learner over the given rules and databases.
     *
     * @param rules all rules of the model; every {@link WeightedRule} among them becomes mutable
     * @param trainTargetDatabase database holding the training targets
     * @param trainTruthDatabase database holding the training truth values
     * @param validationTargetDatabase database holding the validation targets
     * @param validationTruthDatabase database holding the validation truth values
     * @param runValidation whether validation should be run; must not be null (it is unboxed)
     */
    public WeightLearningApplication(List<Rule> rules, Database trainTargetDatabase, Database trainTruthDatabase, Database validationTargetDatabase, Database validationTruthDatabase, Boolean runValidation) {
        this.trainTargetDatabase = trainTargetDatabase;
        this.trainTruthDatabase = trainTruthDatabase;
        this.validationTargetDatabase = validationTargetDatabase;
        this.validationTruthDatabase = validationTruthDatabase;
        this.runValidation = runValidation;
        this.deepPredicates = new ArrayList<DeepPredicate>();
        this.deepModelPredicates = new ArrayList<DeepModelPredicate>();
        this.validationDeepModelPredicates = new ArrayList<DeepModelPredicate>();
        this.allRules = new ArrayList<Rule>();
        this.mutableRules = new ArrayList<WeightedRule>();
        for (Rule rule : rules) {
            this.allRules.add(rule);
            if (rule instanceof WeightedRule) {
                this.mutableRules.add((WeightedRule)rule);
            }
        }
        this.trainInferenceApplication = null;
        this.validationInferenceApplication = null;
        this.trainingMap = null;
        this.validationMap = null;
        this.groundModelInit = false;
        this.inTrainingMAPState = false;
        this.inValidationMAPState = false;
        this.evaluation = null;
    }

    /**
     * Sets the evaluation instance made available to subclasses.
     *
     * @param evaluation the evaluation to use; may be null
     */
    public void setEvaluation(EvaluationInstance evaluation) {
        this.evaluation = evaluation;
    }

    /**
     * Initializes the ground model (if not already done) and runs the subclass learning procedure.
     */
    public void learn() {
        this.initGroundModel();
        this.doLearn();
    }

    /**
     * Performs the actual weight learning. Called by {@link #learn()} after the ground model has
     * been initialized.
     */
    protected abstract void doLearn();

    /**
     * Forwards a budget to the training inference application.
     *
     * @param budget the budget to set
     * @throws IllegalStateException if the ground model has not been initialized yet (or the
     *     application was closed), so no training inference application exists
     */
    public void setBudget(double budget) {
        if (this.trainInferenceApplication == null) {
            throw new IllegalStateException("Cannot set a budget before the training inference application is initialized.");
        }
        this.trainInferenceApplication.setBudget(budget);
    }

    /**
     * Builds training and validation inference applications from the {@code WLA_INFERENCE}
     * option and initializes the ground model. Does nothing if already initialized.
     */
    protected void initGroundModel() {
        if (this.groundModelInit) {
            return;
        }
        InferenceApplication trainInferenceApplication = InferenceApplication.getInferenceApplication(Options.WLA_INFERENCE.getString(), this.allRules, this.trainTargetDatabase);
        trainInferenceApplication.loadDeepPredicates("learning");
        InferenceApplication validationInferenceApplication = InferenceApplication.getInferenceApplication(Options.WLA_INFERENCE.getString(), this.allRules, this.validationTargetDatabase);
        this.initGroundModel(trainInferenceApplication, validationInferenceApplication);
    }

    /**
     * Builds training maps from the given inference applications and initializes the ground
     * model. Does nothing if already initialized.
     */
    private void initGroundModel(InferenceApplication trainInferenceApplication, InferenceApplication validationInferenceApplication) {
        if (this.groundModelInit) {
            return;
        }
        TrainingMap trainingMap = new TrainingMap(trainInferenceApplication.getDatabase(), this.trainTruthDatabase);
        TrainingMap validationMap = new TrainingMap(validationInferenceApplication.getDatabase(), this.validationTruthDatabase);
        this.initGroundModel(trainInferenceApplication, trainingMap, validationInferenceApplication, validationMap);
    }

    /**
     * Installs pre-built inference applications and training maps, optionally randomizes the
     * mutable rule weights, collects every {@link DeepPredicate} (plus a validation copy of its
     * deep model), then calls {@link #postInitGroundModel()}. Does nothing if already initialized.
     *
     * @param trainInferenceApplication inference application over the training targets
     * @param trainingMap map between training targets and truth
     * @param validationInferenceApplication inference application over the validation targets
     * @param validationMap map between validation targets and truth
     */
    public void initGroundModel(InferenceApplication trainInferenceApplication, TrainingMap trainingMap, InferenceApplication validationInferenceApplication, TrainingMap validationMap) {
        if (this.groundModelInit) {
            return;
        }
        this.trainInferenceApplication = trainInferenceApplication;
        this.trainingMap = trainingMap;
        this.validationInferenceApplication = validationInferenceApplication;
        this.validationMap = validationMap;
        if (Options.WLA_RANDOM_WEIGHTS.getBoolean()) {
            this.initRandomWeights();
        }
        for (Predicate predicate : Predicate.getAll()) {
            if (!(predicate instanceof DeepPredicate)) {
                continue;
            }
            DeepPredicate deepPredicate = (DeepPredicate)predicate;
            DeepModelPredicate deepModel = deepPredicate.getDeepModel();
            this.deepPredicates.add(deepPredicate);
            this.deepModelPredicates.add(deepModel);
            // The validation copy must read/write atoms from the validation database, not training.
            DeepModelPredicate validationDeepModel = deepModel.copy();
            validationDeepModel.setAtomStore(validationInferenceApplication.getDatabase().getAtomStore(), true);
            this.validationDeepModelPredicates.add(validationDeepModel);
        }
        this.postInitGroundModel();
        this.groundModelInit = true;
    }

    /** Assigns a random weight (from {@link RandUtils#nextFloat()}) to every mutable rule. */
    private void initRandomWeights() {
        log.trace("Randomly Weighted Rules:");
        for (WeightedRule rule : this.mutableRules) {
            rule.setWeight(RandUtils.nextFloat());
            log.trace("    " + rule.toString());
        }
    }

    /** Hook invoked at the end of ground model initialization. Default does nothing. */
    protected void postInitGroundModel() {
    }

    /** Runs MAP inference on the training application unless the cached state is still valid. */
    protected void computeTrainingMAPState() {
        if (this.inTrainingMAPState) {
            return;
        }
        this.computeMAPState(this.trainInferenceApplication);
        this.inTrainingMAPState = true;
    }

    /** Runs MAP inference on the validation application unless the cached state is still valid. */
    protected void computeValidationMAPState() {
        if (this.inValidationMAPState) {
            return;
        }
        this.computeMAPState(this.validationInferenceApplication);
        this.inValidationMAPState = true;
    }

    /**
     * Runs inference on the given application.
     * NOTE(review): the meaning of the {@code (false, true)} flags is defined by
     * {@code InferenceApplication.inference}; confirm there before changing them.
     */
    protected void computeMAPState(InferenceApplication inferenceApplication) {
        inferenceApplication.inference(false, true);
    }

    /**
     * Commits and closes both inference applications and releases database/map references.
     * Each application is closed even if committing it (or the other application) fails; the
     * first failure is propagated. Calling this more than once is harmless.
     */
    @Override
    public void close() {
        InferenceApplication train = this.trainInferenceApplication;
        InferenceApplication validation = this.validationInferenceApplication;
        // Clear references first so a repeated close() can never commit/close twice.
        this.trainInferenceApplication = null;
        this.validationInferenceApplication = null;
        this.trainingMap = null;
        this.trainTargetDatabase = null;
        this.trainTruthDatabase = null;
        this.validationMap = null;
        this.validationTargetDatabase = null;
        this.validationTruthDatabase = null;
        try {
            commitAndClose(train);
        } finally {
            commitAndClose(validation);
        }
    }

    /** Commits then closes the given application (if non-null), closing it even if commit fails. */
    private static void commitAndClose(InferenceApplication application) {
        if (application == null) {
            return;
        }
        try {
            application.commit();
        } finally {
            application.close();
        }
    }

    /**
     * Reflectively constructs a weight learner by (possibly short) class name.
     *
     * <p>The class must extend {@code WeightLearningApplication} and expose a public constructor
     * {@code (List, Database, Database, Database, Database, boolean)}; a {@code Boolean}-boxed last
     * parameter (as declared by this base class) is also accepted.
     *
     * @throws IllegalArgumentException if the class cannot be found, is not a weight learner, or
     *     has no suitable constructor
     * @throws RuntimeException if construction fails
     */
    public static WeightLearningApplication getWLA(String name, List<Rule> rules, Database trainTargetDatabase, Database trainTruthDatabase, Database validationTargetDatabase, Database validationTruthDatabase, boolean runValidation) {
        String className = Reflection.resolveClassName(name);
        if (className == null) {
            throw new IllegalArgumentException("Could not find class: " + name);
        }
        Class<?> rawClass;
        try {
            rawClass = Class.forName(className);
        }
        catch (ClassNotFoundException ex) {
            throw new IllegalArgumentException("Could not find class: " + className, ex);
        }
        // Reject non-learners before running any constructor (previously a late ClassCastException).
        if (!WeightLearningApplication.class.isAssignableFrom(rawClass)) {
            throw new IllegalArgumentException("Class is not a weight learner: " + className);
        }
        Class<? extends WeightLearningApplication> classObject = rawClass.asSubclass(WeightLearningApplication.class);
        Constructor<? extends WeightLearningApplication> constructor = findWLAConstructor(classObject, className);
        try {
            return constructor.newInstance(rules, trainTargetDatabase, trainTruthDatabase, validationTargetDatabase, validationTruthDatabase, runValidation);
        }
        catch (InstantiationException ex) {
            throw new RuntimeException("Unable to instantiate weight learner (" + className + ")", ex);
        }
        catch (IllegalAccessException ex) {
            throw new RuntimeException("Insufficient access to constructor for " + className, ex);
        }
        catch (InvocationTargetException ex) {
            throw new RuntimeException("Error thrown while constructing " + className, ex);
        }
    }

    /** Finds the standard WLA constructor, trying a primitive then a boxed boolean last parameter. */
    private static Constructor<? extends WeightLearningApplication> findWLAConstructor(Class<? extends WeightLearningApplication> classObject, String className) {
        try {
            return classObject.getConstructor(List.class, Database.class, Database.class, Database.class, Database.class, Boolean.TYPE);
        }
        catch (NoSuchMethodException primitiveMissing) {
            try {
                return classObject.getConstructor(List.class, Database.class, Database.class, Database.class, Database.class, Boolean.class);
            }
            catch (NoSuchMethodException boxedMissing) {
                IllegalArgumentException error = new IllegalArgumentException("No suitable constructor found for weight learner: " + className + ".", primitiveMissing);
                error.addSuppressed(boxedMissing);
                throw error;
            }
        }
    }
}

