/*
 * Decompiled with CFR 0.152.
 */
package org.linqs.psl.reasoner.term;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;
import org.linqs.psl.config.Options;
import org.linqs.psl.model.atom.GroundAtom;
import org.linqs.psl.model.atom.RandomVariableAtom;
import org.linqs.psl.model.predicate.model.ModelPredicate;
import org.linqs.psl.model.rule.GroundRule;
import org.linqs.psl.reasoner.term.Hyperplane;
import org.linqs.psl.reasoner.term.MemoryTermStore;
import org.linqs.psl.reasoner.term.ReasonerLocalVariable;
import org.linqs.psl.reasoner.term.ReasonerTerm;
import org.linqs.psl.reasoner.term.VariableTermStore;
import org.linqs.psl.util.Logger;

/**
 * An in-memory {@link VariableTermStore} that keeps terms in a {@link MemoryTermStore}
 * and assigns every distinct local variable a dense index (in creation order).
 * Each index has a slot in a parallel value array and a parallel atom array.
 *
 * <p>Only {@link RandomVariableAtom}s are tracked. Observed atoms are rejected in
 * {@link #createLocalVariable(GroundAtom)}, so {@link #getNumObservedVariables()} is always 0.
 *
 * <p>The store also tracks "mirror" atoms registered through a hyperplane's integrated RVAs.
 * For atoms whose predicate is a {@link ModelPredicate}:
 * <ul>
 *   <li>each iteration fits the model against the current value of the atom's mirror variable;</li>
 *   <li>the model's new predictions are then pushed into the constants of the terms that reference them.</li>
 * </ul>
 */
