/*
 * Decompiled with CFR 0.152.
 */
package org.linqs.psl.application.inference;

import java.lang.reflect.Constructor;
import java.lang.reflect.InvocationTargetException;
import java.util.ArrayList;
import java.util.List;
import org.linqs.psl.application.ModelApplication;
import org.linqs.psl.config.Options;
import org.linqs.psl.database.Database;
import org.linqs.psl.database.atom.AtomManager;
import org.linqs.psl.database.atom.PersistedAtomManager;
import org.linqs.psl.grounding.GroundRuleStore;
import org.linqs.psl.grounding.Grounding;
import org.linqs.psl.model.atom.RandomVariableAtom;
import org.linqs.psl.model.rule.Rule;
import org.linqs.psl.model.rule.UnweightedRule;
import org.linqs.psl.model.rule.WeightedRule;
import org.linqs.psl.reasoner.InitialValue;
import org.linqs.psl.reasoner.Reasoner;
import org.linqs.psl.reasoner.term.TermGenerator;
import org.linqs.psl.reasoner.term.TermStore;
import org.linqs.psl.util.IteratorUtils;
import org.linqs.psl.util.MathUtils;
import org.linqs.psl.util.Reflection;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Base class for inference applications.
 *
 * <p>An inference application grounds a list of rules through an atom manager built on a
 * {@link Database}, turns the resulting ground rules into terms for a {@link TermStore},
 * and runs a {@link Reasoner} over those terms. The reasoner, ground rule store, term store
 * and term generator are obtained from overridable {@code create*} factory methods. By default,
 * these instantiate the classes configured in {@link Options}.
 *
 * <p>WARNING for subclasses: the constructors call {@link #initialize()}, and through it every
 * {@code create*} factory and {@link #completeInitialize()}. This happens before the subclass
 * constructor body has run. Overrides of those methods must not depend on subclass fields that
 * are assigned in the subclass constructor.
 */
public abstract class InferenceApplication implements ModelApplication {
    private static final Logger log = LoggerFactory.getLogger(InferenceApplication.class);

    protected List<Rule> rules;
    protected Database db;
    protected Reasoner reasoner;
    protected InitialValue initialValue;
    protected boolean normalizeWeights;
    protected boolean relaxHardConstraints;
    protected double relaxationMultiplier;
    protected boolean relaxationSquared;
    protected GroundRuleStore groundRuleStore;
    protected TermStore termStore;
    protected TermGenerator termGenerator;
    protected PersistedAtomManager atomManager;

    /** True once the current atom values have been written out by {@link #commit()}. */
    private boolean atomsCommitted;

    /**
     * Creates an inference application.
     *
     * <p>Whether hard constraints are relaxed is read from {@code Options.INFERENCE_RELAX}.
     *
     * @param rules the model rules (copied; later changes to the caller's list are not seen)
     * @param db the database the atom manager is built on
     */
    protected InferenceApplication(List<Rule> rules, Database db) {
        this(rules, db, Options.INFERENCE_RELAX.getBoolean());
    }

    /**
     * Creates an inference application and immediately runs {@link #initialize()}.
     *
     * <p>Initialization includes grounding the model.
     *
     * @param rules the model rules (copied; later changes to the caller's list are not seen)
     * @param db the database the atom manager is built on
     * @param relaxHardConstraints whether unweighted rules are replaced by relaxed weighted versions
     */
    protected InferenceApplication(List<Rule> rules, Database db, boolean relaxHardConstraints) {
        this.rules = new ArrayList<Rule>(rules);
        this.db = db;
        this.atomsCommitted = false;

        this.initialValue = InitialValue.valueOf(Options.INFERENCE_INITIAL_VARIABLE_VALUE.getString());
        this.normalizeWeights = Options.INFERENCE_NORMALIZE_WEIGHTS.getBoolean();
        this.relaxHardConstraints = relaxHardConstraints;
        this.relaxationMultiplier = Options.INFERENCE_RELAX_MULTIPLIER.getDouble();
        this.relaxationSquared = Options.INFERENCE_RELAX_SQUARED.getBoolean();

        this.initialize();
    }

    /**
     * Builds all inference components and grounds the model.
     *
     * <p>Weight normalization and hard-constraint relaxation run before
     * {@link #completeInitialize()}. Grounding therefore sees the adjusted rule weights and the
     * relaxed replacement rules.
     */
    protected void initialize() {
        log.debug("Creating persisted atom manager.");
        this.atomManager = this.createAtomManager(this.db);
        log.debug("Atom manager initialization complete.");

        this.initializeAtoms();

        this.reasoner = this.createReasoner();
        this.termStore = this.createTermStore();
        this.groundRuleStore = this.createGroundRuleStore();
        this.termGenerator = this.createTermGenerator();

        // Pre-size the term store for every random variable atom already cached.
        this.termStore.ensureVariableCapacity(this.atomManager.getCachedRVACount());

        if (this.normalizeWeights) {
            this.normalizeWeights();
        }

        if (this.relaxHardConstraints) {
            this.relaxHardConstraints();
        }

        this.completeInitialize();
    }

    /**
     * Creates the atom manager used for grounding and for committing results.
     *
     * <p>NOTE(review): the meaning of the {@code false} flag is defined by
     * {@link PersistedAtomManager}; confirm there before changing it.
     */
    protected PersistedAtomManager createAtomManager(Database db) {
        return new PersistedAtomManager(db, false, this.initialValue);
    }

    /** Creates the ground rule store configured by {@code Options.INFERENCE_GRS}. */
    protected GroundRuleStore createGroundRuleStore() {
        return (GroundRuleStore)Options.INFERENCE_GRS.getNewObject();
    }

    /** Creates the reasoner configured by {@code Options.INFERENCE_REASONER}. */
    protected Reasoner createReasoner() {
        return (Reasoner)Options.INFERENCE_REASONER.getNewObject();
    }

    /** Creates the term generator configured by {@code Options.INFERENCE_TG}. */
    protected TermGenerator createTermGenerator() {
        return (TermGenerator)Options.INFERENCE_TG.getNewObject();
    }

    /** Creates the term store configured by {@code Options.INFERENCE_TS}. */
    protected TermStore createTermStore() {
        return (TermStore)Options.INFERENCE_TS.getNewObject();
    }

    /**
     * Grounds all rules into the ground rule store, then generates objective terms into the
     * term store.
     */
    protected void completeInitialize() {
        log.info("Grounding out model.");
        int groundCount = Grounding.groundAll(this.rules, (AtomManager)this.atomManager, this.groundRuleStore);
        log.info("Grounding complete.");

        log.debug("Initializing objective terms for {} ground rules.", groundCount);
        int termCount = this.termGenerator.generateTerms(this.groundRuleStore, this.termStore);
        log.debug("Generated {} objective terms from {} ground rules.", termCount, groundCount);
    }

    /** Runs inference and commits the results, without resetting atoms first. */
    public void inference() {
        this.inference(true, false);
    }

    /**
     * Runs inference.
     *
     * @param commitAtoms if true, write the resulting atom values out via {@link #commit()}
     * @param reset if true, first restore every cached random variable atom to its initial
     *     value and reset the term store (when present)
     */
    public void inference(boolean commitAtoms, boolean reset) {
        if (reset) {
            this.initializeAtoms();
            if (this.termStore != null) {
                this.termStore.reset();
            }
        }

        log.info("Beginning inference.");
        this.internalInference();
        log.info("Inference complete.");

        // Atom values have just changed, so any previous commit is stale.
        this.atomsCommitted = false;

        if (commitAtoms) {
            this.commit();
        }
    }

    /** Performs the actual optimization; by default, runs the reasoner over the term store. */
    protected void internalInference() {
        this.reasoner.optimize(this.termStore);
    }

    public Reasoner getReasoner() {
        return this.reasoner;
    }

    public GroundRuleStore getGroundRuleStore() {
        return this.groundRuleStore;
    }

    public TermStore getTermStore() {
        return this.termStore;
    }

    public PersistedAtomManager getAtomManager() {
        return this.atomManager;
    }

    /** Forwards {@code budget} to the reasoner's {@code setBudget}. */
    public void setBudget(double budget) {
        this.reasoner.setBudget(budget);
    }

    /**
     * Sets every cached random variable atom to the value chosen by the configured
     * {@link InitialValue}.
     */
    public void initializeAtoms() {
        for (RandomVariableAtom atom : this.atomManager.getDatabase().getAllCachedRandomVariableAtoms()) {
            atom.setValue(this.initialValue.getVariableValue(atom));
        }
    }

    /**
     * Writes the persisted atoms' current values out through the atom manager.
     *
     * <p>This is a no-op if nothing has changed since the last commit. Each call to
     * {@link #inference(boolean, boolean)} marks the values as uncommitted again.
     */
    public void commit() {
        if (this.atomsCommitted) {
            return;
        }

        log.info("Writing results to Database.");
        this.atomManager.commitPersistedAtoms();
        log.info("Results committed to database.");

        this.atomsCommitted = true;
    }

    /**
     * Closes and releases the term store, ground rule store and reasoner.
     *
     * <p>It also drops the references to the rules and the database. The database itself is
     * not closed here.
     */
    @Override
    public void close() {
        if (this.termStore != null) {
            this.termStore.close();
            this.termStore = null;
        }

        if (this.groundRuleStore != null) {
            this.groundRuleStore.close();
            this.groundRuleStore = null;
        }

        if (this.reasoner != null) {
            this.reasoner.close();
            this.reasoner = null;
        }

        this.rules = null;
        this.db = null;
    }

    /**
     * Divides every weighted rule's weight by the largest weighted-rule weight.
     *
     * <p>If that maximum is zero (per {@code MathUtils.isZero}), every weight is set to 1.0
     * instead. This does nothing when there are no weighted rules.
     */
    protected void normalizeWeights() {
        double max = 0.0;
        boolean hasWeightedRule = false;
        for (WeightedRule rule : IteratorUtils.filterClass(this.rules, WeightedRule.class)) {
            double weight = rule.getWeight();
            // The first weighted rule seeds the maximum, so all-negative weights are handled too.
            if (!hasWeightedRule || weight > max) {
                max = weight;
                hasWeightedRule = true;
            }
        }

        if (!hasWeightedRule) {
            return;
        }

        for (WeightedRule rule : IteratorUtils.filterClass(this.rules, WeightedRule.class)) {
            double oldWeight = rule.getWeight();
            double newWeight = 1.0;
            if (!MathUtils.isZero(max)) {
                newWeight = oldWeight / max;
            }

            log.debug("Normalizing rule weight (old weight: {}, new weight: {}): {}", oldWeight, newWeight, rule);
            rule.setWeight(newWeight);
        }
    }

    /**
     * Replaces every {@link UnweightedRule} in {@link #rules} with its relaxed weighted form.
     *
     * <p>The relaxed weight is {@code max(1.0, largestWeight * relaxationMultiplier)}, where
     * {@code largestWeight} is the largest weighted-rule weight (floored at 0.0). Squaring is
     * controlled by {@link #relaxationSquared}. This does nothing when there are no unweighted
     * rules.
     */
    protected void relaxHardConstraints() {
        double largestWeight = 0.0;
        boolean hasUnweightedRule = false;
        for (Rule rule : this.rules) {
            if (rule instanceof WeightedRule) {
                double weight = ((WeightedRule)rule).getWeight();
                if (weight > largestWeight) {
                    largestWeight = weight;
                }
            } else if (rule instanceof UnweightedRule) {
                // Only rules that the loop below can actually relax count here.
                hasUnweightedRule = true;
            }
        }

        if (!hasUnweightedRule) {
            return;
        }

        double weight = Math.max(1.0, largestWeight * this.relaxationMultiplier);
        for (int i = 0; i < this.rules.size(); i++) {
            Rule rule = this.rules.get(i);
            if (!(rule instanceof UnweightedRule)) {
                continue;
            }

            log.debug("Relaxing hard constraint (weight: {}, squared: {}): {}", weight, this.relaxationSquared, rule);
            this.rules.set(i, ((UnweightedRule)rule).relax(weight, this.relaxationSquared));
        }
    }

    /**
     * Reflectively constructs an inference application.
     *
     * <p>The name is first passed through {@code Reflection.resolveClassName}. The class must
     * extend {@code InferenceApplication} and expose a public {@code (List, Database)}
     * constructor.
     *
     * @throws IllegalArgumentException if the class cannot be found, is not an
     *     {@code InferenceApplication}, or has no suitable constructor
     * @throws RuntimeException if the constructor cannot be invoked or throws; the original
     *     reflection exception is kept as the cause
     */
    public static InferenceApplication getInferenceApplication(String className, List<Rule> rules, Database db) {
        className = Reflection.resolveClassName(className);

        // Check the type before constructing anything, so a wrong class name fails fast and
        // never runs a foreign constructor (which may have side effects).
        Class<? extends InferenceApplication> classObject;
        try {
            classObject = Class.forName(className).asSubclass(InferenceApplication.class);
        } catch (ClassNotFoundException ex) {
            throw new IllegalArgumentException("Could not find class: " + className, ex);
        } catch (ClassCastException ex) {
            throw new IllegalArgumentException("Class is not an InferenceApplication: " + className, ex);
        }

        Constructor<? extends InferenceApplication> constructor;
        try {
            constructor = classObject.getConstructor(List.class, Database.class);
        } catch (NoSuchMethodException ex) {
            throw new IllegalArgumentException("No suitable constructor (List<Rule>, Database) found for inference application: " + className + ".", ex);
        }

        try {
            return constructor.newInstance(rules, db);
        } catch (InstantiationException ex) {
            throw new RuntimeException("Unable to instantiate inference application (" + className + ")", ex);
        } catch (IllegalAccessException ex) {
            throw new RuntimeException("Insufficient access to constructor for " + className, ex);
        } catch (InvocationTargetException ex) {
            throw new RuntimeException("Error thrown while constructing " + className, ex);
        }
    }
}

