/*
 * Decompiled with CFR 0.152.
 */
package org.linqs.psl.cli;

import java.io.File;
import java.io.FileNotFoundException;
import java.io.FileReader;
import java.io.FileWriter;
import java.io.IOException;
import java.io.PrintStream;
import java.util.LinkedHashSet;
import java.util.Set;
import org.apache.commons.cli.CommandLine;
import org.apache.commons.cli.HelpFormatter;
import org.apache.commons.configuration2.ex.ConfigurationException;
import org.linqs.psl.application.inference.InferenceApplication;
import org.linqs.psl.application.learning.weight.WeightLearningApplication;
import org.linqs.psl.cli.DataLoader;
import org.linqs.psl.database.DataStore;
import org.linqs.psl.database.Database;
import org.linqs.psl.database.Partition;
import org.linqs.psl.database.rdbms.RDBMSDataStore;
import org.linqs.psl.database.rdbms.driver.DatabaseDriver;
import org.linqs.psl.database.rdbms.driver.H2DatabaseDriver;
import org.linqs.psl.database.rdbms.driver.PostgreSQLDriver;
import org.linqs.psl.evaluation.statistics.Evaluator;
import org.linqs.psl.grounding.GroundRuleStore;
import org.linqs.psl.model.Model;
import org.linqs.psl.model.atom.GroundAtom;
import org.linqs.psl.model.predicate.StandardPredicate;
import org.linqs.psl.model.rule.GroundRule;
import org.linqs.psl.model.rule.Rule;
import org.linqs.psl.model.rule.UnweightedGroundRule;
import org.linqs.psl.model.rule.WeightedGroundRule;
import org.linqs.psl.model.term.Constant;
import org.linqs.psl.parser.CommandLineLoader;
import org.linqs.psl.parser.ModelLoader;
import org.linqs.psl.util.Reflection;
import org.linqs.psl.util.StringUtils;
import org.linqs.psl.util.Version;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Command-line driver for PSL.
 *
 * <p>Initializes a data store, loads data and a model, then runs either inference ("i") or
 * weight learning ("l"). It then optionally runs every evaluator class listed under "e".
 */
public class Launcher {
    /** Extension of model files; also used when naming the learned-model output file. */
    public static final String MODEL_FILE_EXTENSION = ".psl";
    /** Name of the partition holding observed data. */
    public static final String PARTITION_NAME_OBSERVATIONS = "observations";
    /** Name of the partition holding target (random-variable) atoms. */
    public static final String PARTITION_NAME_TARGET = "targets";
    /** Name of the partition holding ground-truth labels used by learning and evaluation. */
    public static final String PARTITION_NAME_LABELS = "truth";
    private static final Logger log = LoggerFactory.getLogger(Launcher.class);

    private final CommandLine parsedOptions;

    private Launcher(CommandLine givenOptions) {
        this.parsedOptions = givenOptions;
    }

    /**
     * Creates the data store selected on the command line.
     *
     * <p>By default this is an on-disk H2 database at {@code CommandLineLoader.DEFAULT_H2_DB_PATH}.
     * "h2path" overrides the H2 path. Otherwise "postgres" selects PostgreSQL, with the database
     * name defaulting to "psl_cli". "h2path" takes precedence over "postgres".
     */
    private DataStore initDataStore() {
        DatabaseDriver driver;
        if (parsedOptions.hasOption("h2path")) {
            driver = new H2DatabaseDriver(H2DatabaseDriver.Type.Disk, parsedOptions.getOptionValue("h2path"), true);
        } else if (parsedOptions.hasOption("postgres")) {
            driver = new PostgreSQLDriver(parsedOptions.getOptionValue("postgres", "psl_cli"), true);
        } else {
            driver = new H2DatabaseDriver(H2DatabaseDriver.Type.Disk, CommandLineLoader.DEFAULT_H2_DB_PATH, true);
        }
        return new RDBMSDataStore(driver);
    }

