import pymysql
import time
from pymysql import Error
import pandas as pd
# 条件导入：支持直接运行和包模式
try:
    from .config import get_db_config
except (ImportError, ValueError):
    # 直接运行时失败，尝试绝对路径
    import sys
    import os
    sys.path.insert(0, os.path.dirname(__file__))
    from config import get_db_config


# ===================== 数据库连接 =====================

def _get_connection(database=None, port=None):
    """Open a new PyMySQL connection from the project's DB configuration.

    Args:
        database: Schema to select. A falsy value falls back to the
            ``database`` entry returned by ``get_db_config()``.
        port: TCP port. A falsy value falls back to the ``port`` entry of the
            configuration, or 3306 when that entry is absent.

    Returns:
        The object returned by ``pymysql.connect``. The connection uses the
        ``utf8mb4`` charset and has autocommit disabled, so callers are
        responsible for committing.
    """
    settings = get_db_config()

    connect_kwargs = dict(
        host=settings.get('host'),
        user=settings.get('user'),
        password=settings.get('password'),
        # Explicit arguments win; otherwise use the configured defaults.
        database=database or settings.get('database'),
        port=port or settings.get('port', 3306),
        charset="utf8mb4",
        autocommit=False,
    )
    return pymysql.connect(**connect_kwargs)


# ===================== 查询操作 =====================

def sel_data(sql, params=None, port=None, database=None):
    """Run a query and return all of its rows.

    Args:
        sql: SQL text, optionally containing ``%s`` placeholders.
        params: Values bound to the placeholders, or None.
        port: Optional port override forwarded to ``_get_connection``.
        database: Optional schema override forwarded to ``_get_connection``.

    Returns:
        Whatever ``cursor.fetchall()`` yields, or None when a
        ``pymysql.MySQLError`` is raised (the error is printed, not re-raised).
    """
    try:
        with _get_connection(database, port) as conn, conn.cursor() as cur:
            cur.execute(sql, params)
            rows = cur.fetchall()
    except pymysql.MySQLError as err:
        print("查询失败:", err)
        return None
    return rows
    
# ===================== 增量更新 =====================
def incremental_update(sql, values, port=None, database=None):
    """Execute one statement for every parameter set in *values* and commit.

    Args:
        sql: SQL text with ``%s`` placeholders.
        values: Sequence of parameter sequences handed to ``executemany``.
        port: Optional port override forwarded to ``_get_connection``.
        database: Optional schema override forwarded to ``_get_connection``.

    Errors are not caught here; they propagate to the caller.
    """
    with _get_connection(port=port, database=database) as conn, conn.cursor() as cur:
        cur.executemany(sql, values)
        conn.commit()


# ===================== 删除操作 =====================
def del_data(sql, params=None, port=None, database=None):
    """Execute a single statement (intended for deletes) and commit it.

    Args:
        sql: SQL text, optionally containing ``%s`` placeholders.
        params: Values bound to the placeholders, or None.
        port: Optional port override forwarded to ``_get_connection``.
        database: Optional schema override forwarded to ``_get_connection``.

    Returns:
        None. A ``pymysql.MySQLError`` is printed and swallowed.
    """
    try:
        with _get_connection(database, port) as conn, conn.cursor() as cur:
            cur.execute(sql, params)
            conn.commit()
    except pymysql.MySQLError as err:
        print("删除失败:", err)


# ===================== 表结构管理 =====================

def get_table_columns(table_name, database):
    """List a table's column names in their declared order.

    Args:
        table_name: Table to inspect.
        database: Schema that owns the table; it is used both to open the
            connection and as the ``TABLE_SCHEMA`` filter.

    Returns:
        A list of column names read from ``INFORMATION_SCHEMA.COLUMNS``,
        ordered by ``ORDINAL_POSITION``.
    """
    column_query = """
    SELECT COLUMN_NAME
    FROM INFORMATION_SCHEMA.COLUMNS
    WHERE TABLE_SCHEMA=%s
      AND TABLE_NAME=%s
    ORDER BY ORDINAL_POSITION
    """

    with _get_connection(database) as conn, conn.cursor() as cur:
        cur.execute(column_query, (database, table_name))
        records = cur.fetchall()
        # Each record holds exactly the one selected field.
        return [record[0] for record in records]


