import pymysql
import time
from pymysql import Error

# 条件导入：支持直接运行和包模式
try:
    from .config import get_db_config
except (ImportError, ValueError):
    # 直接运行时失败，尝试绝对路径
    import sys
    import os
    sys.path.insert(0, os.path.dirname(__file__))
    from config import get_db_config


# ===================== 数据库连接 =====================

def _get_connection(database=None, port=None):
    """Open a new PyMySQL connection based on the project configuration.

    Args:
        database: Schema to select. A falsy value falls back to the
            ``database`` entry returned by ``get_db_config()``.
        port: TCP port. A falsy value falls back to the configured ``port``
            entry, or 3306 when the configuration has none.

    Returns:
        The object returned by ``pymysql.connect``, created with the
        ``utf8mb4`` charset and ``autocommit`` disabled, so callers must
        commit explicitly.
    """
    settings = get_db_config()

    connect_kwargs = {
        "host": settings.get('host'),
        "user": settings.get('user'),
        "password": settings.get('password'),
        # Explicit arguments win; configuration values are the fallback.
        "database": database or settings.get('database'),
        "port": port or settings.get('port', 3306),
        "charset": "utf8mb4",
        "autocommit": False,
    }
    return pymysql.connect(**connect_kwargs)


# ===================== 查询操作 =====================

def sel_data(sql, params=None, port=None, database=None):
    """Run a query and return all rows of its result set.

    Args:
        sql: Statement to execute; may contain ``%s`` placeholders.
        params: Values bound to the placeholders, or ``None``.
        port: Optional port override passed to ``_get_connection``.
        database: Optional schema override passed to ``_get_connection``.

    Returns:
        The value of ``cursor.fetchall()``, or ``None`` when a
        ``pymysql.MySQLError`` occurs (the error is printed, not re-raised).
    """
    try:
        with _get_connection(database, port) as conn, conn.cursor() as cur:
            cur.execute(sql, params)
            rows = cur.fetchall()
    except pymysql.MySQLError as err:
        print("查询失败:", err)
        return None
    return rows
    
# ===================== 增量更新 =====================
def incremental_update(sql, values, port=None, database=None):
    """Execute one statement against a batch of parameter sets and commit.

    Args:
        sql: Statement with ``%s`` placeholders.
        values: Sequence of parameter sequences handed to
            ``cursor.executemany``.
        port: Optional port override passed to ``_get_connection``.
        database: Optional schema override passed to ``_get_connection``.

    Exceptions are not handled here; they propagate to the caller.
    """
    with _get_connection(database, port) as conn, conn.cursor() as cur:
        cur.executemany(sql, values)
        # autocommit is off on these connections, so persist explicitly.
        conn.commit()


# ===================== 删除操作 =====================
def del_data(sql, params=None, port=None, database=None):
    """Execute a delete statement and commit it.

    Args:
        sql: Statement to execute; may contain ``%s`` placeholders.
        params: Values bound to the placeholders, or ``None``.
        port: Optional port override passed to ``_get_connection``.
        database: Optional schema override passed to ``_get_connection``.

    Returns:
        ``None``. A ``pymysql.MySQLError`` is printed and swallowed rather
        than re-raised.
    """
    try:
        with _get_connection(database, port) as conn, conn.cursor() as cur:
            cur.execute(sql, params)
            conn.commit()
    except pymysql.MySQLError as err:
        print("删除失败:", err)


# ===================== 表结构管理 =====================

def get_table_columns(table_name, database):
    """List the column names of a table in their ordinal position order.

    Args:
        table_name: Table to inspect.
        database: Schema that owns the table; also the schema the
            connection is opened against.

    Returns:
        A list with the first field of every row returned by the
        ``INFORMATION_SCHEMA.COLUMNS`` lookup (empty when nothing matches).
    """
    query = """
    SELECT COLUMN_NAME
    FROM INFORMATION_SCHEMA.COLUMNS
    WHERE TABLE_SCHEMA=%s
      AND TABLE_NAME=%s
    ORDER BY ORDINAL_POSITION
    """

    names = []
    with _get_connection(database) as conn, conn.cursor() as cur:
        cur.execute(query, (database, table_name))
        for record in cur.fetchall():
            names.append(record[0])
        return names


def add_columns(table_name, database, columns):
    """Add each name in ``columns`` to ``table_name`` as a ``TEXT`` column.

    Identifiers are backtick-quoted with embedded backticks doubled, so a
    table or column name containing a backtick can neither break the
    statement nor inject extra SQL.

    Args:
        table_name: Table to alter.
        database: Schema the connection is opened against.
        columns: Iterable of column names; one ``ALTER TABLE`` statement is
            issued per name, in iteration order.

    Raises:
        Whatever PyMySQL raises for a failing statement, e.g. when a column
        already exists. Nothing is caught here.
    """
    quoted_table = "`" + str(table_name).replace("`", "``") + "`"

    with _get_connection(database) as conn:
        with conn.cursor() as cursor:
            for col in columns:
                quoted_col = "`" + str(col).replace("`", "``") + "`"
                # No parameters are passed, so the statement text is sent
                # without any %-formatting.
                cursor.execute(
                    f"ALTER TABLE {quoted_table} ADD COLUMN {quoted_col} TEXT"
                )
                print(f"新增字段: {col}")

            # NOTE(review): MySQL is documented to commit DDL implicitly, so
            # columns added before a failure are presumably not rolled back;
            # the explicit commit is kept for symmetry with the other helpers.
            conn.commit()

