/*
 * Decompiled with CFR 0.152.
 */
package org.linqs.psl.model.rule.arithmetic.expression;

import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import org.linqs.psl.model.atom.Atom;
import org.linqs.psl.model.formula.Conjunction;
import org.linqs.psl.model.formula.Formula;
import org.linqs.psl.model.rule.arithmetic.expression.SummationAtom;
import org.linqs.psl.model.rule.arithmetic.expression.SummationAtomOrAtom;
import org.linqs.psl.model.rule.arithmetic.expression.SummationVariable;
import org.linqs.psl.model.rule.arithmetic.expression.SummationVariableOrTerm;
import org.linqs.psl.model.rule.arithmetic.expression.coefficient.Cardinality;
import org.linqs.psl.model.rule.arithmetic.expression.coefficient.Coefficient;
import org.linqs.psl.model.rule.arithmetic.expression.coefficient.ConstantNumber;
import org.linqs.psl.model.term.Variable;
import org.linqs.psl.reasoner.function.FunctionComparator;
import org.linqs.psl.util.HashCode;
import org.linqs.psl.util.MathUtils;

/**
 * An arithmetic expression of the shape {@code c_1 * a_1 + ... + c_n * a_n <comparator> constant},
 * where each {@code a_i} is a plain atom or a summation atom and {@code c_i} is its coefficient.
 * <p>
 * Coefficients and atoms are parallel lists: {@code coefficients.get(i)} multiplies {@code atoms.get(i)}.
 * Both lists are copied on construction and only exposed as unmodifiable views, so the cached hash
 * cannot be invalidated by later changes to the caller's lists.
 */
public class ArithmeticRuleExpression {
    /** Coefficient of each term; index {@code i} pairs with {@code atoms.get(i)}. */
    protected final List<Coefficient> coefficients;
    /** Atom (or summation atom) of each term; index {@code i} pairs with {@code coefficients.get(i)}. */
    protected final List<SummationAtomOrAtom> atoms;
    /** Comparator between the weighted sum and the constant. */
    protected final FunctionComparator comparator;
    /** Right-hand side of the comparison. */
    protected final Coefficient constant;
    /** Every ordinary (non-summation) variable appearing as an argument of any atom. */
    protected final Set<Variable> vars;
    /** Maps each summation variable to the single summation atom that declares it. */
    protected final Map<SummationVariable, SummationAtom> summationMapping;
    /** Hash computed once in the constructor from the comparator, constant, coefficients and atoms. */
    private final int hash;

    /**
     * Builds an expression with cardinality validation enabled.
     *
     * @see #ArithmeticRuleExpression(List, List, FunctionComparator, Coefficient, boolean)
     */
    public ArithmeticRuleExpression(List<Coefficient> coefficients, List<SummationAtomOrAtom> atoms, FunctionComparator comparator, Coefficient constant) {
        this(coefficients, atoms, comparator, constant, false);
    }

