"""data_migrate

修订 ID: 71a72119935f
父修订: fff55366306e
创建时间: 2023-10-12 13:31:08.788371

"""

from __future__ import annotations

from collections.abc import Sequence

import sqlalchemy as sa
from alembic.op import run_async
from nonebot import logger, require
from sqlalchemy import Connection, inspect
from sqlalchemy.ext.asyncio import AsyncConnection
from sqlalchemy.ext.automap import automap_base
from sqlalchemy.orm import Session

# Alembic revision identifiers: this revision and the revision it follows.
revision: str = "71a72119935f"
down_revision: str | Sequence[str] | None = "fff55366306e"
# No branch labels and no cross-revision dependencies for this revision.
branch_labels: str | Sequence[str] | None = None
depends_on: str | Sequence[str] | None = None


def _read_old_data(conn: Connection) -> dict[int, dict]:
    """Read sessions from nonebot_plugin_datastore's session table.

    Args:
        conn: Synchronous connection to the datastore database.

    Returns:
        A mapping from the old row ``id`` to the column values to insert into
        ``nonebot_plugin_session_orm_sessionmodel``. Rows are visited in
        ascending ``id`` order. A row whose converted values equal those of an
        earlier row is skipped, so each distinct session is kept once under
        its smallest id. Returns an empty dict if the old table does not exist.
    """
    insp = inspect(conn)
    if "nonebot_plugin_session_sessionmodel" not in insp.get_table_names():
        return {}

    Base = automap_base()
    Base.prepare(autoload_with=conn)
    SessionModel = Base.classes.nonebot_plugin_session_sessionmodel

    session_model_dict: dict[int, dict] = {}
    # Value tuples of the rows already kept. A set lookup replaces scanning
    # ``session_model_dict.values()`` for every row, which was O(n^2).
    seen: set[tuple] = set()

    with Session(conn) as db_session:
        session_models = db_session.scalars(
            sa.select(SessionModel).order_by(SessionModel.id)
        ).all()
        for session_model in session_models:
            # The old level is stored as text such as "LEVEL1"; anything
            # else falls back to 0. isdecimal() (unlike isdigit()) only
            # accepts characters int() can parse, so int() cannot raise here.
            level = 0
            old_level = str(session_model.level)
            if old_level.startswith("LEVEL") and old_level[5:].isdecimal():
                level = int(old_level[5:])

            session_data = {
                "bot_id": session_model.bot_id,
                "bot_type": session_model.bot_type,
                "platform": session_model.platform,
                "level": level,
                "id1": session_model.id1 or "",
                "id2": session_model.id2 or "",
                "id3": session_model.id3 or "",
            }
            # Every dict has the same keys in the same order, so comparing
            # value tuples is equivalent to comparing the dicts.
            key = tuple(session_data.values())
            if key in seen:
                continue
            seen.add(key)
            session_model_dict[session_model.id] = session_data
    return session_model_dict


def _insert_data(conn: Connection, session_model_dict: dict[int, dict]):
    """Insert migrated sessions into this plugin's table, keeping their ids.

    Args:
        conn: Synchronous connection to the database that holds
            ``nonebot_plugin_session_orm_sessionmodel``.
        session_model_dict: Mapping from the row ``id`` to be preserved to the
            column values, as produced by ``_read_old_data``.
    """
    Base = automap_base()
    Base.prepare(autoload_with=conn)
    SessionModel = Base.classes.nonebot_plugin_session_orm_sessionmodel

    with Session(conn) as db_session:
        db_session.add_all(
            SessionModel(id=session_id, **session_data)
            for session_id, session_data in session_model_dict.items()
        )
        db_session.flush()
        # Inserting explicit ids does not advance a PostgreSQL serial/identity
        # sequence, so later autoincrement inserts would reuse the migrated
        # ids and fail with duplicate-key errors. Move the sequence past the
        # largest id now in the table. (SQLite and MySQL derive the next id
        # from the existing rows, so they need no adjustment.)
        if conn.dialect.name == "postgresql":
            db_session.execute(
                sa.text(
                    "SELECT setval("
                    "pg_get_serial_sequence("
                    "'nonebot_plugin_session_orm_sessionmodel', 'id'), "
                    "(SELECT MAX(id) FROM nonebot_plugin_session_orm_sessionmodel))"
                )
            )
        db_session.commit()


async def data_migrate(conn: AsyncConnection):
    """Copy sessions from nonebot_plugin_datastore's database into this plugin.

    The old rows are read through the datastore engine on a separate
    connection. They are written through ``conn`` only if any were found.

    Args:
        conn: Async connection used to write into this plugin's table.
    """
    from nonebot_plugin_datastore.db import get_engine

    datastore_engine = get_engine()
    async with datastore_engine.connect() as datastore_conn:
        old_sessions = await datastore_conn.run_sync(_read_old_data)

    # Nothing to migrate: the old table is missing or empty.
    if not old_sessions:
        return

    logger.info("session-orm: 发现来自 datastore 的数据，正在迁移...")
    await conn.run_sync(_insert_data, old_sessions)
    logger.info("session-orm: 迁移完成")


def upgrade(name: str = "") -> None:
    """Run the one-off data migration from nonebot_plugin_datastore.

    Does nothing when ``name`` is non-empty (presumably a per-branch call from
    the migration runner — confirm). Also does nothing when
    ``require("nonebot_plugin_datastore")`` raises ``RuntimeError``.
    """
    if not name:
        try:
            require("nonebot_plugin_datastore")
        except RuntimeError:
            # The datastore plugin is unavailable, so there is no old data
            # to migrate.
            pass
        else:
            run_async(data_migrate)


def downgrade(name: str = "") -> None:
    """Downgrade hook for this revision; it performs no steps.

    The rows copied by ``upgrade`` are not removed, whatever ``name`` is.
    """
    return None
