/*
 * Decompiled with CFR 0.152.
 */
package org.linqs.psl.application.learning.weight;

import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import org.linqs.psl.database.Database;
import org.linqs.psl.model.atom.GroundAtom;
import org.linqs.psl.model.atom.ObservedAtom;
import org.linqs.psl.model.atom.RandomVariableAtom;
import org.linqs.psl.model.predicate.FunctionalPredicate;
import org.linqs.psl.util.IteratorUtils;
import org.linqs.psl.util.Logger;

/**
 * Pairs the ground atoms of a target database with the ground atoms of a truth database.
 *
 * <p>On construction every (non-functional) target atom is sorted into exactly one of:
 * <ul>
 *   <li>the label map: a {@link RandomVariableAtom} target with an observed truth atom,</li>
 *   <li>the observed map: a non-random-variable target with an observed truth atom,</li>
 *   <li>the latent variables: a {@link RandomVariableAtom} target with no truth atom,</li>
 *   <li>the missing labels: a non-random-variable target with no truth atom.</li>
 * </ul>
 * Observed truth atoms that were not paired with any target are kept as missing targets.
 *
 * <p>The getters hand out unmodifiable views; the map can still be changed through
 * {@link #addRandomVariableTargetAtom(RandomVariableAtom)} and {@link #deleteAtom(GroundAtom)}.
 */
public class TrainingMap {
    private static final Logger log = Logger.getLogger(TrainingMap.class);

    /** Random variable targets mapped to their observed truth atom. */
    private final Map<RandomVariableAtom, ObservedAtom> labelMap;

    /** Observed targets mapped to their observed truth atom. */
    private final Map<ObservedAtom, ObservedAtom> observedMap;

    /** Random variable targets that have no atom in the truth database. */
    private final List<RandomVariableAtom> latentVariables;

    /** Non-random-variable targets that have no atom in the truth database. */
    private final List<ObservedAtom> missingLabels;

    /** Observed truth atoms that were not paired with any target atom. */
    private final List<ObservedAtom> missingTargets;

    /**
     * Builds the mapping by walking the target atom store and then the truth atom store.
     *
     * <p>Target atoms whose predicate is a {@link FunctionalPredicate} are ignored, as are
     * target atoms whose matching truth atom exists but is not an {@link ObservedAtom}.
     * If any missing targets are found, a single warning with the count and one example is logged.
     *
     * @param targetDatabase the database whose atoms are the targets
     * @param truthDatabase the database whose atoms are the truth values
     * @throws IllegalStateException if an unpaired observed truth atom is reported as present
     *         by the target atom store (i.e. the target store has the atom but did not yield it
     *         as a usable target during iteration)
     */
    public TrainingMap(Database targetDatabase, Database truthDatabase) {
        this.labelMap = new HashMap<>(targetDatabase.getAtomStore().size());
        this.observedMap = new HashMap<>();
        this.latentVariables = new ArrayList<>();
        this.missingLabels = new ArrayList<>();
        this.missingTargets = new ArrayList<>();

        // Truth atoms that got paired with a target in the first pass.
        HashSet<ObservedAtom> seenTruthAtoms = new HashSet<>();

        // Pass 1: classify every target atom.
        for (GroundAtom targetAtom : targetDatabase.getAtomStore()) {
            if (targetAtom.getPredicate() instanceof FunctionalPredicate) {
                continue;
            }

            // Only fetch the truth atom when the truth store already has it.
            GroundAtom truthAtom = null;
            if (truthDatabase.getAtomStore().hasAtom(targetAtom.getPredicate(), targetAtom.getArguments())) {
                truthAtom = truthDatabase.getAtomStore().getAtom(targetAtom.getPredicate(), targetAtom.getArguments());
            }

            // A truth atom that is not observed cannot act as a label; skip the target entirely.
            if (truthAtom != null && !(truthAtom instanceof ObservedAtom)) {
                continue;
            }

            if (targetAtom instanceof RandomVariableAtom) {
                if (truthAtom == null) {
                    this.latentVariables.add((RandomVariableAtom)targetAtom);
                } else {
                    seenTruthAtoms.add((ObservedAtom)truthAtom);
                    this.labelMap.put((RandomVariableAtom)targetAtom, (ObservedAtom)truthAtom);
                }
            } else if (truthAtom == null) {
                // Any target that is not a random variable is treated as observed here.
                this.missingLabels.add((ObservedAtom)targetAtom);
            } else {
                seenTruthAtoms.add((ObservedAtom)truthAtom);
                this.observedMap.put((ObservedAtom)targetAtom, (ObservedAtom)truthAtom);
            }
        }

        // Pass 2: collect observed truth atoms that no target was paired with.
        for (GroundAtom truthAtom : truthDatabase.getAtomStore()) {
            if (!(truthAtom instanceof ObservedAtom) || seenTruthAtoms.contains(truthAtom)) {
                continue;
            }

            // The target store claims to have this atom, yet pass 1 did not pair it.
            if (targetDatabase.getAtomStore().hasAtom(truthAtom.getPredicate(), truthAtom.getArguments())) {
                throw new IllegalStateException("Un-persisted target atom: " + truthAtom);
            }

            this.missingTargets.add((ObservedAtom)truthAtom);
        }

        if (!this.missingTargets.isEmpty()) {
            log.warn("Found {} missing targets (truth atoms without a matching target). Example: {}.", this.missingTargets.size(), this.missingTargets.get(0));
        }
    }

