# Copyright (c) 2025 Beijing Volcano Engine Technology Co., Ltd. and/or its affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Dockerfile 元数据处理"""

import re
import hashlib
import json
import logging
from pathlib import Path
from dataclasses import dataclass
from datetime import datetime
from typing import Optional
from enum import Enum

logger = logging.getLogger(__name__)


def calculate_template_hash(template_path: Path) -> str:
    """Fingerprint a Dockerfile template as the first 16 hex chars of its SHA-256.

    Normally the digest is taken over the raw bytes of ``template_path``.
    If reading or hashing the template raises, a warning is logged and a
    deterministic fallback fingerprint is derived from the template path and
    the AgentKit version instead (``"unknown"`` when ``agentkit.version``
    cannot be imported). A 16-character hex string is therefore always
    returned.

    Args:
        template_path: Path to the Dockerfile template.

    Returns:
        The first 16 hex characters of a SHA-256 digest.
    """
    try:
        digest = hashlib.sha256(template_path.read_bytes())
    except Exception as err:
        # Resolve the version lazily; any failure degrades to "unknown".
        try:
            from agentkit.version import VERSION as agentkit_version
        except Exception:
            agentkit_version = "unknown"

        logger.warning(
            "Failed to read Dockerfile template for hashing: %s (%s)", template_path, err
        )
        fallback_seed = f"unreadable-template:{template_path}:{agentkit_version}"
        digest = hashlib.sha256(fallback_seed.encode("utf-8"))
    return digest.hexdigest()[:16]


class DockerfileDecision(Enum):
    """Decision about what to do with a Dockerfile.

    ``GENERATE_*`` members indicate the file should be (re)generated.
    ``KEEP_*`` members indicate the existing file should be kept, with the
    member name recording the reason.
    """

    GENERATE_NEW = "generate_new"  # Generate a new file
    GENERATE_CONFIG_CHANGED = "generate_config_changed"  # Configuration changed; regenerate
    KEEP_UP_TO_DATE = "keep_up_to_date"  # Keep (already up to date)
    KEEP_USER_MODIFIED = "keep_user_modified"  # Keep (modified by the user)
    KEEP_USER_CUSTOM = "keep_user_custom"  # Keep (user's own custom file)
    KEEP_CONFIG_CONFLICT = "keep_config_conflict"  # Keep (config changed AND file modified by user)
    KEEP_ERROR = "keep_error"  # Keep (an error occurred during the check)


@dataclass
class DockerfileMetadata:
    """Metadata describing a Dockerfile, as returned by ``MetadataExtractor.extract``.

    Optional fields are None when the corresponding header entry is absent.
    """

    is_managed: bool  # Whether the file is managed (generated) by the tool
    config_hash: Optional[str]  # Configuration hash
    content_hash: Optional[str]  # Content hash (used to detect user modifications)
    agentkit_version: Optional[str]  # AgentKit version
    generated_at: Optional[datetime]  # Generation time