    /**
     * Loads the data described by the "d" option into the data store.
     *
     * @return the predicates the data loader reports as closed
     * @throws RuntimeException wrapping a missing file or configuration error
     */
    private Set<StandardPredicate> loadData(DataStore dataStore) {
        log.info("Loading data");
        Set<StandardPredicate> closedPredicates;
        try {
            String path = parsedOptions.getOptionValue("d");
            closedPredicates = DataLoader.load(dataStore, path, parsedOptions.hasOption("int"));
        } catch (FileNotFoundException | ConfigurationException ex) {
            throw new RuntimeException("Failed to load data.", ex);
        }
        log.info("Data loading complete");
        return closedPredicates;
    }

    /**
     * Writes one tab-separated row per ground rule in the store, preceded by a header row.
     *
     * <p>The columns are weight, squared flag and rule text. Unweighted rules use "." and "false"
     * for the first two columns. When {@code includeSatisfaction} is set, a satisfaction column is
     * appended. It is computed as 1 - incompatibility for weighted rules and 1 - infeasibility for
     * unweighted rules.
     *
     * @param groundRuleStore store whose ground rules are written
     * @param path file to write to, or null for stdout; if the file cannot be opened, stdout is used
     * @param includeSatisfaction whether to append the satisfaction column
     */
    private void outputGroundRules(GroundRuleStore groundRuleStore, String path, boolean includeSatisfaction) {
        PrintStream stream = System.out;
        boolean closeStream = false;
        if (path != null) {
            try {
                stream = new PrintStream(path);
                closeStream = true;
            } catch (IOException ex) {
                log.error(String.format("Unable to open file (%s) for ground rules, using stdout instead.", path), ex);
            }
        }

        // Close an opened file even if iterating the store or printing throws.
        try {
            String header = StringUtils.join("\t", "Weight", "Squared?", "Rule");
            if (includeSatisfaction) {
                header = StringUtils.join("\t", header, "Satisfaction");
            }
            stream.println(header);

            for (GroundRule groundRule : groundRuleStore.getGroundRules()) {
                String row;
                double satisfaction;
                if (groundRule instanceof WeightedGroundRule) {
                    WeightedGroundRule weightedGroundRule = (WeightedGroundRule) groundRule;
                    row = StringUtils.join("\t",
                            String.valueOf(weightedGroundRule.getWeight()),
                            String.valueOf(weightedGroundRule.isSquared()),
                            groundRule.baseToString());
                    satisfaction = 1.0 - weightedGroundRule.getIncompatibility();
                } else {
                    // Every non-weighted rule is cast to UnweightedGroundRule; any other
                    // GroundRule subtype would fail here with a ClassCastException.
                    UnweightedGroundRule unweightedGroundRule = (UnweightedGroundRule) groundRule;
                    row = StringUtils.join("\t", ".", "false", groundRule.baseToString());
                    satisfaction = 1.0 - unweightedGroundRule.getInfeasibility();
                }

                if (includeSatisfaction) {
                    row = StringUtils.join("\t", row, String.valueOf(satisfaction));
                }
                stream.println(row);
            }

            // PrintStream never throws on write failures; report them instead of losing output silently.
            if (stream.checkError()) {
                log.error("Error while writing ground rules to {}.", closeStream ? path : "stdout");
            }
        } finally {
            if (closeStream) {
                stream.close();
            }
        }
    }

    /**
     * Runs the named inference application and outputs the results.
     *
     * <p>The database spans the target and observation partitions. If "groundrules" is given, the
     * ground rules are dumped before inference. If "satisfaction" is given, they are dumped with
     * satisfaction values after inference. Atoms are committed unless "skipAtomCommit" is given.
     *
     * @return the inference database, left open; the caller is responsible for closing it
     */
    private Database runInference(Model model, DataStore dataStore, Set<StandardPredicate> closedPredicates, String inferenceName) {
        log.info("Starting inference with class: {}", inferenceName);

        Partition targetPartition = dataStore.getPartition(PARTITION_NAME_TARGET);
        Partition observationsPartition = dataStore.getPartition(PARTITION_NAME_OBSERVATIONS);
        Database database = dataStore.getDatabase(targetPartition, closedPredicates, observationsPartition);

        InferenceApplication inferenceApplication = InferenceApplication.getInferenceApplication(inferenceName, model, database);

        if (parsedOptions.hasOption("groundrules")) {
            String path = parsedOptions.getOptionValue("groundrules");
            outputGroundRules(inferenceApplication.getGroundRuleStore(), path, false);
        }

        boolean commitAtoms = !parsedOptions.hasOption("skipAtomCommit");
        inferenceApplication.inference(commitAtoms);

        if (parsedOptions.hasOption("satisfaction")) {
            String path = parsedOptions.getOptionValue("satisfaction");
            outputGroundRules(inferenceApplication.getGroundRuleStore(), path, true);
        }

        log.info("Inference Complete");
        outputResults(database, dataStore, closedPredicates);
        return database;
    }

