import os, sys, json, unittest, logging, datetime, getpass, enum
from uuid import UUID as uuid_type, uuid4

from sqlalchemy import (create_engine, Column, Integer, String, Boolean, Float, LargeBinary, Numeric, Date, Time,
                        DateTime, Text, Enum, text)
from sqlalchemy.dialects.postgresql import UUID, JSONB, JSON, DATE, TIME, TIMESTAMP, ARRAY
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import sessionmaker

sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))

from sqlalchemy_aurora_data_api import register_dialects, _ADA_TIMESTAMP  # noqa

# Root logging at INFO; the Data API client and urllib3 connection-pool
# loggers are raised to DEBUG for more verbose test output.
logging.basicConfig(level=logging.INFO)
for _logger_name in ("aurora_data_api", "urllib3.connectionpool"):
    logging.getLogger(_logger_name).setLevel(logging.DEBUG)
del _logger_name

# Attribute names that test_interface_conformance expects to find in
# dir(engine.dialect). Kept in alphabetical order so additions are easy to spot.
dialect_interface_attributes = {
    "colspecs",
    "convert_unicode",
    "dbapi_exception_translation_map",
    "ddl_compiler",
    "default_schema_name",
    "driver",
    "encoding",
    "execute_sequence_format",
    "execution_ctx_cls",
    "implicit_returning",
    "max_identifier_length",
    "name",
    "paramstyle",
    "positional",
    "preexecute_autoincrement_sequences",
    "preparer",
    "sequences_optional",
    "server_version_info",
    "statement_compiler",
    "supports_alter",
    "supports_default_values",
    "supports_native_boolean",
    "supports_native_enum",
    "supports_sane_multi_rowcount",
    "supports_sane_rowcount",
    "supports_sequences",
    "supports_unicode_binds",
    "supports_unicode_statements",
}

# Names that test_interface_conformance expects to be present AND callable on
# engine.dialect, grouped by name prefix. "reflect_table" is not checked.
dialect_interface_methods = {
    # do_* methods
    "do_begin", "do_begin_twophase", "do_close", "do_commit", "do_commit_twophase",
    "do_execute", "do_execute_no_params", "do_executemany", "do_prepare_twophase",
    "do_recover_twophase", "do_release_savepoint", "do_rollback",
    "do_rollback_to_savepoint", "do_rollback_twophase", "do_savepoint",
    # get_* methods
    "get_check_constraints", "get_columns", "get_dialect_cls", "get_foreign_keys",
    "get_indexes", "get_isolation_level", "get_pk_constraint", "get_table_comment",
    "get_table_names", "get_temp_table_names", "get_temp_view_names",
    "get_unique_constraints", "get_view_definition", "get_view_names",
    # everything else
    "connect", "create_connect_args", "create_xid", "denormalize_name",
    "engine_created", "has_sequence", "has_table", "initialize", "is_disconnect",
    "normalize_name", "reset_isolation_level", "set_isolation_level",
    "type_descriptor",
}

# Two independent declarative bases, each created with its own MetaData:
# BasicBase carries the generic-typed model whose tables the MySQL test creates,
# Base carries the Postgres-typed model whose tables the Postgres test creates.
BasicBase, Base = declarative_base(), declarative_base()


# Sample Python enum stored through an SQLAlchemy Enum column in the ORM test.
# Functional Enum API: member order and values are red=1, green=2, black=3.
Socks = enum.Enum("Socks", [("red", 1), ("green", 2), ("black", 3)])


class BasicUser(BasicBase):
    """Minimal ORM model built only from generic SQLAlchemy column types.

    Registered on ``BasicBase``, whose metadata is what the MySQL ORM test
    creates; unlike ``User`` it uses no ``sqlalchemy.dialects.postgresql`` types.
    """
    __tablename__ = "sqlalchemy_aurora_data_api_testI"

    id = Column(Integer, primary_key=True)
    name = Column(String(64))
    fullname = Column(String(64))
    nickname = Column(String(64))
    birthday = Column(Date)
    # Time and DateTime columns: the MySQL test compares them with microseconds zeroed.
    eats_breakfast_at = Column(Time)
    married_at = Column(DateTime)


