import asyncio
import aiohttp
from copy import deepcopy
from .image_processor import encode_image_to_base64
from .image_processor import DEFAULT_CACHE_DIR
from loguru import logger


async def process_content_recursive(content, session, use_cache=False, cache_dir=DEFAULT_CACHE_DIR, force_refresh=False, **kwargs):
    """Walk ``content`` depth-first and inline image URLs as Base64, mutating it in place.

    Every dict entry whose key is ``"url"`` and whose value is a ``str`` is
    passed to ``encode_image_to_base64``. When that call returns a truthy
    value, the entry is overwritten with it; otherwise the URL is left as is.
    Any other dict value, and every list element, is descended into. Values
    of any other type are ignored. Nothing is returned.

    Args:
        content: The (possibly nested) dict/list structure to walk.
        session: HTTP session handed through to ``encode_image_to_base64``.
        use_cache: Forwarded to ``encode_image_to_base64``.
        cache_dir: Forwarded to ``encode_image_to_base64``.
        force_refresh: Forwarded to ``encode_image_to_base64``.
        **kwargs: Only ``max_width``, ``max_height`` and ``max_pixels`` are
            read here (missing ones are passed as ``None``); the full set of
            kwargs is carried down into the recursive calls.
    """
    # Options that every recursive call receives unchanged.
    passthrough = dict(use_cache=use_cache, cache_dir=cache_dir, force_refresh=force_refresh, **kwargs)

    if isinstance(content, list):
        for element in content:
            await process_content_recursive(element, session, **passthrough)
        return

    if not isinstance(content, dict):
        return

    for field, field_value in content.items():
        if field != "url" or not isinstance(field_value, str):
            await process_content_recursive(field_value, session, **passthrough)
            continue

        encoded = await encode_image_to_base64(
            field_value,
            session,
            max_width=kwargs.get("max_width"),
            max_height=kwargs.get("max_height"),
            max_pixels=kwargs.get("max_pixels"),
            use_cache=use_cache,
            cache_dir=cache_dir,
            force_refresh=force_refresh,
        )
        # Only overwrite on a truthy result; otherwise the URL stays in place.
        # (Replacing the value of an existing key is safe during items() iteration.)
        if encoded:
            content[field] = encoded


async def messages_preprocess(messages, inplace=False, use_cache=False, cache_dir=DEFAULT_CACHE_DIR, force_refresh=False, **kwargs):
    """Process a list of messages, converting URLs in any type of content to Base64.

    Every message is walked concurrently with ``process_content_recursive``,
    all of them sharing a single ``aiohttp.ClientSession``.

    Args:
        messages: List of messages to process.
        inplace: Whether to modify the messages in-place or work on a deep copy.
        use_cache: Whether to use cache for URL images.
        cache_dir: Cache directory path.
        force_refresh: Whether to force refresh the cache even if cached image exists.
        **kwargs: Additional arguments to pass to image processing functions.

    Returns:
        The processed messages: the same object when ``inplace`` is true,
        otherwise the deep copy.

    Raises:
        Whatever the first failing per-message task raises. Any sibling tasks
        still running at that point are cancelled and awaited before the
        exception propagates, so none of them keeps using the closed session.
    """
    if not inplace:
        messages = deepcopy(messages)
    async with aiohttp.ClientSession() as session:
        # Wrap in real tasks so they can be cancelled individually on failure.
        tasks = [
            asyncio.create_task(
                process_content_recursive(
                    message,
                    session,
                    use_cache=use_cache,
                    cache_dir=cache_dir,
                    force_refresh=force_refresh,
                    **kwargs,
                )
            )
            for message in messages
        ]
        try:
            await asyncio.gather(*tasks)
        finally:
            # gather() propagates the first exception immediately but does NOT
            # cancel the other awaitables; without this they would outlive the
            # session closed by the surrounding ``async with``.
            unfinished = [task for task in tasks if not task.done()]
            for task in unfinished:
                task.cancel()
            if unfinished:
                await asyncio.gather(*unfinished, return_exceptions=True)
    return messages