    /**
     * Outputs the values of all random-variable atoms of the open predicates, that is, registered
     * predicates not in {@code closedPredicates}.
     *
     * <p>Without the "o" option, each atom is printed to stdout as "atom = value". Otherwise one
     * "&lt;predicate name&gt;.txt" file per predicate is written into the "o" directory, which is
     * created if needed. Each line holds the tab-separated raw arguments followed by the value.
     * A failure to write one predicate is logged and does not stop the others.
     */
    private void outputResults(Database database, DataStore dataStore, Set<StandardPredicate> closedPredicates) {
        // Work on an order-preserving copy: the returned set may be backed by the data store's
        // own registry, and removeAll must not mutate it.
        Set<StandardPredicate> openPredicates = new LinkedHashSet<>(dataStore.getRegisteredPredicates());
        openPredicates.removeAll(closedPredicates);

        if (!parsedOptions.hasOption("o")) {
            for (StandardPredicate openPredicate : openPredicates) {
                for (GroundAtom groundAtom : database.getAllGroundRandomVariableAtoms(openPredicate)) {
                    System.out.println(groundAtom.toString() + " = " + groundAtom.getValue());
                }
            }
            return;
        }

        File outputDirectory = new File(parsedOptions.getOptionValue("o"));
        outputDirectory.mkdirs();

        for (StandardPredicate standardPredicate : openPredicates) {
            File outputFile = new File(outputDirectory, standardPredicate.getName() + ".txt");
            // try-with-resources closes the writer even when a write fails part-way through.
            try (FileWriter predFileWriter = new FileWriter(outputFile)) {
                StringBuilder row = new StringBuilder();
                for (GroundAtom groundAtom : database.getAllGroundRandomVariableAtoms(standardPredicate)) {
                    row.setLength(0);
                    for (Constant term : groundAtom.getArguments()) {
                        row.append(term.rawToString()).append("\t");
                    }
                    row.append(Double.toString(groundAtom.getValue())).append("\n");
                    predFileWriter.write(row.toString());
                }
            } catch (IOException ex) {
                log.error("Exception writing predicate {}", standardPredicate, ex);
            }
        }
    }

