/*
 * Decompiled with CFR 0.152.
 */
package org.linqs.psl.reasoner.bool;

import java.util.Collection;
import java.util.HashSet;
import org.linqs.psl.config.Config;
import org.linqs.psl.grounding.GroundRules;
import org.linqs.psl.model.atom.GroundAtom;
import org.linqs.psl.model.atom.RandomVariableAtom;
import org.linqs.psl.model.rule.GroundRule;
import org.linqs.psl.model.rule.WeightedGroundRule;
import org.linqs.psl.reasoner.Reasoner;
import org.linqs.psl.reasoner.term.TermStore;
import org.linqs.psl.reasoner.term.blocker.ConstraintBlockerTerm;
import org.linqs.psl.reasoner.term.blocker.ConstraintBlockerTermStore;
import org.linqs.psl.util.MathUtils;
import org.linqs.psl.util.RandUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * A {@link Reasoner} that performs a stochastic local search over the blocks of a
 * {@link ConstraintBlockerTermStore}.
 *
 * <p>Each iteration ("flip") picks one currently unsatisfied weighted ground rule at random,
 * gathers the blocks that contain the rule's random variable atoms, and then changes the
 * assignment of exactly one of those blocks: with probability {@code noise} the new assignment
 * is chosen at random, otherwise the assignment with the lowest total weighted incompatibility
 * over the block's incident rules is chosen. Atom values are only ever set to 0 or 1.
 *
 * <p>Configuration (read once in the constructor):
 * <ul>
 *   <li>{@link #MAX_FLIPS_KEY}: maximum number of flips, must be positive.</li>
 *   <li>{@link #NOISE_KEY}: probability of a random move, must be in [0, 1].</li>
 * </ul>
 */
public class BooleanMaxWalkSat implements Reasoner {
    private static final Logger log = LoggerFactory.getLogger(BooleanMaxWalkSat.class);

    public static final String CONFIG_PREFIX = "booleanmaxwalksat";
    public static final String MAX_FLIPS_KEY = "booleanmaxwalksat.maxflips";
    public static final int MAX_FLIPS_DEFAULT = 50000;
    public static final String NOISE_KEY = "booleanmaxwalksat.noise";
    public static final double NOISE_DEFAULT = 0.01;

    /** Progress is logged every this many flips (including flip 0). */
    private static final int LOG_INTERVAL = 5000;

    private final int maxFlips;
    private final double noise;

    /**
     * Reads the flip budget and the noise probability from {@link Config}.
     *
     * @throws IllegalArgumentException if the configured flip budget is not positive or the
     *         configured noise is outside [0, 1]
     */
    public BooleanMaxWalkSat() {
        this.maxFlips = Config.getInt(MAX_FLIPS_KEY, MAX_FLIPS_DEFAULT);
        if (this.maxFlips <= 0) {
            throw new IllegalArgumentException("Max flips must be positive.");
        }

        this.noise = Config.getDouble(NOISE_KEY, NOISE_DEFAULT);
        if (this.noise < 0.0 || this.noise > 1.0) {
            throw new IllegalArgumentException("Noise must be in [0,1].");
        }
    }

    /**
     * Runs the local search until no tracked rule is unsatisfied or the flip budget is used up.
     *
     * <p>The term store is randomly initialized first, so any previous assignment is discarded.
     *
     * @param termStore the store to optimize; must be a {@link ConstraintBlockerTermStore}
     * @throws IllegalArgumentException if {@code termStore} is of any other type
     */
    @Override
    public void optimize(TermStore termStore) {
        if (!(termStore instanceof ConstraintBlockerTermStore)) {
            throw new IllegalArgumentException("ConstraintBlockerTermStore required.");
        }
        ConstraintBlockerTermStore blocker = (ConstraintBlockerTermStore)termStore;
        blocker.randomlyInitialize();

        // Working set of weighted rules whose incompatibility is currently positive.
        HashSet<GroundRule> unsatGKs = new HashSet<GroundRule>();
        for (GroundRule groundRule : blocker.getGroundRuleStore().getGroundRules()) {
            if (groundRule instanceof WeightedGroundRule
                    && ((WeightedGroundRule)groundRule).getIncompatibility() > 0.0) {
                unsatGKs.add(groundRule);
            }
        }

        HashSet<ConstraintBlockerTerm> blocksToInclude = new HashSet<ConstraintBlockerTerm>();
        for (int flip = 0; flip < this.maxFlips; ++flip) {
            if (unsatGKs.isEmpty()) {
                return;
            }

            GroundRule groundRule = selectAtRandom(unsatGKs);

            // Collect the distinct blocks holding this rule's random variable atoms.
            blocksToInclude.clear();
            for (GroundAtom atom : groundRule.getAtoms()) {
                if (!(atom instanceof RandomVariableAtom)) {
                    continue;
                }
                int blockIndex = blocker.getBlockIndex((RandomVariableAtom)atom);
                if (blockIndex != -1) {
                    blocksToInclude.add(blocker.get(blockIndex));
                }
            }

            if (blocksToInclude.isEmpty()) {
                // None of the rule's atoms map to a block, so no move made here can touch it.
                // Drop it from the working set: leaving it in (as before) made this loop spin
                // forever once only such rules remained, because the flip is not counted.
                unsatGKs.remove(groundRule);
                --flip;
                continue;
            }

            // Snapshot the candidate blocks into parallel arrays indexed by candidate position.
            int candidateCount = blocksToInclude.size();
            RandomVariableAtom[][] candidateRVBlocks = new RandomVariableAtom[candidateCount][];
            WeightedGroundRule[][] candidateIncidentGKs = new WeightedGroundRule[candidateCount][];
            boolean[] candidateExactlyOne = new boolean[candidateCount];
            int candidateBlockIndex = 0;
            for (ConstraintBlockerTerm block : blocksToInclude) {
                candidateRVBlocks[candidateBlockIndex] = block.getAtoms();
                candidateExactlyOne[candidateBlockIndex] = block.getExactlyOne();
                candidateIncidentGKs[candidateBlockIndex] = block.getIncidentGRs();
                ++candidateBlockIndex;
            }

            int blockToChange;
            int positiveRVAIndex;
            if (RandUtils.nextDouble() <= this.noise) {
                // Random move.
                blockToChange = RandUtils.nextInt(candidateCount);
                positiveRVAIndex = pickRandomAssignment(
                        candidateRVBlocks[blockToChange], candidateExactlyOne[blockToChange]);
            } else {
                // Greedy move.
                int[] best = findBestAssignment(candidateRVBlocks, candidateIncidentGKs, candidateExactlyOne);
                blockToChange = best[0];
                positiveRVAIndex = best[1];
            }

            // blockToChange is negative when the greedy search found no acceptable candidate;
            // in that case the assignment is left as is and the flip is still counted.
            if (blockToChange >= 0) {
                assign(candidateRVBlocks[blockToChange], positiveRVAIndex);

                // Only rules incident to the changed block can have changed satisfaction.
                for (WeightedGroundRule incidentGK : candidateIncidentGKs[blockToChange]) {
                    if (incidentGK.getIncompatibility() > 0.0) {
                        unsatGKs.add(incidentGK);
                    } else {
                        unsatGKs.remove(incidentGK);
                    }
                }
            }

            if (flip % LOG_INTERVAL == 0) {
                log.info("Flip {}, Total weighted incompatibility: {}, Infeasbility norm: {}", flip, GroundRules.getTotalWeightedIncompatibility(blocker.getGroundRuleStore().getCompatibilityRules()), GroundRules.getInfeasibilityNorm(blocker.getGroundRuleStore().getConstraintRules()));
            }
        }
    }

    /**
     * Chooses a random new assignment for one block.
     *
     * <p>For an exactly-one block, an atom that is not currently 1 is drawn uniformly by
     * rejection. For any other block a uniformly drawn atom is switched on, unless it is
     * already 1, in which case the whole block is switched off.
     *
     * @param block the atoms of the block
     * @param exactlyOne whether the block must keep exactly one atom at 1
     * @return the index of the atom to set to 1, or -1 to set every atom to 0
     */
    private static int pickRandomAssignment(RandomVariableAtom[] block, boolean exactlyOne) {
        if (block.length == 0) {
            // Nothing to draw from; avoids passing a zero bound to the random generator.
            return -1;
        }

        if (exactlyOne && !hasAtomBelowOne(block)) {
            // No alternative atom exists (e.g. a single-atom block already at 1), so the
            // rejection loop below would never terminate. Keep/normalize to the first atom.
            return 0;
        }

        int index;
        do {
            index = RandUtils.nextInt(block.length);
        } while (exactlyOne && isOne(block[index]));

        return isOne(block[index]) ? -1 : index;
    }

    /**
     * Exhaustively tries every allowed assignment of every candidate block and returns the one
     * with the lowest total weighted incompatibility over that block's incident rules.
     *
     * <p>Each block is restored to its original values before moving on, so on return no atom
     * value has changed. The search stops early once an assignment with (numerically) zero
     * incompatibility is found.
     *
     * @param candidateRVBlocks atoms of each candidate block
     * @param candidateIncidentGKs rules incident to each candidate block
     * @param candidateExactlyOne whether each candidate block must keep exactly one atom at 1
     * @return a two element array {@code {blockIndex, positiveIndex}}; {@code positiveIndex}
     *         equal to the block length means "all atoms 0"; {@code {-1, -1}} if no candidate
     *         assignment had an incompatibility below positive infinity
     */
    private static int[] findBestAssignment(RandomVariableAtom[][] candidateRVBlocks,
            WeightedGroundRule[][] candidateIncidentGKs, boolean[] candidateExactlyOne) {
        int bestBlock = -1;
        int bestPositiveIndex = -1;
        double bestIncompatibility = Double.POSITIVE_INFINITY;

        for (int blockIndex = 0; blockIndex < candidateRVBlocks.length; ++blockIndex) {
            RandomVariableAtom[] block = candidateRVBlocks[blockIndex];

            float[] savedState = new float[block.length];
            for (int i = 0; i < block.length; ++i) {
                savedState[i] = block[i].getValue();
            }

            // Blocks that are not exactly-one get one extra candidate: every atom at 0.
            int assignmentCount = candidateExactlyOne[blockIndex] ? block.length : block.length + 1;
            for (int currentPositiveRVA = 0; currentPositiveRVA < assignmentCount; ++currentPositiveRVA) {
                assign(block, currentPositiveRVA);

                double currentIncompatibility = 0.0;
                for (WeightedGroundRule incidentGK : candidateIncidentGKs[blockIndex]) {
                    currentIncompatibility += incidentGK.getWeight() * incidentGK.getIncompatibility();
                }

                if (currentIncompatibility < bestIncompatibility) {
                    bestIncompatibility = currentIncompatibility;
                    bestBlock = blockIndex;
                    bestPositiveIndex = currentPositiveRVA;
                    if (MathUtils.isZero(bestIncompatibility)) {
                        break;
                    }
                }
            }

            // Undo the trial assignments before looking at the next block.
            for (int i = 0; i < block.length; ++i) {
                block[i].setValue(savedState[i]);
            }

            if (MathUtils.isZero(bestIncompatibility)) {
                break;
            }
        }

        return new int[] {bestBlock, bestPositiveIndex};
    }

    /**
     * Sets the atom at {@code positiveIndex} to 1 and every other atom in the block to 0.
     * An index that matches no atom (such as -1 or the block length) sets all atoms to 0.
     */
    private static void assign(RandomVariableAtom[] block, int positiveIndex) {
        for (int i = 0; i < block.length; ++i) {
            block[i].setValue(i == positiveIndex ? 1.0f : 0.0f);
        }
    }

    /** Returns true if the atom's current value is exactly 1. */
    private static boolean isOne(RandomVariableAtom atom) {
        return (double)atom.getValue() == 1.0;
    }

    /** Returns true if at least one atom in the block has a value other than exactly 1. */
    private static boolean hasAtomBelowOne(RandomVariableAtom[] block) {
        for (RandomVariableAtom atom : block) {
            if (!isOne(atom)) {
                return true;
            }
        }
        return false;
    }

    /**
     * Returns a uniformly chosen element of a non-empty collection by walking its iterator.
     *
     * @param collection the collection to draw from; must not be empty
     * @return the selected element, or {@code null} if iteration ends before the drawn
     *         position is reached (not expected for a collection whose size is stable)
     */
    private static <T> T selectAtRandom(Collection<T> collection) {
        int selection = RandUtils.nextInt(collection.size());
        int position = 0;
        for (T element : collection) {
            if (position++ == selection) {
                return element;
            }
        }
        return null;
    }

    /** Releases nothing; this reasoner holds no resources. */
    @Override
    public void close() {
    }
}

