"""
Execution engine - builds iterator pipeline from logical plan.
"""

from .planner import LogicalPlan, JoinInfo
from .operators import (
    ScanIterator,
    FilterIterator,
    ProjectIterator,
    LookupJoinIterator,
    MergeJoinIterator
)

# Try importing Polars operators (optional)
try:
    from .operators_polars import (
        PolarsLookupJoinIterator,
        PolarsBatchFilterIterator,
        PolarsBatchProjectIterator,
        should_use_polars
    )
    POLARS_AVAILABLE = True
except ImportError:
    POLARS_AVAILABLE = False
    PolarsBatchFilterIterator = None
    PolarsBatchProjectIterator = None

# Try importing mmap operators (optional)
try:
    from .operators_mmap import MmapLookupJoinIterator
    MMAP_AVAILABLE = True
except ImportError:
    MMAP_AVAILABLE = False
    MmapLookupJoinIterator = None


def execute_plan(
    plan,
    sources,
    source_metadata,
    debug=False,
    use_polars=True
):
    """
    Execute a logical plan and return the head of the resulting iterator pipeline.

    The pipeline is assembled in a fixed order: scan of the root table,
    WHERE filter, joins (in plan order), and finally the SELECT projection.

    Args:
        plan: Logical execution plan. This function reads ``root_table``,
            ``root_alias``, ``required_columns``, ``pushable_where_expr``,
            ``where_expr``, ``joins`` and ``projections`` from it.
        sources: Dictionary mapping table names to source functions
        source_metadata: Dictionary with metadata about sources. The root
            table's ``is_database_source`` entry enables pushdown here; the
            whole mapping is also forwarded to the join builder.
        debug: When True, print a trace of pipeline construction to stdout
            and pass the flag on to every operator.
        use_polars: When True and the optional Polars operators were imported
            successfully, try them first for filtering and projection and
            fall back to the pure-Python operators if construction fails.

    Returns:
        The projection iterator at the head of the pipeline; iterating it is
        expected to yield result row dictionaries.

    Raises:
        KeyError: If the root table has no entry in ``sources``.
    """
    root_table = plan.root_table

    # Get required columns for root table (column pruning)
    root_required_columns = plan.required_columns.get(root_table)
    root_source_fn = sources[root_table]

    # Handle pushable WHERE clause (filter pushdown): render it as SQL text so
    # a database source can evaluate it itself.
    pushable_where_sql = None
    if plan.pushable_where_expr:
        from .optimizer import expression_to_sql_string
        try:
            pushable_where_sql = expression_to_sql_string(plan.pushable_where_expr)
        except Exception as e:
            # Best effort: a WHERE clause that cannot be rendered is simply
            # not pushed down.
            if debug:
                print(f"  [OPTIMIZATION] Could not push WHERE clause: {e}")
        else:
            if debug:
                print(f"  [OPTIMIZATION] Pushing WHERE clause to source: {pushable_where_sql}")

    # NOTE(review): when the pushable WHERE is not actually pushed (non-database
    # source, rendering failure, or a source that ignores the hints) nothing in
    # this function applies ``plan.pushable_where_expr`` afterwards. Confirm
    # that the planner keeps those conditions in ``plan.where_expr`` as well.

    # Only sources flagged as database sources receive pushdown hints.
    root_metadata = source_metadata.get(root_table, {})
    if root_metadata.get('is_database_source') and (root_required_columns or pushable_where_sql):
        if debug:
            print("  [OPTIMIZATION] Applying column pruning and filter pushdown to database source")
        root_source_fn = _wrap_source_with_pushdown(
            root_source_fn,
            pushable_where_sql,
            root_required_columns,
            debug=debug
        )

    # Start with scan of root table
    if debug:
        if root_required_columns:
            print(f"  [SCAN] Scanning table: {root_table} (columns: {len(root_required_columns)})")
        else:
            print(f"  [SCAN] Scanning table: {root_table}")

    iterator = ScanIterator(
        root_source_fn,
        root_table,
        plan.root_alias,
        required_columns=root_required_columns,
        debug=debug
    )

    # Apply WHERE filter if present (non-pushable conditions)
    if plan.where_expr:
        if debug:
            print("  [FILTER] Applying WHERE clause (non-pushable conditions)")
        iterator = _build_filter_stage(iterator, plan.where_expr, use_polars, debug)

    # Apply joins in order; each join wraps the pipeline built so far.
    join_count = len(plan.joins)
    for i, join_info in enumerate(plan.joins, 1):
        join_required_columns = plan.required_columns.get(join_info.table)
        if debug:
            if join_required_columns:
                print(f"  [JOIN {i}/{join_count}] {join_info.join_type} JOIN {join_info.table} (columns: {len(join_required_columns)})")
            else:
                print(f"  [JOIN {i}/{join_count}] {join_info.join_type} JOIN {join_info.table}")
        iterator = _build_join_iterator(
            iterator,
            join_info,
            sources,
            source_metadata,
            join_required_columns,
            debug=debug,
            use_polars=use_polars
        )

    # Apply projection
    if debug:
        print("  [PROJECT] Applying SELECT projection")
    iterator = _build_project_stage(iterator, plan.projections, use_polars, debug)

    # Announce readiness only once every stage has actually been constructed.
    if debug:
        print("\nPipeline ready. Starting row processing...\n")
        print("-" * 60)

    return iterator