    /**
     * Runs the named weight-learning application and writes the learned model to a sibling file
     * of the "m" model file.
     *
     * <p>Ground rules ("groundrules") and their satisfaction ("satisfaction") are dumped after
     * learning and before the learner is closed, so that the store is still in use when read.
     *
     * @throws RuntimeException if the learned model cannot be written
     */
    private void learnWeights(Model model, DataStore dataStore, Set<StandardPredicate> closedPredicates, String wlaName) {
        log.info("Starting weight learning with learner: {}", wlaName);

        Partition targetPartition = dataStore.getPartition(PARTITION_NAME_TARGET);
        Partition observationsPartition = dataStore.getPartition(PARTITION_NAME_OBSERVATIONS);
        Partition truthPartition = dataStore.getPartition(PARTITION_NAME_LABELS);

        Database randomVariableDatabase = dataStore.getDatabase(targetPartition, closedPredicates, observationsPartition);
        Database observedTruthDatabase = dataStore.getDatabase(truthPartition, dataStore.getRegisteredPredicates(), new Partition[0]);

        WeightLearningApplication learner = WeightLearningApplication.getWLA(wlaName, model.getRules(), randomVariableDatabase, observedTruthDatabase);
        learner.learn();

        if (parsedOptions.hasOption("groundrules")) {
            String path = parsedOptions.getOptionValue("groundrules");
            outputGroundRules(learner.getGroundRuleStore(), path, false);
        }

        // Read the ground rule store before closing the learner that owns it.
        if (parsedOptions.hasOption("satisfaction")) {
            String path = parsedOptions.getOptionValue("satisfaction");
            outputGroundRules(learner.getGroundRuleStore(), path, true);
        }

        learner.close();
        randomVariableDatabase.close();
        observedTruthDatabase.close();

        log.info("Weight learning complete");

        // "model.psl" -> "model-learned.psl". The LAST occurrence of ".psl" anywhere in the name is
        // used, not necessarily a suffix. A name without ".psl" only gets ".psl" appended.
        String modelFilename = parsedOptions.getOptionValue("m");
        int extensionPos = modelFilename.lastIndexOf(MODEL_FILE_EXTENSION);
        String learnedFilename = (extensionPos == -1)
                ? modelFilename + MODEL_FILE_EXTENSION
                : modelFilename.substring(0, extensionPos) + "-learned" + MODEL_FILE_EXTENSION;
        log.info("Writing learned model to {}", learnedFilename);

        // Strip every "( " and " )" sequence from the serialized model before writing it.
        String outModel = model.asString().replaceAll("\\( | \\)", "");
        try (FileWriter learnedFileWriter = new FileWriter(new File(learnedFilename))) {
            learnedFileWriter.write(outModel);
        } catch (IOException ex) {
            log.error("Failed to write learned model:\n" + outModel);
            throw new RuntimeException("Failed to write learned model to: " + learnedFilename, ex);
        }
    }

    /**
     * Runs one evaluator class against the truth partition for every open predicate that has at
     * least one ground-truth atom, and logs the results.
     *
     * @param predictionDatabase database holding predictions, or null to open (and afterwards
     *     close) one over the target and observation partitions
     * @param evalClassName fully qualified name of the {@link Evaluator} to instantiate reflectively
     */
    private void evaluation(DataStore dataStore, Database predictionDatabase, Set<StandardPredicate> closedPredicates, String evalClassName) {
        log.info("Starting evaluation with class: {}.", evalClassName);

        // Order-preserving copy so the data store's own predicate set is never mutated.
        Set<StandardPredicate> openPredicates = new LinkedHashSet<>(dataStore.getRegisteredPredicates());
        openPredicates.removeAll(closedPredicates);

        Partition targetPartition = dataStore.getPartition(PARTITION_NAME_TARGET);
        Partition observationsPartition = dataStore.getPartition(PARTITION_NAME_OBSERVATIONS);
        Partition truthPartition = dataStore.getPartition(PARTITION_NAME_LABELS);

        // Only close the prediction database if this method opened it.
        boolean closePredictionDB = (predictionDatabase == null);
        if (closePredictionDB) {
            predictionDatabase = dataStore.getDatabase(targetPartition, closedPredicates, observationsPartition);
        }
        Database truthDatabase = dataStore.getDatabase(truthPartition, dataStore.getRegisteredPredicates(), new Partition[0]);

        Evaluator evaluator = (Evaluator) Reflection.newObject(evalClassName);
        for (StandardPredicate targetPredicate : openPredicates) {
            if (truthDatabase.countAllGroundAtoms(targetPredicate) == 0) {
                log.info("Skipping evaluation for {} since there are no ground truth atoms", targetPredicate);
                continue;
            }
            evaluator.compute(predictionDatabase, truthDatabase, targetPredicate, !closePredictionDB);
            log.info("Evaluation results for {} -- {}", targetPredicate.getName(), evaluator.getAllStats());
        }

        if (closePredictionDB) {
            predictionDatabase.close();
        }
        truthDatabase.close();
    }

