"""Main query executor for NLQL."""

import logging
from typing import Any

from nlql.adapters.base import QueryPlan
from nlql.ast.nodes import SelectStatement
from nlql.engine.context import ExecutionContext
from nlql.engine.evaluator import WhereEvaluator
from nlql.errors import NLQLExecutionError
from nlql.result import Result
from nlql.text.units import TextUnit

logger = logging.getLogger(__name__)


class Executor:
    """Main executor for NLQL queries.

    This orchestrates the three-phase execution:
    1. Parsing (done before this)
    2. Routing (pushdown vs in-memory)
    3. Reshaping (granularity processing)
    """

    def __init__(self, context: ExecutionContext) -> None:
        """Initialize the executor.

        Args:
            context: Execution context with adapter and configuration
        """
        self.context = context
        self.evaluator = WhereEvaluator(context)

    def execute(self, ast: SelectStatement) -> list[Result]:
        """Execute a parsed NLQL query.

        Execution order:
        1. Create query plan
        2. Fetch data from adapter
        3. Apply semantic search (if SIMILAR_TO in WHERE)
        4. Apply granularity transformation (SENTENCE/SPAN)
        5. Apply WHERE filtering
        6. Apply ORDER BY sorting
        7. Apply LIMIT

        Args:
            ast: Parsed SELECT statement

        Returns:
            List of query results

        Raises:
            NLQLExecutionError: If execution fails. Any exception raised by a
                phase is wrapped, with the original attached as the cause.
        """
        debug_mode = self.context.config.debug_mode

        try:
            if debug_mode:
                logger.debug(f"Executing query: SELECT {ast.select_unit}")
                logger.debug(f"AST: {ast}")

            # Phase 1: Create query plan (simplified for now)
            plan = self._create_query_plan(ast)
            if debug_mode:
                logger.debug(f"Query plan created: {plan}")

            # Phase 2: Execute against data source
            units = self.context.adapter.query(plan)
            if debug_mode:
                logger.debug(f"Fetched {len(units)} units from adapter")

            # Phase 2.5: Apply semantic search if SIMILAR_TO is in WHERE clause
            # This must happen BEFORE granularity transformation so that similarity
            # scores are computed on the original chunks
            if ast.where is not None:
                units = self._apply_semantic_search(units, ast.where.condition)
                if debug_mode:
                    logger.debug(f"After semantic search: {len(units)} units")

            # Phase 3: Apply granularity transformation
            # This must happen BEFORE WHERE filtering so that WHERE operates
            # on the requested granularity (SENTENCE/SPAN) not on chunks
            units = self._apply_granularity_transformation(units, ast)
            if debug_mode:
                logger.debug(f"After granularity transformation: {len(units)} units")

            # Phase 4: Apply in-memory filters
            if ast.where is not None:
                units = [
                    unit
                    for unit in units
                    if self.evaluator.evaluate(ast.where.condition, unit)
                ]
                if debug_mode:
                    logger.debug(f"After WHERE filtering: {len(units)} units")

            # Phase 5: Apply ordering
            if ast.order_by:
                units = self._apply_ordering(units, ast.order_by)
                if debug_mode:
                    logger.debug(f"After ORDER BY: {len(units)} units")

            # Phase 6: Apply limit
            # Use query LIMIT if specified, otherwise use config default_limit
            limit = ast.limit if ast.limit is not None else self.context.config.default_limit
            if limit is not None:
                units = units[:limit]
                if debug_mode:
                    logger.debug(f"After LIMIT {limit}: {len(units)} units")

            # Convert to Result objects.
            # chunk_id is preferred as the source identifier; source_id is the
            # fallback for units that have no (or a falsy) chunk_id.
            results = [
                Result(
                    content=unit.content,
                    metadata=unit.metadata,
                    unit=ast.select_unit,
                    source_id=getattr(unit, "chunk_id", None)
                    or getattr(unit, "source_id", None),
                )
                for unit in units
            ]

            if debug_mode:
                logger.debug(f"Query execution completed: {len(results)} results")

            return results

        except Exception as e:
            if debug_mode:
                logger.exception(f"Query execution failed with exception: {e}")
                logger.debug(f"AST at failure: {ast}")
            raise NLQLExecutionError(f"Query execution failed: {e}") from e

    def _apply_semantic_search(
        self, units: list[TextUnit], where_condition: Any
    ) -> list[TextUnit]:
        """Apply semantic search if SIMILAR_TO is in WHERE clause.

        This function:
        1. Extracts the query text from SIMILAR_TO("query")
        2. Computes embeddings and similarity scores
        3. Stores scores in unit.metadata["similarity"]

        Args:
            units: List of text units
            where_condition: WHERE clause AST node

        Returns:
            Same list of units with similarity scores added to metadata

        Note:
            This modifies units in-place by adding "similarity" to metadata.
            The similarity score is stored in metadata (not as object attribute)
            because users need to access it in WHERE and ORDER BY clauses.
        """
        from nlql.text.similarity import (
            compute_similarity_scores,
            extract_query_from_similar_to,
        )

        # Extract query text from SIMILAR_TO operator
        query_text = extract_query_from_similar_to(where_condition)

        if query_text is None:
            # No SIMILAR_TO in WHERE clause, nothing to do
            return units

        # Compute similarity scores and store in metadata
        return compute_similarity_scores(units, query_text)

    @staticmethod
    def _split_units_into_sentences(units: list[TextUnit]) -> list[TextUnit]:
        """Split every unit into sentences that inherit the parent's data.

        Args:
            units: List of text units (typically chunks from adapter)

        Returns:
            Flat list of sentence units, in the order of their parent units
        """
        from nlql.text.splitting import split_into_sentences

        sentences: list[TextUnit] = []
        for unit in units:
            unit_sentences = split_into_sentences(unit.content)

            for sentence in unit_sentences:
                # Copy user's business metadata from the parent unit
                sentence.metadata.update(unit.metadata)

                # Set system fields as object attributes (not in metadata)
                if hasattr(unit, "chunk_id") and unit.chunk_id is not None:
                    sentence.source_chunk_id = unit.chunk_id
                if hasattr(unit, "source_id") and unit.source_id is not None:
                    sentence.source_id = unit.source_id

            sentences.extend(unit_sentences)

        return sentences

    def _apply_granularity_transformation(
        self, units: list[TextUnit], ast: SelectStatement
    ) -> list[TextUnit]:
        """Apply granularity transformation based on SELECT unit type.

        Args:
            units: List of text units (typically chunks from adapter)
            ast: SELECT statement with unit type

        Returns:
            Transformed list of text units

        Raises:
            NLQLExecutionError: If the SELECT unit or SPAN base unit is not
                supported, or if SPAN is requested without a span_config
        """
        select_unit = ast.select_unit

        # CHUNK and DOCUMENT: no transformation needed
        if select_unit in ("CHUNK", "DOCUMENT"):
            return units

        # SENTENCE: split each chunk into sentences
        if select_unit == "SENTENCE":
            return self._split_units_into_sentences(units)

        if select_unit != "SPAN":
            raise NLQLExecutionError(f"Unsupported SELECT unit type: {select_unit}")

        # SPAN: create context windows
        if ast.span_config is None:
            raise NLQLExecutionError("SPAN requires configuration (unit and window)")

        from nlql.text.splitting import create_span
        from nlql.text.units import Span

        span_unit = ast.span_config.get("unit", "SENTENCE")
        window_size = ast.span_config.get("window", 1)

        # First, convert to the base unit type
        if span_unit == "SENTENCE":
            base_units = self._split_units_into_sentences(units)
        elif span_unit == "CHUNK":
            base_units = units
        else:
            raise NLQLExecutionError(f"Unsupported SPAN unit type: {span_unit}")

        # Create one span per base unit, each carrying its surrounding context
        spans: list[TextUnit] = []
        for i, target in enumerate(base_units):
            context_before, _, context_after = create_span(base_units, i, window_size)

            spans.append(
                Span(
                    content=target.content,
                    # Copy so the span's metadata is a separate dict from the target's
                    metadata=target.metadata.copy(),
                    # source_id is treated as optional everywhere else in this class
                    source_id=getattr(target, "source_id", None),
                    target_unit=target,
                    context_before=context_before,
                    context_after=context_after,
                    window_size=window_size,
                )
            )

        return spans

    def _apply_ordering(
        self, units: list[TextUnit], order_by: list
    ) -> list[TextUnit]:
        """Apply ORDER BY sorting to text units.

        Clauses whose field type is unknown are skipped. A clause whose values
        cannot be compared with each other is skipped with a warning.

        Args:
            units: List of text units to sort
            order_by: List of OrderByClause objects

        Returns:
            Sorted list of text units (the input list itself when order_by
            is empty, otherwise a new list)
        """
        import warnings

        if not order_by:
            return units

        # Process order_by clauses in reverse order (last clause has lowest priority).
        # sorted() is stable, so each later pass preserves the order established
        # by the lower-priority clauses among equal keys.
        sorted_units = units.copy()
        for clause in reversed(order_by):
            reverse = clause.direction == "DESC"

            # Create sort key function based on field type
            sort_key = self._create_sort_key(clause, reverse)
            if sort_key is None:
                # Unknown field type - skip this clause
                continue

            try:
                sorted_units = sorted(sorted_units, key=sort_key, reverse=reverse)
            except TypeError as e:
                # Handle comparison errors (e.g., comparing incompatible types):
                # warn and skip this sort clause
                warnings.warn(
                    f"Cannot sort by {clause.field}: {e}. Skipping this ORDER BY clause.",
                    stacklevel=2,
                )

        return sorted_units

    @staticmethod
    def _metadata_sort_key(field_name: str, reverse: bool) -> Any:
        """Build a sort key on a metadata field that keeps None values last.

        Args:
            field_name: Metadata key to sort by
            reverse: Whether the caller will sort with reverse=True

        Returns:
            Key function mapping a unit to a (priority, value) tuple
        """
        # The caller passes reverse=True to sorted() for DESC, which flips the
        # whole ordering. The priorities are flipped here as well so that units
        # without a value still end up at the end of the result.
        none_rank, value_rank = (0, 1) if reverse else (1, 0)

        def sort_key(unit: TextUnit) -> tuple[int, Any]:
            value = unit.metadata.get(field_name)
            if value is None:
                # Placeholder value; it is only ever compared with other placeholders
                return (none_rank, 0)
            return (value_rank, value)

        return sort_key

    def _create_sort_key(self, clause: Any, reverse: bool) -> Any:
        """Create a sort key function for an ORDER BY clause.

        Args:
            clause: OrderByClause object
            reverse: Whether sorting in reverse order

        Returns:
            Sort key function, or None if field type is unknown
        """
        field = clause.field

        if isinstance(field, str):
            # String field - could be "SIMILARITY" or a field name
            if field == "SIMILARITY":

                def similarity_key(unit: TextUnit) -> float:
                    # Get similarity score from metadata, defaulting to 0.0
                    return unit.metadata.get("similarity", 0.0)

                return similarity_key

            # Sort by metadata field (string field name)
            return self._metadata_sort_key(field, reverse)

        from nlql.ast.nodes import Identifier

        if isinstance(field, Identifier):
            # Identifier - sort by the metadata field it names
            return self._metadata_sort_key(field.name, reverse)

        # Unknown field type
        return None

    def _create_query_plan(self, ast: SelectStatement) -> QueryPlan:
        """Create a query plan from AST.

        This is a simplified version - full implementation would analyze
        WHERE clause and determine what can be pushed down.

        Args:
            ast: SELECT statement AST (currently unused; the plan is always
                the same empty plan)

        Returns:
            Query plan for the adapter
        """
        # For now, create a simple plan
        # TODO: Implement WHERE clause analysis and pushdown logic
        # Note: We don't pass limit to the adapter because we need to apply
        # WHERE filters first in the executor, then apply LIMIT
        return QueryPlan(
            filters=None,
            query_text=None,
            limit=None,  # Don't push down LIMIT - apply it after WHERE filtering
        )

