"""JavaScript/TypeScript/TSX import and function parser using ast-grep."""

from pathlib import Path
from ast_grep_py import SgRoot

from .base import BaseParser, ImportInfo, FunctionInfo, ClassInfo, ParseResult


class JsParser(BaseParser):
    """Parser for JavaScript, TypeScript, and TSX files."""

    supported_extensions = [".js", ".jsx", ".ts", ".tsx", ".mjs", ".cjs"]

    def parse(self, content: str, file_path: Path) -> list[ImportInfo]:
        """Parse JS/TS source and extract every import-like reference.

        Collects ES6 ``import`` statements, ``require()`` calls, dynamic
        ``import()`` calls and ``export ... from`` re-exports.

        Args:
            content: Source text of the file.
            file_path: Path of the file; only its suffix is used, to choose
                the grammar.

        Returns:
            The imports found. If parsing raises, a warning is printed and
            whatever was collected before the failure is returned.
        """
        imports: list[ImportInfo] = []

        # .tsx needs the dedicated TSX grammar: the plain TypeScript grammar
        # does not accept JSX syntax. Any suffix that is not TypeScript is
        # parsed with the JavaScript grammar.
        ext = file_path.suffix.lower()
        lang = {".ts": "typescript", ".tsx": "tsx"}.get(ext, "javascript")

        try:
            node = SgRoot(content, lang).root()

            # ES6 static imports: import xxx from 'path'
            for imp in node.find_all(kind="import_statement"):
                self._extract_es6_import(imp, imports)

            # require() calls
            for req in node.find_all(pattern="require($PATH)"):
                self._extract_require(req, imports)

            # Dynamic imports: import('path'); the extractor filters out
            # call expressions that are not import() calls.
            for dyn in node.find_all(kind="call_expression"):
                self._extract_dynamic_import(dyn, imports)

            # export ... from 'path'
            for exp in node.find_all(kind="export_statement"):
                self._extract_export_from(exp, imports)

        except Exception as e:
            # Best-effort: one unparsable file must not abort the caller.
            print(f"Warning: Failed to parse {file_path}: {e}")

        return imports

    def parse_full(self, content: str, file_path: Path) -> ParseResult:
        """Parse JS/TS source and extract imports, functions and classes.

        Args:
            content: Source text of the file.
            file_path: Path of the file; only its suffix is used, to choose
                the grammar.

        Returns:
            A ``ParseResult`` whose ``imports``, ``functions`` and ``classes``
            lists are filled in. If parsing raises, a warning is printed and
            the partially filled result is returned.
        """
        result = ParseResult()

        # .tsx needs the dedicated TSX grammar: the plain TypeScript grammar
        # does not accept JSX syntax. Any suffix that is not TypeScript is
        # parsed with the JavaScript grammar.
        ext = file_path.suffix.lower()
        lang = {".ts": "typescript", ".tsx": "tsx"}.get(ext, "javascript")

        try:
            node = SgRoot(content, lang).root()

            # Imports: static, require(), dynamic import(), export ... from
            for imp in node.find_all(kind="import_statement"):
                self._extract_es6_import(imp, result.imports)
            for req in node.find_all(pattern="require($PATH)"):
                self._extract_require(req, result.imports)
            for dyn in node.find_all(kind="call_expression"):
                self._extract_dynamic_import(dyn, result.imports)
            for exp in node.find_all(kind="export_statement"):
                self._extract_export_from(exp, result.imports)

            # Plain function declarations
            for func_node in node.find_all(kind="function_declaration"):
                func_info = self._extract_function(func_node, content)
                if func_info:
                    result.functions.append(func_info)

            # Arrow functions assigned to const/let (lexical_declaration)
            # and then var (variable_declaration), in that order.
            for decl_kind in ("lexical_declaration", "variable_declaration"):
                for var_decl in node.find_all(kind=decl_kind):
                    func_info = self._extract_arrow_function(var_decl, content)
                    if func_info:
                        result.functions.append(func_info)

            # Classes (each ClassInfo carries its own methods)
            for class_node in node.find_all(kind="class_declaration"):
                class_info = self._extract_class(class_node, content)
                if class_info:
                    result.classes.append(class_info)

        except Exception as e:
            # Best-effort: one unparsable file must not abort the caller.
            print(f"Warning: Failed to parse {file_path}: {e}")

        return result

    def _extract_es6_import(self, node, imports: list[ImportInfo]) -> None:
        """Record the module named by an ES6 ``import`` statement.

        The first string literal found under ``node`` is taken as the module
        specifier. A statement without any string literal adds nothing.

        Args:
            node: An ``import_statement`` node.
            imports: List that receives the new ``ImportInfo`` entry.
        """
        specifier = node.find(kind="string")
        if not specifier:
            return

        # The reported line is the statement's start line plus one.
        line_number = node.range().start.line + 1
        module_path = self._clean_string(specifier.text())
        imports.append(
            ImportInfo(module=module_path, import_type="static", line=line_number)
        )

    def _extract_require(self, node, imports: list[ImportInfo]) -> None:
        """Record the module named by a ``require(...)`` call.

        The first string literal found inside the call's argument list is
        used as the module path. Calls without an argument list or without
        a string literal in it add nothing.

        Args:
            node: A node matched by the ``require($PATH)`` pattern.
            imports: List that receives the new ``ImportInfo`` entry.
        """
        arg_list = node.find(kind="arguments")
        if not arg_list:
            return

        literal = arg_list.find(kind="string")
        if not literal:
            return

        module_path = self._clean_string(literal.text())
        # The reported line is the call's start line plus one.
        line_number = node.range().start.line + 1
        imports.append(
            ImportInfo(module=module_path, import_type="require", line=line_number)
        )

    def _extract_dynamic_import(self, node, imports: list[ImportInfo]) -> None:
        """Record the module path of a dynamic ``import('path')`` call.

        Args:
            node: A ``call_expression`` node. Anything that is not itself an
                ``import(...)`` call is ignored.
            imports: List that receives the new ``ImportInfo`` entry.
        """
        # Accept only calls whose callee *is* the `import` keyword. Searching
        # all descendants would also match enclosing calls such as
        # `lazy(() => import('./x'))` or `import('./x').then(...)` and record
        # the same import a second time at the outer call's line.
        callee = node.field("function")
        if callee is None or callee.kind() != "import":
            return

        args = node.field("arguments")
        if not args:
            return

        # Only literal specifiers can be resolved; import(someVariable) has
        # no string node and is skipped.
        string_node = args.find(kind="string")
        if not string_node:
            return

        imports.append(ImportInfo(
            module=self._clean_string(string_node.text()),
            import_type="dynamic",
            line=node.range().start.line + 1
        ))

    def _extract_export_from(self, node, imports: list[ImportInfo]) -> None:
        """Record the module path of an ``export ... from 'path'`` statement.

        Export statements without a ``from`` clause (plain ``export const``,
        ``export function``, ``export default`` and so on) add nothing.

        Args:
            node: An ``export_statement`` node.
            imports: List that receives the new ``ImportInfo`` entry.
        """
        # The grammar attaches the re-exported module to the statement's
        # `source` field. A textual test for "from" plus "first string in the
        # statement" misfires on exports whose body merely contains that
        # word, e.g. `export function f() { log("hi from f") }`.
        source = node.field("source")
        if source is None:
            return

        imports.append(ImportInfo(
            module=self._clean_string(source.text()),
            import_type="static",
            line=node.range().start.line + 1
        ))

    def _extract_function(self, node, content: str) -> FunctionInfo | None:
        """Extract function information from a function_declaration node.

        Args:
            node: A ``function_declaration`` node.
            content: Full source text the node was parsed from; used to
                rebuild the signature from raw lines.

        Returns:
            A ``FunctionInfo`` for the declaration, or ``None`` if it has no
            name or anything raises during extraction.
        """
        try:
            # Read the declared name from the grammar's `name` field rather
            # than the first identifier found anywhere in the subtree.
            name_node = node.field("name")
            if name_node is None:
                return None
            name = name_node.text()

            # Reported lines are the node's start/end lines plus one.
            span = node.range()
            start_line = span.start.line + 1
            end_line = span.end.line + 1

            lines = content.split('\n')
            signature_line = lines[start_line - 1].strip() if start_line <= len(lines) else ""

            # Join the declaration's lines up to (not including) the first
            # opening brace.
            # NOTE: a destructured object parameter (`function f({a}) {`)
            # contains an earlier '{' and therefore truncates the signature.
            sig_parts = []
            for raw in lines[start_line - 1:end_line]:
                text = raw.strip()
                brace = text.find('{')
                if brace != -1:
                    sig_parts.append(text[:brace].strip())
                    break
                sig_parts.append(text)
            signature = ' '.join(sig_parts).strip()

            # Look for the `async` keyword token among the node's children.
            # A substring test on the signature would also flag names such
            # as `asyncTask`.
            is_async = any(child.kind() == "async" for child in node.children())

            # Extract function calls
            calls = self._extract_calls(node)

            return FunctionInfo(
                name=name,
                signature=signature if signature else signature_line,
                start_line=start_line,
                end_line=end_line,
                calls=calls,
                is_method=False,
                is_async=is_async,
                docstring=None  # JS doesn't have docstrings in the same way
            )
        except Exception:
            # Best-effort: a node that cannot be processed is skipped.
            return None

    def _extract_arrow_function(self, node, content: str) -> FunctionInfo | None:
        """Extract an arrow function bound directly by a variable declaration.

        Only declarators of the form ``name = (...) => ...`` count. An arrow
        that merely appears somewhere inside the initializer (for example
        ``const xs = items.map(x => x * 2)``) does not make the variable a
        function.

        Args:
            node: A ``lexical_declaration`` or ``variable_declaration`` node.
            content: Full source text the node was parsed from; used to
                rebuild the signature from raw lines.

        Returns:
            A ``FunctionInfo`` for the first matching declarator, or ``None``
            if there is none or anything raises during extraction.
        """
        try:
            for declarator in node.children():
                if declarator.kind() != "variable_declarator":
                    continue

                # The initializer itself must be the arrow function.
                arrow = declarator.field("value")
                if arrow is None or arrow.kind() != "arrow_function":
                    continue

                # Destructuring targets have no single name to report.
                name_node = declarator.field("name")
                if name_node is None or name_node.kind() != "identifier":
                    continue
                name = name_node.text()

                # Reported lines span the whole declaration statement.
                span = node.range()
                start_line = span.start.line + 1
                end_line = span.end.line + 1

                lines = content.split('\n')
                signature_line = lines[start_line - 1].strip() if start_line <= len(lines) else ""

                # Join the declaration's lines up to and including the
                # first `=>`.
                sig_parts = []
                for raw in lines[start_line - 1:end_line]:
                    text = raw.strip()
                    arrow_at = text.find('=>')
                    if arrow_at != -1:
                        sig_parts.append(text[:arrow_at + 2].strip())
                        break
                    sig_parts.append(text)
                signature = ' '.join(sig_parts).strip()

                # Look for the `async` keyword token on the arrow function.
                # A substring test would also flag names like `asyncFetch`.
                is_async = any(child.kind() == "async" for child in arrow.children())

                # Extract function calls
                calls = self._extract_calls(arrow)

                return FunctionInfo(
                    name=name,
                    signature=signature if signature else signature_line,
                    start_line=start_line,
                    end_line=end_line,
                    calls=calls,
                    is_method=False,
                    is_async=is_async,
                    docstring=None
                )

            return None
        except Exception:
            # Best-effort: a node that cannot be processed is skipped.
            return None

    def _extract_class(self, node, content: str) -> ClassInfo | None:
        """Extract class information from a class_declaration node.

        Args:
            node: A ``class_declaration`` node.
            content: Full source text the node was parsed from; used to
                rebuild signatures from raw lines.

        Returns:
            A ``ClassInfo`` with the class's name, signature, line span,
            base class (if an identifier is found in the heritage clause)
            and methods, or ``None`` if the class has no name or anything
            raises during extraction.
        """
        try:
            # Read the name from the grammar's `name` field. The TypeScript
            # grammars use a `type_identifier` there, so searching for the
            # first `identifier` would return some unrelated identifier from
            # the class body, or nothing at all.
            name_node = node.field("name")
            if name_node is None:
                return None
            name = name_node.text()

            # Reported lines are the node's start/end lines plus one.
            span = node.range()
            start_line = span.start.line + 1
            end_line = span.end.line + 1

            lines = content.split('\n')
            signature_line = lines[start_line - 1].strip() if start_line <= len(lines) else ""

            # Join at most the first four lines of the class, stopping at
            # (and excluding) the first opening brace.
            sig_parts = []
            for raw in lines[start_line - 1:start_line + 3]:
                text = raw.strip()
                brace = text.find('{')
                if brace != -1:
                    sig_parts.append(text[:brace].strip())
                    break
                sig_parts.append(text)
            signature = ' '.join(sig_parts).strip()

            # Base class: first identifier inside the heritage clause.
            bases = []
            heritage = node.find(kind="class_heritage")
            if heritage:
                extends = heritage.find(kind="identifier")
                if extends:
                    bases.append(extends.text())

            # Methods: only direct members of this class body. A recursive
            # search would also collect methods of object literals or
            # nested classes defined inside a method.
            methods = []
            class_body = node.field("body")
            if class_body is not None:
                for member in class_body.children():
                    if member.kind() != "method_definition":
                        continue
                    method_info = self._extract_method(member, content)
                    if method_info:
                        methods.append(method_info)

            return ClassInfo(
                name=name,
                signature=signature if signature else signature_line,
                start_line=start_line,
                end_line=end_line,
                methods=methods,
                bases=bases,
                docstring=None
            )
        except Exception:
            # Best-effort: a node that cannot be processed is skipped.
            return None

    def _extract_method(self, node, content: str) -> FunctionInfo | None:
        """Extract method information from a method_definition node.

        Args:
            node: A ``method_definition`` node.
            content: Full source text the node was parsed from; used to
                rebuild the signature from raw lines.

        Returns:
            A ``FunctionInfo`` with ``is_method=True``, or ``None`` if the
            method has no name node or anything raises during extraction.
        """
        try:
            # Read the name from the grammar's `name` field. Searching for
            # the first `property_identifier` picks a property from the
            # method *body* when the name is `#private`, computed or a
            # string literal.
            name_node = node.field("name")
            if name_node is None:
                return None
            name = name_node.text()
            if name_node.kind() == "string":
                # 'quoted'() {} -> report the bare name
                name = self._clean_string(name)

            # Reported lines are the node's start/end lines plus one.
            span = node.range()
            start_line = span.start.line + 1
            end_line = span.end.line + 1

            lines = content.split('\n')
            signature_line = lines[start_line - 1].strip() if start_line <= len(lines) else ""

            # Join the method's lines up to (not including) the first
            # opening brace.
            # NOTE: a destructured object parameter contains an earlier '{'
            # and therefore truncates the signature.
            sig_parts = []
            for raw in lines[start_line - 1:end_line]:
                text = raw.strip()
                brace = text.find('{')
                if brace != -1:
                    sig_parts.append(text[:brace].strip())
                    break
                sig_parts.append(text)
            signature = ' '.join(sig_parts).strip()

            # Look for the `async` keyword token among the node's children.
            # A substring test would also flag names such as `asyncLoad`.
            is_async = any(child.kind() == "async" for child in node.children())

            # Extract function calls
            calls = self._extract_calls(node)

            return FunctionInfo(
                name=name,
                signature=signature if signature else signature_line,
                start_line=start_line,
                end_line=end_line,
                calls=calls,
                is_method=True,
                is_async=is_async,
                docstring=None
            )
        except Exception:
            # Best-effort: a node that cannot be processed is skipped.
            return None

    def _extract_calls(self, node) -> list[str]:
        """Collect the names of functions and methods called under ``node``.

        For a plain call ``foo()`` the name is ``foo``; for a member call
        ``obj.method()`` it is ``method``. Calls whose callee is neither an
        identifier nor a member expression (``import(...)``, ``super(...)``,
        ``fn()()``) are skipped.

        Args:
            node: Any syntax node; all ``call_expression`` nodes found under
                it are inspected.

        Returns:
            Unique callee names in order of first appearance.
        """
        calls: list[str] = []
        try:
            for call in node.find_all(kind="call_expression"):
                # Inspect the callee itself. Searching the whole call for
                # the first identifier would return the receiver (`obj` in
                # `obj.method()`) or even an argument.
                callee = call.field("function")
                if callee is None:
                    continue

                callee_kind = callee.kind()
                if callee_kind == "identifier":
                    calls.append(callee.text())
                elif callee_kind == "member_expression":
                    prop = callee.field("property")
                    if prop is not None:
                        calls.append(prop.text())
        except Exception:
            # Best-effort: keep whatever names were gathered before the
            # failure rather than discarding the enclosing function.
            pass
        # dict.fromkeys de-duplicates while keeping a stable order;
        # list(set(...)) ordering can differ between runs.
        return list(dict.fromkeys(calls))

    def _clean_string(self, s: str) -> str:
        """Remove quotes from string literal."""
        return s.strip("'\"`")
