#!/usr/bin/env python3
"""
依赖树管理模块 - 分析、可视化和管理 Python 包的依赖关系
"""

import subprocess
import sys
import os
from typing import Dict, List, Set, Tuple, Optional, Any
from collections import defaultdict, deque
import json
import re


class DependencyNode:
    """One package in a dependency graph.

    Edges are stored in both directions: ``dependencies`` holds the nodes this
    package requires and ``dependents`` holds the nodes that require it.
    Equality and hashing use only the lowercased ``name``, so two nodes with
    the same name are interchangeable in sets, dicts and ``in`` checks.
    """

    def __init__(self, name: str, version: str = "", installed: bool = True):
        """Create a node with no edges.

        Args:
            name: Package name; stored lowercased.
            version: Version string; may be "" (then omitted from ``str()``).
            installed: Whether the package is installed; shown as ✓/✗.
        """
        self.name = name.lower()
        self.version = version
        self.installed = installed
        self.dependencies: List[DependencyNode] = []
        self.dependents: List[DependencyNode] = []
        # Graph-level attributes: only initialised here, presumably kept up
        # to date by whatever tree owns this node.
        self.depth = 0
        self.is_root = False

    def add_dependency(self, node: 'DependencyNode') -> None:
        """Record that this package requires *node*; no-op if already recorded.

        Updates both ``self.dependencies`` and ``node.dependents`` so the
        two directions stay in sync.
        """
        if node in self.dependencies:
            return
        self.dependencies.append(node)
        if self not in node.dependents:
            node.dependents.append(self)

    def remove_dependency(self, node: 'DependencyNode') -> None:
        """Forget that this package requires *node*; no-op if not recorded.

        Removes the edge from both ``self.dependencies`` and ``node.dependents``.
        """
        if node not in self.dependencies:
            return
        self.dependencies.remove(node)
        if self in node.dependents:
            node.dependents.remove(self)

    def __str__(self) -> str:
        """Return e.g. "✓ requests==2.31.0" (✗ when not installed)."""
        version_str = f"=={self.version}" if self.version else ""
        status = "✓" if self.installed else "✗"
        return f"{status} {self.name}{version_str}"

    def __repr__(self) -> str:
        # !r quotes/escapes the values so the repr stays a valid expression.
        return f"DependencyNode({self.name!r}, {self.version!r}, {self.installed!r})"

    def __eq__(self, other: object) -> bool:
        # Returning NotImplemented (not False) lets Python try the reflected
        # comparison for foreign types; the final result for them is still False.
        if not isinstance(other, DependencyNode):
            return NotImplemented
        return self.name == other.name

    def __hash__(self) -> int:
        # Must agree with __eq__: identity is the lowercased name only.
        return hash(self.name)


