/*
 * Decompiled with CFR 0.152.
 */
package org.linqs.psl.reasoner.admm.term;

import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import org.linqs.psl.model.atom.RandomVariableAtom;
import org.linqs.psl.model.rule.GroundRule;
import org.linqs.psl.reasoner.admm.term.ADMMObjectiveTerm;
import org.linqs.psl.reasoner.admm.term.LocalVariable;
import org.linqs.psl.reasoner.term.MemoryTermStore;
import org.linqs.psl.reasoner.term.TermStore;
import org.linqs.psl.util.IteratorUtils;

public class ADMMTermStore
implements TermStore<ADMMObjectiveTerm, LocalVariable> {
    private TermStore<ADMMObjectiveTerm, ?> store;
    private Map<RandomVariableAtom, Integer> variableIndexes;
    private List<List<LocalVariable>> localVariables;
    private float[] consensusValues;
    private int numLocalVariables;

    public ADMMTermStore() {
        this(new MemoryTermStore<ADMMObjectiveTerm>());
    }

    public ADMMTermStore(TermStore<ADMMObjectiveTerm, ?> store) {
        this.store = store;
        this.variableIndexes = new HashMap<RandomVariableAtom, Integer>();
        this.localVariables = new ArrayList<List<LocalVariable>>();
        this.numLocalVariables = 0;
        this.consensusValues = null;
    }

    @Override
    public void ensureVariableCapacity(int capacity) {
        if (capacity == 0) {
            return;
        }
        ((ArrayList)this.localVariables).ensureCapacity(capacity);
        if (this.variableIndexes.size() == 0) {
            this.variableIndexes = new HashMap<RandomVariableAtom, Integer>((int)Math.ceil((double)capacity / 0.75));
        }
    }

    @Override
    public synchronized LocalVariable createLocalVariable(RandomVariableAtom atom) {
        int globalId;
        ++this.numLocalVariables;
        if (this.variableIndexes.containsKey(atom)) {
            globalId = this.variableIndexes.get(atom);
        } else {
            if (this.consensusValues != null) {
                throw new RuntimeException("No new variables can be created after the consensus varibles have been requested.");
            }
            globalId = this.variableIndexes.size();
            this.variableIndexes.put(atom, globalId);
            this.localVariables.add(new ArrayList());
        }
        LocalVariable localVariable = new LocalVariable(globalId, atom.getValue());
        this.localVariables.get(globalId).add(localVariable);
        return localVariable;
    }

    public int getNumLocalVariables() {
        return this.numLocalVariables;
    }

    public int getNumGlobalVariables() {
        return this.variableIndexes.size();
    }

    public List<LocalVariable> getLocalVariables(int globalId) {
        return this.localVariables.get(globalId);
    }

    public float[] getConsensusValues() {
        if (this.consensusValues != null) {
            return this.consensusValues;
        }
        this.consensusValues = new float[this.variableIndexes.size()];
        for (Map.Entry<RandomVariableAtom, Integer> entry : this.variableIndexes.entrySet()) {
            this.consensusValues[entry.getValue().intValue()] = entry.getKey().getValue();
        }
        return this.consensusValues;
    }

    public Map<RandomVariableAtom, Integer> getGlobalVariables() {
        return Collections.unmodifiableMap(this.variableIndexes);
    }

    public void updateVariables() {
        for (Map.Entry<RandomVariableAtom, Integer> entry : this.variableIndexes.entrySet()) {
            entry.getKey().setValue(this.consensusValues[entry.getValue()]);
        }
    }

    @Override
    public void add(GroundRule rule, ADMMObjectiveTerm term) {
        this.store.add(rule, term);
    }

    @Override
    public void clear() {
        if (this.store != null) {
            this.store.clear();
        }
        if (this.variableIndexes != null) {
            this.variableIndexes.clear();
        }
        if (this.localVariables != null) {
            this.localVariables.clear();
        }
        this.numLocalVariables = 0;
        this.consensusValues = null;
    }

    @Override
    public void reset() {
        for (Map.Entry<RandomVariableAtom, Integer> entry : this.variableIndexes.entrySet()) {
            if (this.consensusValues != null) {
                this.consensusValues[entry.getValue().intValue()] = entry.getKey().getValue();
            }
            for (LocalVariable local : this.localVariables.get(entry.getValue())) {
                local.setValue(entry.getKey().getValue());
                local.setLagrange(0.0f);
            }
        }
    }

    @Override
    public void close() {
        this.clear();
        if (this.store != null) {
            this.store.close();
            this.store = null;
        }
        this.variableIndexes = null;
        this.localVariables = null;
    }

    @Override
    public void initForOptimization() {
        this.store.initForOptimization();
    }

    @Override
    public void iterationComplete() {
        this.store.iterationComplete();
    }

    @Override
    public ADMMObjectiveTerm get(int index) {
        return this.store.get(index);
    }

    @Override
    public int size() {
        return this.store.size();
    }

    @Override
    public void ensureCapacity(int capacity) {
        this.store.ensureCapacity(capacity);
    }

    @Override
    public Iterator<ADMMObjectiveTerm> iterator() {
        return this.store.iterator();
    }

    @Override
    public Iterator<ADMMObjectiveTerm> noWriteIterator() {
        return this.iterator();
    }

    public Iterable<ADMMObjectiveTerm> getTerms(GroundRule groundRule) {
        final GroundRule finalGroundRule = groundRule;
        return IteratorUtils.filter(this.store, new IteratorUtils.FilterFunction<ADMMObjectiveTerm>(){

            @Override
            public boolean keep(ADMMObjectiveTerm term) {
                return finalGroundRule.equals(term.getGroundRule());
            }
        });
    }
}