public abstract class MemoryVariableTermStore<T extends ReasonerTerm, V extends ReasonerLocalVariable>
implements VariableTermStore<T, V> {
    private static final Logger log = Logger.getLogger(MemoryVariableTermStore.class);

    /** Load factor used when presizing the variable map (java.util.HashMap's default). */
    private static final double HASH_LOAD_FACTOR = 0.75;

    private MemoryTermStore<T> store;

    // Variable -> slot in variableValues / variableAtoms. Slots [0, variables.size()) are in use;
    // the arrays may be longer than that (spare capacity).
    private Map<V, Integer> variables;
    private float[] variableValues;
    private RandomVariableAtom[] variableAtoms;

    private final boolean shuffle = Options.MEMORY_VTS_SHUFFLE.getBoolean();
    private final int defaultSize = Options.MEMORY_VTS_DEFAULT_SIZE.getInt();

    private final Set<ModelPredicate> modelPredicates;
    private final Map<RandomVariableAtom, List<MirrorTermCoefficient<T>>> mirrorVariables;
    private boolean variablesExternallyUpdatedFlag;

    public MemoryVariableTermStore() {
        this.store = new MemoryTermStore<T>();

        // Start with a non-null map so the store stays usable even when the configured default size is 0
        // (ensureVariableCapacity(0) allocates nothing).
        this.variables = new HashMap<V, Integer>();
        this.ensureVariableCapacity(this.defaultSize);

        this.modelPredicates = new HashSet<ModelPredicate>();
        this.mirrorVariables = new HashMap<RandomVariableAtom, List<MirrorTermCoefficient<T>>>();
        this.variablesExternallyUpdatedFlag = false;
    }

    /**
     * Get the dense index previously assigned to a variable.
     *
     * @throws NullPointerException if the variable was never created in this store
     *     (the missing map entry is unboxed to an int).
     */
    @Override
    public int getVariableIndex(V variable) {
        return this.variables.get(variable);
    }

    @Override
    public float getVariableValue(int index) {
        return this.variableValues[index];
    }

    /**
     * Get the backing value array itself (not a copy).
     * Its length may exceed {@link #getNumVariables()}; only the first getNumVariables() entries are meaningful.
     */
    @Override
    public float[] getVariableValues() {
        return this.variableValues;
    }

    /**
     * Write every variable's current value back into its atom.
     *
     * @return the Euclidean (L2) distance between the atoms' previous values and the new ones.
     */
    @Override
    public double syncAtoms() {
        double movement = 0.0;
        for (int i = 0; i < this.variables.size(); ++i) {
            movement += Math.pow(this.variableAtoms[i].getValue() - this.variableValues[i], 2.0);
            this.variableAtoms[i].setValue(this.variableValues[i]);
        }
        return Math.sqrt(movement);
    }

    /**
     * Get the backing atom array itself (not a copy).
     * Its length may exceed {@link #getNumVariables()}.
     */
    @Override
    public GroundAtom[] getVariableAtoms() {
        return this.variableAtoms;
    }

    @Override
    public int getNumVariables() {
        return this.variables.size();
    }

    @Override
    public int getNumRandomVariables() {
        // Every tracked variable is a random variable.
        return this.getNumVariables();
    }

    @Override
    public int getNumObservedVariables() {
        // Observed atoms are rejected by createLocalVariable().
        return 0;
    }

    @Override
    public boolean isLoaded() {
        return true;
    }

    /**
     * Get (creating if necessary) the local variable for an atom.
     * A new variable is appended at the next free index, seeded with the atom's current value.
     *
     * @throws IllegalArgumentException if the atom is not a {@link RandomVariableAtom}.
     */
    @Override
    public synchronized V createLocalVariable(GroundAtom groundAtom) {
        if (!(groundAtom instanceof RandomVariableAtom)) {
            throw new IllegalArgumentException("MemoryVariableTermStores do not keep track of observed atoms (" + groundAtom + ").");
        }

        RandomVariableAtom atom = (RandomVariableAtom)groundAtom;
        V variable = this.convertAtomToVariable(atom);
        if (this.variables.containsKey(variable)) {
            return variable;
        }

        int index = this.variables.size();
        if (this.variableAtoms == null || index >= this.variableAtoms.length) {
            // Grow when the arrays are full or absent (absent after clear(), or when the default size is 0).
            // A plain `index * 2` would request 0 when the store is empty, so always ask for at least one slot.
            this.ensureVariableCapacity(Math.max(Math.max(1, this.defaultSize), index * 2));
        }

        this.variables.put(variable, index);
        this.variableValues[index] = atom.getValue();
        this.variableAtoms[index] = atom;
        return variable;
    }

    /**
     * Record that {@code term} contains {@code coefficient * atom}.
     * Model-predicate atoms can then be refreshed later by updateModelAtoms().
     */
    private synchronized void createMirrorVariable(RandomVariableAtom atom, float coefficient, T term) {
        if (atom.getPredicate() instanceof ModelPredicate) {
            this.modelPredicates.add((ModelPredicate)atom.getPredicate());
        }

        List<MirrorTermCoefficient<T>> coefficients = this.mirrorVariables.get(atom);
        if (coefficients == null) {
            coefficients = new ArrayList<MirrorTermCoefficient<T>>();
            this.mirrorVariables.put(atom, coefficients);
        }
        coefficients.add(new MirrorTermCoefficient<T>(term, coefficient));
    }

    @Override
    public void variablesExternallyUpdated() {
        this.variablesExternallyUpdatedFlag = true;
        this.store.variablesExternallyUpdated();
    }

    public boolean getVariablesExternallyUpdatedFlag() {
        return this.variablesExternallyUpdatedFlag;
    }

    public void resetVariablesExternallyUpdatedFlag() {
        this.variablesExternallyUpdatedFlag = false;
    }

    /**
     * Make room for at least {@code capacity} variables.
     *
     * <ul>
     *   <li>When no variables exist yet, the storage is freshly (re)allocated at exactly {@code capacity}.</li>
     *   <li>Otherwise the storage only grows, and by at least a factor of two over the current variable count,
     *       so that repeated small requests do not cause repeated copies.</li>
     * </ul>
     *
     * @throws IllegalArgumentException if {@code capacity} is negative.
     */
    @Override
    public void ensureVariableCapacity(int capacity) {
        if (capacity < 0) {
            throw new IllegalArgumentException("Variable capacity must be non-negative. Got: " + capacity);
        }

        if (capacity == 0) {
            return;
        }

        if (this.variables == null || this.variables.size() == 0) {
            this.variables = new HashMap<V, Integer>((int)Math.ceil(capacity / HASH_LOAD_FACTOR));
            this.variableValues = new float[capacity];
            this.variableAtoms = new RandomVariableAtom[capacity];
        } else if (this.variables.size() < capacity) {
            if (capacity < this.variables.size() * 2) {
                capacity = this.variables.size() * 2;
            }

            Map<V, Integer> newVariables = new HashMap<V, Integer>((int)Math.ceil(capacity / HASH_LOAD_FACTOR));
            newVariables.putAll(this.variables);
            this.variables = newVariables;

            this.variableValues = Arrays.copyOf(this.variableValues, capacity);
            this.variableAtoms = Arrays.copyOf(this.variableAtoms, capacity);
        }
    }

    @Override
    public Iterable<V> getVariables() {
        return this.variables.keySet();
    }

    /**
     * Add a term to the underlying term store.
     * Each of the hyperplane's integrated RVAs (if any) is also registered as a mirror variable of that term.
     */
    @Override
    public void add(GroundRule rule, T term, Hyperplane hyperplane) {
        this.store.add(rule, term, hyperplane);

        if (hyperplane.getIntegratedRVAs() != null) {
            for (Hyperplane.IntegratedRVA integratedRVA : hyperplane.getIntegratedRVAs()) {
                this.createMirrorVariable(integratedRVA.atom, integratedRVA.coefficient, term);
            }
        }
    }

    /**
     * Drop all terms, variables, and mirror/model bookkeeping, and release the variable arrays.
     * The store can still be used afterwards; storage is re-allocated on the next createLocalVariable().
     */
    @Override
    public void clear() {
        if (this.store != null) {
            this.store.clear();
        }

        if (this.variables != null) {
            this.variables.clear();
        }

        this.modelPredicates.clear();
        this.mirrorVariables.clear();

        this.variableValues = null;
        this.variableAtoms = null;
    }

    /** Reload every variable's value from its atom (the inverse direction of syncAtoms()). */
    @Override
    public void reset() {
        for (int i = 0; i < this.variables.size(); ++i) {
            this.variableValues[i] = this.variableAtoms[i].getValue();
        }
    }

    @Override
    public void close() {
        this.clear();

        if (this.store != null) {
            this.store.close();
            this.store = null;
        }

        this.variables = null;
    }

    @Override
    public void initForOptimization() {
        this.initialFitModelAtoms();
        this.updateModelAtoms();
    }

    @Override
    public void iterationComplete() {
        this.fitModelAtoms();
        this.updateModelAtoms();
    }

    public RandomVariableAtom getAtom(int index) {
        return this.variableAtoms[index];
    }

    /**
     * Run every model predicate, then copy each model atom's new prediction into the atom.
     * Each referencing term's constant is shifted from {@code coefficient * old} to {@code coefficient * new}.
     * No-op when there are no model predicates.
     */
    private void updateModelAtoms() {
        if (this.modelPredicates.size() == 0) {
            return;
        }

        for (ModelPredicate predicate : this.modelPredicates) {
            predicate.runModel();
        }

        double rmse = 0.0;
        int count = 0;
        for (Map.Entry<RandomVariableAtom, List<MirrorTermCoefficient<T>>> entry : this.mirrorVariables.entrySet()) {
            RandomVariableAtom mirrorAtom = entry.getKey();
            if (!(mirrorAtom.getPredicate() instanceof ModelPredicate)) {
                continue;
            }

            ModelPredicate predicate = (ModelPredicate)mirrorAtom.getPredicate();
            float oldValue = mirrorAtom.getValue();
            float newValue = predicate.getValue(mirrorAtom);
            mirrorAtom.setValue(newValue);

            for (MirrorTermCoefficient<T> pair : entry.getValue()) {
                pair.term.adjustConstant(pair.coefficient * oldValue, pair.coefficient * newValue);
            }

            rmse += Math.pow(newValue - predicate.getLabel(mirrorAtom), 2.0);
            ++count;
        }

        if (count != 0) {
            rmse = Math.pow(rmse / (double)count, 0.5);
        }

        log.trace("Batch update of {} model atoms. RMSE: {}", count, rmse);
        this.variablesExternallyUpdated();
    }

    private void initialFitModelAtoms() {
        for (ModelPredicate predicate : this.modelPredicates) {
            predicate.initialFit();
        }
    }

    /**
     * Label each model atom with the current value of its mirror variable, then refit every model predicate.
     * No-op when there are no model predicates.
     */
    private void fitModelAtoms() {
        if (this.modelPredicates.size() == 0) {
            return;
        }

        for (ModelPredicate predicate : this.modelPredicates) {
            predicate.resetLabels();
        }

        int count = 0;
        for (RandomVariableAtom mirrorAtom : this.mirrorVariables.keySet()) {
            if (!(mirrorAtom.getPredicate() instanceof ModelPredicate)) {
                continue;
            }

            // NOTE(review): unboxing throws NPE if the mirror atom was never registered via createLocalVariable().
            float labelValue = this.variableValues[this.variables.get(this.convertAtomToVariable(mirrorAtom.getMirror()))];
            ((ModelPredicate)mirrorAtom.getPredicate()).setLabel(mirrorAtom, labelValue);
            ++count;
        }

        for (ModelPredicate predicate : this.modelPredicates) {
            predicate.fit();
        }

        log.trace("Batch fit of {} model atoms.", count);
    }

    @Override
    public T get(long index) {
        return this.store.get(index);
    }

    @Override
    public long size() {
        return this.store.size();
    }

    @Override
    public void ensureCapacity(long capacity) {
        this.store.ensureCapacity(capacity);
    }

    /** Iterate the terms, shuffling them first when the shuffle option is enabled. */
    @Override
    public Iterator<T> iterator() {
        if (this.shuffle) {
            this.store.shuffle();
        }
        return this.store.iterator();
    }

    /** Same as {@link #iterator()} (including the optional shuffle); the in-memory store has no separate write path. */
    @Override
    public Iterator<T> noWriteIterator() {
        return this.iterator();
    }

    protected abstract V convertAtomToVariable(RandomVariableAtom var1);

    /**
     * A (term, coefficient) pair recording that {@code term} contains {@code coefficient} times some mirror atom.
     * Static so it does not capture the enclosing store.
     */
    private static final class MirrorTermCoefficient<T extends ReasonerTerm> {
        public final T term;
        public final float coefficient;

        public MirrorTermCoefficient(T term, float coefficient) {
            this.term = term;
            this.coefficient = coefficient;
        }
    }
}