class DependencyTree:
    """Build, query and render the dependency graph of Python packages.

    Package data comes from running ``<python_exe> -m pip`` in a subprocess,
    so the inspected interpreter may differ from the one running this code.
    The graph-building methods key ``nodes`` by each node's (lowercased)
    ``name``. Methods that take a package name also accept other spellings
    that normalize to the same PEP 503 name (case, runs of ``-``/``_``/``.``).
    """

    # Characters that end the project name inside a requirement string such
    # as "foo[extra] (>=1.0); python_version < '3.8'".
    _REQUIREMENT_NAME_END = re.compile(r"[\s\[;(<>=!~]")

    def __init__(self, python_exe: Optional[str] = None):
        """Create an empty tree.

        Args:
            python_exe: Interpreter whose ``pip`` is queried. Defaults to the
                interpreter running this module (``sys.executable``).
        """
        self.python_exe = python_exe or sys.executable
        self.nodes: Dict[str, DependencyNode] = {}
        self.root_nodes: List[DependencyNode] = []

    # ------------------------------------------------------------------ #
    # Name handling
    # ------------------------------------------------------------------ #

    @staticmethod
    def _canonical_name(name: str) -> str:
        """Return the PEP 503 normalized form of *name*; used only for matching."""
        return re.sub(r"[-_.]+", "-", name).lower()

    def _resolve(self, package_name: str) -> Optional[str]:
        """Return the key in ``self.nodes`` that refers to *package_name*, or None.

        An exact key match wins. Otherwise names are compared in PEP 503
        normalized form, so "Typing-Extensions" finds "typing_extensions".
        """
        if package_name in self.nodes:
            return package_name
        wanted = self._canonical_name(package_name)
        for key in self.nodes:
            if self._canonical_name(key) == wanted:
                return key
        return None

    # ------------------------------------------------------------------ #
    # pip queries
    # ------------------------------------------------------------------ #

    def _run_pip(self, *args: str) -> Optional[str]:
        """Run ``<python_exe> -m pip <args>`` and return its stdout, or None.

        None is returned when pip exits non-zero or when the interpreter
        cannot be launched at all (OSError, e.g. a missing ``python_exe``).
        Undecodable output bytes are replaced rather than raising.
        """
        try:
            result = subprocess.run(
                [self.python_exe, "-m", "pip", *args],
                capture_output=True, text=True, errors="replace", check=True,
            )
        except (subprocess.CalledProcessError, OSError):
            return None
        return result.stdout

    def get_installed_packages(self) -> List[Tuple[str, str]]:
        """Return ``(lowercased name, version)`` for every package pip lists.

        Lines without ``==`` get an empty version. Returns [] when pip
        cannot be run or fails.
        """
        output = self._run_pip("list", "--format=freeze")
        if output is None:
            return []

        packages = []
        for line in output.splitlines():
            if not line or line.startswith('#'):
                continue
            name, _, version = line.partition('==')
            packages.append((name.lower(), version))
        return packages

    def get_package_info(self, package_name: str) -> Optional[Dict[str, Any]]:
        """Return the ``Key: value`` fields printed by ``pip show``, or None.

        Only lines containing ':' are parsed; if a key repeats, the last
        occurrence wins. None means pip could not be run or failed (for
        example, the package is not installed).
        """
        output = self._run_pip("show", package_name)
        if output is None:
            return None

        info: Dict[str, Any] = {}
        for line in output.splitlines():
            key, sep, value = line.partition(':')
            if sep:
                info[key.strip()] = value.strip()
        return info

    def get_package_dependencies(self, package_name: str) -> List[str]:
        """Return the lowercased names in the package's ``Requires`` field.

        Version specifiers, extras and environment markers are stripped.
        Returns [] when the package is unknown or has no requirements.
        """
        info = self.get_package_info(package_name)
        if not info:
            return []

        requires = info.get('Requires')
        if not requires or requires == 'None':
            return []

        dependencies = []
        for requirement in requires.split(','):
            name = self._REQUIREMENT_NAME_END.split(requirement.strip(), 1)[0]
            if name:
                dependencies.append(name.lower())
        return dependencies

    # ------------------------------------------------------------------ #
    # Building and maintaining the graph
    # ------------------------------------------------------------------ #

    def build_tree(self, packages: Optional[List[str]] = None,
                   root_packages: Optional[List[str]] = None) -> None:
        """Rebuild the tree from pip metadata, replacing any previous contents.

        Args:
            packages: Names to include. When None, every package from
                :meth:`get_installed_packages` is used and its version kept.
            root_packages: Names to mark as roots. When empty or None, the
                roots are the packages that nothing else depends on.

        Calls :meth:`get_package_dependencies` once per package (one
        ``pip show`` each by default). Requirements naming packages outside
        *packages* are ignored.
        """
        self.nodes.clear()
        self.root_nodes.clear()

        versions: Dict[str, str] = {}
        if packages is None:
            installed_packages = self.get_installed_packages()
            packages = [name for name, _ in installed_packages]
            versions = dict(installed_packages)

        for package in packages:
            node = DependencyNode(package, versions.get(package, ""))
            self.nodes[node.name] = node

        # `pip list` reports metadata names (e.g. "typing_extensions") while
        # `pip show` Requires may use another spelling ("typing-extensions"),
        # so match them in PEP 503 normalized form.
        by_canonical = {self._canonical_name(name): node
                        for name, node in self.nodes.items()}

        for name, node in self.nodes.items():
            for dep in self.get_package_dependencies(name):
                dep_node = by_canonical.get(self._canonical_name(dep))
                if dep_node is not None:
                    node.add_dependency(dep_node)

        if root_packages:
            for package in root_packages:
                node = by_canonical.get(self._canonical_name(package))
                if node is not None and node not in self.root_nodes:
                    node.is_root = True
                    self.root_nodes.append(node)
        else:
            self._find_root_nodes()

        self._calculate_depths()

    def _find_root_nodes(self) -> None:
        """Recompute roots as every node that no other node depends on.

        ``is_root`` is reset on every node, so a node that gained a dependent
        since the last computation is no longer flagged. Note this replaces
        any roots chosen explicitly in :meth:`build_tree`.
        """
        self.root_nodes = []
        for node in self.nodes.values():
            node.is_root = not node.dependents
            if node.is_root:
                self.root_nodes.append(node)

    def _calculate_depths(self) -> None:
        """Set ``depth`` to the longest path from any node without dependents.

        Uses Kahn's topological order. All depths are reset to 0 first so
        repeated calls (after add/remove) do not keep stale values. Nodes on
        a cycle never reach in-degree 0. They keep the depth propagated from
        acyclic predecessors, and nodes only reachable through a cycle stay at 0.
        """
        in_degree: Dict[str, int] = defaultdict(int)
        queue: deque = deque()
        for node in self.nodes.values():
            node.depth = 0
            in_degree[node.name] = len(node.dependents)
            if in_degree[node.name] == 0:
                queue.append(node)

        while queue:
            node = queue.popleft()
            for dep in node.dependencies:
                in_degree[dep.name] -= 1
                dep.depth = max(dep.depth, node.depth + 1)
                if in_degree[dep.name] == 0:
                    queue.append(dep)

    def _get_or_create(self, package_name: str) -> 'DependencyNode':
        """Return the existing node for *package_name*, creating it if absent."""
        key = self._resolve(package_name)
        if key is not None:
            return self.nodes[key]
        node = DependencyNode(package_name)
        self.nodes[node.name] = node
        return node

    def add_package(self, package_name: str,
                    dependencies: Optional[List[str]] = None) -> 'DependencyNode':
        """Add *package_name* (or reuse its existing node) with extra dependencies.

        An existing node is reused rather than replaced, so edges that other
        nodes already hold to it stay valid (e.g. a placeholder created as
        someone's dependency). Unknown dependency names get new nodes. Roots
        and depths are recomputed afterwards.

        Returns:
            The node for *package_name*.
        """
        node = self._get_or_create(package_name)
        for dep_name in dependencies or []:
            node.add_dependency(self._get_or_create(dep_name))

        self._find_root_nodes()
        self._calculate_depths()
        return node

    def remove_package(self, package_name: str) -> bool:
        """Remove a package and every edge touching it.

        Returns:
            False if the package is not in the tree, True otherwise.
        """
        key = self._resolve(package_name)
        if key is None:
            return False

        node = self.nodes.pop(key)
        # Iterate over copies: remove_dependency mutates these lists.
        for dep in list(node.dependencies):
            node.remove_dependency(dep)
        for dependent in list(node.dependents):
            dependent.remove_dependency(node)

        self._find_root_nodes()
        self._calculate_depths()
        return True

    # ------------------------------------------------------------------ #
    # Queries
    # ------------------------------------------------------------------ #

    def _walk(self, package_name: str, edges: str) -> List[Tuple['DependencyNode', int]]:
        """Depth-first pre-order walk from *package_name* along ``node.<edges>``.

        Returns ``(node, distance)`` pairs in visit order, each node once.
        Distance is the DFS depth at which the node was first reached, which
        is not necessarily the shortest distance. An unknown package gives [].
        """
        key = self._resolve(package_name)
        if key is None:
            return []

        chain: List[Tuple['DependencyNode', int]] = []
        seen: Set[str] = set()

        def visit(node: 'DependencyNode', distance: int) -> None:
            if node.name in seen:
                return
            seen.add(node.name)
            chain.append((node, distance))
            for neighbour in getattr(node, edges):
                visit(neighbour, distance + 1)

        visit(self.nodes[key], 0)
        return chain

    def get_dependency_chain(self, package_name: str) -> List[Tuple['DependencyNode', int]]:
        """Return the package and everything it depends on, directly or transitively.

        The first pair is ``(package, 0)``; see :meth:`_walk` for ordering.
        """
        return self._walk(package_name, "dependencies")

    def get_dependent_chain(self, package_name: str) -> List[Tuple['DependencyNode', int]]:
        """Return the package and everything that depends on it, directly or transitively.

        The first pair is ``(package, 0)``; see :meth:`_walk` for ordering.
        """
        return self._walk(package_name, "dependents")

    def find_orphaned_packages(self) -> List['DependencyNode']:
        """Return nodes that cannot be reached from any root.

        Returns [] when there are no roots at all. With automatically chosen
        roots, this is typically nodes only reachable through a cycle.
        """
        if not self.root_nodes:
            return []

        reachable: Set[str] = set()
        for root in self.root_nodes:
            self._mark_reachable(root, reachable)

        return [node for node in self.nodes.values() if node.name not in reachable]

    def _mark_reachable(self, node: 'DependencyNode', reachable: Set[str]) -> None:
        """Add the names of *node* and all its transitive dependencies to *reachable*.

        Uses an explicit stack, so long chains cannot hit the recursion limit.
        """
        stack = [node]
        while stack:
            current = stack.pop()
            if current.name in reachable:
                continue
            reachable.add(current.name)
            stack.extend(current.dependencies)

    def get_circular_dependencies(self) -> List[List['DependencyNode']]:
        """Return dependency cycles found by a depth-first search.

        Each cycle is a closed path: its first node is repeated at the end.
        Because fully explored nodes are not revisited, this reports at least
        one cycle per strongly connected region it enters. It does not
        necessarily enumerate every elementary cycle.
        """
        cycles: List[List['DependencyNode']] = []
        visited: Set[str] = set()
        rec_stack: Set[str] = set()  # names on the current DFS path

        def dfs(node: 'DependencyNode', path: List['DependencyNode']) -> None:
            if node.name in rec_stack:
                cycle_start = path.index(node)
                cycles.append(path[cycle_start:] + [node])
                return
            if node.name in visited:
                return

            visited.add(node.name)
            rec_stack.add(node.name)
            path.append(node)

            for dep in node.dependencies:
                dfs(dep, path)

            path.pop()
            rec_stack.remove(node.name)

        for node in self.nodes.values():
            if node.name not in visited:
                dfs(node, [])

        return cycles

    def visualize_tree(self, max_depth: Optional[int] = None) -> str:
        """Render the tree as box-drawing text, one package per line.

        Args:
            max_depth: Deepest display level to expand (roots are level 0).
                Children beyond it are shown as "... (depth limit)" lines.
                None means unlimited.

        Returns:
            The rendered text, or "依赖树为空" ("dependency tree is empty")
            when there are no roots.

        A package is expanded only the first time it is drawn. Later
        occurrences are marked "(see above)", so shared dependencies and
        cycles neither vanish (leaving broken connectors) nor recurse forever.
        """
        if not self.root_nodes:
            return "依赖树为空"

        lines: List[str] = []
        shown: Set[str] = set()

        def render(node: 'DependencyNode', prefix: str, is_last: bool, level: int) -> None:
            connector = '└── ' if is_last else '├── '
            label = str(node)
            if node.is_root:
                label += " (root)"
            if node.name in shown:
                lines.append(f"{prefix}{connector}{label} (see above)")
                return
            shown.add(node.name)
            lines.append(f"{prefix}{connector}{label}")

            children = sorted(node.dependencies, key=lambda x: x.name)
            child_prefix = prefix + ('    ' if is_last else '│   ')
            expand = max_depth is None or level < max_depth
            for i, child in enumerate(children):
                child_is_last = i == len(children) - 1
                if expand:
                    render(child, child_prefix, child_is_last, level + 1)
                else:
                    child_connector = '└── ' if child_is_last else '├── '
                    lines.append(f"{child_prefix}{child_connector}... (depth limit)")

        roots = sorted(self.root_nodes, key=lambda x: x.name)
        for i, root in enumerate(roots):
            render(root, "", i == len(roots) - 1, 0)

        return '\n'.join(lines)

    # ------------------------------------------------------------------ #
    # Serialization and statistics
    # ------------------------------------------------------------------ #

    def export_json(self) -> str:
        """Serialize the tree (nodes, edges by name, roots, counts) as indented JSON."""
        data = {
            "nodes": {
                name: {
                    "name": node.name,
                    "version": node.version,
                    "installed": node.installed,
                    "depth": node.depth,
                    "is_root": node.is_root,
                    "dependencies": [dep.name for dep in node.dependencies],
                    "dependents": [dep.name for dep in node.dependents],
                }
                for name, node in self.nodes.items()
            },
            "root_nodes": [node.name for node in self.root_nodes],
            "metadata": {
                "total_packages": len(self.nodes),
                "root_count": len(self.root_nodes),
            },
        }
        return json.dumps(data, indent=2, ensure_ascii=False)

    def import_json(self, json_data: str) -> bool:
        """Replace the tree with one produced by :meth:`export_json`.

        Depth and root flags are taken from the JSON as-is, not recomputed.
        Edges are rebuilt from each node's "dependencies" list, matched in
        PEP 503 normalized form; names with no matching node are dropped.

        Returns:
            True on success. On malformed input, an error is printed to
            stderr, the current tree is left unchanged, and False is returned.
        """
        try:
            data = json.loads(json_data)

            nodes: Dict[str, DependencyNode] = {}
            pending: Dict[str, Any] = {}  # node name -> dependency names
            for node_data in data["nodes"].values():
                node = DependencyNode(
                    node_data["name"],
                    node_data["version"],
                    node_data["installed"],
                )
                node.depth = node_data["depth"]
                node.is_root = node_data["is_root"]
                nodes[node.name] = node
                pending[node.name] = node_data["dependencies"]

            by_canonical = {self._canonical_name(name): node
                            for name, node in nodes.items()}
            for name, node in nodes.items():
                for dep_name in pending[name]:
                    dep_node = by_canonical.get(self._canonical_name(dep_name))
                    if dep_node is not None:
                        node.add_dependency(dep_node)

            roots = []
            for root_name in data["root_nodes"]:
                root = by_canonical.get(self._canonical_name(root_name))
                if root is not None:
                    roots.append(root)
        except (ValueError, KeyError, TypeError, AttributeError) as e:
            # json.JSONDecodeError is a ValueError; the rest cover missing
            # keys and wrongly-typed fields.
            print(f"导入 JSON 失败: {e}", file=sys.stderr)
            return False

        # Commit only after everything parsed; mutate in place so existing
        # references to these containers see the new contents.
        self.nodes.clear()
        self.nodes.update(nodes)
        self.root_nodes.clear()
        self.root_nodes.extend(roots)
        return True

    def get_statistics(self) -> Dict[str, Any]:
        """Return package/root counts, max/average depth, orphan and cycle counts."""
        if not self.nodes:
            return {
                "total_packages": 0,
                "root_packages": 0,
                "max_depth": 0,
                "avg_depth": 0,
                "orphaned_packages": 0,
                "circular_dependencies": 0
            }

        depths = [node.depth for node in self.nodes.values()]
        orphaned = len(self.find_orphaned_packages())
        cycles = len(self.get_circular_dependencies())

        return {
            "total_packages": len(self.nodes),
            "root_packages": len(self.root_nodes),
            "max_depth": max(depths),
            "avg_depth": sum(depths) / len(depths),
            "orphaned_packages": orphaned,
            "circular_dependencies": cycles
        }