class User(Base):
    """ORM model that exercises Postgres-dialect column types over the Data API.

    Registered on ``Base``, whose metadata the Postgres ORM test creates. Mixes
    generic types with ``sqlalchemy.dialects.postgresql`` types (JSONB, JSON,
    UUID, DATE, TIME, TIMESTAMP, ARRAY).
    """
    __tablename__ = "sqlalchemy_aurora_data_api_testJ"
    id = Column(Integer, primary_key=True)
    name = Column(String)
    fullname = Column(String)
    nickname = Column(String)
    doc = Column(JSONB)
    doc2 = Column(JSON)
    # Plain UUID round-trips as a string; as_uuid=True yields uuid.UUID objects.
    uuid = Column(UUID)
    uuid2 = Column(UUID(as_uuid=True), default=uuid4)
    flag = Column(Boolean, nullable=True)
    nonesuch = Column(Boolean, nullable=True)
    birthday = Column(DATE)
    wakes_up_at = Column(TIME)
    added = Column(TIMESTAMP)
    floated = Column(Float)
    nybbled = Column(LargeBinary)
    friends = Column(ARRAY(String))
    # Wrapped in Column() so these map to real table columns; a bare
    # Numeric(...) type object is just an unmapped class attribute and its
    # values would never reach the database.
    num_friends = Column(Numeric(asdecimal=True))
    num_laptops = Column(Numeric(asdecimal=False))
    first_date = Column(Date)
    note = Column(Text)
    socks = Column(Enum(Socks))


class TestAuroraDataAPI(unittest.TestCase):
    """Checks shared by every Aurora Data API dialect test case.

    Concrete subclasses must override ``setUpClass`` and set ``cls.engine``.
    This base class has no engine of its own, so when a test runner collects
    it directly its tests are skipped rather than failing.
    """

    @classmethod
    def setUpClass(cls):
        # unittest/pytest collect this base TestCase like any other, but it has
        # no engine; skip it instead of erroring on self.engine.
        if cls is TestAuroraDataAPI:
            raise unittest.SkipTest("abstract base class; run a dialect-specific subclass")

    @classmethod
    def tearDownClass(cls):
        # Release the connection pool of the engine a subclass created, if any.
        engine = getattr(cls, "engine", None)
        if engine is not None:
            engine.dispose()

    def test_interface_conformance(self):
        """Assert the dialect exposes every expected attribute and callable method."""
        dialect = self.engine.dialect
        members = set(dir(dialect))  # computed once instead of per name

        for attr in dialect_interface_attributes:
            self.assertIn(attr, members)

        for attr in dialect_interface_methods:
            self.assertIn(attr, members)
            # assertTrue (not bare assert) so the check survives python -O.
            self.assertTrue(callable(getattr(dialect, attr)), attr)