class MetadataExtractor:
    """Read, strip and compute the metadata stored in a Dockerfile's header."""

    HEADER_LINES = 20  # Only the first 20 lines are searched for metadata
    HEADER_MARKER = "AUTO-GENERATED by AgentKit"
    CHECKSUM_PATTERN = r"# Checksum:\s*sha256:(\w+)"
    CONTENT_HASH_PATTERN = r"# ContentHash:\s*sha256:(\w+)"
    VERSION_PATTERN = r"AUTO-GENERATED by AgentKit v([\d.]+)"
    TIMESTAMP_PATTERN = r"# Generated:\s*(.+)"

    @staticmethod
    def extract(content: str) -> DockerfileMetadata:
        """
        Extract metadata from the comment header of a Dockerfile.

        Only the first ``HEADER_LINES`` lines are examined. Scanning stops at
        the first non-blank line that does not start with ``#`` (the first
        Dockerfile instruction), and that line itself is not searched.

        Args:
            content: Dockerfile content.

        Returns:
            A DockerfileMetadata. ``is_managed`` is True if a header line
            contains HEADER_MARKER. Any other field that is not found is None,
            as is a ``# Generated:`` value that ``datetime.fromisoformat``
            cannot parse.
        """
        lines = content.split("\n")[: MetadataExtractor.HEADER_LINES]

        is_managed = False
        version = None
        config_hash = None
        content_hash = None
        generated_at = None

        for line in lines:
            # The header ends at the first real instruction. Stop *before*
            # inspecting it, so text inside an instruction (e.g. a RUN echo or
            # LABEL) is never mistaken for metadata.
            stripped = line.strip()
            if stripped and not stripped.startswith("#"):
                break

            # Is this file generated by the tool?
            if MetadataExtractor.HEADER_MARKER in line:
                is_managed = True
                match = re.search(MetadataExtractor.VERSION_PATTERN, line)
                if match:
                    version = match.group(1)

            # Configuration hash
            if "# Checksum:" in line:
                match = re.search(MetadataExtractor.CHECKSUM_PATTERN, line)
                if match:
                    config_hash = match.group(1)

            # Content hash
            if "# ContentHash:" in line:
                match = re.search(MetadataExtractor.CONTENT_HASH_PATTERN, line)
                if match:
                    content_hash = match.group(1)

            # Generation timestamp
            if "# Generated:" in line:
                match = re.search(MetadataExtractor.TIMESTAMP_PATTERN, line)
                if match:
                    try:
                        generated_at = datetime.fromisoformat(match.group(1).strip())
                    except ValueError:
                        # Unparseable timestamp: leave generated_at as None.
                        pass

        return DockerfileMetadata(
            is_managed=is_managed,
            config_hash=config_hash,
            content_hash=content_hash,
            agentkit_version=version,
            generated_at=generated_at,
        )

    @staticmethod
    def remove_metadata_header(content: str) -> str:
        """
        Strip the leading header and return the Dockerfile body.

        Every line before the first non-blank line that does not start with
        ``#`` is dropped. This covers all leading comments, not only AgentKit's
        own header, so a user edit confined to that region (e.g. adding a
        ``# syntax=`` parser directive) is not visible in the returned body.
        If the content has no instruction line at all, the whole content is
        header and an empty string is returned.

        Args:
            content: Full Dockerfile content.

        Returns:
            The content from the first instruction line onward, or ``""``.
        """
        lines = content.split("\n")

        for idx, line in enumerate(lines):
            stripped = line.strip()
            if stripped and not stripped.startswith("#"):
                return "\n".join(lines[idx:])

        # No instruction found: previously the header itself was returned here,
        # which made the "body" (and its hash) include the metadata lines.
        return ""

    @staticmethod
    def calculate_config_hash(config_dict: dict) -> str:
        """
        Hash the configuration that determines the Dockerfile's content.

        The dict is serialized to JSON with sorted keys, so key order does not
        matter.

        Args:
            config_dict: Configuration values that affect the Dockerfile.

        Returns:
            The first 16 hex characters of the SHA-256 of the JSON.
            Returns ``"unknown"`` if hashing fails, e.g. when a value is not
            JSON-serializable. Note that all such configs share that value.
        """
        try:
            # Serialize to JSON with sorted keys for a stable representation.
            content = json.dumps(config_dict, sort_keys=True, ensure_ascii=False)
            hash_value = hashlib.sha256(content.encode("utf-8")).hexdigest()
            # 16 hex chars are used as the short fingerprint.
            return hash_value[:16]
        except Exception as e:
            logger.error(f"计算配置哈希失败: {e}", exc_info=True)
            return "unknown"

    @staticmethod
    def calculate_content_hash(content: str) -> str:
        """
        Hash the Dockerfile body, ignoring the leading header.

        The header is removed with ``remove_metadata_header`` and the body is
        normalized with ``ContentComparator.normalize_content`` before hashing.
        NOTE(review): this value appears to be embedded in generated headers
        (see CONTENT_HASH_PATTERN). If so, changing the normalization
        invalidates existing hashes.

        Args:
            content: Dockerfile content (may include the metadata header).

        Returns:
            The first 16 hex characters of the SHA-256 of the normalized body,
            or ``"unknown"`` if hashing fails.
        """
        try:
            body = MetadataExtractor.remove_metadata_header(content)
            normalized = ContentComparator.normalize_content(body)
            hash_value = hashlib.sha256(normalized.encode("utf-8")).hexdigest()
            return hash_value[:16]
        except Exception as e:
            logger.error(f"计算内容哈希失败: {e}", exc_info=True)
            return "unknown"


class ContentComparator:
    """Compare Dockerfile bodies while ignoring cosmetic differences."""

    @staticmethod
    def is_modified(current_content: str, expected_content: str) -> bool:
        """
        Report whether ``current_content`` differs from ``expected_content``.

        Each input goes through the same pipeline before comparison:
        ``MetadataExtractor.remove_metadata_header`` strips the leading
        comment/blank header, then ``normalize_content`` canonicalizes the
        body. Only differences that survive this pipeline count.

        Args:
            current_content: Content of the Dockerfile as it currently exists.
            expected_content: Expected content (regenerated from the previous
                configuration).

        Returns:
            True if the normalized bodies differ, otherwise False.
        """
        current_canon, expected_canon = (
            ContentComparator.normalize_content(
                MetadataExtractor.remove_metadata_header(text)
            )
            for text in (current_content, expected_content)
        )
        return current_canon != expected_canon

    @staticmethod
    def normalize_content(content: str) -> str:
        """
        Canonicalize Dockerfile text for comparison and hashing.

        Per line (split with ``str.splitlines``):
        * trailing whitespace is removed;
        * a run of blank lines becomes a single blank line;
        * a line starting with ``#`` becomes ``"# <text>"``, where <text> is
          the stripped remainder, or just ``"#"`` when nothing remains.
          Indented comments do not start with ``#`` and are left as is.

        Args:
            content: Raw text.

        Returns:
            The normalized lines joined with ``"\\n"``.
        """
        canon = []
        for raw_line in content.splitlines():
            text = raw_line.rstrip()

            if not text:
                # Only blank input lines ever append "", so checking the last
                # entry is enough to collapse consecutive blanks.
                if not canon or canon[-1] != "":
                    canon.append("")
                continue

            if text.startswith("#"):
                remark = text[1:].strip()
                text = "# " + remark if remark else "#"

            canon.append(text)

        return "\n".join(canon)
