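"""
SoIdea-update-python: multi-source update-package download module.

Supports prioritized multiple download sources, streamed chunked downloads
with a progress bar, and placeholder hooks for resumable / multithreaded
downloading (not implemented yet: downloads are sequential and restart
from scratch).
Dependencies: requests, tqdm
"""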
import hashlib
import logging
import os
import time
import zipfile
from typing import Any, Dict, List, Optional

import requests
from tqdm import tqdm

logger = logging.getLogger("SoIdea-update-python.sources")

class PackageSource:
    """
    Multi-source update-package download manager (Windows/Linux/macOS).

    Sources are tried in ascending ``priority`` order (lower value first);
    sources whose ``url`` is empty/None are skipped.

    NOTE(review): HTTP range-based resume and true multithreaded downloading
    are not implemented; files are streamed sequentially in ``chunk_size``
    pieces and a failed download restarts from scratch.

    Attributes:
        sources (List[Dict]): source entries with keys ``type``, ``url`` and
            ``priority``. A list passed to ``__init__`` is used as-is (not
            copied), so ``set_source`` mutates it.
        chunk_size (int): streaming chunk size in bytes.
    """

    def __init__(self, sources: Optional[List[Dict[str, Any]]] = None, chunk_size: int = 1024*1024):
        """
        Initialize the source list and chunk size.

        Args:
            sources (Optional[List[Dict[str, Any]]]): source entries; when
                falsy, github/gitee/custom placeholders without URLs are used.
            chunk_size (int, optional): streaming chunk size, default 1 MiB.
        """
        self.sources = sources or [
            {"type": "github", "url": None, "priority": 1},
            {"type": "gitee", "url": None, "priority": 2},
            {"type": "custom", "url": None, "priority": 3}
        ]
        self.chunk_size = chunk_size

    def set_source(self, type_: str, url: str, priority: int = 1) -> None:
        """
        Set or update the url and priority of a source type.

        Updates the first entry whose ``type`` matches; appends a new entry
        if no entry matches.

        Args:
            type_ (str): source type, e.g. "github".
            url (str): base URL; the filename is appended when downloading.
            priority (int, optional): lower values are tried first.
        """
        for s in self.sources:
            if s["type"] == type_:
                s["url"] = url
                s["priority"] = priority
                return
        self.sources.append({"type": type_, "url": url, "priority": priority})

    def get_best_source(self) -> Optional[Dict[str, Any]]:
        """
        Return the lowest-priority-value source that has a url.

        Returns:
            Optional[Dict[str, Any]]: the source entry, or None if no source
            has a url configured.
        """
        sorted_sources = sorted(self.sources, key=lambda x: x["priority"])
        for s in sorted_sources:
            if s["url"]:
                return s
        return None

    def download(self, filename: str, timeout: int = 30, target_path: Optional[str] = None, logger_: Optional[logging.Logger] = None) -> Any:
        """
        Try each configured source in priority order until one succeeds.

        Args:
            filename (str): file name appended to each source's base url.
            timeout (int, optional): request timeout in seconds; applies to
                both in-memory and to-file downloads.
            target_path (Optional[str], optional): when given, the file is
                streamed to this path instead of being returned.
            logger_ (Optional[logging.Logger], optional): logger; defaults to
                the module logger.
        Returns:
            Any: ``bytes`` content when ``target_path`` is not given, ``True``
            when the file was saved, or ``None`` if every source failed or no
            source has a url.
        """
        logger_ = logger_ or logger
        sorted_sources = sorted(self.sources, key=lambda x: x["priority"])
        for src in sorted_sources:
            if not src["url"]:
                continue
            url = src["url"].rstrip("/") + "/" + filename
            try:
                logger_.info(f"下载升级包: {url}")
                if target_path:
                    # Forward the caller's timeout (previously hard-coded to 30s).
                    self._download_from_url(url, target_path, logger_, timeout=timeout)
                    logger_.info(f'下载成功: {url}')
                    return True
                resp = requests.get(url, timeout=timeout)
                resp.raise_for_status()
                logger_.info(f'下载成功: {url}')
                return resp.content
            except Exception as e:
                # Deliberate fallback boundary: any failure moves on to the next source.
                logger_.warning(f'下载失败: {url}, {e}')
        logger_.error("无可用升级源")
        return None

    def _download_from_url(self, url: str, target_path: str, logger_: Optional[logging.Logger] = None, timeout: float = 30) -> None:
        """
        Stream a single url to ``target_path`` with a tqdm progress bar.

        Args:
            url (str): download url.
            target_path (str): destination file; parent dirs are created.
            logger_ (Optional[logging.Logger], optional): logger; defaults to
                the module logger.
            timeout (float, optional): request timeout in seconds.
        Raises:
            requests.RequestException / OSError: on HTTP or file errors. A
            partially written file is removed before re-raising.
        """
        logger_ = logger_ or logger
        # Context manager guarantees the streamed connection is released.
        with requests.get(url, stream=True, timeout=timeout) as resp:
            resp.raise_for_status()
            # A missing Content-Length means "unknown size" for tqdm (None), not 0.
            total = int(resp.headers.get('content-length', 0)) or None
            parent = os.path.dirname(target_path)
            if parent:
                os.makedirs(parent, exist_ok=True)
            try:
                with open(target_path, 'wb') as f, tqdm(total=total, unit='B', unit_scale=True, desc='下载中') as pbar:
                    for chunk in resp.iter_content(chunk_size=self.chunk_size):
                        if chunk:
                            f.write(chunk)
                            pbar.update(len(chunk))
            except BaseException:
                # Do not leave a truncated file that could be mistaken for a full package.
                try:
                    os.remove(target_path)
                except OSError:
                    pass
                raise
        logger_.info(f'文件已保存: {target_path}')

    def download_multithread(self, filename: str, target_path: str, threads: int = 4, logger_: Optional[logging.Logger] = None) -> Any:
        """
        Placeholder for multithreaded chunked download.

        Currently delegates to a sequential ``download``; ``threads`` is
        accepted for API compatibility but ignored.

        Args:
            filename (str): file name appended to each source url.
            target_path (str): destination file.
            threads (int, optional): intended worker count (unused).
            logger_ (Optional[logging.Logger], optional): logger.
        Returns:
            Any: ``True`` on success, ``None`` if all sources failed.
        """
        # TODO: implement range-based concurrent chunk download.
        return self.download(filename, target_path=target_path, logger_=logger_)

def download_and_extract(
    zip_url: str,
    extract_to: str,
    logger: logging.Logger,
    expected_hash: Optional[str] = None,
    max_retries: int = 3,
    timeout: int = 30,
    retry_delay: float = 2.0
) -> bool:
    """
    Download an update zip into ``extract_to`` and extract it there.

    The archive is saved as ``<extract_to>/update.zip``, optionally verified
    against ``expected_hash``, extracted, then deleted. Failures are logged
    and retried up to ``max_retries`` times.

    Args:
        zip_url (str): update package url.
        extract_to (str): extraction directory (created if missing).
        logger (logging.Logger): logger.
        expected_hash (Optional[str], optional): expected hex digest. The
            algorithm is chosen by digest length (32 -> md5, 40 -> sha1,
            64 -> sha256; anything else -> sha256), matching the 32-64 hex
            hashes that ``get_latest_release`` can return.
        max_retries (int, optional): maximum attempts.
        timeout (int, optional): request timeout in seconds.
        retry_delay (float, optional): seconds to wait after a failed
            download/extract before the next attempt (not after the last).
    Returns:
        bool: True if the package was downloaded (and verified) and extracted.
    """
    hash_by_hex_len = {32: 'md5', 40: 'sha1', 64: 'sha256'}

    os.makedirs(extract_to, exist_ok=True)
    local_zip = os.path.join(extract_to, 'update.zip')
    for attempt in range(1, max_retries + 1):
        try:
            logger.info(f"下载升级包: {zip_url} -> {local_zip} (尝试{attempt}/{max_retries})")
            with requests.get(zip_url, stream=True, timeout=timeout) as r:
                r.raise_for_status()
                with open(local_zip, 'wb') as f:
                    for chunk in r.iter_content(chunk_size=8192):
                        f.write(chunk)
            # Integrity check before touching the extraction directory.
            if expected_hash:
                expected = expected_hash.strip().lower()
                hasher = hashlib.new(hash_by_hex_len.get(len(expected), 'sha256'))
                with open(local_zip, 'rb') as f:
                    for chunk in iter(lambda: f.read(8192), b''):
                        hasher.update(chunk)
                file_hash = hasher.hexdigest()
                if file_hash != expected:
                    logger.error(f"升级包 hash 校验失败: 期望 {expected_hash}, 实际 {file_hash}")
                    os.remove(local_zip)
                    continue
            logger.info(f"解压升级包到: {extract_to}")
            with zipfile.ZipFile(local_zip, 'r') as zip_ref:
                zip_ref.extractall(extract_to)
            os.remove(local_zip)
            return True
        except Exception as e:
            logger.error(f"下载或解压升级包失败: {e}", exc_info=True)
            if os.path.exists(local_zip):
                try:
                    os.remove(local_zip)
                except OSError:
                    # Best-effort cleanup; the next attempt truncates the file anyway.
                    pass
            # Only back off if another attempt will follow.
            if attempt < max_retries:
                time.sleep(retry_delay)
    logger.error(f"升级包下载/解压连续失败 {max_retries} 次，终止升级流程。")
    return False


# Helper: fetch the latest stable GitHub release and its package hash (if published).
def get_latest_release(repo_url: str, logger: logging.Logger, proxy: Optional[str] = None) -> Any:
    """
    Fetch the newest stable (non-prerelease) GitHub release and its hash.

    The hash is taken from the first asset whose name ends in ``.sha256``
    (first whitespace-separated token of its text), falling back to a
    ``hash: <32-64 hex chars>`` pattern in the release body.

    Args:
        repo_url (str): repository path; it is interpolated directly into
            ``https://api.github.com/repos/{repo_url}/releases``, so it is
            presumably ``owner/name``.
        logger (logging.Logger): logger.
        proxy (Optional[str], optional): proxy url used for both http and https.
    Returns:
        Any: tuple ``(release, hash_value)`` where ``release`` is the release
        JSON dict and ``hash_value`` is a hex string or None.
    Raises:
        requests.HTTPError: if the releases API request fails.
        ValueError: if the returned page contains no stable release.
    """
    import re

    api = f"https://api.github.com/repos/{repo_url}/releases"
    proxies = {"http": proxy, "https": proxy} if proxy else None
    logger.info(f"请求 GitHub Release: {api} 代理: {proxy}")
    resp = requests.get(api, timeout=10, proxies=proxies)
    resp.raise_for_status()
    releases = resp.json()
    stable_releases = [r for r in releases if not r['prerelease']]
    if not stable_releases:
        # Same exception type max() would raise, but with an actionable message.
        raise ValueError(f"仓库 {repo_url} 没有可用的正式 release")
    # ISO-8601 timestamps compare correctly as strings; guard against null.
    latest = max(stable_releases, key=lambda r: r.get('published_at') or '')
    logger.info(f"检测到最新 release: {latest['tag_name']}")

    hash_value = None
    assets = latest.get('assets') or []
    asset_hash = next((a for a in assets if a['name'].endswith('.sha256')), None)
    if asset_hash:
        try:
            hash_resp = requests.get(asset_hash['browser_download_url'], timeout=10, proxies=proxies)
            hash_resp.raise_for_status()
            hash_value = hash_resp.text.strip().split()[0]
            logger.info(f"Release API 获取到 hash: {hash_value} (asset)")
        except Exception as e:
            # Best-effort: a missing/garbled hash asset falls back to the body.
            logger.warning(f"获取 hash asset 失败: {e}")
    # The API returns "body": null for releases without notes; treat as empty.
    body = latest.get('body') or ''
    if not hash_value and body:
        match = re.search(r'hash[:：]?\s*([a-fA-F0-9]{32,64})', body)
        if match:
            hash_value = match.group(1)
            logger.info(f"Release API 获取到 hash: {hash_value} (body)")
    return latest, hash_value
