/*
 * Decompiled with CFR 0.152.
 */
package org.linqs.psl.application.inference;

import java.lang.reflect.Constructor;
import java.lang.reflect.InvocationTargetException;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import org.linqs.psl.application.ModelApplication;
import org.linqs.psl.application.learning.weight.TrainingMap;
import org.linqs.psl.config.Options;
import org.linqs.psl.database.Database;
import org.linqs.psl.database.atom.PersistedAtomManager;
import org.linqs.psl.evaluation.statistics.Evaluator;
import org.linqs.psl.grounding.GroundRuleStore;
import org.linqs.psl.grounding.Grounding;
import org.linqs.psl.model.atom.RandomVariableAtom;
import org.linqs.psl.model.predicate.StandardPredicate;
import org.linqs.psl.model.rule.Rule;
import org.linqs.psl.model.rule.UnweightedRule;
import org.linqs.psl.model.rule.WeightedRule;
import org.linqs.psl.reasoner.InitialValue;
import org.linqs.psl.reasoner.Reasoner;
import org.linqs.psl.reasoner.term.TermGenerator;
import org.linqs.psl.reasoner.term.TermStore;
import org.linqs.psl.util.IteratorUtils;
import org.linqs.psl.util.Logger;
import org.linqs.psl.util.MathUtils;
import org.linqs.psl.util.Reflection;

/**
 * Base class for applications that ground a PSL model and run inference over it.
 *
 * <p>All setup happens eagerly during construction (see {@link #initialize()}):
 * <ol>
 *   <li>the persisted atom manager is created;</li>
 *   <li>weights are optionally normalized and hard constraints optionally relaxed;</li>
 *   <li>the reasoner, term generator, term store and ground rule store are built;</li>
 *   <li>the model is grounded.</li>
 * </ol>
 * Inference itself is run through one of the {@code inference(...)} overloads.
 */
public abstract class InferenceApplication implements ModelApplication {
    private static final Logger log = Logger.getLogger(InferenceApplication.class);

    // Private copy of the model rules; relaxation may replace entries in this list.
    protected List<Rule> rules;
    protected Database database;
    protected Reasoner reasoner;
    // Strategy used to (re)set random variable atom values before inference.
    protected InitialValue initialValue;
    // When true, grounding still happens but term generation and optimization are skipped.
    protected boolean skipInference;
    protected boolean normalizeWeights;
    protected boolean relaxHardConstraints;
    protected float relaxationMultiplier;
    protected boolean relaxationSquared;

    protected GroundRuleStore groundRuleStore;
    protected TermStore termStore;
    protected TermGenerator termGenerator;
    protected PersistedAtomManager atomManager;

    // Tracks whether the results of the most recent inference have already been written out.
    private boolean atomsCommitted;

    /**
     * Builds the application, taking the hard-constraint relaxation setting from
     * {@code Options.INFERENCE_RELAX}.
     *
     * @param rules the model rules
     * @param database the database to run inference against
     */
    protected InferenceApplication(List<Rule> rules, Database database) {
        this(rules, database, Options.INFERENCE_RELAX.getBoolean());
    }

    /**
     * Builds and fully initializes the application.
     *
     * <p>Note: this constructor calls the overridable {@link #initialize()}, which in turn calls the
     * {@code create*} factory hooks and {@link #completeInitialize()}. Subclass overrides of those
     * methods therefore run before the subclass's own field initializers.
     *
     * @param rules the model rules. The list itself is copied, but weight normalization updates
     *     the weights of the contained {@link WeightedRule} objects in place.
     * @param database the database to run inference against
     * @param relaxHardConstraints whether unweighted rules should be replaced by weighted
     *     (relaxed) versions before grounding
     */
    protected InferenceApplication(List<Rule> rules, Database database, boolean relaxHardConstraints) {
        this.rules = new ArrayList<Rule>(rules);
        this.database = database;
        this.atomsCommitted = false;

        this.initialValue = InitialValue.valueOf(Options.INFERENCE_INITIAL_VARIABLE_VALUE.getString());
        this.skipInference = Options.INFERENCE_SKIP_INFERENCE.getBoolean();
        this.normalizeWeights = Options.INFERENCE_NORMALIZE_WEIGHTS.getBoolean();
        this.relaxHardConstraints = relaxHardConstraints;
        this.relaxationMultiplier = Options.INFERENCE_RELAX_MULTIPLIER.getFloat();
        this.relaxationSquared = Options.INFERENCE_RELAX_SQUARED.getBoolean();

        this.initialize();
    }

