/*
 * Decompiled with CFR 0.152.
 */
package org.linqs.psl.application.learning.weight.search.grid;

import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.linqs.psl.application.learning.weight.WeightLearningApplication;
import org.linqs.psl.database.Database;
import org.linqs.psl.model.Model;
import org.linqs.psl.model.rule.Rule;
import org.linqs.psl.model.rule.WeightedRule;
import org.linqs.psl.reasoner.admm.term.ADMMTermStore;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Base class for weight learners that evaluate a sequence of candidate weight configurations
 * ("locations") and keep the one with the lowest objective.
 *
 * <p>Subclasses choose which location to visit next ({@link #chooseNextLocation()}) and supply
 * the weights for that location ({@link #getWeights(double[])}). This class applies those
 * weights to the mutable rules and scores them with {@link #inspectLocation(double[])}. At the
 * end it installs the best-scoring weights.
 */
public abstract class BaseGridSearch
extends WeightLearningApplication {
    private static final Logger log = LoggerFactory.getLogger(BaseGridSearch.class);
    public static final String CONFIG_PREFIX = "basegridsearch";

    /**
     * Identifier of the location being inspected. It is used as the key in {@link #objectives}
     * and in log output. It is not assigned in this class, so presumably subclasses update it
     * in {@link #chooseNextLocation()}.
     */
    protected String currentLocation = null;

    /** Only reported in log output here; presumably the total number of locations (set by subclasses). */
    protected int maxNumLocations = 0;

    /** Upper bound on the number of iterations of the search loop in {@link #doLearn()}. */
    protected int numLocations = 0;

    /** Objective recorded for each inspected location, keyed by {@link #currentLocation}. */
    protected Map<String, Double> objectives = new HashMap<String, Double>();

    public BaseGridSearch(Model model, Database rvDB, Database observedDB) {
        this(model.getRules(), rvDB, observedDB);
    }

    public BaseGridSearch(List<Rule> rules, Database rvDB, Database observedDB) {
        super(rules, rvDB, observedDB, false);
    }

    /**
     * Visits up to {@link #numLocations} locations, scores each one, and installs the weights
     * of the location with the lowest objective. Ties keep the earliest location.
     *
     * <p>If no location is inspected (for example, {@code numLocations} is 0 or the first call
     * to {@link #chooseNextLocation()} returns false), the rules' current weights are left as
     * they are instead of being overwritten with zeros.
     */
    @Override
    protected void doLearn() {
        int numRules = this.mutableRules.size();
        double[] weights = new double[numRules];
        double[] bestWeights = new double[numRules];
        double bestObjective = 0.0;
        boolean foundBest = false;

        this.computeObservedIncompatibility();

        for (int iteration = 0; iteration < this.numLocations; ++iteration) {
            if (!this.chooseNextLocation()) {
                log.debug("Stopping search.");
                break;
            }
            log.debug("Iteration {} / {} ({}) -- Inspecting location {}", iteration, this.numLocations, this.maxNumLocations, this.currentLocation);

            this.getWeights(weights);
            for (int i = 0; i < numRules; ++i) {
                ((WeightedRule)this.mutableRules.get(i)).setWeight(weights[i]);
            }
            log.trace("Weights: {}", (Object)weights);

            // The rule weights just changed, so any cached MPE state is stale.
            this.inMPEState = false;
            this.inLatentMPEState = false;

            double objective = this.inspectLocation(weights);
            this.objectives.put(this.currentLocation, Double.valueOf(objective));

            // The first inspected location is always accepted, even if its objective is NaN.
            if (!foundBest || objective < bestObjective) {
                foundBest = true;
                bestObjective = objective;
                // Copy the values: the weights array is reused on the next iteration.
                System.arraycopy(weights, 0, bestWeights, 0, numRules);
            }
            log.debug("Location {} -- objective: {}", (Object)this.currentLocation, (Object)objective);
        }

        if (foundBest) {
            for (int i = 0; i < numRules; ++i) {
                ((WeightedRule)this.mutableRules.get(i)).setWeight(bestWeights[i]);
            }
        } else {
            // Previously this path zeroed every rule weight; keep the existing weights instead.
            log.warn("No locations were inspected during grid search; rule weights were left unchanged.");
        }
        this.inMPEState = false;
        this.inLatentMPEState = false;
    }

    /**
     * Scores the weights currently installed on the rules. Lower values are better.
     *
     * <p>The method resets random variables (and, for an {@link ADMMTermStore}, its local
     * variables), calls {@code computeExpectedIncompatibility()}, and evaluates against
     * {@code trainingMap}. The evaluator's representative metric is negated when higher is
     * better, so that minimization is always correct. Presumably
     * {@code computeExpectedIncompatibility()} runs inference under the current weights; confirm
     * in WeightLearningApplication.
     *
     * @param weights the weights of the location being inspected (not read by this default
     *     implementation; available to overriding subclasses)
     * @return the objective for this location, where lower is better
     */
    protected double inspectLocation(double[] weights) {
        this.setDefaultRandomVariables();
        if (this.termStore instanceof ADMMTermStore) {
            // NOTE: "Vairables" is the spelling of the method on ADMMTermStore itself.
            ((ADMMTermStore)this.termStore).resetLocalVairables();
        }
        this.computeExpectedIncompatibility();
        this.evaluator.compute(this.trainingMap);
        double score = this.evaluator.getRepresentativeMetric();
        return this.evaluator.isHigherRepresentativeBetter() ? -1.0 * score : score;
    }

    /**
     * Fills {@code weights} with the weights of the current location. Index {@code i} is
     * applied to {@code mutableRules.get(i)}.
     *
     * @param weights output array, sized to the number of mutable rules
     */
    protected abstract void getWeights(double[] weights);

    /**
     * Advances the search to the next location. Presumably this also updates
     * {@link #currentLocation}; confirm in subclasses.
     *
     * @return false to stop the search, true if a location is ready to be inspected
     */
    protected abstract boolean chooseNextLocation();
}

