/*
 * Decompiled with CFR 0.152.
 */
package org.linqs.psl.model.deep;

import java.io.BufferedReader;
import java.io.FileNotFoundException;
import java.io.IOException;
import java.io.InputStreamReader;
import java.io.OutputStreamWriter;
import java.io.PrintWriter;
import java.io.RandomAccessFile;
import java.io.Writer;
import java.net.Socket;
import java.nio.MappedByteBuffer;
import java.nio.channels.FileChannel;
import java.util.HashMap;
import java.util.Map;
import org.json.JSONObject;
import org.linqs.psl.config.Config;
import org.linqs.psl.config.Options;
import org.linqs.psl.util.FileUtils;
import org.linqs.psl.util.Logger;

/**
 * Base class for models whose computation is delegated to an external Python server process.
 *
 * <p>Each instance reserves a local port in its constructor, launches
 * {@code python3 -m <pythonModule> <port>}, and talks to that process over a localhost socket,
 * one JSON object per line. Bulk data is exchanged through a memory-mapped file at
 * {@link #sharedMemoryPath}: subclasses fill or read {@link #sharedBuffer} in the
 * {@code write*}/{@code read*} hooks. The socket carries only small control messages
 * ({@code init}, {@code fit}, {@code predict}, {@code predict_learn}, {@code eval},
 * {@code save}, {@code close}).
 *
 * <p>Instance methods perform no synchronization; only the static port bookkeeping is synchronized.
 */
public abstract class DeepModel {
    private static final Logger log = Logger.getLogger(DeepModel.class);

    /** Key in {@link #pythonOptions} that must hold the model path once {@link #init()} returns. */
    protected static final String CONFIG_MODEL_PATH = "model-path";
    /** Key in {@link #pythonOptions} holding the {@code runtime.relativebasepath} config value. */
    protected static final String CONFIG_RELATIVE_DIR = "relative-dir";

    /** Pause used while waiting for the server to start, or to process a close/failure. */
    private static final long SERVER_SLEEP_TIME_MS = 500L;
    /** Connection attempts (each preceded by a SERVER_SLEEP_TIME_MS pause) before giving up. */
    private static final int SERVER_CONNECT_ATTEMPTS = 20;
    private static final String SERVER_HOST = "127.0.0.1";

    private static final int startingPort = Options.PREDICATE_DEEP_PYTHON_PORT.getInt();
    /** Ports currently reserved by models; only accessed under the class lock. */
    private static final Map<Integer, DeepModel> usedPorts = new HashMap<>();

    /** Identifier of this model, sent to the server as {@code deep_model}. */
    protected String deepModel;
    /** Options forwarded to the server with each task. */
    protected Map<String, String> pythonOptions;
    protected String application;
    /** Local port reserved for this model's server. */
    protected int port;
    protected String pythonModule;
    protected String sharedMemoryPath;
    protected Process pythonServerProcess;
    protected RandomAccessFile sharedFile;
    protected MappedByteBuffer sharedBuffer;
    protected Socket socket;
    protected BufferedReader socketInput;
    protected PrintWriter socketOutput;
    /** True only while the socket connection is usable; cleared on close or after a server failure. */
    protected boolean serverOpen;

    protected DeepModel(String deepModel) {
        this.deepModel = deepModel;
        this.pythonOptions = new HashMap<>();
        this.application = null;
        this.port = DeepModel.getOpenPort(this);
        this.pythonModule = Options.PREDICATE_DEEP_PYTHON_WRAPPER_MODULE.getString();
        this.sharedMemoryPath = Options.PREDICATE_DEEP_SHARED_MEMORY_PATH.getString();
        this.pythonServerProcess = null;
        this.sharedFile = null;
        this.sharedBuffer = null;
        this.socket = null;
        this.socketInput = null;
        this.socketOutput = null;
        this.serverOpen = false;
    }

    /**
     * Performs subclass-specific setup; called from {@link #initDeepModel(String)}.
     * Implementations must put {@link #CONFIG_MODEL_PATH} into {@link #pythonOptions}.
     *
     * @return the size in bytes of the shared memory region to map when a server is started
     */
    public abstract int init();

    /** Fills {@link #sharedBuffer} for a {@code fit} task; the buffer is cleared before and forced after. */
    public abstract void writeFitData();

    /** Fills {@link #sharedBuffer} for a {@code predict} task; the buffer is cleared before and forced after. */
    public abstract void writePredictData();

