import ast
import inspect
import io
from ast import NodeVisitor
from typing import Any, Dict, List, Optional


import dlt.reflection.names as n


class PipelineScriptVisitor(NodeVisitor):
    """Collect dlt-related facts from the AST of a pipeline script.

    Walking a tree with ``visit`` fills in:

    * ``mod_aliases`` - names under which ``dlt`` or a ``dlt.*`` module is bound,
      mapped to the imported module name,
    * ``func_aliases`` - source text of expressions referring to one of
      ``n.DETECTED_FUNCTIONS`` (e.g. ``"dlt.pipeline"`` or an imported alias),
      mapped to the canonical function name,
    * ``is_destination_imported`` - whether the destinations module (or one of
      its submodules) is imported,
    * ``known_calls`` - per detected function, the ``inspect.BoundArguments`` of
      every call that fits its signature in ``n.SIGNATURES``; the bound values
      are the AST nodes passed in the call,
    * ``known_sources`` - functions decorated with a detected source or resource
      decorator, by function name,
    * ``known_source_calls`` - calls to those functions, by function name. Only
      calls visited after the decorated definition are recorded.

    Aliases are resolved in visiting order, so imports must precede their use.
    """

    def __init__(self, source: str, add_parents: bool = False):
        """
        Args:
            source: the full text of the script; used to recover the source text of nodes.
            add_parents: if True, ``visit`` first sets a ``parent`` attribute on every node
                below the visited tree (the direct children of the tree root get ``None``).
        """
        self.source = source
        self.add_parents = add_parents

        self.mod_aliases: Dict[str, str] = {}
        self.func_aliases: Dict[str, str] = {}
        self.is_destination_imported: bool = False
        self.known_calls: Dict[str, List[inspect.BoundArguments]] = {}
        self.known_sources: Dict[str, ast.FunctionDef] = {}
        self.known_source_calls: Dict[str, List[ast.Call]] = {}

        # True while a traversal started by an outermost `visit` call is in progress
        self._is_visiting = False
        # `source` split into UTF-8 encoded lines, and the exact string they were computed from
        self._encoded_lines: List[bytes] = []
        self._encoded_lines_source: Optional[str] = None

    def visit(self, tree: ast.AST) -> Any:
        """Visit ``tree`` and every node below it.

        ``NodeVisitor.generic_visit`` calls ``self.visit`` for each child, so this method is
        re-entered for every node. Parent links are therefore set only by the outermost call:
        re-linking on each nested call would be quadratic and would reset the ``parent`` of
        every visited node's children to ``None``.

        Returns:
            Whatever the handler for ``tree`` returns.
        """
        if self._is_visiting:
            return super().visit(tree)
        if self.add_parents:
            for node in ast.walk(tree):
                for child in ast.iter_child_nodes(node):
                    # direct children of the visited root get no parent
                    child.parent = node if node is not tree else None  # type: ignore
        self._is_visiting = True
        try:
            return super().visit(tree)
        finally:
            self._is_visiting = False

    def visit_Import(self, node: ast.Import) -> Any:
        """Track ``import dlt``, ``import dlt.x`` and imports of the destinations package."""
        for alias in node.names:
            if alias.name == n.DLT:
                eff_name = alias.asname or alias.name
                self.mod_aliases[eff_name] = alias.name
                self._add_f_aliases(eff_name)
            elif alias.name.startswith(f"{n.DLT}.") and alias.asname is None:
                # keep registering the full dotted name; in addition `import dlt.x` binds the
                # top-level `dlt` name in the script, so `dlt.<function>` calls must resolve too
                self.mod_aliases[alias.name] = alias.name
                self._add_f_aliases(alias.name)
                self.mod_aliases[n.DLT] = n.DLT
                self._add_f_aliases(n.DLT)
            if self._is_destinations_module(alias.name):
                self.is_destination_imported = True
        super().generic_visit(node)

    def visit_ImportFrom(self, node: ast.ImportFrom) -> Any:
        """Track ``from dlt import <function>`` aliases and imports from the destinations package."""
        # relative imports (`from .dlt import ...`) refer to local modules, not the dlt package
        is_absolute = not node.level
        if is_absolute and node.module == n.DLT:
            for alias in node.names:
                if alias.name in n.DETECTED_FUNCTIONS:
                    self.func_aliases[alias.asname or alias.name] = alias.name
        if is_absolute and self._is_destinations_module(node.module):
            self.is_destination_imported = True
        super().generic_visit(node)

    def visit_FunctionDef(self, node: ast.FunctionDef) -> Any:
        """Register ``node`` in ``known_sources`` if it has a detected source/resource decorator."""
        for deco in node.decorator_list:
            alias_name = self._decorator_name(deco)
            fn = self.func_aliases.get(alias_name) if alias_name is not None else None
            if fn in [n.SOURCE, n.RESOURCE]:
                self.known_sources[str(node.name)] = node
        super().generic_visit(node)

    def visit_Call(self, node: ast.Call) -> Any:
        """Record calls to detected functions and calls to known sources/resources."""
        alias_name = self.source_segment(node.func)
        fn = self.func_aliases.get(alias_name) if alias_name is not None else None
        if not fn and isinstance(node.func, ast.Attribute) and node.func.attr == n.RUN:
            # fallback: `.run(...)` may be called on a pipeline or a source object
            fn = n.RUN
        if fn:
            sig = n.SIGNATURES[fn]
            try:
                # bind the signature using the argument AST nodes as values; note that a `**kw`
                # expansion has `arg=None` and is therefore passed as the keyword "None"
                bound_args = sig.bind(*node.args, **{str(kwd.arg): kwd.value for kwd in node.keywords})
            except TypeError:
                # the arguments do not fit the signature: skip this call
                pass
            else:
                bound_args.apply_defaults()
                self.known_calls.setdefault(fn, []).append(bound_args)
        elif alias_name is not None and alias_name in self.known_sources:
            self.known_source_calls.setdefault(alias_name, []).append(node)

        # visit the children
        super().generic_visit(node)

    def source_segment(self, node: ast.AST) -> Optional[str]:
        """Return the source text of ``node``, or None if the node carries no position info.

        Gives the same result as ``ast.get_source_segment(self.source, node)``, but the
        source is split into lines once instead of on every call.
        """
        lineno = getattr(node, "lineno", None)
        end_lineno = getattr(node, "end_lineno", None)
        col_offset = getattr(node, "col_offset", None)
        end_col_offset = getattr(node, "end_col_offset", None)
        if lineno is None or end_lineno is None or col_offset is None or end_col_offset is None:
            return None
        lines = self._encoded_source_lines()
        # AST line numbers are 1-based; column offsets are UTF-8 byte offsets into the line
        first, last = lineno - 1, end_lineno - 1
        if first == last:
            segment = lines[first][col_offset:end_col_offset]
        else:
            segment = b"".join(
                [lines[first][col_offset:], *lines[first + 1 : last], lines[last][:end_col_offset]]
            )
        return segment.decode()

    def _encoded_source_lines(self) -> List[bytes]:
        """Return ``self.source`` split into UTF-8 encoded lines, recomputed only when it changes."""
        if self._encoded_lines_source is not self.source:
            # newline="" splits on "\n", "\r" and "\r\n" only and keeps the endings - the same
            # line breaks the Python parser counts (str.splitlines would also split on "\f" etc.)
            with io.StringIO(self.source, newline="") as buffer:
                self._encoded_lines = [line.encode() for line in buffer]
            self._encoded_lines_source = self.source
        return self._encoded_lines

    def _decorator_name(self, deco: ast.expr) -> Optional[str]:
        """Return the source text of the decorator callable, without any call arguments.

        Both ``@dlt.source`` and ``@dlt.source(name="x")`` yield ``"dlt.source"``.
        """
        if isinstance(deco, ast.Call):
            deco = deco.func
        return self.source_segment(deco)

    @staticmethod
    def _is_destinations_module(module: Optional[str]) -> bool:
        """Tell whether ``module`` is the destinations package or one of its submodules."""
        return module is not None and (
            module == n.DESTINATIONS or module.startswith(f"{n.DESTINATIONS}.")
        )

    def _add_f_aliases(self, module_name: str) -> None:
        """Register ``<module_name>.<fn>`` as an alias of every detected function ``fn``."""
        for fn in n.DETECTED_FUNCTIONS:
            self.func_aliases[f"{module_name}.{fn}"] = fn