def add_columns(table_name, database, columns):
    """Add each named column to a table as a ``TEXT`` column.

    Args:
        table_name: Table to alter; interpolated into the statement inside
            backticks.
        database: Schema forwarded to ``_get_connection``.
        columns: Iterable of column names to add, one ``ALTER TABLE`` each.

    A line is printed for every column added, and the connection is
    committed once after the loop. Errors propagate to the caller.
    """
    with _get_connection(database) as conn, conn.cursor() as cursor:
        for column_name in columns:
            ddl = f"""
                ALTER TABLE `{table_name}`
                ADD COLUMN `{column_name}` TEXT 
                """
            cursor.execute(ddl)
            print(f"新增字段: {column_name}")

        conn.commit()

 # 表结构校验
def check_and_sync_columns(df ,table_name, database, max_new=5):
    """Compare DataFrame columns with the table and add the missing ones.

    Args:
        df: DataFrame whose column labels are compared with the table.
        table_name: Table to inspect and, if needed, alter.
        database: Schema that owns the table.
        max_new: Minimum allowed number of new columns. The effective limit
            is the larger of this value and 20% of the table's current
            column count (truncated to an integer).

    Returns:
        True when the table already has every DataFrame column, or after the
        new columns were added.

    Raises:
        Exception: When the number of new columns exceeds the effective
            limit. Errors from ``get_table_columns`` / ``add_columns``
            propagate unchanged.
    """
    # Columns present in the DataFrame.
    df_cols = set(df.columns)
    # Columns present in the database table.
    db_cols = set(get_table_columns(table_name, database))
    # Columns the table is missing.
    new_cols = df_cols - db_cols
    # Columns the DataFrame lacks; reported but deliberately left alone.
    missing_cols = db_cols - df_cols
    print("新增字段:", new_cols if new_cols else "无")
    print("缺失字段(忽略):", missing_cols if missing_cols else "无")

    # Nothing to add: the schema is already in sync. Previously this case
    # fell into the "too many new columns" error branch.
    if not new_cols:
        return True

    # Allow more additions on wide tables: at least max_new, or 20% of the
    # existing column count when that is larger.
    max_new = max(max_new, int(len(db_cols) * 0.2))

    if len(new_cols) > max_new:
        raise Exception(f"新增字段超过{max_new}个: {new_cols}")

    print("开始同步字段...")
    # Add columns in DataFrame order (de-duplicated) rather than in
    # arbitrary set order, so the resulting table layout is deterministic.
    ordered_new = list(dict.fromkeys(c for c in df.columns if c in new_cols))
    add_columns(table_name, database, ordered_new)
    print("字段同步成功")
    return True