    /**
     * Creates the atom manager and inference machinery, preprocesses the rules, and grounds the model.
     */
    protected void initialize() {
        log.debug("Creating persisted atom manager.");
        this.atomManager = this.createAtomManager(this.database);
        log.debug("Atom manager initialization complete.");

        this.initializeAtoms();

        if (this.normalizeWeights) {
            this.normalizeWeights();
        }

        if (this.relaxHardConstraints) {
            this.relaxHardConstraints();
        }

        this.reasoner = this.createReasoner();
        this.termGenerator = this.createTermGenerator();
        this.termStore = this.createTermStore();
        this.groundRuleStore = this.createGroundRuleStore();

        // Pre-size the term store for every random variable atom the atom manager has cached.
        this.termStore.ensureVariableCapacity(this.atomManager.getCachedRVACount());

        this.completeInitialize();
    }

    /** Factory hook for the atom manager. The default passes {@code false} and the configured initial value. */
    protected PersistedAtomManager createAtomManager(Database database) {
        return new PersistedAtomManager(database, false, this.initialValue);
    }

    /** Factory hook for the ground rule store (configured by {@code Options.INFERENCE_GRS}). */
    protected GroundRuleStore createGroundRuleStore() {
        return (GroundRuleStore)Options.INFERENCE_GRS.getNewObject();
    }

    /** Factory hook for the reasoner (configured by {@code Options.INFERENCE_REASONER}). */
    protected Reasoner createReasoner() {
        return (Reasoner)Options.INFERENCE_REASONER.getNewObject();
    }

    /** Factory hook for the term generator (configured by {@code Options.INFERENCE_TG}). */
    protected TermGenerator createTermGenerator() {
        return (TermGenerator)Options.INFERENCE_TG.getNewObject();
    }

    /** Factory hook for the term store (configured by {@code Options.INFERENCE_TS}). */
    protected TermStore createTermStore() {
        return (TermStore)Options.INFERENCE_TS.getNewObject();
    }

    /**
     * Grounds all rules into the ground rule store and, unless inference is skipped, generates
     * the objective terms from those ground rules.
     */
    protected void completeInitialize() {
        log.info("Grounding out model.");

        // Temporarily disable querying the DB for closed atoms during grounding, then restore the
        // previous setting.
        boolean oldValue = this.atomManager.queryDBForClosedAtoms(false);
        long groundCount = Grounding.groundAll(this.rules, this.atomManager, this.groundRuleStore);
        this.atomManager.queryDBForClosedAtoms(oldValue);

        log.info("Grounding complete.");
        log.debug("Generated {} ground rules.", groundCount);

        if (this.skipInference) {
            return;
        }

        log.debug("Initializing objective terms for {} ground rules.", groundCount);
        long termCount = this.termGenerator.generateTerms(this.groundRuleStore, this.termStore);
        log.debug("Generated {} objective terms from {} ground rules.", termCount, groundCount);
    }

    /**
     * Runs inference, commits the results and does not reset beforehand.
     *
     * @return the objective value reported by the reasoner, or {@code -1.0} if inference is skipped
     */
    public double inference() {
        return this.inference(true, false);
    }

    /**
     * Runs inference without evaluation.
     *
     * @param commitAtoms whether to write results to the database afterwards
     * @param reset whether to re-initialize atom values and reset the term store first
     * @return the objective value reported by the reasoner, or {@code -1.0} if inference is skipped
     */
    public double inference(boolean commitAtoms, boolean reset) {
        return this.inference(commitAtoms, reset, null, null);
    }