async def batch_messages_preprocess(
    messages_list,
    max_concurrent=5,
    inplace=False,
    use_cache=False,
    cache_dir=DEFAULT_CACHE_DIR,
    force_refresh=False,
    as_iterator=False,
    **kwargs,
):
    """Process multiple lists of messages concurrently.

    Each message list is handed to ``messages_preprocess``; at most
    ``max_concurrent`` of them are processed at the same time.

    Args:
        messages_list: List, tuple, iterator or async iterator of message lists
            to process. A sync iterable that is not a list/tuple is
            materialized (and therefore consumed) with ``list()``.
        max_concurrent: Maximum number of message lists processed at once.
            Must be at least 1.
        inplace: Whether to modify the messages in-place.
        use_cache: Whether to use cache for URL images.
        cache_dir: Cache directory path.
        force_refresh: Whether to force refresh the cache even if cached image exists.
        as_iterator: Whether to return an async iterator instead of a list.
        **kwargs: Additional arguments to pass to image processing functions.

    Returns:
        If ``as_iterator`` is false, a list of processed message lists in
        input order.

        Otherwise, an async iterator:

        - For a list/tuple/sync iterable, results are yielded in input order,
          processed in consecutive chunks of ``max_concurrent``.
        - For an async iterable, results are yielded in completion order.
          Closing that iterator early cancels any work still in flight.

        If processing a message list raises an ``Exception``, the error is
        logged and that exact input object is returned in its place (with
        ``inplace=True`` it may be partially modified).

    Raises:
        ValueError: If ``max_concurrent`` is less than 1. A zero-sized
            semaphore would otherwise block forever.
    """
    if max_concurrent < 1:
        raise ValueError(f"max_concurrent must be at least 1, got {max_concurrent!r}")

    async def process_single_batch(messages, semaphore):
        # Preprocess one message list under the semaphore; on failure log and
        # fall back to returning the input unchanged.
        async with semaphore:
            try:
                return await messages_preprocess(
                    messages,
                    inplace=inplace,
                    use_cache=use_cache,
                    cache_dir=cache_dir,
                    force_refresh=force_refresh,
                    **kwargs,
                )
            except Exception as e:
                logger.error(f"{e=}\n")
                return messages

    async def iterate_async_source():
        # Sliding window: keep up to max_concurrent tasks in flight and yield
        # each result as soon as it finishes (completion order).
        semaphore = asyncio.Semaphore(max_concurrent)
        pending = set()
        try:
            async for messages in messages_list:
                if len(pending) >= max_concurrent:
                    # asyncio.wait returns (done, pending) as *sets*, so
                    # ``pending`` is kept a set everywhere (add, not append).
                    done, pending = await asyncio.wait(
                        pending, return_when=asyncio.FIRST_COMPLETED
                    )
                    for task in done:
                        yield task.result()
                pending.add(asyncio.create_task(process_single_batch(messages, semaphore)))
            for next_done in asyncio.as_completed(pending):
                yield await next_done
        finally:
            # No-op after normal exhaustion (all tasks are done); if the
            # consumer stops early, don't leave orphaned tasks running.
            for task in pending:
                task.cancel()

    async def iterate_sync_source():
        # Materialize the input, then process it in consecutive chunks of
        # max_concurrent, yielding each chunk's results in input order.
        semaphore = asyncio.Semaphore(max_concurrent)
        if isinstance(messages_list, (list, tuple)):
            sequence = messages_list
        else:
            sequence = list(messages_list)
        for start in range(0, len(sequence), max_concurrent):
            chunk = sequence[start:start + max_concurrent]
            results = await asyncio.gather(
                *(process_single_batch(messages, semaphore) for messages in chunk)
            )
            for result in results:
                yield result

    if as_iterator:
        if hasattr(messages_list, "__aiter__"):
            return iterate_async_source()
        return iterate_sync_source()

    # List mode: collect every input first, then process them all under the semaphore.
    if hasattr(messages_list, "__aiter__"):
        items = [messages async for messages in messages_list]
    elif isinstance(messages_list, (list, tuple)):
        items = messages_list
    else:
        items = list(messages_list)

    if not items:
        return []

    semaphore = asyncio.Semaphore(max_concurrent)
    return await asyncio.gather(
        *(process_single_batch(messages, semaphore) for messages in items)
    )
    

# Alternate name for batch_messages_preprocess; both names refer to the same function object.
batch_process_messages = batch_messages_preprocess

