/*
 * Decompiled with CFR 0.152.
 */
package org.linqs.psl.model.predicate.model;

import java.util.Map;
import org.linqs.psl.model.atom.RandomVariableAtom;
import org.linqs.psl.model.predicate.StandardPredicate;
import org.linqs.psl.model.predicate.model.SupportingModel;
import org.linqs.psl.model.term.ConstantType;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * A {@link StandardPredicate} whose atom values are produced by a {@link SupportingModel}.
 *
 * <p>Lifecycle enforced by this class:
 * <ol>
 *   <li>{@link #loadModel(Map, String)} must be called before any other model operation
 *       ({@link #runModel()}, {@link #getValue(RandomVariableAtom)}, {@link #resetLabels()},
 *       {@link #setLabel(RandomVariableAtom, float)}, {@link #fit()}).</li>
 *   <li>{@link #runModel()} must be called before {@link #getValue(RandomVariableAtom)}.</li>
 * </ol>
 * Violations are reported with {@link IllegalStateException}.
 */
public class ModelPredicate extends StandardPredicate {
    private static final Logger log = LoggerFactory.getLogger(ModelPredicate.class);

    /** The model that backs this predicate's values. */
    protected SupportingModel model;

    /** Set once {@link #loadModel(Map, String)} has completed. */
    private boolean modelLoaded;

    /** Set once {@link #runModel()} has completed at least once. */
    private boolean modelRan;

    /**
     * Creates a model predicate. External callers should go through
     * {@link #get(String, SupportingModel, ConstantType...)}.
     *
     * @param name the predicate name
     * @param types the argument types of the predicate
     * @param model the model that will supply this predicate's values
     */
    protected ModelPredicate(String name, ConstantType[] types, SupportingModel model) {
        super(name, types);
        this.model = model;
        this.modelLoaded = false;
        this.modelRan = false;
    }

    /**
     * Loads the backing model and marks this predicate as initialized.
     *
     * @param config configuration passed through unchanged to {@code SupportingModel.load}
     *     (presumably model-specific options; confirm against the model implementation)
     * @param relativeDir directory passed through unchanged to {@code SupportingModel.load}
     *     (presumably used to resolve relative paths in {@code config}; TODO confirm)
     */
    public void loadModel(Map<String, String> config, String relativeDir) {
        this.model.load(config, relativeDir);
        this.modelLoaded = true;
    }

    /**
     * Returns the model's value for the given atom, clamped to the range [0, 1].
     *
     * <p>Note: if the model returns NaN, NaN is returned, because
     * {@code Math.min}/{@code Math.max} propagate NaN rather than clamping it.
     *
     * @param atom the atom to look up
     * @return the model's value for {@code atom}, clamped to [0, 1]
     * @throws IllegalStateException if the model has not been loaded or has not been run
     */
    public float getValue(RandomVariableAtom atom) {
        this.checkModel();
        if (!this.modelRan) {
            throw new IllegalStateException("Cannot invoke getValue() before runModel() has been called.");
        }
        float value = this.model.getValue(atom);
        return Math.max(0.0f, Math.min(1.0f, value));
    }

    /**
     * Runs the backing model. After this call, {@link #getValue(RandomVariableAtom)} may be used.
     *
     * @throws IllegalStateException if the model has not been loaded
     */
    public void runModel() {
        this.checkModel();
        this.model.run();
        this.modelRan = true;
    }

    /**
     * Delegates to the backing model's {@code resetLabels()}.
     *
     * @throws IllegalStateException if the model has not been loaded
     */
    public void resetLabels() {
        this.checkModel();
        this.model.resetLabels();
    }

    /**
     * Delegates to the backing model's {@code setLabel(atom, label)}.
     *
     * @param atom the atom whose label is being set
     * @param label the label value passed through to the model
     * @throws IllegalStateException if the model has not been loaded
     */
    public void setLabel(RandomVariableAtom atom, float label) {
        this.checkModel();
        this.model.setLabel(atom, label);
    }

    /**
     * Delegates to the backing model's {@code fit()}, with trace logging before and after.
     *
     * @throws IllegalStateException if the model has not been loaded
     */
    public void fit() {
        this.checkModel();
        log.trace("Fitting {} ({}).", this, this.model);
        this.model.fit();
        log.trace("Done fitting {} ({}).", this, this.model);
    }

    /** Throws {@link IllegalStateException} unless {@link #loadModel(Map, String)} has been called. */
    private void checkModel() {
        if (!this.modelLoaded) {
            throw new IllegalStateException("ModelPredicate (" + this + ") has not been initialized via loadModel().");
        }
    }

    /**
     * Looks up an existing model predicate by name.
     *
     * @param name the predicate name
     * @return the predicate, or {@code null} if {@code StandardPredicate.get(name)} returned {@code null}
     * @throws ClassCastException if a predicate with that name exists but is not a {@code ModelPredicate}
     */
    public static ModelPredicate get(String name) {
        StandardPredicate predicate = StandardPredicate.get(name);
        if (predicate == null) {
            return null;
        }
        if (!(predicate instanceof ModelPredicate)) {
            throw new ClassCastException("Predicate (" + name + ") is not a ModelPredicate.");
        }
        return (ModelPredicate)predicate;
    }

    /**
     * Returns the model predicate with the given name, creating it if none exists.
     *
     * <p>If a predicate with this name already exists, its types are checked via
     * {@code StandardPredicate.validateTypes}, and the existing instance is returned.
     * In that case the {@code model} argument is ignored.
     *
     * @param name the predicate name
     * @param model the backing model, used only when a new predicate is created
     * @param types the expected argument types
     * @return the existing or newly created predicate
     * @throws ClassCastException if a predicate with that name exists but is not a {@code ModelPredicate}
     */
    public static ModelPredicate get(String name, SupportingModel model, ConstantType ... types) {
        ModelPredicate predicate = ModelPredicate.get(name);
        if (predicate == null) {
            return new ModelPredicate(name, types, model);
        }
        StandardPredicate.validateTypes(predicate, types);
        return predicate;
    }
}

