"""NLQL parser implementation using Lark."""

from pathlib import Path
from typing import Any, NoReturn

from lark import Lark, UnexpectedInput, UnexpectedToken
from lark.exceptions import LarkError, VisitError

from nlql.ast.nodes import SelectStatement
from nlql.errors import NLQLParseError
from nlql.parser.transformer import NLQLTransformer


class NLQLParser:
    """Parser for NLQL queries.

    Uses Lark (LALR mode) to parse the query text and ``NLQLTransformer`` to
    turn the resulting parse tree into AST nodes. Every failure surfaces as
    an ``NLQLParseError``.
    """

    def __init__(self) -> None:
        """Initialize the parser with the NLQL grammar.

        The grammar is read from ``grammar.lark`` in the same directory as
        this module.
        """
        grammar_path = Path(__file__).parent / "grammar.lark"
        grammar = grammar_path.read_text(encoding="utf-8")

        self._parser = Lark(
            grammar,
            parser="lalr",
            start="start",
            # Attach line/column metadata to tree nodes.
            propagate_positions=True,
        )
        self._transformer = NLQLTransformer()

    def parse(self, query: str) -> SelectStatement:
        """Parse an NLQL query string into an AST.

        Args:
            query: NLQL query string

        Returns:
            SelectStatement AST node

        Raises:
            NLQLParseError: If parsing or tree transformation fails
        """
        try:
            tree = self._parser.parse(query)
            return self._transformer.transform(tree)
        except UnexpectedInput as e:
            # Base class of UnexpectedToken / UnexpectedCharacters /
            # UnexpectedEOF; these carry position info for a rich message.
            self._raise_parse_error(e, query)
        except VisitError as e:
            # Lark wraps any exception raised inside a transformer callback in
            # VisitError. If the transformer already produced an
            # NLQLParseError, surface it unchanged instead of nesting its
            # message inside a generic one.
            if isinstance(e.orig_exc, NLQLParseError):
                raise e.orig_exc from None
            raise NLQLParseError(f"Parse error: {e}") from e
        except LarkError as e:
            raise NLQLParseError(f"Parse error: {e}") from e

    def _raise_parse_error(self, error: Any, query: str) -> NoReturn:
        """Convert a Lark input error to NLQLParseError with context.

        Args:
            error: Lark error object (expected to be an UnexpectedInput)
            query: Original query string

        Raises:
            NLQLParseError: Always; with line/column information (None when
                unknown) and, when the line is known, the surrounding lines
                with the offending one marked by ``>>> ``.
        """
        line = getattr(error, "line", None)
        column = getattr(error, "column", None)

        # Lark signals an unknown position with sentinels (-1 on
        # UnexpectedEOF, "?" from UnexpectedToken for position-less tokens).
        # Report those as None instead of bogus coordinates, and never
        # compare a non-int against an int below.
        if not isinstance(line, int) or line < 1:
            line = None
        if not isinstance(column, int) or column < 1:
            column = None

        context = None
        if line is not None:
            # Split on "\n" only, which is how Lark counts lines.
            lines = query.split("\n")
            if line <= len(lines):
                error_index = line - 1
                # Show line before, error line, and line after.
                first = max(0, error_index - 1)
                stop = min(len(lines), error_index + 2)
                context = "\n".join(
                    (">>> " if i == error_index else "    ")
                    # Drop a trailing CR so CRLF input renders cleanly.
                    + lines[i].rstrip("\r")
                    for i in range(first, stop)
                )

        raise NLQLParseError(
            message=str(error),
            line=line,
            column=column,
            context=context,
        ) from error