# ===================== 通用查询重复记录函数（适配自定义索引） =====================
def find_duplicate_records(table_name, database, unique_index_fields, port=None):
    """Fetch every row whose ``unique_index_fields`` combination is repeated.

    When duplicates are found they are also written to
    ``duplicate_records_<table_name>.csv`` in the current working directory.

    Args:
        table_name: Table to scan; interpolated into the SQL inside backticks.
        database: Schema forwarded to ``_get_connection``.
        unique_index_fields: Non-empty list of column names that together
            should be unique, e.g. ``['begindate', 'enddate']``.
        port: Optional port override forwarded to ``_get_connection``.

    Returns:
        A DataFrame with the duplicated rows (empty when there are none), or
        None when the arguments are invalid or any error occurs; the error
        is printed rather than raised.
    """
    # Only a non-empty list is accepted; anything else is rejected up front.
    if not isinstance(unique_index_fields, list) or not unique_index_fields:
        print("错误：unique_index_fields 必须是非空列表！")
        return None

    group_by_str = ",".join(f"`{field}`" for field in unique_index_fields)
    # Qualified with the outer alias so ORDER BY does not depend on how the
    # server resolves names that exist in both `t` and `dup`.
    order_by_str = ",".join(f"t.`{field}`" for field in unique_index_fields)
    join_condition = " AND ".join(
        f"t.`{field}` = dup.`{field}`" for field in unique_index_fields
    )

    # Join the table against the key combinations that occur more than once.
    sql = f"""
    SELECT t.* 
    FROM `{table_name}` t
    INNER JOIN (
        SELECT {group_by_str}, COUNT(*) AS duplicate_count
        FROM `{table_name}`
        GROUP BY {group_by_str}
        HAVING COUNT(*) > 1
    ) dup ON {join_condition}
    ORDER BY {order_by_str}
    """

    try:
        with _get_connection(database, port) as conn:
            df_duplicates = pd.read_sql(sql, conn)
            if df_duplicates.empty:
                print(f"表 {table_name} 中无重复记录")
            else:
                group_count = len(df_duplicates.drop_duplicates(subset=unique_index_fields))
                print(f"发现 {len(df_duplicates)} 条重复记录（{group_count} 组重复组合）")
                # Keep a CSV copy of the duplicates for inspection.
                csv_path = f"duplicate_records_{table_name}.csv"
                df_duplicates.to_csv(csv_path, index=False, encoding="utf-8-sig")
                # The message now interpolates the real file name; it used to
                # print the literal "{table_name}" placeholder.
                print(f"重复记录已保存到 {csv_path}")
            return df_duplicates
    except pymysql.MySQLError as e:
        print("数据库查询失败:", e)
        return None
    except Exception as e:
        print("未知错误:", e)
        return None
# ===================== 删除重复记录函数（修复参数+调用逻辑） =====================
def delete_duplicate_records(table_name, database, unique_index_fields, port=None, keep_strategy="min_id"):
    """Delete duplicated rows, keeping one row per unique key combination.

    The row to keep is chosen by the table's ``id`` column, which the
    generated SQL references directly. The function prints a warning and
    blocks on ``input()`` before deleting anything.

    Args:
        table_name: Table to clean; interpolated into the SQL inside backticks.
        database: Schema forwarded to ``_get_connection``.
        unique_index_fields: Non-empty list of column names that together
            should be unique.
        port: Optional port override forwarded to ``_get_connection``.
        keep_strategy: ``"min_id"`` keeps the smallest ``id`` of each group,
            ``"max_id"`` the largest. Any other value falls back to
            ``"min_id"`` after printing a notice.

    Returns:
        The number of deleted rows, 0 when there are no duplicates, or None
        when the duplicate lookup or the delete fails (the error is printed).
    """
    # 1. Look up the duplicates first; this also validates the field list.
    duplicate_df = find_duplicate_records(table_name, database, unique_index_fields, port)
    if duplicate_df is None:
        # The lookup failed and already reported why. Report an error
        # instead of claiming that there were no duplicates.
        return None
    if duplicate_df.empty:
        return 0

    # 2. Build the DELETE statement.
    keep_aggregates = {"min_id": "MIN(`id`)", "max_id": "MAX(`id`)"}
    if keep_strategy not in keep_aggregates:
        print(f"不支持的策略：{keep_strategy}，默认使用min_id")
        # Normalise so the confirmation below names the strategy really used.
        keep_strategy = "min_id"
    keep_condition = keep_aggregates[keep_strategy]

    group_by_str = ",".join(f"`{field}`" for field in unique_index_fields)
    join_condition = " AND ".join(
        f"t.`{field}` = dup.`{field}`" for field in unique_index_fields
    )

    delete_sql = f"""
    DELETE t FROM `{table_name}` t
    INNER JOIN (
        SELECT {group_by_str}, {keep_condition} AS keep_id
        FROM `{table_name}`
        GROUP BY {group_by_str}
        HAVING COUNT(*) > 1
    ) dup ON {join_condition}
    WHERE t.`id` != dup.keep_id
    """

    try:
        # Ask for confirmation before connecting, so no connection sits idle
        # while the operator decides.
        print(f"\n⚠️  即将删除表 {table_name} 的重复记录，保留策略：{keep_strategy}")
        print("确认删除请按任意键，取消请关闭程序...")
        input()

        with _get_connection(database, port) as conn:
            try:
                with conn.cursor() as cursor:
                    affected_rows = cursor.execute(delete_sql)
                conn.commit()
            except Exception:
                # Roll back while the connection is still open. Once the
                # `with` block exits, the connection is closed and a
                # rollback would itself raise, hiding the real error.
                conn.rollback()
                raise

        print(f"✅ 成功删除 {affected_rows} 条重复记录")
        return affected_rows
    except pymysql.MySQLError as e:
        print("❌ 删除失败（数据库错误）:", e)
        return None
    except Exception as e:
        print("❌ 删除失败（未知错误）:", e)
        return None