class TestAuroraDataAPIPostgresDialect(TestAuroraDataAPI):
    """Postgres flavour of the Aurora Data API dialect tests."""
    dialect = "postgresql+auroradataapi://"
    # Alternative for running against a local Postgres through psycopg2:
    # dialect = "postgresql+psycopg2://" + getpass.getuser()

    @classmethod
    def setUpClass(cls):
        register_dialects()
        cls.db_name = os.environ.get("AURORA_DB_NAME", __name__)
        cls.engine = create_engine(cls.dialect + ':@/' + cls.db_name)

    def test_execute(self):
        """Run a raw catalog query and print each row."""
        with self.engine.connect() as conn:
            # text() keeps this working where plain-string execute is deprecated/removed.
            for result in conn.execute(text("select * from pg_catalog.pg_tables")):
                print(result)

    def test_orm(self):
        """Round-trip a User row through the ORM and check every column type."""
        uuid = uuid4()
        doc = {'foo': [1, 2, 3]}
        blob = b"0123456789ABCDEF" * 1024
        friends = ["Scarlett O'Hara", 'Ada "Hacker" Lovelace']
        Base.metadata.create_all(self.engine)
        added = datetime.datetime.now()
        ed_user = User(name='ed', fullname='Ed Jones', nickname='edsnickname', doc=doc, doc2=doc, uuid=str(uuid),
                       flag=True, birthday=datetime.datetime.fromtimestamp(0), added=added, floated=1.2, nybbled=blob,
                       friends=friends, num_friends=500, num_laptops=9000, first_date=added, note='note',
                       socks=Socks.red)
        Session = sessionmaker(bind=self.engine)

        session = Session()
        try:
            session.query(User).delete()
            session.commit()

            session.add(ed_user)
            # Queried before commit; presumably relies on the session's autoflush.
            self.assertEqual(session.query(User).filter_by(name='ed').first().name, "ed")
            session.commit()
            self.assertGreater(session.query(User).filter(User.name.like('%ed')).count(), 0)
            u = session.query(User).filter(User.name.like('%ed')).first()
            self.assertEqual(u.doc, doc)
            self.assertEqual(u.doc2, doc)
            self.assertEqual(u.flag, True)
            self.assertIsNone(u.nonesuch)
            self.assertEqual(u.birthday, datetime.date.fromtimestamp(0))
            # Microseconds are expected to be dropped on the TIMESTAMP round trip.
            self.assertEqual(u.added, added.replace(microsecond=0))
            self.assertEqual(u.floated, 1.2)
            self.assertEqual(u.nybbled, blob)
            self.assertEqual(u.friends, friends)
            self.assertEqual(u.num_friends, 500)
            self.assertEqual(u.num_laptops, 9000)
            self.assertEqual(u.first_date, added.date())
            self.assertEqual(u.note, 'note')
            self.assertEqual(u.socks, Socks.red)
            self.assertEqual(u.uuid, str(uuid))
            self.assertIsInstance(u.uuid2, uuid_type)

            u.socks = Socks.green
            session.commit()
        finally:
            session.close()

        # A brand-new session has an empty identity map, so this reads the
        # committed update back from the database.
        session2 = Session()
        try:
            u2 = session2.query(User).filter(User.name.like('%ed')).first()
            self.assertEqual(u2.socks, Socks.green)
        finally:
            session2.close()

    @unittest.skipIf(sys.version_info < (3, 7), "Skipping test that requires Python 3.7+")
    def test_timestamp_microsecond_padding(self):
        """Fractional seconds shorter than six digits are right-padded with zeros."""
        ts = '2019-10-31 09:37:17.3186'
        processor = _ADA_TIMESTAMP.result_processor(_ADA_TIMESTAMP, None, None)
        self.assertEqual(processor(ts), datetime.datetime.fromisoformat(ts.ljust(26, "0")))


class TestAuroraDataAPIMySQLDialect(TestAuroraDataAPI):
    """MySQL flavour of the Aurora Data API dialect tests."""
    dialect = "mysql+auroradataapi://"

    @classmethod
    def setUpClass(cls):
        register_dialects()
        cls.db_name = os.environ.get("AURORA_DB_NAME", __name__)
        # charset is supplied via the URL query string.
        cls.engine = create_engine(cls.dialect + ':@/' + cls.db_name + "?charset=utf8mb4")

    def test_execute(self):
        """Run a raw information_schema query and print each row."""
        with self.engine.connect() as conn:
            # text() keeps this working where plain-string execute is deprecated/removed.
            for result in conn.execute(text("select * from information_schema.tables")):
                print(result)

    def test_orm(self):
        """Round-trip a BasicUser row and check date/time column handling."""
        BasicBase.metadata.create_all(self.engine)
        birthday = datetime.datetime.fromtimestamp(0).date()
        eats_breakfast_at = datetime.time(9, 0, 0, 123)
        married_at = datetime.datetime(2020, 2, 20, 2, 20, 2, 200200)
        ed_user = BasicUser(name='ed', fullname='Ed Jones', nickname='edsnickname',
                            birthday=birthday, eats_breakfast_at=eats_breakfast_at, married_at=married_at)
        Session = sessionmaker(bind=self.engine)

        session = Session()
        try:
            session.query(BasicUser).delete()
            session.commit()

            session.add(ed_user)
            # Queried before commit; presumably relies on the session's autoflush.
            self.assertEqual(session.query(BasicUser).filter_by(name='ed').first().name, "ed")
            session.commit()
            self.assertGreater(session.query(BasicUser).filter(BasicUser.name.like('%ed')).count(), 0)
            u = session.query(BasicUser).filter(BasicUser.name.like('%ed')).first()
            self.assertEqual(u.nickname, "edsnickname")
            self.assertEqual(u.birthday, birthday)
            # Fractional seconds are expected to be dropped on the round trip.
            self.assertEqual(u.eats_breakfast_at, eats_breakfast_at.replace(microsecond=0))
            self.assertEqual(u.married_at, married_at.replace(microsecond=0))
        finally:
            session.close()


if __name__ == "__main__":
    # Running this file directly loads and runs the TestCases defined in it.
    unittest.main(module=__name__)
