from prompt_toolkit.completion import NestedCompleter
from prompt_toolkit.document import Document
from prompt_toolkit.validation import DummyValidator, Validator, ValidationError

from python_tty.commands.registry import COMMAND_REGISTRY, ArgSpec
from python_tty.exceptions.console_exception import ConsoleInitException
from python_tty.utils import split_cmd


class CommandValidator(Validator):
    """Validate a command line by dispatching to a per-command validator.

    The input text is split with ``split_cmd``; the first element of the
    result selects the validator and the second element is the text handed
    to that validator.
    """

    def __init__(self, command_validators: dict, enable_undefined_command=False):
        """Store the dispatch table and the unknown-command policy.

        Args:
            command_validators: Mapping of command token -> validator object
                exposing ``validate(document)``.
            enable_undefined_command: When true, tokens missing from
                ``command_validators`` are accepted instead of rejected.
        """
        super().__init__()
        self.command_validators = command_validators
        self.enable_undefined_command = enable_undefined_command

    def validate(self, document: Document) -> None:
        """Validate ``document`` against the validator registered for its command.

        Raises:
            ValidationError: With message "Bad command" when the command token
                is not registered and undefined commands are not enabled; the
                selected per-command validator may raise its own errors too.
        """
        try:
            # The third element returned by split_cmd is not needed here.
            token, arg_text, _ = split_cmd(document.text)
            registered = self.command_validators
            if token not in registered:
                if self.enable_undefined_command:
                    return
                raise ValidationError(message="Bad command")
            # Only the argument portion is passed on to the command validator.
            registered[token].validate(Document(text=arg_text))
        except ValueError:
            # Any ValueError raised above (splitting, unpacking or the
            # delegated validator) is treated as "nothing to report".
            # NOTE(review): presumably split_cmd raises ValueError for input
            # it cannot tokenize yet — confirm against its implementation.
            return


class BaseCommands:
    """Command table for one console: functions, completers and validators.

    On construction the command definitions are pulled from the registry and
    indexed by command id and by every name the definition reports through
    ``all_names()``. A ``NestedCompleter`` and a ``CommandValidator`` are then
    built from the collected per-command completers and validators.
    """

    def __init__(self, console, registry=None):
        """Collect the command definitions for ``console`` and build the helpers.

        Args:
            console: Console object the commands belong to; must not be None.
            registry: Source of command definitions. ``COMMAND_REGISTRY`` is
                used when omitted.

        Raises:
            ConsoleInitException: If ``console`` is None or a completer cannot
                be constructed.
        """
        if registry is None:
            registry = COMMAND_REGISTRY
        self.console = console
        self.registry = registry
        # Lookup tables below are populated by _init_funcs().
        self.command_defs = []
        self.command_defs_by_name = {}
        self.command_defs_by_id = {}
        self.command_completers = {}
        self.command_validators = {}
        self.command_funcs = {}
        self._init_funcs()
        # Built only after the tables are populated.
        self.completer = NestedCompleter.from_nested_dict(self.command_completers)
        self.validator = CommandValidator(
            self.command_validators,
            self.enable_undefined_command,
        )

    @property
    def enable_undefined_command(self):
        """Whether the validator accepts unregistered commands (False here)."""
        return False

    def _init_funcs(self):
        """Fetch the command definitions from the registry and index them.

        Definitions registered for the console's class are used first; when
        that lookup yields an empty collection, definitions are collected from
        this commands class instead.

        Raises:
            ConsoleInitException: If ``self.console`` is None.
        """
        console = self.console
        if console is None:
            raise ConsoleInitException("Console is None")
        found = self.registry.get_command_defs_for_console(console.__class__)
        if len(found) == 0:
            found = self.registry.collect_from_commands_cls(self.__class__)
        self.command_defs = found
        self._collect_completer_and_validator(found)

    def _collect_completer_and_validator(self, defs):
        """Register every command definition in ``defs`` via ``_map_components``."""
        for definition in defs:
            self._map_components(definition)

    def _map_components(self, command_def):
        """Index ``command_def`` by id and by each of its names.

        For every name the function, the definition, a completer (or None when
        the definition has no completer) and a validator are recorded. The
        completer and validator are built anew for each name.
        """
        identifier = self._build_command_id(command_def)
        if identifier is not None:
            self.command_defs_by_id[identifier] = command_def
        for name in command_def.all_names():
            self.command_funcs[name] = command_def.func
            self.command_defs_by_name[name] = command_def
            self.command_completers[name] = (
                None
                if command_def.completer is None
                else self._build_completer(command_def)
            )
            self.command_validators[name] = self._build_validator(command_def)

    def _build_completer(self, command_def):
        """Instantiate the completer declared by ``command_def``.

        The factory is first called as ``(console, arg_spec)``; if that raises
        ``TypeError`` it is called again as ``(console)``.

        Raises:
            ConsoleInitException: If the second call also raises ``TypeError``.
        """
        try:
            built = command_def.completer(self.console, command_def.arg_spec)
        except TypeError:
            # Retry with the shorter call form before giving up.
            try:
                built = command_def.completer(self.console)
            except TypeError as exc:
                message = (
                    "Completer init failed. Use completer_from(...) to adapt "
                    "prompt_toolkit completers."
                )
                raise ConsoleInitException(message) from exc
        return built

    def _build_validator(self, command_def):
        """Instantiate the validator declared by ``command_def``.

        Returns a ``DummyValidator`` when the definition declares none.
        Otherwise the factory is called as ``(console, func, arg_spec)`` and,
        if that raises ``TypeError``, as ``(console, func)``; a ``TypeError``
        from the second call propagates unchanged.
        """
        factory = command_def.validator
        if factory is None:
            return DummyValidator()
        try:
            return factory(self.console, command_def.func, command_def.arg_spec)
        except TypeError:
            return factory(self.console, command_def.func)

    def get_command_def(self, command_name):
        """Look up a definition by command id first, then by command name.

        Returns:
            The matching definition, or None when neither table has the key.
        """
        by_id = self.command_defs_by_id.get(command_name)
        return by_id if by_id is not None else self.command_defs_by_name.get(command_name)

    def get_command_def_by_id(self, command_id):
        """Return the definition stored under ``command_id``, or None."""
        return self.command_defs_by_id.get(command_id)

    def get_command_id(self, command_name):
        """Return the command id for ``command_name``, or None if unknown."""
        definition = self.command_defs_by_name.get(command_name)
        return None if definition is None else self._build_command_id(definition)

    def _build_command_id(self, command_def):
        """Compose ``cmd:<console name>:<func name>`` for ``command_def``.

        The console name is the console's ``console_name`` attribute when it is
        present and truthy, otherwise the lower-cased console class name.
        """
        console_name = (
            getattr(self.console, "console_name", None)
            or self.console.__class__.__name__.lower()
        )
        return f"cmd:{console_name}:{command_def.func_name}"

    def deserialize_args(self, command_def, raw_text):
        """Parse ``raw_text`` using the argument spec of ``command_def``.

        When the definition carries no ``arg_spec``, one is derived from the
        command function's signature via ``ArgSpec.from_signature``.

        Returns:
            Whatever the spec's ``parse`` method returns for ``raw_text``.
        """
        spec = command_def.arg_spec
        if spec is None:
            spec = ArgSpec.from_signature(command_def.func)
        return spec.parse(raw_text)

