"""Lark transformer to convert parse tree to AST."""

from lark import Token, Transformer

from nlql.ast.nodes import (
    ComparisonExpr,
    FunctionCall,
    Identifier,
    Literal,
    LogicalExpr,
    OperatorCall,
    OrderByClause,
    SelectStatement,
    UnaryOp,
    WhereClause,
)


class NLQLTransformer(Transformer):
    """Transform a Lark parse tree into NLQL AST nodes.

    Each public method is a rule callback: Lark calls the method named after
    the grammar rule being reduced and passes that rule's already-transformed
    children as ``items``. Methods starting with an underscore are internal
    helpers, not rule callbacks.
    """

    # -- internal helpers ------------------------------------------------

    @staticmethod
    def _optional(items: list, index: int, default: object) -> object:
        """Return ``items[index]``, or ``default`` if it is missing or ``None``.

        With Lark's ``maybe_placeholders`` option enabled, an unmatched
        ``[optional]`` grammar element is passed as ``None`` instead of being
        omitted, so a bare length check cannot detect that it is absent.
        """
        if len(items) > index and items[index] is not None:
            return items[index]
        return default

    @staticmethod
    def _flatten_logical(op: str, items: list) -> LogicalExpr:
        """Combine ``items`` into a single n-ary ``LogicalExpr`` using ``op``.

        If the first operand is already a ``LogicalExpr`` with the same
        ``op`` (e.g. from an earlier left-associative reduction), its operands
        are spliced in instead of nesting it. Later operands are appended
        unchanged. With a left-recursive grammar, ``a OR b OR c`` therefore
        becomes one node with three operands.
        """
        first, *rest = items
        if isinstance(first, LogicalExpr) and first.op == op:
            # Build a new list; never mutate the earlier node's operands.
            operands = [*first.operands, *rest]
        else:
            operands = [first, *rest]
        return LogicalExpr(op=op, operands=operands)

    @staticmethod
    def _unquote(text: str) -> str:
        """Strip exactly one pair of matching enclosing quotes from ``text``.

        Unlike ``str.strip`` with a set of quote characters, this keeps quote
        characters that are part of the value (``'"hi"'`` -> ``"hi"``).
        Escape sequences inside the string are not processed.
        """
        if len(text) >= 2 and text[0] == text[-1] and text[0] in "\"'":
            return text[1:-1]
        return text

    @staticmethod
    def _parse_number(text: str) -> int | float:
        """Parse numeric token text as ``int`` when possible, else ``float``.

        Falling back to ``float`` (instead of checking for a ``"."``) also
        handles exponent notation such as ``1e5``, which ``int()`` rejects.
        """
        try:
            return int(text)
        except ValueError:
            return float(text)

    # -- statement structure ---------------------------------------------

    def select_statement(self, items: list) -> SelectStatement:
        """Build a ``SelectStatement`` from the SELECT rule's children.

        ``items[0]`` is the select unit: either a unit-name string or a span
        configuration dict produced by ``span_unit``. The remaining children
        are recognized by type:
        - a ``WhereClause``;
        - a non-empty list of ``OrderByClause``;
        - an ``int`` LIMIT value.
        Any other child, such as a ``None`` placeholder, is ignored.
        """
        select_unit = items[0]
        where = None
        order_by = []
        limit = None

        for item in items[1:]:
            if isinstance(item, WhereClause):
                where = item
            elif isinstance(item, list) and item and isinstance(item[0], OrderByClause):
                order_by = item
            elif isinstance(item, int):
                limit = item

        # A dict select unit means SPAN: keep its configuration separately
        # and record the unit itself as the plain name "SPAN".
        span_config = None
        if isinstance(select_unit, dict):
            span_config = select_unit
            select_unit = "SPAN"

        return SelectStatement(
            select_unit=select_unit,
            span_config=span_config,
            where=where,
            order_by=order_by,
            limit=limit,
        )

    def select_unit(self, items: list) -> str | dict:
        """Return the select unit, defaulting to ``"CHUNK"`` when absent.

        A dict child (the output of ``span_unit``) is passed through
        unchanged. Any other child (a token naming the unit) is converted
        to ``str``.
        """
        item = self._optional(items, 0, None)
        if item is None:
            return "CHUNK"  # Default
        if isinstance(item, dict):
            return item
        return str(item)

    def span_unit(self, items: list) -> dict:
        """Return the SPAN configuration as ``{"unit": str, "window": int}``.

        ``items[0]`` is the unit type. The optional ``items[1]`` is the window
        size token, which defaults to 1 when it is absent.
        """
        unit_type = str(items[0])
        window_token = self._optional(items, 1, None)
        window = 1 if window_token is None else int(str(window_token))
        return {"unit": unit_type, "window": window}

    def where_clause(self, items: list) -> WhereClause:
        """Wrap the transformed condition in a ``WhereClause``."""
        return WhereClause(condition=items[0])

    def condition(self, items: list) -> object:
        """Pass through the single transformed child unchanged."""
        return items[0]

    # -- boolean logic ---------------------------------------------------

    def logical_or(self, items: list) -> LogicalExpr:
        """Build an OR expression, flattening a left-hand OR chain."""
        return self._flatten_logical("OR", items)

    def logical_and(self, items: list) -> LogicalExpr:
        """Build an AND expression, flattening a left-hand AND chain."""
        return self._flatten_logical("AND", items)

    def logical_not(self, items: list) -> UnaryOp:
        """Wrap the single operand in a NOT ``UnaryOp``."""
        return UnaryOp(op="NOT", operand=items[0])

    def comparison_expr(self, items: list) -> ComparisonExpr:
        """Build ``left <op> right``; the operator token is converted to ``str``."""
        left = items[0]
        op = str(items[1])
        right = items[2]
        return ComparisonExpr(op=op, left=left, right=right)

    # -- calls -----------------------------------------------------------

    def operator_call(self, items: list) -> OperatorCall:
        """Build an ``OperatorCall``; missing arguments become an empty list."""
        operator = str(items[0])
        args = self._optional(items, 1, [])
        return OperatorCall(operator=operator, args=args)

    def function_call(self, items: list) -> FunctionCall:
        """Build a ``FunctionCall``; missing arguments become an empty list."""
        name = str(items[0])
        args = self._optional(items, 1, [])
        return FunctionCall(name=name, args=args)

    def arguments(self, items: list) -> list:
        """Return the transformed argument children as a list."""
        return items

    # -- ORDER BY / LIMIT ------------------------------------------------

    def order_by_clause(self, items: list) -> list[OrderByClause]:
        """Return the list of ``OrderByClause`` children."""
        return items

    def order_by_field(self, items: list) -> OrderByClause:
        """Transform ORDER BY field.

        The field can be:
        - A Token with value "SIMILARITY" (for ORDER BY SIMILARITY)
        - An Identifier node (for ORDER BY META("field") - parsed as identifier)
        - A string identifier name (for simple field names)

        The optional direction defaults to ``"ASC"`` and is upper-cased, so
        a case-insensitively matched keyword (e.g. ``desc``) compares equal
        to ``"DESC"``.

        Raises:
            ValueError: If ``items`` is empty.
        """
        if not items:
            raise ValueError("order_by_field requires at least one item")

        field = items[0]

        if hasattr(field, 'type') and hasattr(field, 'value'):
            # Token-like object (e.g. the SIMILARITY keyword): use its value.
            field_value = field.value
        elif isinstance(field, Identifier):
            # Keep the Identifier node as is, for META field access.
            field_value = field
        else:
            # Fallback: convert to string.
            field_value = str(field)

        direction = str(self._optional(items, 1, "ASC")).upper()
        return OrderByClause(field=field_value, direction=direction)

    def limit_clause(self, items: list) -> int:
        """Convert the LIMIT token to ``int``."""
        return int(items[0])

    # -- atoms -----------------------------------------------------------

    def literal(self, items: list) -> Literal:
        """Convert a literal token into a typed ``Literal`` node.

        Handled token types:
        - ``STRING``: one pair of enclosing quotes is removed.
        - ``NUMBER``: parsed as ``int`` or ``float``.
        - ``BOOLEAN``: ``True`` iff the text case-insensitively equals "true".
        Anything else falls back to a string literal of ``str(token)``.
        """
        token = items[0]
        if isinstance(token, Token):
            if token.type == "STRING":
                value = self._unquote(str(token.value))
                return Literal(value=value, type="string")
            if token.type == "NUMBER":
                return Literal(value=self._parse_number(token.value), type="number")
            if token.type == "BOOLEAN":
                return Literal(value=token.value.lower() == "true", type="boolean")
        return Literal(value=str(token), type="string")

    def identifier(self, items: list) -> Identifier:
        """Wrap the identifier token's text in an ``Identifier`` node."""
        return Identifier(name=str(items[0]))

    def expr(self, items: list) -> object:
        """Pass through the single transformed child unchanged."""
        return items[0]