    /**
     * Builds an expression and validates how variables and summation variables are used.
     *
     * @param coefficients coefficient of each term, parallel to {@code atoms}; copied
     * @param atoms atom of each term; must be non-empty; copied
     * @param comparator comparator between the weighted sum and {@code constant}
     * @param constant right-hand side of the comparison
     * @param skipCardinalityValidation when {@code true}, cardinality coefficients are not checked
     *        against the set of summation variables
     * @throws IllegalArgumentException if {@code atoms} is empty, if a summation variable appears more
     *         than once, if a summation variable's name is also used as an ordinary variable, or (unless
     *         skipped) if a {@link Cardinality} coefficient names a variable that is not a summation variable
     */
    public ArithmeticRuleExpression(List<Coefficient> coefficients, List<SummationAtomOrAtom> atoms, FunctionComparator comparator, Coefficient constant, boolean skipCardinalityValidation) {
        // Defensive copies: a plain unmodifiable *view* would still reflect later caller mutations,
        // which would silently change this expression and break the hash cached below.
        this.coefficients = Collections.unmodifiableList(new ArrayList<Coefficient>(coefficients));
        this.atoms = Collections.unmodifiableList(new ArrayList<SummationAtomOrAtom>(atoms));
        this.comparator = comparator;
        this.constant = constant;

        if (this.atoms.isEmpty()) {
            throw new IllegalArgumentException("Cannot have an arithmetic rule without atoms.");
        }

        Set<Variable> variables = new HashSet<Variable>();
        Set<String> sumVarNames = new HashSet<String>();
        Map<SummationVariable, SummationAtom> mapping = new HashMap<SummationVariable, SummationAtom>();

        // Iterate the field directly rather than calling the overridable getAtoms() from a constructor.
        for (SummationAtomOrAtom atom : this.atoms) {
            if (atom instanceof SummationAtom) {
                SummationAtom summationAtom = (SummationAtom)atom;
                for (SummationVariableOrTerm argument : summationAtom.getArguments()) {
                    if (argument instanceof Variable) {
                        variables.add((Variable)argument);
                    } else if (argument instanceof SummationVariable) {
                        SummationVariable summationVariable = (SummationVariable)argument;
                        if (mapping.containsKey(summationVariable)) {
                            throw new IllegalArgumentException("Each summation variable in an ArithmeticRuleExpression must be unique.");
                        }
                        sumVarNames.add(summationVariable.getVariable().getName());
                        mapping.put(summationVariable, summationAtom);
                    }
                    // Any other argument kind (e.g. a constant) contributes no variables.
                }
            } else {
                for (SummationVariableOrTerm argument : ((Atom)atom).getArguments()) {
                    if (argument instanceof Variable) {
                        variables.add((Variable)argument);
                    }
                }
            }
        }

        // Summation variables and ordinary variables are compared by name.
        for (Variable variable : variables) {
            if (sumVarNames.contains(variable.getName())) {
                throw new IllegalArgumentException(String.format("Summation variable (+%s) cannot be used as a normal variable (%s).", variable.getName(), variable.getName()));
            }
        }

        if (!skipCardinalityValidation) {
            for (Coefficient coefficient : this.coefficients) {
                if (!(coefficient instanceof Cardinality)) {
                    continue;
                }
                String name = ((Cardinality)coefficient).getSummationVariable().getVariable().getName();
                if (!sumVarNames.contains(name)) {
                    throw new IllegalArgumentException(String.format("Cannot use variable (%s) in cardinality. Only summation variables can be used in cardinality.", name));
                }
            }
        }

        this.vars = Collections.unmodifiableSet(variables);
        this.summationMapping = Collections.unmodifiableMap(mapping);

        // The (Object) cast is kept to preserve HashCode.build overload resolution.
        int hashValue = HashCode.build(HashCode.build((Object)comparator), constant);
        for (Coefficient coefficient : this.coefficients) {
            hashValue = HashCode.build(hashValue, coefficient);
        }
        for (SummationAtomOrAtom atom : this.atoms) {
            hashValue = HashCode.build(hashValue, atom);
        }
        this.hash = hashValue;
    }

    /** Returns the hash computed once at construction time. */
    @Override
    public int hashCode() {
        return this.hash;
    }

    /** Returns the unmodifiable list of term coefficients, parallel to {@link #getAtoms()}. */
    public List<Coefficient> getAtomCoefficients() {
        return this.coefficients;
    }

    /** Returns the unmodifiable list of term atoms, parallel to {@link #getAtomCoefficients()}. */
    public List<SummationAtomOrAtom> getAtoms() {
        return this.atoms;
    }

    /** Returns the comparator between the weighted sum and the constant. */
    public FunctionComparator getComparator() {
        return this.comparator;
    }

    /** Returns the right-hand-side constant of the comparison. */
    public Coefficient getFinalCoefficient() {
        return this.constant;
    }

    /** Returns the unmodifiable set of ordinary (non-summation) variables used by the atoms. */
    public Set<Variable> getVariables() {
        return this.vars;
    }

    /** Returns the unmodifiable set of summation variables declared by the summation atoms. */
    public Set<SummationVariable> getSummationVariables() {
        return this.summationMapping.keySet();
    }

    /** Returns the unmodifiable map from each summation variable to the summation atom declaring it. */
    public Map<SummationVariable, SummationAtom> getSummationMapping() {
        return this.summationMapping;
    }