def _wrap_source_with_pushdown(source_fn, where_sql, required_columns, debug=False):
    """Return a zero-argument callable that forwards pushdown hints to ``source_fn``.

    Only the hint keywords (``dynamic_where``, ``dynamic_columns``) that
    ``source_fn`` can actually receive are passed. If it accepts neither, or
    its signature cannot be inspected, ``source_fn`` is returned unchanged.
    """
    import inspect

    hint_names = ("dynamic_where", "dynamic_columns")

    try:
        parameters = tuple(inspect.signature(source_fn).parameters.values())
    except (TypeError, ValueError):
        # Some callables (e.g. certain builtins) expose no signature.
        parameters = ()

    if any(p.kind is inspect.Parameter.VAR_KEYWORD for p in parameters):
        # A **kwargs parameter accepts every hint.
        supported = hint_names
    else:
        keyword_capable = (
            inspect.Parameter.POSITIONAL_OR_KEYWORD,
            inspect.Parameter.KEYWORD_ONLY,
        )
        accepted = {p.name for p in parameters if p.kind in keyword_capable}
        supported = tuple(name for name in hint_names if name in accepted)

    if not supported:
        if debug:
            print("  [OPTIMIZATION] Source does not accept dynamic_where/dynamic_columns; scanning without pushdown")
        return source_fn

    def optimized_source_fn():
        # Build the hints per call so each invocation gets a fresh column list.
        hints = {
            "dynamic_where": where_sql,
            "dynamic_columns": list(required_columns) if required_columns else None,
        }
        return source_fn(**{name: hints[name] for name in supported})

    return optimized_source_fn


def _build_filter_stage(iterator, where_expr, use_polars, debug):
    """Wrap ``iterator`` in a WHERE filter, trying the Polars batch filter first."""
    if use_polars and POLARS_AVAILABLE and PolarsBatchFilterIterator is not None:
        try:
            filtered = PolarsBatchFilterIterator(iterator, where_expr, batch_size=10000, debug=debug)
        except Exception as e:
            # Best effort: any construction failure falls back to Python.
            if debug:
                print(f"  [OPTIMIZATION] Polars filtering failed: {e}, using Python")
        else:
            if debug:
                print("  [OPTIMIZATION] Using Polars vectorized filtering (SIMD)")
            return filtered
    return FilterIterator(iterator, where_expr, debug=debug)


def _build_project_stage(iterator, projections, use_polars, debug):
    """Wrap ``iterator`` in the SELECT projection, trying the Polars batch projection first."""
    if use_polars and POLARS_AVAILABLE and PolarsBatchProjectIterator is not None:
        try:
            projected = PolarsBatchProjectIterator(iterator, projections, batch_size=10000, debug=debug)
        except Exception as e:
            # Best effort: any construction failure falls back to Python.
            if debug:
                print(f"  [OPTIMIZATION] Polars projection failed: {e}, using Python")
        else:
            if debug:
                print("  [OPTIMIZATION] Using Polars vectorized projection")
            return projected
    return ProjectIterator(iterator, projections, debug=debug)
    
    