    /**
     * Reads the server's prediction output from {@link #sharedBuffer}, which has just been rewound.
     *
     * @return the value returned by {@link #predictDeepModel(Boolean)}
     */
    public abstract float readPredictData();

    /** Fills {@link #sharedBuffer} for an {@code eval} task; the buffer is cleared before and forced after. */
    public abstract void writeEvalData();

    /**
     * Runs subclass setup, starts the Python server if none is running, and sends the {@code init} task.
     *
     * @param application application name forwarded to the server
     * @throws IllegalArgumentException if {@link #init()} did not provide a model path
     */
    public void initDeepModel(String application) {
        log.debug("Init deep model {}.", this);
        this.application = application;
        this.pythonOptions.put(CONFIG_RELATIVE_DIR, Config.getString("runtime.relativebasepath", null));

        int bufferLength = this.init();
        if (this.pythonOptions.get(CONFIG_MODEL_PATH) == null) {
            throw new IllegalArgumentException(String.format(
                    "A DeepModel must have a model path (\"%s\") specified in predicate config.", CONFIG_MODEL_PATH));
        }

        if (this.pythonServerProcess == null) {
            log.debug("DeepModel server not found for {}. Starting server.", this);
            this.initServer(bufferLength);
        }

        JSONObject message = this.newTaskMessage("init");
        message.put("shared_memory_path", this.sharedMemoryPath);
        message.put("application", application);

        log.debug("Sending init message to deep model server for {}.", this);
        log.debug("Message: {}", message);
        JSONObject response = this.sendSocketMessage(message);
        log.debug("Init deep model results for {} : {}", this, this.getResultString(response));
    }

    /**
     * Writes training data into shared memory and asks the server to fit the model.
     *
     * @throws IllegalStateException if the server has not been started by {@link #initDeepModel(String)}
     */
    public void fitDeepModel() {
        log.debug("Fit deep model {}.", this);
        this.checkServerStarted("fit");
        this.sharedBuffer.clear();
        this.writeFitData();
        this.sharedBuffer.force();

        JSONObject response = this.sendSocketMessage(this.newTaskMessage("fit"));
        log.debug("Fit deep model results for {} : {}", this, this.getResultString(response));
    }

    /**
     * Writes prediction input into shared memory, asks the server to predict, and reads the result back.
     *
     * @param learning selects the {@code predict_learn} task when true, {@code predict} otherwise; must not be null
     * @return the value produced by {@link #readPredictData()}
     * @throws IllegalStateException if the server has not been started by {@link #initDeepModel(String)}
     */
    public float predictDeepModel(Boolean learning) {
        log.debug("Predict deep model {}.", this);
        this.checkServerStarted("predict");
        this.sharedBuffer.clear();
        this.writePredictData();
        this.sharedBuffer.force();

        JSONObject message = this.newTaskMessage(learning.booleanValue() ? "predict_learn" : "predict");
        JSONObject response = this.sendSocketMessage(message);

        // Rewind so the subclass reads the server's output from the start of the shared region.
        this.sharedBuffer.clear();
        float movement = this.readPredictData();
        log.debug("Predict deep model result for {} : {}", this, this.getResultString(response));
        return movement;
    }

    /**
     * Writes evaluation data into shared memory and asks the server to evaluate the model.
     *
     * @throws IllegalStateException if the server has not been started by {@link #initDeepModel(String)}
     */
    public void evalDeepModel() {
        log.debug("Eval deep model {}.", this);
        this.checkServerStarted("eval");
        this.sharedBuffer.clear();
        this.writeEvalData();
        this.sharedBuffer.force();

        JSONObject response = this.sendSocketMessage(this.newTaskMessage("eval"));
        log.debug("Eval deep model result for {} : {}", this, this.getResultString(response));
    }

    /** Asks the server to save the model (the message carries only the task and options). */
    public void saveDeepModel() {
        log.debug("Save deep model {}.", this);
        JSONObject message = new JSONObject();
        message.put("task", "save");
        message.put("options", this.pythonOptions);
        JSONObject response = this.sendSocketMessage(message);
        log.debug("Save deep model result for {} : {}", this, this.getResultString(response));
    }

