/*
 * Decompiled with CFR 0.152.
 */
package org.linqs.psl.model.predicate;

import java.io.IOException;
import java.io.NotSerializableException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.ObjectStreamException;
import org.linqs.psl.database.AtomStore;
import org.linqs.psl.model.deep.DeepModelPredicate;
import org.linqs.psl.model.predicate.StandardPredicate;
import org.linqs.psl.model.term.ConstantType;
import org.linqs.psl.util.Logger;

/**
 * A {@link StandardPredicate} whose values are backed by a {@link DeepModelPredicate}.
 *
 * <p>Each instance owns exactly one {@code DeepModelPredicate}, created in the field initializer
 * with a reference to this predicate. All deep-model operations (init, fit, predict, eval, save,
 * close) are delegated to that object.
 *
 * <p>Java serialization is explicitly disabled for this class.
 */
public class DeepPredicate
extends StandardPredicate {
    private static final Logger log = Logger.getLogger(DeepPredicate.class);

    // NOTE(review): passes a partially-constructed `this` to DeepModelPredicate before the
    // superclass constructor has finished; confirm DeepModelPredicate only stores the reference.
    private DeepModelPredicate deepModel = new DeepModelPredicate(this);

    /**
     * Creates a deep predicate. Callers should normally go through
     * {@link #get(String, ConstantType...)}.
     *
     * @param name the predicate name
     * @param types the argument types
     */
    protected DeepPredicate(String name, ConstantType[] types) {
        super(name, types);
    }

    /**
     * Attaches the atom store to the deep model and initializes it.
     *
     * @param atomStore the atom store handed to the deep model
     * @param application passed through to {@code DeepModelPredicate.initDeepModel}
     */
    public void initDeepPredicate(AtomStore atomStore, String application) {
        this.deepModel.setAtomStore(atomStore);
        this.deepModel.initDeepModel(application);
    }

    /**
     * Supplies gradients from the symbolic side to the deep model, then fits it.
     *
     * @param symbolicGradients gradients forwarded to {@code setSymbolicGradients}
     */
    public void fitDeepPredicate(float[] symbolicGradients) {
        this.deepModel.setSymbolicGradients(symbolicGradients);
        this.deepModel.fitDeepModel();
    }

    /** Returns the deep model this predicate delegates to. */
    public DeepModelPredicate getDeepModel() {
        return this.deepModel;
    }

    /**
     * Replaces the deep model this predicate delegates to.
     *
     * @param deepModel the new delegate; no null check is performed here
     */
    public void setDeepModel(DeepModelPredicate deepModel) {
        this.deepModel = deepModel;
    }

    /** Runs deep-model prediction with {@code learning} set to {@code false}. */
    public float predictDeepModel() {
        return this.deepModel.predictDeepModel(false);
    }

    /**
     * Runs deep-model prediction.
     *
     * @param learning forwarded to {@code DeepModelPredicate.predictDeepModel}; NOTE(review):
     *     if that method takes a primitive {@code boolean}, passing {@code null} here throws
     *     NullPointerException on unboxing
     * @return the value returned by the deep model
     */
    public float predictDeepModel(Boolean learning) {
        return this.deepModel.predictDeepModel(learning);
    }

    /** Delegates to {@code DeepModelPredicate.evalDeepModel}. */
    public void evalDeepModel() {
        this.deepModel.evalDeepModel();
    }

    /** Delegates to {@code DeepModelPredicate.saveDeepModel}. */
    public void saveDeepModel() {
        this.deepModel.saveDeepModel();
    }

    /**
     * Closes the parent predicate and then the deep model.
     *
     * <p>The deep model is closed even when {@code super.close()} throws. If both closes fail,
     * the parent's exception is rethrown with the deep model's exception attached via
     * {@link Throwable#addSuppressed}.
     */
    @Override
    public synchronized void close() {
        try {
            super.close();
        } catch (RuntimeException | Error primary) {
            try {
                this.deepModel.close();
            } catch (RuntimeException | Error secondary) {
                primary.addSuppressed(secondary);
            }
            throw primary;
        }
        this.deepModel.close();
    }

    /**
     * Looks up an existing predicate by name and returns it as a {@code DeepPredicate}.
     *
     * @param name the predicate name
     * @return the predicate, or {@code null} if {@code StandardPredicate.get(name)} returns null
     * @throws ClassCastException if a predicate with this name exists but is not a
     *     {@code DeepPredicate}
     */
    public static DeepPredicate get(String name) {
        StandardPredicate predicate = StandardPredicate.get(name);
        if (predicate == null) {
            return null;
        }
        if (!(predicate instanceof DeepPredicate)) {
            throw new ClassCastException("Predicate (" + name + ") is not a DeepPredicate.");
        }
        return (DeepPredicate)predicate;
    }

    /**
     * Returns the deep predicate with this name, or creates a new one if none exists.
     *
     * <p>Presumably the {@code StandardPredicate} constructor registers the new instance so
     * later lookups find it; that is not visible here, so confirm before relying on it.
     *
     * @param name the predicate name
     * @param types the expected argument types
     * @return the existing or newly constructed predicate
     * @throws ClassCastException if a predicate with this name exists but is not a
     *     {@code DeepPredicate}
     */
    public static DeepPredicate get(String name, ConstantType ... types) {
        DeepPredicate predicate = DeepPredicate.get(name);
        if (predicate == null) {
            return new DeepPredicate(name, types);
        }
        // NOTE(review): presumably rejects a type mismatch with the existing predicate — confirm.
        StandardPredicate.validateTypes(predicate, types);
        return predicate;
    }

    // Java serialization is deliberately unsupported. The exception message carries the class
    // name, following the NotSerializableException convention.
    private void writeObject(ObjectOutputStream out) throws IOException {
        throw new NotSerializableException(DeepPredicate.class.getName());
    }

    private void readObject(ObjectInputStream in) throws IOException, ClassNotFoundException {
        throw new NotSerializableException(DeepPredicate.class.getName());
    }

    private void readObjectNoData() throws ObjectStreamException {
        throw new NotSerializableException(DeepPredicate.class.getName());
    }
}

