"""Python import and function parser using ast-grep."""

from pathlib import Path
from ast_grep_py import SgRoot

from .base import BaseParser, ImportInfo, FunctionInfo, ClassInfo, ParseResult


class PythonParser(BaseParser):
    """Parser for Python files built on ast-grep's tree-sitter Python grammar.

    ``parse`` returns only the imports of a file; ``parse_full`` also returns
    classes (with their methods) and module-level functions.  Parse failures
    are reported with a printed warning and produce partial results instead
    of raising.
    """

    supported_extensions = [".py", ".pyw"]

    def parse(self, content: str, file_path: Path) -> list[ImportInfo]:
        """Parse Python source and extract its imports.

        Args:
            content: Python source text.
            file_path: Path of the source file; used only in the warning message.

        Returns:
            One ``ImportInfo`` per imported module: all ``import x`` statements
            first, then all ``from x import y`` statements.  If parsing fails,
            a warning is printed and whatever was collected so far is returned.
        """
        imports: list[ImportInfo] = []
        try:
            root_node = SgRoot(content, "python").root()
            self._collect_imports(root_node, imports)
        except Exception as e:
            # Best effort: one unparsable file must not abort the caller.
            print(f"Warning: Failed to parse {file_path}: {e}")
        return imports

    def parse_full(self, content: str, file_path: Path) -> ParseResult:
        """Parse Python source and extract imports, classes and functions.

        Args:
            content: Python source text.
            file_path: Path of the source file; used only in the warning message.

        Returns:
            A ``ParseResult`` whose ``imports`` are collected as in ``parse``,
            whose ``classes`` contain every class definition in the file
            (nested ones included), and whose ``functions`` contain only
            functions defined directly at module level (decorated or not).
            On failure a warning is printed and the partial result is returned.
        """
        result = ParseResult()
        try:
            root_node = SgRoot(content, "python").root()

            self._collect_imports(root_node, result.imports)

            for class_node in root_node.find_all(kind="class_definition"):
                class_info = self._extract_class(class_node, content)
                if class_info:
                    result.classes.append(class_info)

            # Only direct module statements count as top-level functions; this
            # excludes methods and nested functions, including decorated ones.
            for func_node in self._direct_functions(root_node):
                func_info = self._extract_function(func_node, content)
                if func_info:
                    result.functions.append(func_info)
        except Exception as e:
            print(f"Warning: Failed to parse {file_path}: {e}")
        return result

    def _collect_imports(self, root_node, imports: list[ImportInfo]) -> None:
        """Append every import found anywhere under ``root_node`` to ``imports``.

        Imports nested inside functions, classes or ``try`` blocks are included.
        Plain ``import`` statements are appended before ``from`` imports.
        """
        for imp in root_node.find_all(kind="import_statement"):
            self._extract_import_statement(imp, imports)
        for imp in root_node.find_all(kind="import_from_statement"):
            self._extract_from_import(imp, imports)

    @staticmethod
    def _normalize_module(text: str) -> str:
        """Drop whitespace inside a module path (``a . b`` -> ``a.b``, ``. x`` -> ``.x``)."""
        return "".join(text.split())

    def _extract_import_statement(self, node, imports: list[ImportInfo]) -> None:
        """Extract module names from an ``import a, b.c as d`` statement.

        One ``ImportInfo`` is appended per imported module; aliases are ignored.
        """
        for child in node.children():
            kind = child.kind()
            if kind == "dotted_name":
                module_node = child
            elif kind == "aliased_import":
                # ``import x as y``: the module is the aliased_import's ``name`` field.
                module_node = child.field("name")
            else:
                continue  # the ``import`` keyword, commas, comments
            if module_node is None:
                continue
            imports.append(ImportInfo(
                module=self._normalize_module(module_node.text()),
                import_type="static",
                line=child.range().start.line + 1,
            ))

    def _extract_from_import(self, node, imports: list[ImportInfo]) -> None:
        """Extract the module of a ``from xxx import yyy`` statement.

        Relative imports keep their leading dots: ``from . import x`` -> ``"."``,
        ``from .foo import x`` -> ``".foo"``, ``from ..bar import x`` -> ``"..bar"``.
        """
        # Read the grammar's ``module_name`` field instead of slicing the
        # statement text, which broke on names like ``.important``/``.fromage``.
        module_node = node.field("module_name")
        if module_node is None:
            return
        module = self._normalize_module(module_node.text())
        if not module:
            return
        imports.append(ImportInfo(
            module=module,
            import_type="static",
            line=node.range().start.line + 1,
        ))

    @staticmethod
    def _unwrap_decorated(node):
        """Return the definition inside a ``decorated_definition``, else ``node`` itself."""
        if node.kind() == "decorated_definition":
            inner = node.field("definition")
            if inner is not None:
                return inner
        return node

    def _direct_functions(self, container):
        """Yield ``function_definition`` nodes that are direct statements of ``container``.

        ``container`` is a module or a class body block.  Decorated functions are
        unwrapped; functions nested deeper (inside other defs, ``if``/``try``
        blocks, or nested classes) are not yielded.
        """
        for child in container.children():
            definition = self._unwrap_decorated(child)
            if definition.kind() == "function_definition":
                yield definition

    @staticmethod
    def _header_text(node) -> str:
        """Return the ``def``/``class`` header of ``node`` collapsed onto one line.

        The header is the node's text minus its ``body``, so multi-line parameter
        lists and base-class lists are kept whole.  Full-line comments (for
        example a comment between the colon and the first body statement) are
        dropped.  A comment on the colon line itself is kept.
        """
        text = node.text()
        body = node.field("body")
        body_text = body.text() if body is not None else ""
        if body_text and text.endswith(body_text):
            header = text[: len(text) - len(body_text)]
        else:
            # Unexpected tree shape: fall back to the first line of the definition.
            header = text.split("\n", 1)[0]
        stripped = (line.strip() for line in header.splitlines())
        return " ".join(line for line in stripped if line and not line.startswith("#"))

    def _extract_function(self, node, content: str) -> FunctionInfo | None:
        """Build a ``FunctionInfo`` from a ``function_definition`` node.

        ``content`` is accepted for backward compatibility; the signature is now
        read from the node itself.  ``is_method`` is always False here (callers
        flip it for methods).  Returns None if the node has no name or
        extraction fails.
        """
        try:
            name_node = node.field("name")
            if name_node is None:
                return None
            node_range = node.range()
            return FunctionInfo(
                name=name_node.text(),
                signature=self._header_text(node),
                start_line=node_range.start.line + 1,
                end_line=node_range.end.line + 1,
                calls=self._extract_calls(node),
                is_method=False,
                # A function_definition's text begins with either ``def`` or ``async``.
                is_async=node.text().startswith("async"),
                docstring=self._extract_docstring(node),
            )
        except Exception:
            return None

    def _extract_class(self, node, content: str) -> ClassInfo | None:
        """Build a ``ClassInfo`` from a ``class_definition`` node.

        Bases are the identifier, dotted (``a.B``) and subscripted (``Generic[T]``)
        entries of the class's own base list; keyword arguments such as
        ``metaclass=...`` are skipped.  Methods are the functions defined directly
        in the class body (decorated or not).  Nested functions and methods of
        nested classes are excluded.  Returns None if extraction fails.
        """
        try:
            name_node = node.field("name")
            if name_node is None:
                return None

            # Use the ``superclasses`` field: searching descendants for an
            # argument_list would pick up calls inside the class body.
            bases: list[str] = []
            superclasses = node.field("superclasses")
            if superclasses is not None:
                bases = [
                    arg.text()
                    for arg in superclasses.children()
                    if arg.kind() in ("identifier", "attribute", "subscript")
                ]

            methods: list[FunctionInfo] = []
            body = node.field("body")
            if body is not None:
                for func_node in self._direct_functions(body):
                    func_info = self._extract_function(func_node, content)
                    if func_info:
                        func_info.is_method = True
                        methods.append(func_info)

            node_range = node.range()
            return ClassInfo(
                name=name_node.text(),
                signature=self._header_text(node),
                start_line=node_range.start.line + 1,
                end_line=node_range.end.line + 1,
                methods=methods,
                bases=bases,
                docstring=self._extract_docstring(node),
            )
        except Exception:
            return None

    def _extract_calls(self, node) -> list[str]:
        """Return the names called anywhere inside ``node``, de-duplicated.

        ``foo()`` yields ``"foo"`` and ``obj.method()`` yields ``"method"``.
        Other callees (``a[0]()``, ``f()()``) are skipped.  The order is first
        occurrence in the tree.  Best effort: names gathered before an error
        are still returned.
        """
        names: list[str] = []
        try:
            for call in node.find_all(kind="call"):
                callee = call.field("function")
                if callee is None:
                    continue
                callee_kind = callee.kind()
                if callee_kind == "identifier":
                    names.append(callee.text())
                elif callee_kind == "attribute":
                    # The ``attribute`` field is the rightmost name (the method).
                    attr = callee.field("attribute")
                    if attr is not None:
                        names.append(attr.text())
        except Exception:
            pass
        # dict.fromkeys de-duplicates while keeping a deterministic order.
        return list(dict.fromkeys(names))

    def _extract_docstring(self, node) -> str | None:
        """Return the docstring of a function or class node, or None.

        The docstring is the first non-comment statement of the body, if that
        statement consists of a single string literal.  Quotes (and an ``r``/``u``
        prefix) are removed and the result is whitespace-stripped.
        """
        try:
            body = node.field("body")
            if body is None:
                return None
            first = next((c for c in body.children() if c.kind() != "comment"), None)
            if first is None or first.kind() != "expression_statement":
                return None
            # Require the statement to be exactly one string, so that
            # ``print("hi")`` or ``"a", 1`` are not mistaken for docstrings.
            parts = first.children()
            if len(parts) != 1 or parts[0].kind() != "string":
                return None
            return self._strip_string_quotes(parts[0].text())
        except Exception:
            return None

    @staticmethod
    def _strip_string_quotes(literal: str) -> str | None:
        """Strip the quotes (and any ``r``/``u`` prefix) from a string literal.

        Returns None for literals that are not valid docstrings (``b``/``f``
        prefixes) or that are not recognisably quoted.
        """
        unprefixed = literal.lstrip("rRuU")
        for quote in ('"""', "'''", '"', "'"):
            if (
                len(unprefixed) >= 2 * len(quote)
                and unprefixed.startswith(quote)
                and unprefixed.endswith(quote)
            ):
                return unprefixed[len(quote):-len(quote)].strip()
        return None