# Table schema validation
def check_and_sync_columns(df, table_name, database, max_new=5):
    """Compare ``df``'s columns with the table and add any that are missing.

    Column names are compared case-insensitively because MySQL treats column
    names as case-insensitive; a frame column that differs from a table
    column only by case is therefore not reported as new. New columns are
    added in DataFrame order so the resulting table layout is deterministic.

    Args:
        df: Object exposing ``columns`` (the names to be stored).
        table_name: Target table.
        database: Schema that owns the table.
        max_new: Largest number of new columns that may be added
            automatically.

    Returns:
        ``True`` when the table has every column of ``df`` afterwards.

    Raises:
        ValueError: More than ``max_new`` new columns were detected. (A
            subclass of ``Exception``, so existing handlers still match.)
    """
    frame_cols = list(df.columns)
    table_cols = get_table_columns(table_name, database)

    table_keys = {str(name).lower() for name in table_cols}
    frame_keys = {str(name).lower() for name in frame_cols}

    # Columns present in the frame but not in the table, first occurrence
    # only, in frame order.
    new_cols = []
    seen = set()
    for name in frame_cols:
        key = str(name).lower()
        if key in table_keys or key in seen:
            continue
        seen.add(key)
        new_cols.append(name)

    # Columns the table has but the frame lacks; reported only, never dropped.
    missing_cols = [
        name for name in table_cols if str(name).lower() not in frame_keys
    ]

    print("新增字段:", new_cols if new_cols else "无")
    print("缺失字段(忽略):", missing_cols if missing_cols else "无")

    # Refuse large schema drifts instead of silently widening the table.
    if len(new_cols) > max_new:
        raise ValueError(f"新增字段超过{max_new}个: {new_cols}")

    if new_cols:
        print("开始同步字段...")
        add_columns(table_name, database, new_cols)
        print("字段同步成功")

    return True


# ===================== DataFrame 入库 =====================

def update_datas(df, table_name, database):
    """Upsert every row of ``df`` into ``table_name``.

    One multi-row ``INSERT ... ON DUPLICATE KEY UPDATE`` statement is built
    from the frame's columns and executed in a single transaction. If MySQL
    reports error 1054 (unknown column), the table schema is synchronised
    through ``check_and_sync_columns`` and the insert is retried once.

    Args:
        df: Object exposing ``columns``, ``values`` and ``len()`` (used as a
            pandas DataFrame). Every column except ``id`` is overwritten on
            a duplicate key.
        table_name: Target table.
        database: Schema the connection is opened against.

    Returns:
        ``None``. An empty frame is a no-op.

    Raises:
        Exception: The original error is re-raised unchanged when the insert
            fails for any reason other than an unknown column, or when it
            still fails after the schema has been synchronised.
    """
    cols = list(df.columns)
    if not cols or len(df) == 0:
        # An empty VALUES list is invalid SQL; there is nothing to write.
        print("成功插入/更新 0 条记录")
        return

    def quote(name):
        # Double embedded backticks (MySQL identifier quoting) and double
        # "%" because PyMySQL %-formats the statement with the parameters.
        return "`" + str(name).replace("`", "``").replace("%", "%%") + "`"

    quoted_cols = [quote(c) for c in cols]
    row_tpl = "(" + ",".join(["%s"] * len(cols)) + ")"
    update_clause = ",".join(
        f"{q}=VALUES({q})" for c, q in zip(cols, quoted_cols) if c != "id"
    )
    if not update_clause:
        # Only the key column is present: a self-assignment keeps the
        # statement valid without changing any data.
        update_clause = f"{quoted_cols[0]}={quoted_cols[0]}"

    sql = (
        f"INSERT INTO {quote(table_name)} ({','.join(quoted_cols)}) "
        f"VALUES {','.join([row_tpl] * len(df))} "
        f"ON DUPLICATE KEY UPDATE {update_clause}"
    )
    flat_data = [value for row in df.values for value in row]

    max_attempts = 2
    for attempt in range(1, max_attempts + 1):
        # Reset per attempt so a failed connect never rolls back or closes
        # the previous attempt's already-closed connection.
        conn = None
        try:
            conn = _get_connection(database)
            with conn.cursor() as cursor:
                cursor.execute(sql, flat_data)
                affected = cursor.rowcount
            conn.commit()
            print(f"成功插入/更新 {affected} 条记录")
            return
        except Exception as e:
            print("入库失败:", e)
            if conn is not None and conn.open:
                conn.rollback()
            # PyMySQL errors carry (errno, message) in args; 1054 is
            # "Unknown column".
            unknown_column = e.args[:1] == (1054,)
            if not unknown_column or attempt == max_attempts:
                # Re-raise as-is so the caller sees the real error type
                # instead of the retry loop ending silently.
                raise
        finally:
            if conn is not None:
                conn.close()

        # Reached only after an unknown-column failure with a retry left:
        # add the missing columns, then try the insert again.
        check_and_sync_columns(df, table_name, database)