def _build_join_iterator(
    left_iterator,
    join_info,
    sources,
    source_metadata,
    required_columns=None,
    debug=False,
    use_polars=True
):
    """
    Build appropriate join iterator based on source capabilities.

    Strategies are tried in this order, and the first one that can be
    constructed is returned:
      1. Merge join, when the right source's ``ordered_by`` metadata equals
         the column part of the right join key.
      2. Mmap lookup join, when the mmap operator is importable and the right
         source's metadata provides a ``filename``.
      3. Polars lookup join, when ``use_polars`` is set, Polars is importable
         and ``should_use_polars`` approves the right source.
      4. Plain Python lookup join.

    Args:
        left_iterator: Pipeline built so far; becomes the left side of the join.
        join_info: Join description; ``table``, ``alias``, ``left_key``,
            ``right_key`` and ``join_type`` are read from it.
        sources: Dictionary mapping table names to source functions
        source_metadata: Dictionary with metadata about sources
        required_columns: Set of column names needed from right table (for column pruning)
        debug: When True, print which strategy was chosen (and why a
            strategy was abandoned) and pass the flag to the operator.
        use_polars: Allow the Polars lookup join to be considered.

    Returns:
        A join iterator wrapping ``left_iterator``.

    Raises:
        KeyError: If ``join_info.table`` has no entry in ``sources``.
    """
    right_source = sources[join_info.table]
    right_metadata = source_metadata.get(join_info.table, {})

    # Positional arguments shared by every join operator.
    join_args = (
        left_iterator,
        right_source,
        join_info.left_key,
        join_info.right_key,
        join_info.join_type,
        join_info.table,
        join_info.alias,
    )

    # Merge join is selected from the right side's declared ordering only.
    # NOTE(review): the ordering of the left input is not verified here —
    # confirm callers guarantee it is sorted on the left join key.
    right_ordered_by = right_metadata.get("ordered_by")
    if (right_ordered_by is not None and
            right_ordered_by == _extract_column_from_key(join_info.right_key)):
        if debug:
            print("      Using MERGE JOIN (sorted data)")
        return MergeJoinIterator(*join_args, debug=debug)

    # Check if we can use mmap-based join (lowest memory - PRIORITY for large tables)
    right_table_filename = right_metadata.get("filename")
    if MMAP_AVAILABLE and MmapLookupJoinIterator is not None and right_table_filename:
        if debug:
            if required_columns:
                print(f"      Using MMAP LOOKUP JOIN (low memory, columns: {len(required_columns)})...")
            else:
                print("      Using MMAP LOOKUP JOIN (low memory, position-based index)...")
        try:
            return MmapLookupJoinIterator(
                *join_args,
                right_table_filename=right_table_filename,
                required_columns=required_columns,
                debug=debug
            )
        except Exception as e:
            # Best effort: fall through to the Polars/Python strategies.
            if debug:
                print(f"      ⚠️  Mmap join failed: {e}")
                import traceback
                traceback.print_exc()
                print("      Falling back to Polars/Python")

    # Decide between Polars and Python implementation
    # NOTE: Polars is only reached when the mmap join was unavailable or failed.
    if (use_polars and POLARS_AVAILABLE and
            should_use_polars(right_source, threshold=10000)):
        if debug:
            if required_columns:
                print(f"      Using POLARS LOOKUP JOIN (fast, columns: {len(required_columns)})...")
            else:
                print("      Using POLARS LOOKUP JOIN (fast, vectorized)...")
        try:
            return PolarsLookupJoinIterator(
                *join_args,
                batch_size=10000,
                required_columns=required_columns,
                debug=debug
            )
        except Exception as e:
            # Best effort: fall through to the pure-Python lookup join.
            if debug:
                print(f"      Polars join failed: {e}, falling back to Python")

    if debug:
        if required_columns:
            print(f"      Using LOOKUP JOIN (building index, columns: {len(required_columns)})...")
        else:
            print("      Using LOOKUP JOIN (building index...)")
    return LookupJoinIterator(
        *join_args,
        required_columns=required_columns,
        debug=debug
    )


def _extract_table_from_key(key):
    """Extract table alias from a key like 'alias.column'."""
    if "." in key:
        return key.split(".", 1)[0]
    return None


def _extract_column_from_key(key):
    """Extract column name from a key like 'alias.column'."""
    if "." in key:
        return key.split(".", 1)[1]
    return key