# ===================== DataFrame 入库 =====================

def update_datas(df, table_name, database):
    """Upsert every row of a DataFrame into a MySQL table.

    One ``INSERT ... ON DUPLICATE KEY UPDATE`` statement is built from the
    DataFrame's column labels and executed for all rows in one transaction.
    If the server answers with error 1054 (unknown column), the table is
    synced through ``check_and_sync_columns`` and the insert is retried once.

    Args:
        df: DataFrame to store; its column labels are used as column names
            and are interpolated into the SQL inside backticks.
        table_name: Target table, also interpolated into the SQL.
        database: Schema forwarded to ``_get_connection``.

    Returns:
        None. An empty DataFrame is skipped without touching the database.

    Raises:
        Exception: The original error is re-raised when the insert fails for
            a reason other than an unknown column, or still fails after the
            retry. Errors from ``check_and_sync_columns`` also propagate.
    """
    if df.empty:
        # Zero rows (or zero columns) would otherwise produce invalid SQL.
        print("无数据可入库，已跳过")
        return

    cols = list(df.columns)
    col_str = ",".join(f"`{c}`" for c in cols)
    placeholders = ",".join(["%s"] * len(cols))
    update_clause = ",".join(f"`{c}`=VALUES(`{c}`)" for c in cols if c != "id")
    if not update_clause:
        # Only an `id` column is present: use a no-op assignment so the
        # ON DUPLICATE KEY UPDATE clause stays syntactically valid.
        update_clause = "`id`=`id`"

    # Single-row template; executemany expands it to multi-row INSERTs.
    sql = f"""
            INSERT INTO `{table_name}` ({col_str})
            VALUES ({placeholders})
            ON DUPLICATE KEY UPDATE {update_clause}
            """
    print(sql)
    # NOTE(review): NaN cells are passed through unchanged; confirm whether
    # they should be converted to None (SQL NULL) before insertion.
    rows = [tuple(row) for row in df.values]

    max_attempts = 2
    for attempt in range(max_attempts):
        # Reset per attempt so a failed reconnect never touches the
        # already-closed connection from the previous attempt.
        conn = None
        try:
            conn = _get_connection(database)
            with conn.cursor() as cursor:
                cursor.executemany(sql, rows)
                affected = cursor.rowcount
            conn.commit()
            print(f"成功插入/更新 {affected} 条记录")
            return
        except Exception as e:
            if conn is not None:
                try:
                    conn.rollback()
                except pymysql.MySQLError:
                    # The connection is already unusable; the original
                    # error reported below is the one worth surfacing.
                    pass
            print("入库失败:", e)

            # MySQL error 1054 = unknown column: the table lacks a column
            # that the DataFrame has.
            unknown_column = bool(getattr(e, "args", None)) and e.args[0] == 1054
            last_attempt = attempt == max_attempts - 1
            if not unknown_column or last_attempt:
                # Re-raise the original error (type and traceback intact)
                # instead of returning silently after a failed retry.
                raise
            # Sync the table schema with the DataFrame, then retry.
            check_and_sync_columns(df, table_name, database)
        finally:
            if conn is not None:
                conn.close()


