/*
 * Decompiled with CFR 0.152.
 */
package org.linqs.psl.reasoner.term;

import java.util.HashSet;
import java.util.Set;
import org.linqs.psl.config.Config;
import org.linqs.psl.grounding.GroundRuleStore;
import org.linqs.psl.model.atom.RandomVariableAtom;
import org.linqs.psl.model.rule.GroundRule;
import org.linqs.psl.model.rule.UnweightedGroundRule;
import org.linqs.psl.model.rule.WeightedGroundRule;
import org.linqs.psl.model.rule.WeightedRule;
import org.linqs.psl.reasoner.function.ConstraintTerm;
import org.linqs.psl.reasoner.function.FunctionComparator;
import org.linqs.psl.reasoner.function.FunctionTerm;
import org.linqs.psl.reasoner.function.GeneralFunction;
import org.linqs.psl.reasoner.term.Hyperplane;
import org.linqs.psl.reasoner.term.ReasonerLocalVariable;
import org.linqs.psl.reasoner.term.ReasonerTerm;
import org.linqs.psl.reasoner.term.TermGenerator;
import org.linqs.psl.reasoner.term.TermStore;
import org.linqs.psl.util.MathUtils;
import org.linqs.psl.util.Parallel;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Base class for term generators that flatten each ground rule's function into a {@link Hyperplane}
 * over term-store local variables, and then let a subclass turn that hyperplane into a concrete term.
 *
 * <p>Weighted ground rules become loss terms via {@link #createLossTerm}; unweighted ground rules
 * become linear constraint terms via {@link #createLinearConstraintTerm}.
 *
 * <p>Ground rules with a negative weight are skipped unless {@value #INVERT_NEGATIVE_WEIGHTS_KEY} is
 * enabled. When it is enabled, the rules produced by {@link GroundRule#negate()} are converted instead.
 */
public abstract class HyperplaneTermGenerator<T extends ReasonerTerm, V extends ReasonerLocalVariable>
        implements TermGenerator<T, V> {
    private static final Logger log = LoggerFactory.getLogger(HyperplaneTermGenerator.class);

    public static final String CONFIG_PREFIX = "hyperplanetermgenerator";

    /** Config key: whether negative-weight ground rules are negated (true) or skipped (false). */
    public static final String INVERT_NEGATIVE_WEIGHTS_KEY = CONFIG_PREFIX + ".invertnegativeweights";
    public static final boolean INVERT_NEGATIVE_WEIGHTS_DEFAULT = false;

    private final boolean invertNegativeWeight =
            Config.getBoolean(INVERT_NEGATIVE_WEIGHTS_KEY, INVERT_NEGATIVE_WEIGHTS_DEFAULT);

    /**
     * Generates terms for every ground rule in {@code ruleStore} and adds them to {@code termStore}.
     *
     * <p>A ground rule whose created term is null or empty contributes nothing. A weighted ground rule
     * with a negative weight is handled in one of two ways:
     * <ul>
     *   <li>If inversion is disabled, it is skipped, and one warning is logged per distinct rule.</li>
     *   <li>If inversion is enabled, each rule returned by {@link GroundRule#negate()} is converted.
     *       The resulting terms are registered under the original ground rule.</li>
     * </ul>
     *
     * @param ruleStore the ground rules to convert
     * @param termStore the store that receives the generated terms
     * @return the growth of {@code termStore.size()} during this call
     */
    @Override
    public int generateTerms(GroundRuleStore ruleStore, final TermStore<T, V> termStore) {
        int initialSize = termStore.size();
        termStore.ensureCapacity(initialSize + ruleStore.size());

        // The warning text says the rule is being skipped, which is only true when not inverting.
        if (!invertNegativeWeight) {
            warnAboutSkippedNegativeWeights(ruleStore);
        }

        // NOTE(review): Parallel.foreach may run workers concurrently; this presumes
        // termStore.add/createLocalVariable tolerate concurrent callers — confirm.
        Parallel.foreach(ruleStore.getGroundRules(), new Parallel.Worker<GroundRule>() {
            @Override
            public void work(int index, GroundRule rule) {
                boolean negativeWeight = rule instanceof WeightedGroundRule
                        && ((WeightedGroundRule)rule).getWeight() < 0.0;

                if (!negativeWeight) {
                    HyperplaneTermGenerator.this.addTermIfNonTrivial(rule, rule, termStore);
                    return;
                }

                if (!HyperplaneTermGenerator.this.invertNegativeWeight) {
                    return;
                }

                // Replace the negative-weight rule by its negation(s), but keep the original
                // ground rule as the owner of the resulting terms.
                for (GroundRule negatedRule : rule.negate()) {
                    HyperplaneTermGenerator.this.addTermIfNonTrivial(rule, negatedRule, termStore);
                }
            }
        });

        return termStore.size() - initialSize;
    }

    /**
     * Logs one warning per distinct {@link WeightedRule} (reached through a weighted ground rule)
     * whose weight is negative.
     */
    private static void warnAboutSkippedNegativeWeights(GroundRuleStore ruleStore) {
        Set<WeightedRule> negativeRules = new HashSet<>();
        for (GroundRule groundRule : ruleStore.getGroundRules()) {
            if (!(groundRule instanceof WeightedGroundRule)) {
                continue;
            }

            WeightedRule rule = (WeightedRule)groundRule.getRule();
            if (rule.getWeight() < 0.0) {
                negativeRules.add(rule);
            }
        }

        for (WeightedRule rule : negativeRules) {
            log.warn("Found a rule with a negative weight, but config says not to invert it... skipping: {}", rule);
        }
    }

    /**
     * Converts {@code source} into a term and, if the term is non-null and non-empty, adds it to
     * {@code termStore} under {@code owner}.
     */
    private void addTermIfNonTrivial(GroundRule owner, GroundRule source, TermStore<T, V> termStore) {
        T term = createTerm(source, termStore);
        if (term != null && term.size() > 0) {
            termStore.add(owner, term);
        }
    }

    /**
     * Creates a single term for {@code groundRule}.
     *
     * <p>Weighted ground rules are passed to {@link #createLossTerm}. Unweighted ground rules are
     * passed to {@link #createLinearConstraintTerm}, after the constraint's value has been added
     * to the hyperplane constant.
     *
     * @param groundRule the ground rule to convert
     * @param termStore the store used to resolve local variables
     * @return the created term, or null when the hyperplane is trivial (see {@link #processHyperplane})
     * @throws IllegalArgumentException if the ground rule is neither weighted nor unweighted, or if
     *     its function contains an unsupported summand
     */
    public T createTerm(GroundRule groundRule, TermStore<T, V> termStore) {
        if (groundRule instanceof WeightedGroundRule) {
            GeneralFunction function = ((WeightedGroundRule)groundRule).getFunctionDefinition();
            Hyperplane<V> hyperplane = processHyperplane(function, termStore);
            if (hyperplane == null) {
                return null;
            }

            return createLossTerm(termStore, function.isNonNegative(), function.isSquared(), groundRule, hyperplane);
        }

        if (groundRule instanceof UnweightedGroundRule) {
            ConstraintTerm constraint = ((UnweightedGroundRule)groundRule).getConstraintDefinition();
            GeneralFunction function = constraint.getFunction();
            Hyperplane<V> hyperplane = processHyperplane(function, termStore);
            if (hyperplane == null) {
                return null;
            }

            // The hyperplane constant already holds the function's constants moved to the
            // right-hand side; add the constraint's value on top of them.
            hyperplane.setConstant(constraint.getValue() + hyperplane.getConstant());
            return createLinearConstraintTerm(termStore, groundRule, hyperplane, constraint.getComparator());
        }

        throw new IllegalArgumentException("Unsupported ground rule: " + groundRule);
    }

    /**
     * Flattens {@code sum} into a hyperplane.
     *
     * <p>Each random-variable summand becomes a local-variable coefficient. Repeated variables have
     * their coefficients merged. Constant summands, together with the function's own constant, are
     * subtracted into the hyperplane constant.
     *
     * @return the hyperplane, or null when {@code sum} is non-negative and the same variable
     *     appears with coefficients of differing signs
     * @throws IllegalArgumentException if a summand is neither a random variable nor a constant
     */
    private Hyperplane<V> processHyperplane(GeneralFunction sum, TermStore<T, V> termStore) {
        // The function's constant starts on the right-hand side, hence the negation.
        Hyperplane<V> hyperplane = new Hyperplane<V>(getLocalVariableType(), sum.size(), -1.0f * sum.getConstant());

        for (int i = 0; i < sum.size(); ++i) {
            float coefficient = sum.getCoefficient(i);
            FunctionTerm term = sum.getTerm(i);

            if (term instanceof RandomVariableAtom) {
                V variable = termStore.createLocalVariable((RandomVariableAtom)term);
                int localIndex = hyperplane.indexOfVariable(variable);
                if (localIndex == -1) {
                    hyperplane.addTerm(variable, coefficient);
                    continue;
                }

                // A repeated variable with mismatched signs in a non-negative function: this
                // hyperplane is abandoned, and the caller skips the rule.
                if (sum.isNonNegative() && !MathUtils.signsMatch(hyperplane.getCoefficient(localIndex), coefficient)) {
                    return null;
                }

                hyperplane.appendCoefficient(localIndex, coefficient);
            } else if (term.isConstant()) {
                // Move the constant summand to the right-hand side.
                hyperplane.setConstant(hyperplane.getConstant() - coefficient * term.getValue());
            } else {
                throw new IllegalArgumentException("Unexpected summand: " + sum + "[" + i + "] (" + term + ").");
            }
        }

        return hyperplane;
    }

    /** Returns the concrete local-variable class that hyperplanes built by this generator hold. */
    public abstract Class<V> getLocalVariableType();

    /**
     * Builds the term for a weighted ground rule.
     *
     * @param termStore the store the term is being created for
     * @param isNonNegative {@code isNonNegative()} of the rule's function
     * @param isSquared {@code isSquared()} of the rule's function
     * @param groundRule the weighted ground rule
     * @param hyperplane the flattened function
     * @return the term. Callers in this class treat a null or empty (size 0) result as "no term".
     */
    public abstract T createLossTerm(TermStore<T, V> termStore, boolean isNonNegative, boolean isSquared,
            GroundRule groundRule, Hyperplane<V> hyperplane);

    /**
     * Builds the term for an unweighted (constraint) ground rule.
     *
     * @param termStore the store the term is being created for
     * @param groundRule the unweighted ground rule
     * @param hyperplane the flattened constraint function, with the constraint value included in its constant
     * @param comparator the constraint's comparator
     * @return the term. Callers in this class treat a null or empty (size 0) result as "no term".
     */
    public abstract T createLinearConstraintTerm(TermStore<T, V> termStore, GroundRule groundRule,
            Hyperplane<V> hyperplane, FunctionComparator comparator);
}