    /**
     * Reports whether this expression has no summation variables, exactly one atom, an {@code EQ}
     * comparator, and a {@link ConstantNumber} constant whose value is zero.
     */
    public boolean looksLikeNegativePrior() {
        return this.summationMapping.isEmpty()
                && this.atoms.size() == 1
                && FunctionComparator.EQ.equals(this.comparator)
                && this.constant instanceof ConstantNumber
                && MathUtils.isZero(this.constant.getValue(null));
    }

    /**
     * Builds the formula used to query groundings of this expression: each summation atom is replaced
     * by its query atom, and plain atoms are used as-is.
     *
     * @return the single atom when there is only one, otherwise a {@link Conjunction} of all of them
     */
    public Formula getQueryFormula() {
        List<Atom> queryAtoms = new ArrayList<Atom>(this.atoms.size());
        for (SummationAtomOrAtom atom : this.atoms) {
            if (atom instanceof SummationAtom) {
                queryAtoms.add(((SummationAtom)atom).getQueryAtom());
            } else {
                queryAtoms.add((Atom)atom);
            }
        }
        if (queryAtoms.size() == 1) {
            return (Formula)queryAtoms.get(0);
        }
        return new Conjunction(queryAtoms.toArray(new Formula[0]));
    }

    /** Renders the expression as {@code c1 * a1 + c2 * a2 ... <comparator> <constant>}. */
    @Override
    public String toString() {
        StringBuilder s = new StringBuilder();
        if (this.coefficients.size() > 0) {
            for (int i = 0; i < this.coefficients.size(); ++i) {
                if (i != 0) {
                    s.append(" + ");
                }
                s.append(this.coefficients.get(i));
                s.append(" * ");
                s.append(this.atoms.get(i));
            }
        } else {
            s.append("0.0");
        }
        s.append(" ");
        s.append(this.comparator);
        s.append(" ");
        s.append(this.constant);
        return s.toString();
    }

    /**
     * Two expressions are equal when they have the same comparator and an equal constant, and their
     * (atom, coefficient) terms are equal as multisets. Each term is matched against a distinct term of
     * the other expression, so repeated atoms are handled symmetrically.
     */
    @Override
    public boolean equals(Object other) {
        if (this == other) {
            return true;
        }
        if (other == null || this.getClass() != other.getClass()) {
            return false;
        }
        ArithmeticRuleExpression otherExpression = (ArithmeticRuleExpression)other;

        // Expressions whose cached hashes differ are rejected immediately.
        // NOTE(review): if HashCode.build is order-sensitive, this also rejects reorderings of the same
        // terms; confirm whether that is intended.
        if (this.hash != otherExpression.hash) {
            return false;
        }
        // Constants are compared by value, not by reference.
        if (this.comparator != otherExpression.comparator || !Objects.equals(this.constant, otherExpression.constant)) {
            return false;
        }
        if (this.atoms.size() != otherExpression.atoms.size()) {
            return false;
        }

        // Track which of the other expression's terms are already claimed, so a repeated atom in
        // this expression cannot match the same term of the other expression twice.
        boolean[] claimed = new boolean[otherExpression.atoms.size()];
        for (int thisIndex = 0; thisIndex < this.atoms.size(); ++thisIndex) {
            if (!claimMatchingTerm(this.atoms.get(thisIndex), this.coefficients.get(thisIndex), otherExpression, claimed)) {
                return false;
            }
        }
        return true;
    }

    /**
     * Finds an unclaimed term of {@code otherExpression} whose atom and coefficient both equal the given
     * ones, marks it as claimed, and reports whether such a term was found.
     */
    private static boolean claimMatchingTerm(SummationAtomOrAtom atom, Coefficient coefficient, ArithmeticRuleExpression otherExpression, boolean[] claimed) {
        for (int otherIndex = 0; otherIndex < otherExpression.atoms.size(); ++otherIndex) {
            if (claimed[otherIndex]) {
                continue;
            }
            if (atom.equals(otherExpression.atoms.get(otherIndex)) && coefficient.equals(otherExpression.coefficients.get(otherIndex))) {
                claimed[otherIndex] = true;
                return true;
            }
        }
        return false;
    }
}

