/*
 * Decompiled with CFR 0.152.
 */
package org.linqs.psl.application.learning.weight.search;

import java.util.List;
import org.linqs.psl.application.learning.weight.VotedPerceptron;
import org.linqs.psl.application.learning.weight.WeightLearningApplication;
import org.linqs.psl.application.learning.weight.maxlikelihood.MaxLikelihoodMPE;
import org.linqs.psl.application.learning.weight.search.Hyperband;
import org.linqs.psl.config.Config;
import org.linqs.psl.database.Database;
import org.linqs.psl.model.Model;
import org.linqs.psl.model.rule.Rule;
import org.linqs.psl.model.rule.WeightedRule;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * A {@link Hyperband} variant that runs an internal weight learner before
 * handing each run off to the parent implementation.
 *
 * <p>The internal learner is built by name through
 * {@link WeightLearningApplication#getWLA} and is stored as a
 * {@link VotedPerceptron}; a configured class that is not a
 * {@code VotedPerceptron} makes the constructor's cast fail with a
 * {@link ClassCastException}.
 */
public class InitialWeightHyperband extends Hyperband {
    private static final Logger log = LoggerFactory.getLogger(InitialWeightHyperband.class);

    /** Common prefix for this class's configuration keys. */
    public static final String CONFIG_PREFIX = "initialweighthyperband";

    /** Config key holding the class name of the internal weight learner. */
    public static final String INTERNAL_WLA_KEY = "initialweighthyperband.internalwla";

    /** Internal learner class name used when {@link #INTERNAL_WLA_KEY} is not set. */
    public static final String INTERNAL_WLA_DEFAULT = MaxLikelihoodMPE.class.getName();

    /** Learner invoked at the start of every {@link #run(double[])}. */
    private VotedPerceptron internalWLA;

    /**
     * Convenience constructor that takes the rules from {@code model}.
     *
     * @param model source of the rules
     * @param rvDB forwarded unchanged to the list-based constructor
     * @param observedDB forwarded unchanged to the list-based constructor
     */
    public InitialWeightHyperband(Model model, Database rvDB, Database observedDB) {
        this(model.getRules(), rvDB, observedDB);
    }

    /**
     * Initializes the parent and then creates the internal learner from the
     * configured class name, giving it the same rules and databases.
     *
     * @param rules rules passed to both the parent and the internal learner
     * @param rvDB database passed to both the parent and the internal learner
     * @param observedDB database passed to both the parent and the internal learner
     * @throws ClassCastException if the configured learner is not a {@link VotedPerceptron}
     */
    public InitialWeightHyperband(List<Rule> rules, Database rvDB, Database observedDB) {
        super(rules, rvDB, observedDB);

        internalWLA = (VotedPerceptron) WeightLearningApplication.getWLA(
                Config.getString(INTERNAL_WLA_KEY, INTERNAL_WLA_DEFAULT),
                rules,
                rvDB,
                observedDB);
    }

    /**
     * Runs the parent's hook, then initializes the internal learner with this
     * instance's own ground-model components, so both objects reference the
     * same reasoner, stores, term generator, atom manager and training map.
     */
    @Override
    protected void postInitGroundModel() {
        super.postInitGroundModel();

        internalWLA.initGroundModel(
                reasoner,
                groundRuleStore,
                termStore,
                termGenerator,
                atomManager,
                trainingMap);
    }

    /**
     * Applies {@code budget} to the internal learner first and then to the parent.
     *
     * @param budget value forwarded to both
     */
    @Override
    public void setBudget(double budget) {
        internalWLA.setBudget(budget);
        super.setBudget(budget);
    }

    /**
     * Lets the internal learner learn, writes each mutable rule's resulting
     * weight into {@code weights}, and returns the parent's result for that array.
     *
     * @param weights overwritten in place at indices {@code 0..mutableRules.size() - 1};
     *     NOTE(review): presumably the candidate weights have already been applied
     *     to the rules before this is called — confirm against {@code Hyperband}
     * @return whatever {@code Hyperband.run} returns for the updated weights
     */
    @Override
    protected double run(double[] weights) {
        internalWLA.learn();

        // Capture the weights currently on the rules so the caller's array matches them.
        for (int ruleIndex = 0; ruleIndex < mutableRules.size(); ruleIndex++) {
            WeightedRule rule = (WeightedRule) mutableRules.get(ruleIndex);
            weights[ruleIndex] = rule.getWeight();
        }

        double result = super.run(weights);
        return result;
    }

    /**
     * Closes the parent first and the internal learner second.
     */
    @Override
    public void close() {
        super.close();
        internalWLA.close();
    }
}