    /**
     * Parses the model file named by the "m" option against the data store.
     *
     * @throws RuntimeException wrapping any I/O error while reading the file
     */
    private Model loadModel(DataStore dataStore) {
        String modelPath = parsedOptions.getOptionValue("m");
        log.info("Loading model from {}", modelPath);

        Model model;
        try (FileReader reader = new FileReader(new File(modelPath))) {
            model = ModelLoader.load(dataStore, reader);
        } catch (IOException ex) {
            throw new RuntimeException("Failed to load model from file: " + modelPath, ex);
        }

        log.debug("Model:");
        for (Rule rule : model.getRules()) {
            log.debug("   {}", rule);
        }
        log.info("Model loading complete");
        return model;
    }

    /**
     * Executes the full pipeline: load data and model, run inference or learning, optionally
     * evaluate, then close the databases and the data store.
     *
     * @throws IllegalArgumentException if neither inference nor learning was requested
     */
    private void run() {
        log.info("Running PSL CLI Version {}", Version.getFull());
        DataStore dataStore = initDataStore();
        Set<StandardPredicate> closedPredicates = loadData(dataStore);
        Model model = loadModel(dataStore);

        // Only inference yields a database for evaluation; after learning, each evaluator opens its own.
        Database evalDB = null;
        if (parsedOptions.hasOption("i")) {
            evalDB = runInference(model, dataStore, closedPredicates, parsedOptions.getOptionValue("i", CommandLineLoader.DEFAULT_IA));
        } else if (parsedOptions.hasOption("l")) {
            learnWeights(model, dataStore, closedPredicates, parsedOptions.getOptionValue("l", CommandLineLoader.DEFAULT_WLA));
        } else {
            throw new IllegalArgumentException("No valid operation provided.");
        }

        if (parsedOptions.hasOption("e")) {
            for (String evaluator : parsedOptions.getOptionValues("e")) {
                evaluation(dataStore, evalDB, closedPredicates, evaluator);
            }
            log.info("Evaluation complete.");
        }

        if (evalDB != null) {
            evalDB.close();
        }
        dataStore.close();
    }

    /**
     * Checks that the required options are present.
     *
     * <p>Help ("h") and version ("v") requests also return false, so that {@code main} stops
     * without running. For a missing option, a message and the usage help are printed to stdout.
     */
    private static boolean isCommandLineValid(CommandLine givenOptions) {
        if (givenOptions.hasOption("h") || givenOptions.hasOption("v")) {
            return false;
        }

        HelpFormatter helpFormatter = new HelpFormatter();
        if (!givenOptions.hasOption("d")) {
            printMissingOption(helpFormatter, "data", "d");
            return false;
        }
        if (!givenOptions.hasOption("m")) {
            printMissingOption(helpFormatter, "model", "m");
            return false;
        }
        // Either inference ("i") or learning ("l") is required; the message names only inference.
        if (!givenOptions.hasOption("i") && !givenOptions.hasOption("l")) {
            printMissingOption(helpFormatter, "infer", "i");
            return false;
        }
        return true;
    }

    /** Prints a "missing required option" message followed by the CLI usage help. */
    private static void printMissingOption(HelpFormatter helpFormatter, String longName, String shortName) {
        System.out.println(String.format("Missing required option: --%s/-%s.", longName, shortName));
        helpFormatter.printHelp("psl", CommandLineLoader.getOptions(), true);
    }

    public static void main(String[] args) {
        Launcher.main(args, false);
    }

    /**
     * Parses the arguments and runs the CLI.
     *
     * @param rethrow if true, any failure is rethrown wrapped in a RuntimeException; otherwise it is
     *     printed to stderr and the JVM exits with status 1
     */
    public static void main(String[] args, boolean rethrow) {
        try {
            CommandLineLoader commandLineLoader = new CommandLineLoader(args);
            CommandLine givenOptions = commandLineLoader.getParsedOptions();
            if (givenOptions == null || !Launcher.isCommandLineValid(givenOptions)) {
                return;
            }
            Launcher pslLauncher = new Launcher(givenOptions);
            pslLauncher.run();
        } catch (Exception ex) {
            if (rethrow) {
                throw new RuntimeException("Failed to run CLI: " + ex.getMessage(), ex);
            }
            System.err.println("Unexpected exception!");
            ex.printStackTrace(System.err);
            System.exit(1);
        }
    }
}