def create_dependency_tree(python_exe: str = None) -> DependencyTree:
    """Return a new, empty DependencyTree bound to *python_exe*.

    Passing None lets DependencyTree fall back to ``sys.executable``.
    """
    tree = DependencyTree(python_exe=python_exe)
    return tree


def analyze_dependencies(packages: List[str] = None, python_exe: str = None, root_packages: List[str] = None) -> DependencyTree:
    """Create a tree for *python_exe* and populate it with ``build_tree``.

    Args:
        packages: Package names to include; None means all installed ones.
        python_exe: Interpreter to inspect; None means the current one.
        root_packages: Names to mark as roots; None means auto-detect.

    Returns:
        The populated DependencyTree.
    """
    result = create_dependency_tree(python_exe)
    result.build_tree(packages=packages, root_packages=root_packages)
    return result


def find_unused_dependencies(tree: DependencyTree) -> List[str]:
    """Return the names of packages not reachable from any root of *tree*.

    Thin wrapper over ``tree.find_orphaned_packages()`` that maps nodes to names.
    """
    return [orphan.name for orphan in tree.find_orphaned_packages()]


def check_circular_dependencies(tree: DependencyTree) -> List[List[str]]:
    """Return each cycle from ``tree.get_circular_dependencies()`` as package names.

    The cycles come back as closed paths: the first name is repeated at the end.
    """
    named_cycles: List[List[str]] = []
    for cycle in tree.get_circular_dependencies():
        named_cycles.append([member.name for member in cycle])
    return named_cycles