    /**
     * Sends the {@code close} task (if connected) and releases every local resource.
     *
     * <p>Local cleanup and termination of the Python process run even if the close handshake fails.
     * This method is also registered as a JVM shutdown hook and is safe to call more than once.
     */
    public void close() {
        log.debug("Close deep model {}.", this);
        if (this.pythonOptions != null) {
            this.pythonOptions.clear();
        }

        try {
            if (this.socketOutput != null) {
                JSONObject message = new JSONObject();
                message.put("task", "close");
                JSONObject response = this.sendSocketMessage(message);
                log.debug("Close deep model result for {} : {}", this, this.getResultString(response));
            }
        } finally {
            this.closeServer();
        }
    }

    /** Builds the common task envelope: task name, model identifier, and a snapshot of the options. */
    private JSONObject newTaskMessage(String task) {
        JSONObject message = new JSONObject();
        message.put("task", task);
        message.put("deep_model", this.deepModel);
        message.put("options", this.pythonOptions);
        return message;
    }

    /** Fails fast with a descriptive error instead of an NPE when no server/shared buffer exists. */
    private void checkServerStarted(String task) {
        if (this.sharedBuffer == null) {
            throw new IllegalStateException(String.format(
                    "Cannot run \"%s\" on deep model %s before initDeepModel() has started its server.", task, this));
        }
    }

    /**
     * Renders the {@code result} field of a server response for logging.
     *
     * @param response the parsed response, or null when no message was sent because the server is not open
     */
    private String getResultString(JSONObject response) {
        if (response == null) {
            return "<No Response: Server Not Open>";
        }

        JSONObject result = response.optJSONObject("result");
        if (result == null) {
            return "<No Result Provided>";
        }
        return result.toString();
    }

    /**
     * Maps the shared file, launches the Python server, and connects to it.
     * On failure, everything acquired so far is released before the exception propagates.
     */
    private void initServer(int bufferLength) {
        // NOTE: registered once per server start and never removed, so it keeps this model reachable until JVM exit.
        Runtime.getRuntime().addShutdownHook(new Thread(this::close));

        try {
            this.sharedFile = new RandomAccessFile(this.sharedMemoryPath, "rw");
        } catch (FileNotFoundException ex) {
            throw new RuntimeException("Could not open random access file: " + this.sharedMemoryPath, ex);
        }

        try {
            this.sharedBuffer = this.sharedFile.getChannel().map(FileChannel.MapMode.READ_WRITE, 0L, bufferLength);
            this.sharedBuffer.clear();

            ProcessBuilder builder = new ProcessBuilder("python3", "-m", this.pythonModule, String.valueOf(this.port));
            builder.inheritIO();
            this.pythonServerProcess = builder.start();

            this.socket = this.connectToServer();
            this.socketInput = new BufferedReader(
                    new InputStreamReader(this.socket.getInputStream(), FileUtils.DEFAULT_CHARSET));
            this.socketOutput = new PrintWriter(
                    new OutputStreamWriter(this.socket.getOutputStream(), FileUtils.DEFAULT_CHARSET), true);

            // Only advertise the server as open once the connection is actually usable.
            this.serverOpen = true;
        } catch (IOException ex) {
            RuntimeException failure = new RuntimeException(ex);
            // Do not leave a half-started server (live process, mapped file, open socket) behind.
            try {
                this.closeServer();
            } catch (RuntimeException cleanupEx) {
                failure.addSuppressed(cleanupEx);
            }
            throw failure;
        }
    }

    /**
     * Connects to the freshly launched server, retrying while it starts up.
     *
     * @throws IOException if the process exits, the thread is interrupted, or all attempts fail
     */
    private Socket connectToServer() throws IOException {
        IOException lastFailure = null;
        for (int attempt = 0; attempt < SERVER_CONNECT_ATTEMPTS; attempt++) {
            // Give the process time to bind its listening socket before each attempt.
            this.sleepForServer();
            if (Thread.currentThread().isInterrupted()) {
                throw new IOException("Interrupted while waiting for deep model server on port " + this.port + ".");
            }
            if (!this.pythonServerProcess.isAlive()) {
                throw new IOException(String.format(
                        "Deep model server process exited (code %d) before accepting connections on port %d.",
                        this.pythonServerProcess.exitValue(), this.port));
            }

            try {
                return new Socket(SERVER_HOST, this.port);
            } catch (IOException ex) {
                lastFailure = ex;
            }
        }

        throw new IOException(String.format(
                "Could not connect to deep model server on port %d after %d attempts.",
                this.port, SERVER_CONNECT_ATTEMPTS), lastFailure);
    }

