import csv
import os
import time
from unittest import TestCase
from e6data_python_connector import Connection
import json
import logging

# Bind the module logger (the bare call previously discarded its result).
# Test methods below log through the root ``logging`` functions directly.
logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.DEBUG)


class TestE6X(TestCase):
    """Integration tests for the e6data connector against a live engine.

    Connection settings come from the environment variables ENGINE_IP,
    DB_NAME, EMAIL, PASSWORD, CATALOG and PORT (default 80). Several tests
    query the ``lineitem``, ``customer`` and ``nation_json`` tables in
    ``<CATALOG>.<DB_NAME>``, so those must exist on the target engine.
    """

    def setUp(self) -> None:
        """Build a ``Connection`` from the environment for each test."""
        self._host = os.environ.get('ENGINE_IP')
        self._database = os.environ.get('DB_NAME')
        self._email = os.environ.get('EMAIL')
        self._password = os.environ.get('PASSWORD')
        self._catalog = os.environ.get('CATALOG')
        self._port = int(os.environ.get('PORT', 80))
        logging.debug('Trying to connect to engine host {}, database {}.'.format(self._host, self._database))
        self.e6x_connection = Connection(
            host=self._host,
            port=self._port,
            username=self._email,
            database=self._database,
            password=self._password,
            catalog=self._catalog
        )
        logging.debug('Successfully connected to engine.')

    def _execute(self, sql, params=None):
        """Open a cursor, execute *sql* and assert that a query id came back.

        :param sql: SQL text to run.
        :param params: optional parameter mapping; only forwarded to
            ``cursor.execute`` when given, so the no-params call is unchanged.
        :returns: the open cursor; the caller must call ``cursor.clear()``.
        """
        logging.debug('Executing query: {}'.format(sql))
        cursor = self.e6x_connection.cursor()
        if params is None:
            query_id = cursor.execute(sql)
        else:
            query_id = cursor.execute(sql, params)
        logging.debug('Query Id {}'.format(query_id))
        self.assertIsNotNone(query_id)
        return cursor

    def _assert_query_returns_rows(self, sql):
        """Run *sql*, clear the cursor and assert at least one row was returned."""
        cursor = self._execute(sql)
        records = cursor.fetchall()
        cursor.clear()
        self.assertGreater(len(records), 0)

    def test_connection(self):
        self.assertIsNotNone(self.e6x_connection, 'Unable to connect.')

    def disconnect(self):
        """Close the connection and assert ``check_connection()`` reports it closed.

        NOTE: unittest only collects methods whose names start with ``test``,
        so this helper never runs on its own.
        """
        self.e6x_connection.close()
        self.assertFalse(self.e6x_connection.check_connection())

    def test_query_1(self):
        """``select 1`` yields a first row containing 1."""
        cursor = self._execute('select 1')
        records = cursor.fetchall()
        self.assertIn(1, records[0])
        cursor.clear()
        # The connection is closed in tearDown; closing it here as well
        # meant it was closed twice.

    def test_query_2(self):
        """``select 1,2,3`` yields three columns whose first row holds 1, 2 and 3."""
        cursor = self._execute('select 1,2,3')
        records = cursor.fetchall()
        self.assertEqual(3, len(cursor.description))
        for value in (1, 2, 3):
            self.assertIn(value, records[0])
        cursor.clear()

    def test_query_3(self):
        """A ``limit 10`` query over ``lineitem`` returns exactly 10 rows."""
        sql = 'select * from {}.{}.lineitem limit 10'.format(self._catalog, self._database)
        cursor = self._execute(sql)
        records = cursor.fetchall()
        self.assertEqual(10, len(records))
        cursor.clear()

    def test_query_fetchmany(self):
        """Fetch 100 rows in batches of 5, 10, 50 and 50; the last batch holds the remaining 35."""
        sql = 'select * from {}.{}.lineitem limit 100'.format(self._catalog, self._database)
        cursor = self._execute(sql)
        for batch_size, expected in ((5, 5), (10, 10), (50, 50), (50, 35)):
            records = cursor.fetchmany(batch_size)
            self.assertEqual(expected, len(records))
        cursor.clear()

    def test_query_query_planner(self):
        """``explain_analyse`` produces at least one planner row."""
        sql = 'select * from {}.{}.lineitem limit 100'.format(self._catalog, self._database)
        cursor = self.e6x_connection.cursor()
        cursor.explain_analyse(sql)
        records = cursor.fetchall()
        for record in records:
            logging.debug('Query planner results {}'.format(record))
        self.assertGreater(len(records), 0)
        cursor.clear()

    def test_query_parameterized_query(self):
        """A ``%(name)s`` integer parameter filters on ``l_partkey`` (column index 1)."""
        sql = 'select * from {}.{}.lineitem where l_partkey=%(partkey)s limit 10'.format(self._catalog, self._database)
        params = {
            'partkey': 10
        }
        cursor = self._execute(sql, params)
        records = cursor.fetchall()
        # Guard before indexing so an empty result fails cleanly, not with IndexError.
        self.assertTrue(records, 'Parameterized query returned no rows.')
        logging.debug('First records {}'.format(records[0]))
        for record in records:
            self.assertEqual(10, record[1])
        cursor.clear()

    def test_query_parameterized_string_query(self):
        """String parameters filter on ``l_linestatus`` (index 9) and ``l_returnflag`` (index 8)."""
        sql = "select * from {}.{}.lineitem where l_linestatus=%(lineitem_status)s and " \
              "l_returnflag=%(return_flag)s limit 10".format(self._catalog, self._database)
        params = {
            'lineitem_status': 'F',
            'return_flag': 'R'
        }
        cursor = self._execute(sql, params)
        records = cursor.fetchall()
        for record in records:
            self.assertEqual('F', record[9])
            self.assertEqual('R', record[8])
        cursor.clear()

    def test_query_get_schema(self):
        """Check the column names and type names reported for ``customer``."""
        expected_columns = [
            ('c_custkey', 'LONG'),
            ('c_name', 'STRING'),
            ('c_address', 'STRING'),
            ('c_nationkey', 'LONG'),
            ('c_phone', 'STRING'),
            ('c_acctbal', 'DOUBLE'),
            ('c_mktsegment', 'STRING'),
            ('c_comment', 'STRING'),
        ]
        sql = 'select * from {}.{}.customer limit 1'.format(self._catalog, self._database)
        cursor = self._execute(sql)
        records = cursor.fetchall()
        self.assertEqual(1, len(records))
        self.assertEqual(len(expected_columns), len(cursor.description))
        for (name, type_name), column in zip(expected_columns, cursor.description):
            with self.subTest(column=name):
                self.assertEqual(name, column.name)
                self.assertEqual(type_name, column.type_code.name)
        cursor.clear()

    def test_catalog_query(self):
        self._assert_query_returns_rows('show catalogs')

    def test_database_query(self):
        self._assert_query_returns_rows('show databases in {}'.format(self._catalog))

    def test_tables_query(self):
        self._assert_query_returns_rows('show tables in {}.{}'.format(self._catalog, self._database))

    def test_write_csv(self):
        """Stream a ``lineitem`` result into a CSV file and check a non-empty file was written."""
        sql = 'select * from {}.{}.lineitem limit 10'.format(self._catalog, self._database)
        file_path = 'test_lineitem.csv'
        cursor = self._execute(sql)
        columns = [column.name for column in cursor.description]
        try:
            # The csv module requires newline='' so it controls line endings itself.
            with open(file_path, 'w', newline='', encoding='utf-8') as write_file:
                writer = csv.writer(write_file)
                writer.writerow(columns)
                writer.writerows(cursor)
            cursor.clear()
            self.assertTrue(os.path.exists(file_path))
            self.assertGreater(os.path.getsize(file_path), 0)
        finally:
            # Remove the scratch file even when an assertion above fails.
            if os.path.exists(file_path):
                os.remove(file_path)

    def test_multiple_cursors(self):
        """Several cursors can be opened on one connection."""
        cursors = [self.e6x_connection.cursor() for _ in range(3)]
        for cursor in cursors:
            self.assertIsNotNone(cursor)
        for cursor in cursors:
            cursor.clear()

    def test_switch_database(self):
        """Cursors can target an explicit database/catalog; a ``None`` database is tolerated."""
        sql = 'show tables in {}.{}'.format(self._catalog, self._database)
        cursor = self.e6x_connection.cursor(db_name=self._database, catalog_name=self._catalog)
        query_id = cursor.execute(sql)
        self.assertIsNotNone(query_id)
        records = cursor.fetchall()
        self.assertGreater(len(records), 0)
        # Release the first cursor before rebinding the name (it was leaked before).
        cursor.clear()

        # Test null database.
        cursor = self.e6x_connection.cursor(db_name=None, catalog_name=self._catalog)
        try:
            query_id = cursor.execute(sql)
            # NOTE(review): AssertionError is an Exception, so a failure here is
            # caught and only logged below; this check is best-effort.
            self.assertIsNone(query_id)
        except Exception as e:
            logging.debug(e)
            logging.debug('Expected, catalog invalid.')

        cursor.clear()

    def test_get_query_time(self):
        """Run a query, wait 5 seconds, then fetch the result.

        NOTE(review): despite the name, no query time is asserted here.
        """
        sql = 'select * from {}.{}.lineitem limit 100'.format(self._catalog, self._database)
        cursor = self._execute(sql)
        time.sleep(5)
        cursor.fetchall()
        cursor.clear()

    def test_cancel_query(self):
        """Cancel a running ``count(*)`` query; any error from cancel/fetch is only logged."""
        sql = 'select count(*) from {}.{}.lineitem'.format(self._catalog, self._database)
        cursor = self.e6x_connection.cursor()
        query_id = cursor.execute(sql)
        logging.debug('Query Id {}'.format(query_id))
        try:
            cursor.cancel()
            records = cursor.fetchall()
            # Best-effort: the query may finish before cancel() lands, and the
            # except below also catches AssertionError, so this is advisory.
            self.assertEqual(len(records), 0)
        except Exception as ex:
            logging.debug(ex)
        cursor.clear()

    def test_json_parsing(self):
        """The first ``nation_json`` value parses as JSON with ``n_name`` == 'ALGERIA'."""
        sql = 'select nation_json from {}.{}.nation_json'.format(self._catalog, self._database)
        cursor = self._execute(sql)
        records = cursor.fetchall()
        # Previously an empty result skipped the loop and passed vacuously.
        self.assertTrue(records, 'nation_json query returned no rows.')
        self.assertEqual(json.loads(records[0][0])['n_name'], 'ALGERIA')
        cursor.clear()

    def test_row_count(self):
        """``rowcount`` is 10 after fetching a ``limit 10`` result."""
        sql = 'select * from {}.{}.lineitem limit 10'.format(self._catalog, self._database)
        cursor = self._execute(sql)
        # Fetch kept in case rowcount is only populated after fetching -- not
        # verifiable from this file.
        cursor.fetchall()
        self.assertEqual(cursor.rowcount, 10)
        cursor.clear()

    def test_executemany(self):
        """``executemany`` must raise with the message 'Not Supported'."""
        cursor = self.e6x_connection.cursor()
        # assertRaises makes the test fail if no exception is raised at all;
        # the old try/except passed silently in that case.
        with self.assertRaises(Exception) as context:
            cursor.executemany()
        self.assertEqual(str(context.exception), 'Not Supported')
        cursor.clear()

    def test_fetchone(self):
        """Repeated ``fetchone`` calls yield 100 rows before returning None."""
        sql = 'select * from {}.{}.lineitem limit 100'.format(self._catalog, self._database)
        cursor = self._execute(sql)
        # Two-argument iter() calls fetchone until it returns the None sentinel.
        count = sum(1 for _ in iter(cursor.fetchone, None))
        self.assertEqual(count, 100)
        cursor.clear()

    def tearDown(self) -> None:
        """Close the connection opened in setUp."""
        logging.debug('Closing connection.')
        self.e6x_connection.close()