def get_dependency_path(tree: 'DependencyTree', from_package: str, to_package: str) -> List[str]:
    """Return a shortest dependency path from *from_package* to *to_package*.

    Package names are matched against ``tree.nodes`` exactly first, then in
    PEP 503 normalized form (case-insensitive, ``-``/``_``/``.`` equivalent),
    so "Requests" or "typing-extensions" find their stored keys.

    Returns:
        Node names along the path, starting with *from_package*'s node and
        ending with *to_package*'s. Returns [] if either package is unknown
        or no path exists.
    """
    def canonical(name: str) -> str:
        return re.sub(r"[-_.]+", "-", name).lower()

    def lookup(name: str):
        node = tree.nodes.get(name)
        if node is not None:
            return node
        wanted = canonical(name)
        for key, candidate in tree.nodes.items():
            if canonical(key) == wanted:
                return candidate
        return None

    start = lookup(from_package)
    target = lookup(to_package)
    if start is None or target is None:
        return []

    # BFS guarantees the first path reaching the target is a shortest one.
    # Marking nodes visited when enqueued keeps each node in the queue once.
    queue = deque([(start, [start.name])])
    visited = {start.name}
    while queue:
        node, path = queue.popleft()
        if node.name == target.name:
            return path
        for dep in node.dependencies:
            if dep.name not in visited:
                visited.add(dep.name)
                queue.append((dep, path + [dep.name]))

    return []


if __name__ == "__main__":
    # Manual smoke test: analyze the environment of the interpreter running
    # this script and print statistics, the tree, orphans and cycles.
    tree = analyze_dependencies()
    print("依赖树统计:")
    stats = tree.get_statistics()
    for key, value in stats.items():
        print(f"  {key}: {value}")

    print("\n依赖树可视化:")
    print(tree.visualize_tree())

    print("\n未使用的依赖:")
    unused = find_unused_dependencies(tree)
    for dep in unused:
        print(f"  - {dep}")

    print("\n循环依赖:")
    cycles = check_circular_dependencies(tree)
    for cycle in cycles:
        # Each cycle is already closed (its first package is repeated at the
        # end), so appending cycle[0] again would print it twice.
        print(f"  {' -> '.join(cycle)}")
