/*
 * Decompiled with CFR 0.152.
 */
package org.linqs.psl.application.learning.weight.search.grid;

import java.util.List;
import org.linqs.psl.application.learning.weight.search.grid.BaseGridSearch;
import org.linqs.psl.config.Config;
import org.linqs.psl.database.Database;
import org.linqs.psl.model.Model;
import org.linqs.psl.model.rule.Rule;
import org.linqs.psl.util.StringUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Weight learning by exhaustive grid search.
 *
 * <p>Every mutable rule is assigned one of the candidate weights listed in the
 * {@value #POSSIBLE_WEIGHTS_KEY} config option (default {@value #POSSIBLE_WEIGHTS_DEFAULT}).
 * A location is encoded as the {@value #DELIM}-joined list of per-rule indexes into
 * {@link #possibleWeights}. The number of locations is
 * {@code possibleWeights.length ^ mutableRules.size()}, capped at {@link Integer#MAX_VALUE}.
 */
public class GridSearch
extends BaseGridSearch {
    private static final Logger log = LoggerFactory.getLogger(GridSearch.class);
    public static final String CONFIG_PREFIX = "gridsearch";
    public static final String POSSIBLE_WEIGHTS_KEY = "gridsearch.weights";
    public static final String POSSIBLE_WEIGHTS_DEFAULT = "0.001:0.01:0.1:1:10";
    public static final String DELIM = ":";

    /** Candidate weights parsed from the {@value #POSSIBLE_WEIGHTS_KEY} config option. */
    protected final double[] possibleWeights = StringUtils.splitDouble(
            Config.getString(POSSIBLE_WEIGHTS_KEY, POSSIBLE_WEIGHTS_DEFAULT), DELIM);

    public GridSearch(Model model, Database rvDB, Database observedDB) {
        this(model.getRules(), rvDB, observedDB);
    }

    /**
     * Creates a grid search over the given rules.
     *
     * @throws IllegalArgumentException if the configured weight list is empty.
     */
    public GridSearch(List<Rule> rules, Database rvDB, Database observedDB) {
        super(rules, rvDB, observedDB);
        if (possibleWeights.length == 0) {
            throw new IllegalArgumentException("No weights provided for grid search.");
        }
        numLocations = maxNumLocations = countLocations(possibleWeights.length, mutableRules.size());
    }

    /**
     * Computes {@code numWeights ^ numRules}, saturating at {@link Integer#MAX_VALUE}.
     *
     * <p>This yields the same values as the former {@code (int) Math.pow(numWeights, numRules)}.
     * That cast clamped silently; this version logs a warning when the grid has to be capped.
     *
     * @param numWeights number of candidate weights; at least 1.
     * @param numRules number of mutable rules; at least 0.
     * @return the number of grid locations, at most {@link Integer#MAX_VALUE}.
     */
    private static int countLocations(int numWeights, int numRules) {
        long count = 1;
        for (int i = 0; i < numRules; i++) {
            // Before this multiply, count <= Integer.MAX_VALUE and numWeights <= Integer.MAX_VALUE,
            // so the product always fits in a long.
            count *= numWeights;
            if (count > Integer.MAX_VALUE) {
                log.warn("Grid search over {} weights and {} rules exceeds {} locations; capping at {}.",
                        numWeights, numRules, Integer.MAX_VALUE, Integer.MAX_VALUE);
                return Integer.MAX_VALUE;
            }
        }
        return (int)count;
    }

    /**
     * Writes the candidate weight selected for each mutable rule at the current location.
     *
     * @param weights output array; entry {@code i} receives the weight chosen for the
     *     {@code i}-th mutable rule.
     */
    @Override
    protected void getWeights(double[] weights) {
        int[] indexes = StringUtils.splitInt(currentLocation, DELIM);
        assert indexes.length == mutableRules.size();
        for (int i = 0; i < mutableRules.size(); i++) {
            weights[i] = possibleWeights[indexes[i]];
        }
    }

    /**
     * Advances {@code currentLocation} to the next combination of weight indexes.
     *
     * <p>The first call starts with every rule at index 0. Each later call increments the indexes
     * odometer-style: the last rule changes fastest and wraps into earlier rules.
     * NOTE(review): after the final combination the indexes wrap back to all zeros and this still
     * returns {@code true}. Presumably the base class stops after {@code numLocations} locations;
     * confirm in BaseGridSearch.
     *
     * @return always {@code true}.
     */
    @Override
    protected boolean chooseNextLocation() {
        if (currentLocation == null) {
            currentLocation = StringUtils.join(DELIM, new int[mutableRules.size()]);
            return true;
        }

        int[] indexes = StringUtils.splitInt(currentLocation, DELIM);
        assert indexes.length == mutableRules.size();
        for (int i = mutableRules.size() - 1; i >= 0; i--) {
            indexes[i]++;
            if (indexes[i] != possibleWeights.length) {
                // No carry needed; the earlier rules keep their current index.
                break;
            }
            // This digit overflowed: reset it and carry into the previous rule.
            indexes[i] = 0;
        }

        currentLocation = StringUtils.join(DELIM, indexes);
        return true;
    }
}