    /**
     * Runs inference, optionally passing evaluators to the reasoner.
     *
     * <p>Evaluation is enabled only when {@code truthDatabase} is non-null and {@code evaluators} is
     * non-null and non-empty. In that case a {@link TrainingMap} is built, and every registered
     * predicate that has at least one ground atom in the truth database is evaluated.
     *
     * @param commitAtoms whether to write results to the database afterwards
     * @param reset whether to re-initialize atom values and reset the term store first
     * @param evaluators evaluators to hand to the reasoner; may be {@code null}
     * @param truthDatabase database holding the truth values; may be {@code null}
     * @return the objective value reported by the reasoner, or {@code -1.0} if inference is skipped
     */
    public double inference(boolean commitAtoms, boolean reset, List<Evaluator> evaluators, Database truthDatabase) {
        if (reset) {
            this.initializeAtoms();
            if (this.termStore != null) {
                this.termStore.reset();
            }
        }

        if (this.skipInference) {
            log.info("Skipping inference.");
            return -1.0;
        }

        TrainingMap trainingMap = null;
        Set<StandardPredicate> evaluationPredicates = null;

        // Guard against a null evaluator list: a truth database alone is not enough to evaluate.
        if (truthDatabase != null && evaluators != null && !evaluators.isEmpty()) {
            trainingMap = new TrainingMap(this.atomManager, truthDatabase);
            evaluationPredicates = new HashSet<StandardPredicate>();
            for (StandardPredicate predicate : this.database.getDataStore().getRegisteredPredicates()) {
                // Only predicates that actually have truth atoms can be evaluated.
                if (truthDatabase.countAllGroundAtoms(predicate) > 0) {
                    evaluationPredicates.add(predicate);
                }
            }
        }

        log.info("Beginning inference.");
        double objective = this.internalInference(evaluators, trainingMap, evaluationPredicates);
        log.info("Inference complete.");

        // Fresh results are available, so a subsequent commit() must write them.
        this.atomsCommitted = false;
        if (commitAtoms) {
            this.commit();
        }

        return objective;
    }

    /**
     * Hook performing the actual optimization. The default delegates to the reasoner.
     */
    protected double internalInference(List<Evaluator> evaluators, TrainingMap trainingMap, Set<StandardPredicate> evaluationPredicates) {
        return this.reasoner.optimize(this.termStore, evaluators, trainingMap, evaluationPredicates);
    }

    public Reasoner getReasoner() {
        return this.reasoner;
    }

    public GroundRuleStore getGroundRuleStore() {
        return this.groundRuleStore;
    }

    public TermStore getTermStore() {
        return this.termStore;
    }

    public PersistedAtomManager getAtomManager() {
        return this.atomManager;
    }

    /** Forwards the budget to the reasoner. */
    public void setBudget(double budget) {
        this.reasoner.setBudget(budget);
    }

    /**
     * Sets every cached random variable atom to the value chosen by the configured
     * {@link InitialValue} strategy.
     */
    public void initializeAtoms() {
        for (RandomVariableAtom atom : this.atomManager.getDatabase().getAllCachedRandomVariableAtoms()) {
            atom.setValue(this.initialValue.getVariableValue(atom));
        }
    }

    /**
     * Writes the persisted atoms to the database. This is a no-op if the results of the most
     * recent inference have already been committed.
     */
    public void commit() {
        if (this.atomsCommitted) {
            return;
        }

        log.info("Writing results to Database.");
        this.atomManager.commitPersistedAtoms();
        log.info("Results committed to database.");

        this.atomsCommitted = true;
    }

    /**
     * Closes the term store, ground rule store and reasoner, and drops references to the rules
     * and database. Safe to call more than once.
     */
    @Override
    public void close() {
        if (this.termStore != null) {
            this.termStore.close();
            this.termStore = null;
        }

        if (this.groundRuleStore != null) {
            this.groundRuleStore.close();
            this.groundRuleStore = null;
        }

        if (this.reasoner != null) {
            this.reasoner.close();
            this.reasoner = null;
        }

        this.rules = null;
        this.database = null;
    }