    /**
     * Releases the socket, shared file and Python process.
     *
     * <p>Every resource is attempted even if an earlier one fails, so the process is always killed.
     * The first failure is rethrown afterwards, with any later failures attached as suppressed.
     */
    private void closeServer() {
        RuntimeException failure = null;

        if (this.socketOutput != null) {
            this.serverOpen = false;
            // Give the server a moment to act on any final message before the connection drops.
            this.sleepForServer();
            DeepModel.freePort(this.port);
            this.socketOutput.close();
            this.socketOutput = null;
        }

        if (this.socketInput != null) {
            try {
                this.socketInput.close();
            } catch (IOException ex) {
                failure = DeepModel.chainFailure(failure, new RuntimeException(ex));
            }
            this.socketInput = null;
        }

        if (this.socket != null) {
            try {
                if (!this.socket.isClosed()) {
                    this.socket.close();
                }
            } catch (IOException ex) {
                failure = DeepModel.chainFailure(failure, new RuntimeException(ex));
            }
            this.socket = null;
        }

        // Drop our reference; the mapping itself stays valid until the buffer is garbage collected.
        this.sharedBuffer = null;

        if (this.sharedFile != null) {
            try {
                this.sharedFile.close();
                FileUtils.delete(this.sharedMemoryPath);
            } catch (IOException ex) {
                failure = DeepModel.chainFailure(failure,
                        new RuntimeException("Failed to clean up shared file: " + this.sharedMemoryPath, ex));
            }
            this.sharedFile = null;
        }

        if (this.pythonServerProcess != null) {
            if (this.pythonServerProcess.isAlive()) {
                this.pythonServerProcess.destroyForcibly();
            }
            this.pythonServerProcess = null;
        }

        if (failure != null) {
            throw failure;
        }
    }

    /** Keeps the first cleanup failure as the primary one and attaches later ones as suppressed. */
    private static RuntimeException chainFailure(RuntimeException first, RuntimeException next) {
        if (first == null) {
            return next;
        }
        first.addSuppressed(next);
        return first;
    }

    /** Pauses for SERVER_SLEEP_TIME_MS; an interrupt ends the pause early and is re-asserted, not swallowed. */
    private void sleepForServer() {
        try {
            Thread.sleep(SERVER_SLEEP_TIME_MS);
        } catch (InterruptedException ex) {
            Thread.currentThread().interrupt();
        }
    }

    /**
     * Sends one JSON line to the server and parses its one-line JSON reply.
     *
     * @return the parsed response, or null if the server is not open (nothing is sent in that case)
     * @throws RuntimeException on I/O failure, if the server hangs up, or if it reports a non-success status;
     *     in all of these cases the server is marked as no longer open
     */
    private JSONObject sendSocketMessage(JSONObject message) {
        if (!this.serverOpen) {
            return null;
        }

        String payload = message.toString();
        log.trace(String.format("Sending server message: '%s'.", payload));

        this.socketOutput.println(payload);
        // PrintWriter never throws on write; checkError() is the only way to detect a failed send.
        if (this.socketOutput.checkError()) {
            this.serverOpen = false;
            throw new RuntimeException("Failed to send message to deep model server on port " + this.port + ".");
        }

        String rawResponse;
        try {
            rawResponse = this.socketInput.readLine();
        } catch (IOException ex) {
            this.serverOpen = false;
            throw new RuntimeException(ex);
        }
        log.trace(String.format("Received server message: '%s'.", rawResponse));

        if (rawResponse == null) {
            // End of stream: the server closed the connection (presumably the process exited).
            this.serverOpen = false;
            throw new RuntimeException(
                    "Deep model server on port " + this.port + " closed the connection without a response.");
        }

        JSONObject response = new JSONObject(rawResponse);
        String status = response.optString("status", "<UNKNOWN>");
        if (!status.equals("success")) {
            this.serverOpen = false;
            // Presumably gives the server time to emit its own diagnostics before we fail.
            this.sleepForServer();
            String failureMessage = response.optString("message", "<no message provided>");
            throw new RuntimeException(String.format("Server sent a failure status (%s): '%s'.", status, failureMessage));
        }
        return response;
    }

    /** Reserves and returns the lowest port at or above startingPort not already held by a model. */
    private static synchronized int getOpenPort(DeepModel model) {
        int port = startingPort;
        while (usedPorts.containsKey(port)) {
            ++port;
        }
        usedPorts.put(port, model);
        return port;
    }

    /** Releases a port previously reserved by getOpenPort. */
    protected static synchronized void freePort(int port) {
        usedPorts.remove(port);
    }
}

