/*
 * Decompiled with CFR 0.152.
 */
package org.linqs.psl.database.rdbms.driver;

import com.healthmarketscience.sqlbuilder.CreateTableQuery;
import com.zaxxer.hikari.HikariConfig;
import com.zaxxer.hikari.HikariDataSource;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.UnsupportedEncodingException;
import java.net.URLEncoder;
import java.sql.Connection;
import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.sql.Statement;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.json.JSONArray;
import org.linqs.psl.config.Options;
import org.linqs.psl.database.Partition;
import org.linqs.psl.database.rdbms.PredicateInfo;
import org.linqs.psl.database.rdbms.SelectivityHistogram;
import org.linqs.psl.database.rdbms.TableStats;
import org.linqs.psl.database.rdbms.driver.DatabaseDriver;
import org.linqs.psl.model.term.ConstantType;
import org.linqs.psl.util.ListUtils;
import org.linqs.psl.util.Parallel;
import org.linqs.psl.util.StringUtils;
import org.postgresql.PGConnection;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * A {@link DatabaseDriver} backed by PostgreSQL.
 *
 * <p>Connections come from a HikariCP pool. Tables are created {@code UNLOGGED}
 * (see {@link #finalizeCreateTable}). Bulk loads go through the PostgreSQL {@code COPY} API.
 * Column statistics and histograms are read from {@code pg_stats}.
 */
public class PostgreSQLDriver implements DatabaseDriver {
    /** Upper bound on the per-column statistics target set by {@link #updateTableStats}. */
    private static final int MAX_STATS = 10000;
    private static final String ENCODING = "UTF-8";
    private static final Logger log = LoggerFactory.getLogger(PostgreSQLDriver.class);

    private final HikariDataSource dataSource;

    /** Fraction of a table's row count used as the statistics target (Options.POSTGRES_STATS_PERCENTAGE). */
    private final double statsPercentage;

    /**
     * Connects using the host and port from {@code Options.POSTGRES_HOST} and {@code Options.POSTGRES_PORT}.
     */
    public PostgreSQLDriver(String databaseName, boolean clearDatabase) {
        this(Options.POSTGRES_HOST.getString(), Options.POSTGRES_PORT.getString(), databaseName, clearDatabase);
    }

    /**
     * Connects using the credentials from {@code Options.POSTGRES_USER} and {@code Options.POSTGRES_PASSWORD}.
     */
    public PostgreSQLDriver(String host, String port, String databaseName, boolean clearDatabase) {
        this(host, port, Options.POSTGRES_USER.getString(), (String) Options.POSTGRES_PASSWORD.getUnlogged(),
                databaseName, clearDatabase);
    }

    /**
     * Builds a JDBC URL from the individual parts and connects to it.
     * An empty or null user or password is left out of the URL.
     */
    public PostgreSQLDriver(String host, String port, String user, String password, String databaseName, boolean clearDatabase) {
        this(formatConnectionString(host, port, user, password, databaseName), databaseName, clearDatabase);
    }

    /**
     * Connects to a PostgreSQL database through a pooled data source.
     *
     * @param connectionString a full {@code jdbc:postgresql://...} URL
     * @param databaseName the database name; only used for logging here
     * @param clearDatabase if true, drop and recreate the {@code public} schema (destroying all its tables)
     * @throws RuntimeException if the PostgreSQL JDBC driver is not on the classpath or clearing fails
     */
    public PostgreSQLDriver(String connectionString, String databaseName, boolean clearDatabase) {
        try {
            Class.forName("org.postgresql.Driver");
        } catch (ClassNotFoundException ex) {
            throw new RuntimeException("Could not find postgres driver. Please check classpath.", ex);
        }

        log.debug("Connecting to PostgreSQL database: {}", databaseName);

        this.statsPercentage = Options.POSTGRES_STATS_PERCENTAGE.getDouble();

        HikariConfig config = new HikariConfig();
        config.setJdbcUrl(connectionString);
        config.setMaximumPoolSize(Math.max(8, Parallel.getNumThreads() * 2));
        // Per the HikariCP docs, 0 means connections have no maximum lifetime.
        config.setMaxLifetime(0L);
        this.dataSource = new HikariDataSource(config);

        if (clearDatabase) {
            try {
                executeUpdate("DROP SCHEMA public CASCADE");
                executeUpdate("CREATE SCHEMA public");
                executeUpdate("GRANT ALL ON SCHEMA public TO public");
            } catch (RuntimeException ex) {
                // The constructor is failing, so no caller will ever hold a reference that could close the pool.
                dataSource.close();
                throw ex;
            }
        }
    }

    /** Closes the underlying connection pool. */
    @Override
    public void close() {
        dataSource.close();
    }

    /**
     * Borrows a connection from the pool. The caller must close it to return it.
     *
     * @throws RuntimeException wrapping the {@link SQLException} if no connection can be obtained
     */
    @Override
    public Connection getConnection() {
        try {
            return dataSource.getConnection();
        } catch (SQLException ex) {
            throw new RuntimeException("Failed to get connection from pool.", ex);
        }
    }

    @Override
    public boolean supportsBulkCopy() {
        return true;
    }

    /**
     * Loads a delimited file into the predicate's table using PostgreSQL {@code COPY ... FROM STDIN}.
     *
     * <p>The file must hold the argument columns, followed by a value column when {@code hasTruth} is set.
     * The partition is not in the file. It is supplied through a temporary default on {@code partition_id},
     * which is removed afterwards even if the copy fails.
     *
     * @throws IllegalArgumentException if {@code delimiter} contains a single quote
     *         (it is embedded in a SQL string literal)
     */
    @Override
    public void bulkCopy(String path, String delimiter, boolean hasTruth, PredicateInfo predicateInfo, Partition partition) {
        if (delimiter.contains("'")) {
            throw new IllegalArgumentException("Delimiter (" + delimiter + ") may not contain a single quote.");
        }

        String sql = String.format("COPY %s(%s%s) FROM STDIN WITH DELIMITER '%s'",
                predicateInfo.tableName(),
                ListUtils.join(", ", predicateInfo.argumentColumns()),
                hasTruth ? ", value" : "",
                delimiter);

        // partition_id is not in the COPY column list, so it takes the column default set here.
        setColumnDefault(predicateInfo.tableName(), "partition_id", "'" + partition.getID() + "'");
        try (Connection connection = getConnection();
                FileInputStream inFile = new FileInputStream(path)) {
            PGConnection pgConnection = connection.unwrap(PGConnection.class);
            pgConnection.getCopyAPI().copyIn(sql, inFile);
        } catch (SQLException ex) {
            throw new RuntimeException("Could not perform bulk insert on " + predicateInfo.predicate(), ex);
        } catch (IOException ex) {
            throw new RuntimeException("Error bulk copying file: " + path, ex);
        } finally {
            dropColumnDefault(predicateInfo.tableName(), "partition_id");
        }
    }

    /**
     * Runs {@code ALTER TABLE ... ALTER COLUMN ... SET DEFAULT ...}.
     *
     * @param defaultValue raw SQL expression, inserted verbatim (quote string literals yourself)
     */
    public void setColumnDefault(String tableName, String columnName, String defaultValue) {
        String sql = String.format("ALTER TABLE %s ALTER COLUMN %s SET DEFAULT %s", tableName, columnName, defaultValue);
        try (Connection connection = getConnection();
                PreparedStatement statement = connection.prepareStatement(sql)) {
            statement.executeUpdate();
        } catch (SQLException ex) {
            throw new RuntimeException(String.format("Could not set the column default of %s for %s.%s.",
                    defaultValue, tableName, columnName), ex);
        }
    }

    /** Runs {@code ALTER TABLE ... ALTER COLUMN ... DROP DEFAULT}. */
    public void dropColumnDefault(String tableName, String columnName) {
        String sql = String.format("ALTER TABLE %s ALTER COLUMN %s DROP DEFAULT", tableName, columnName);
        try (Connection connection = getConnection();
                PreparedStatement statement = connection.prepareStatement(sql)) {
            statement.executeUpdate();
        } catch (SQLException ex) {
            throw new RuntimeException(String.format("Could not drop the column default for %s.%s.",
                    tableName, columnName), ex);
        }
    }

    /**
     * Maps a PSL constant type to its PostgreSQL column type.
     *
     * @throws IllegalStateException for a type this driver does not know
     */
    @Override
    public String getTypeName(ConstantType type) {
        switch (type) {
            case Double:
                return "DOUBLE PRECISION";
            case Integer:
            case UniqueIntID:
                return "INT";
            case String:
            case UniqueStringID:
                return "TEXT";
            case Long:
                return "BIGINT";
        }

        throw new IllegalStateException("Unknown ConstantType: " + type);
    }

    @Override
    public String getSurrogateKeyColumnDefinition(String columnName) {
        return columnName + " SERIAL PRIMARY KEY";
    }

    @Override
    public String getDoubleTypeName() {
        return "DOUBLE PRECISION";
    }

    /**
     * Builds an {@code INSERT ... ON CONFLICT (keys) DO UPDATE SET} statement.
     * It has one {@code ?} placeholder per entry in {@code columns}, in order, and overwrites
     * every column with the incoming ({@code EXCLUDED}) value.
     */
    @Override
    public String getUpsert(String tableName, String[] columns, String[] keyColumns) {
        List<String> updateValues = new ArrayList<String>(columns.length);
        for (String column : columns) {
            updateValues.add(String.format("%s = EXCLUDED.%s", column, column));
        }

        List<String> sql = new ArrayList<String>();
        sql.add("INSERT INTO " + tableName);
        sql.add("    (" + StringUtils.join(", ", (Object[]) columns) + ")");
        sql.add("VALUES");
        sql.add("    (" + StringUtils.repeat("?", ", ", columns.length) + ")");
        sql.add("ON CONFLICT");
        sql.add("    (" + StringUtils.join(", ", (Object[]) keyColumns) + ")");
        sql.add("DO UPDATE SET");
        sql.add("    " + ListUtils.join(", ", updateValues));

        return ListUtils.join("\n", sql);
    }

    /**
     * Executes a single statement on a pooled connection.
     *
     * @throws RuntimeException wrapping any {@link SQLException}; the message includes the SQL text
     */
    private void executeUpdate(String sql) {
        try (Connection connection = getConnection();
                Statement stmt = connection.createStatement()) {
            stmt.executeUpdate(sql);
        } catch (SQLException ex) {
            throw new RuntimeException("Failed to execute a general update: [" + sql + "].", ex);
        }
    }

    /** Validates the query and rewrites it to create an {@code UNLOGGED} table. */
    @Override
    public String finalizeCreateTable(CreateTableQuery createTable) {
        return ((CreateTableQuery) createTable.validate()).toString().replace("CREATE TABLE", "CREATE UNLOGGED TABLE");
    }

    /**
     * Builds a {@code STRING_AGG} expression over {@code columnName} cast to TEXT.
     *
     * @param distinct if true, aggregate only distinct values
     * @throws IllegalArgumentException if {@code delimiter} contains a single quote
     *         (it is embedded in a SQL string literal)
     */
    @Override
    public String getStringAggregate(String columnName, String delimiter, boolean distinct) {
        if (delimiter.contains("'")) {
            throw new IllegalArgumentException("Delimiter (" + delimiter + ") may not contain a single quote.");
        }

        return String.format("STRING_AGG(%sCAST(%s AS TEXT), '%s')",
                distinct ? "DISTINCT " : "",
                columnName, delimiter);
    }

    /**
     * Reads per-column selectivity and histograms for a predicate's table from {@code pg_stats}.
     *
     * <p>The {@code partition_id} and {@code value} columns are skipped.
     *
     * @return the stats, or null when {@code pg_stats} has no rows for the table
     *         (for example, before it has been analyzed)
     */
    @Override
    public TableStats getTableStats(PredicateInfo predicate) {
        String tableName = predicate.tableName();

        List<String> sql = new ArrayList<String>();
        sql.add("SELECT");
        sql.add("    UPPER(attname) AS col,");
        sql.add("    (SELECT COUNT(*) FROM " + tableName + ") AS tableCount,");
        // A non-negative n_distinct is an absolute count; a negative one is minus the distinct fraction.
        sql.add("    CASE WHEN n_distinct >= 0");
        // NULLIF guards against a division-by-zero error when stale stats describe a now-empty table.
        sql.add("        THEN n_distinct / NULLIF((SELECT COUNT(*) FROM " + tableName + "), 0)");
        sql.add("        ELSE -1.0 * n_distinct");
        sql.add("        END AS selectivity,");
        sql.add("    array_to_json(histogram_bounds) AS histogram,");
        sql.add("    array_to_json(most_common_vals) AS most_common_vals,");
        sql.add("    array_to_json(most_common_freqs) AS most_common_freqs");
        sql.add("FROM pg_stats");
        sql.add("WHERE");
        // Bind the name and fold case in the database, so the JVM's default locale cannot affect the match.
        sql.add("    UPPER(tablename) = UPPER(?)");
        sql.add("    AND UPPER(attname) NOT IN ('PARTITION_ID', 'VALUE')");

        TableStats stats = null;
        try (Connection connection = getConnection();
                PreparedStatement statement = connection.prepareStatement(ListUtils.join("\n", sql))) {
            statement.setString(1, tableName);
            try (ResultSet result = statement.executeQuery()) {
                while (result.next()) {
                    if (stats == null) {
                        stats = new TableStats(result.getInt(2));
                    }

                    String columnName = result.getString(1).toUpperCase();
                    stats.addColumnSelectivity(columnName, result.getDouble(3));

                    SelectivityHistogram histogram = parseHistogram(
                            result.getString(4), result.getString(5), result.getString(6), stats.getCount());
                    if (histogram != null) {
                        stats.addColumnHistogram(columnName, histogram);
                    }
                }
            }
        } catch (SQLException ex) {
            throw new RuntimeException("Failed to get stats from table: " + tableName, ex);
        }

        return stats;
    }

    /**
     * Builds a histogram from the JSON arrays returned by {@link #getTableStats}.
     *
     * <p>With at least two bounds, the rows are spread evenly over the buckets between adjacent bounds.
     * Any most-common-value counts are then added to the bucket they fall in.
     * Without usable bounds, the most-common values (if any) become an exact histogram.
     *
     * @param rawBounds JSON array of histogram bounds, or null
     * @param rawMostCommonVals JSON array of most common values, or null
     * @param rawMostCommonCounts JSON array of frequencies (fractions of rows), parallel to the values, or null
     * @param rowCount total row count, used to turn fractions into counts
     * @return the histogram, or null if there is nothing to build one from
     */
    @SuppressWarnings({"rawtypes", "unchecked"})
    private SelectivityHistogram parseHistogram(String rawBounds, String rawMostCommonVals, String rawMostCommonCounts, int rowCount) {
        List<Comparable> bounds = null;
        List<Integer> counts = null;
        if (rawBounds != null) {
            JSONArray boundsJson = new JSONArray(rawBounds);
            // A bucket needs two bounds; with fewer, the per-bucket division below would divide by zero.
            if (boundsJson.length() >= 2) {
                int bucketCount = rowCount / (boundsJson.length() - 1);
                bounds = new ArrayList<Comparable>(boundsJson.length());
                counts = new ArrayList<Integer>(boundsJson.length() - 1);

                bounds.add(convertHistogramBound(boundsJson.get(0)));
                for (int i = 1; i < boundsJson.length(); i++) {
                    bounds.add(convertHistogramBound(boundsJson.get(i)));
                    counts.add(bucketCount);
                }
            }
        }

        Map<Comparable, Integer> mostCommonHistogram = null;
        if (rawMostCommonVals != null && rawMostCommonCounts != null) {
            JSONArray mostCommonVals = new JSONArray(rawMostCommonVals);
            JSONArray mostCommonCounts = new JSONArray(rawMostCommonCounts);
            if (mostCommonVals.length() > 0) {
                mostCommonHistogram = new HashMap<Comparable, Integer>();
                for (int i = 0; i < mostCommonVals.length(); i++) {
                    double proportion = ((Number) mostCommonCounts.get(i)).doubleValue();
                    // Clamp to 1 so truncation never reports a listed common value as absent.
                    int count = Math.max(1, (int) (proportion * rowCount));
                    mostCommonHistogram.put(convertHistogramBound(mostCommonVals.get(i)), count);
                }
            }
        }

        if (bounds != null) {
            if (mostCommonHistogram != null) {
                addMostCommonValsToBuckets(bounds, counts, mostCommonHistogram);
            }

            SelectivityHistogram histogram = new SelectivityHistogram();
            histogram.addHistogramBounds(bounds, counts);
            return histogram;
        }

        if (mostCommonHistogram != null) {
            SelectivityHistogram histogram = new SelectivityHistogram();
            histogram.addHistogramExact(mostCommonHistogram);
            return histogram;
        }

        return null;
    }

    /**
     * Adds each most-common value's count to the bucket that contains it. Mutates {@code counts}.
     *
     * <p>Bucket {@code i} is bounded above by {@code bounds[i + 1]}. Values below the first bound go to
     * the first bucket, and values above the last bound go to the last bucket.
     * Expects {@code counts} to be non-empty with {@code bounds.size() == counts.size() + 1}.
     */
    @SuppressWarnings({"rawtypes", "unchecked"})
    private void addMostCommonValsToBuckets(List<Comparable> bounds, List<Integer> counts, Map<Comparable, Integer> mostCommonHistogram) {
        List<Comparable> sortedKeys = new ArrayList<Comparable>(mostCommonHistogram.keySet());
        Collections.sort(sortedKeys);

        int commonIndex = 0;
        int bucketIndex = 0;
        while (commonIndex < sortedKeys.size()) {
            Comparable commonValue = sortedKeys.get(commonIndex);

            if (bucketIndex == counts.size()) {
                // Past the last upper bound: fold the remaining values into the final bucket.
                int lastIndex = counts.size() - 1;
                counts.set(lastIndex, counts.get(lastIndex) + mostCommonHistogram.get(commonValue));
                commonIndex++;
                continue;
            }

            Comparable bucketEndValue = bounds.get(bucketIndex + 1);
            if (commonValue.compareTo(bucketEndValue) > 0) {
                bucketIndex++;
                continue;
            }

            counts.set(bucketIndex, counts.get(bucketIndex) + mostCommonHistogram.get(commonValue));
            commonIndex++;
        }
    }

    /**
     * Converts a JSON-parsed stats value into a comparable key.
     * Longs and Integers become Integers; anything else becomes its string form.
     */
    @SuppressWarnings("rawtypes")
    private Comparable convertHistogramBound(Object bound) {
        if (bound instanceof Long) {
            // NOTE(review): narrowing to int overflows for BIGINT values outside the int range.
            return Integer.valueOf(((Long) bound).intValue());
        }

        if (bound instanceof Integer) {
            return (Integer) bound;
        }

        return bound.toString();
    }

    /**
     * Builds a {@code jdbc:postgresql} URL with logging turned off.
     * The host, port, database, user and password are URL-encoded as UTF-8.
     * A null or empty user or password is left out.
     */
    private static String formatConnectionString(String host, String port, String user, String password, String databaseName) {
        StringBuilder connectionString = new StringBuilder(String.format("jdbc:postgresql://%s:%s/%s?loggerLevel=OFF",
                urlEncode(host), urlEncode(port), urlEncode(databaseName)));

        if (user != null && user.length() > 0) {
            connectionString.append("&user=").append(urlEncode(user));
        }

        if (password != null && password.length() > 0) {
            connectionString.append("&password=").append(urlEncode(password));
        }

        return connectionString.toString();
    }

    private static String urlEncode(String text) {
        try {
            return URLEncoder.encode(text, ENCODING);
        } catch (UnsupportedEncodingException ex) {
            throw new RuntimeException(String.format("Bad encoding: '%s'.", ENCODING), ex);
        }
    }

    /** Runs {@code VACUUM ANALYZE} over the whole database. */
    @Override
    public void updateDBStats() {
        executeUpdate("VACUUM ANALYZE");
    }

    /**
     * Sets each argument column's statistics target to the table's row count times {@code statsPercentage},
     * capped at {@link #MAX_STATS}. Does nothing if that target is 0.
     */
    @Override
    public void updateTableStats(PredicateInfo predicate) {
        int count;
        try (Connection connection = getConnection()) {
            count = predicate.getCount(connection);
        } catch (SQLException ex) {
            throw new RuntimeException("Could not get table count for stats update: " + predicate, ex);
        }

        int statsCount = (int) Math.min(MAX_STATS, count * statsPercentage);
        if (statsCount == 0) {
            return;
        }

        for (String col : predicate.argumentColumns()) {
            executeUpdate(String.format("ALTER TABLE %s ALTER COLUMN %s SET STATISTICS %d",
                    predicate.tableName(), col, statsCount));
        }
    }
}