    /**
     * Divides every weighted rule's weight by the largest weight among them.
     *
     * <p>If the largest weight is zero, every weighted rule is set to {@code 1.0}. Rule objects are
     * modified in place. Does nothing when there are no weighted rules.
     */
    protected void normalizeWeights() {
        float max = 0.0f;
        boolean hasWeightedRule = false;
        for (WeightedRule rule : IteratorUtils.filterClass(this.rules, WeightedRule.class)) {
            float weight = rule.getWeight();
            // The first weighted rule seeds the maximum; later ones only replace a smaller one.
            if (!hasWeightedRule || weight > max) {
                max = weight;
                hasWeightedRule = true;
            }
        }

        if (!hasWeightedRule) {
            return;
        }

        for (WeightedRule rule : IteratorUtils.filterClass(this.rules, WeightedRule.class)) {
            float oldWeight = rule.getWeight();
            float newWeight = 1.0f;
            if (!MathUtils.isZero(max)) {
                newWeight = oldWeight / max;
            }

            log.debug("Normalizing rule weight (old weight: {}, new weight: {}): {}", Float.valueOf(oldWeight), Float.valueOf(newWeight), rule);
            rule.setWeight(newWeight);
        }
    }

    /**
     * Replaces every {@link UnweightedRule} in {@link #rules} with its relaxed form.
     *
     * <p>The relaxed weight is {@code max(1, largestWeightedRuleWeight * relaxationMultiplier)}.
     * The relaxed rule is squared if {@code relaxationSquared} is set. Does nothing when there are
     * no unweighted rules.
     */
    protected void relaxHardConstraints() {
        float largestWeight = 0.0f;
        boolean hasUnweightedRule = false;
        for (Rule rule : this.rules) {
            if (rule instanceof WeightedRule) {
                float weight = ((WeightedRule)rule).getWeight();
                if (weight > largestWeight) {
                    largestWeight = weight;
                }
            } else if (rule instanceof UnweightedRule) {
                // Use the same test as the relaxation loop below, so "has work to do" is accurate.
                hasUnweightedRule = true;
            }
        }

        if (!hasUnweightedRule) {
            return;
        }

        float weight = Math.max(1.0f, largestWeight * this.relaxationMultiplier);
        for (int i = 0; i < this.rules.size(); ++i) {
            if (!(this.rules.get(i) instanceof UnweightedRule)) {
                continue;
            }

            log.debug("Relaxing hard constraint (weight: {}, squared: {}): {}", Float.valueOf(weight), this.relaxationSquared, this.rules.get(i));
            this.rules.set(i, ((UnweightedRule)this.rules.get(i)).relax(weight, this.relaxationSquared));
        }
    }

    /**
     * Reflectively constructs an inference application by class name.
     *
     * @param className a class name, resolved through {@link Reflection#resolveClassName(String)}
     * @param rules the model rules
     * @param database the database to run inference against
     * @return the new, fully initialized inference application
     * @throws IllegalArgumentException if the class cannot be found, is not an
     *     {@code InferenceApplication}, or lacks a public {@code (List, Database)} constructor
     * @throws RuntimeException if instantiation fails or the constructor itself throws
     */
    public static InferenceApplication getInferenceApplication(String className, List<Rule> rules, Database database) {
        className = Reflection.resolveClassName(className);

        Class<?> classObject;
        try {
            classObject = Class.forName(className);
        } catch (ClassNotFoundException ex) {
            throw new IllegalArgumentException("Could not find class: " + className, ex);
        }

        // Reject foreign classes up front instead of constructing them and failing on the cast.
        if (!InferenceApplication.class.isAssignableFrom(classObject)) {
            throw new IllegalArgumentException("Class is not an inference application: " + className + ".");
        }

        Constructor<? extends InferenceApplication> constructor;
        try {
            constructor = classObject.asSubclass(InferenceApplication.class).getConstructor(List.class, Database.class);
        } catch (NoSuchMethodException ex) {
            throw new IllegalArgumentException("No suitable constructor (List<Rules>, Database) found for inference application: " + className + ".", ex);
        }

        try {
            return constructor.newInstance(rules, database);
        } catch (InstantiationException ex) {
            throw new RuntimeException("Unable to instantiate inference application (" + className + ")", ex);
        } catch (IllegalAccessException ex) {
            throw new RuntimeException("Insufficient access to constructor for " + className, ex);
        } catch (InvocationTargetException ex) {
            throw new RuntimeException("Error thrown while constructing " + className, ex);
        }
    }
}