    /**
     * @return an unmodifiable view of the random variable target to observed truth mapping
     */
    public Map<RandomVariableAtom, ObservedAtom> getLabelMap() {
        return Collections.unmodifiableMap(this.labelMap);
    }

    /**
     * @return an unmodifiable view of the observed target to observed truth mapping
     */
    public Map<ObservedAtom, ObservedAtom> getObservedMap() {
        return Collections.unmodifiableMap(this.observedMap);
    }

    /**
     * @return an unmodifiable view of the random variable targets that have no truth atom
     */
    public List<RandomVariableAtom> getLatentVariables() {
        return Collections.unmodifiableList(this.latentVariables);
    }

    /**
     * @return an unmodifiable view of the non-random-variable targets that have no truth atom
     */
    public List<ObservedAtom> getMissingLabels() {
        return Collections.unmodifiableList(this.missingLabels);
    }

    /**
     * @return an unmodifiable view of the observed truth atoms that have no paired target
     */
    public List<ObservedAtom> getMissingTargets() {
        return Collections.unmodifiableList(this.missingTargets);
    }

    /**
     * @return the label map keys joined with the latent variables, i.e. every random variable target
     */
    public Iterable<RandomVariableAtom> getAllPredictions() {
        return IteratorUtils.join(this.labelMap.keySet(), this.latentVariables);
    }

    /**
     * @return every target atom held by this map: label map keys, observed map keys,
     *         latent variables, and missing labels (joined in that argument order)
     */
    public Iterable<GroundAtom> getAllTargets() {
        return IteratorUtils.join(this.labelMap.keySet(), this.observedMap.keySet(), this.latentVariables, this.missingLabels);
    }

    /**
     * @return every truth atom held by this map: label map values, observed map values,
     *         and missing targets (joined in that argument order)
     */
    public Iterable<GroundAtom> getAllTruths() {
        return IteratorUtils.join(this.labelMap.values(), this.observedMap.values(), this.missingTargets);
    }

    /**
     * Registers a random variable atom as a target.
     *
     * <p>If the missing targets contain an element equal to {@code atom}, that truth atom is
     * removed from the missing targets and becomes the label of {@code atom}. Otherwise the atom
     * is stored as a latent variable, replacing an equal existing entry if there is one.
     *
     * <p>NOTE(review): matching relies on {@code List.indexOf}, so it presumably depends on the
     * atom classes defining equality across atom types (e.g. by predicate and arguments) --
     * confirm against the atom {@code equals()} implementation. The label map is not consulted,
     * so an atom that is already labeled would also be stored as latent -- verify that callers
     * never do this.
     *
     * @param atom the random variable atom to add
     */
    public void addRandomVariableTargetAtom(RandomVariableAtom atom) {
        int missingTargetIndex = this.missingTargets.indexOf(atom);
        if (missingTargetIndex != -1) {
            // A previously unmatched truth atom now has a target: promote it to a label.
            ObservedAtom observedAtom = this.missingTargets.remove(missingTargetIndex);
            this.labelMap.put(atom, observedAtom);
            return;
        }

        int latentVariableIndex = this.latentVariables.indexOf(atom);
        if (latentVariableIndex == -1) {
            this.latentVariables.add(atom);
        } else {
            // Swap in the new instance for the equal one already stored.
            this.latentVariables.set(latentVariableIndex, atom);
        }
    }

    /**
     * Removes an atom from every collection that can hold an atom of its kind.
     *
     * <p>A {@link RandomVariableAtom} is removed from the label map (as a key) and from the
     * latent variables. Any other atom is cast to {@link ObservedAtom} and removed from the
     * observed map (as a key), the missing labels, and the missing targets. Truth atoms stored
     * as map values are dropped together with their key; they are not moved elsewhere.
     *
     * @param atom the atom to delete
     */
    public void deleteAtom(GroundAtom atom) {
        if (atom instanceof RandomVariableAtom) {
            this.labelMap.remove((RandomVariableAtom)atom);
            this.latentVariables.remove((RandomVariableAtom)atom);
        } else {
            this.observedMap.remove((ObservedAtom)atom);
            this.missingLabels.remove((ObservedAtom)atom);
            this.missingTargets.remove((ObservedAtom)atom);
        }
    }

    /**
     * @return the entries of the label map joined with the entries of the observed map,
     *         i.e. every (target, truth) pair
     */
    public Iterable<Map.Entry<GroundAtom, GroundAtom>> getFullMap() {
        return IteratorUtils.join(this.labelMap.entrySet(), this.observedMap.entrySet());
    }

    /**
     * @return a one-line summary with the size of each of the five collections
     */
    @Override
    public String toString() {
        return String.format("Training Map -- Label Map: %d, Observed Map: %d, Latent Variables: %d, Missing Labels: %d, Missing Targets: %d", this.labelMap.size(), this.observedMap.size(), this.latentVariables.size(), this.missingLabels.size(), this.missingTargets.size());
    }
}

