#!/usr/bin/env python3
import __future__

import argparse
import ast
import copy
import functools
import inspect
import itertools
import logging
import operator
import re
import sys
import traceback
import types
import warnings
from typing import Any, Dict, List, Set, TextIO, Union

# module-level logger used for diagnostics in this file
# NOTE(review): the logger name embeds the full path via `__file__` rather
# than `__name__`; confirm this is intended before relying on it in logging
# configuration
L = logging.getLogger("ipython_utils." + __file__)

# prefix of every name injected into the shell / generated code (e.g.
# `_ipy_magic_shell`, `_ipy_magic_inner`, `_ipy_magic_exc`); user code should
# avoid names with this prefix
magic = "_ipy_magic"
# annotation-only declaration so IDEs/linters know the name exists
_ipy_magic_inner: None  # dummy to resolve IDE errors
# bind a placeholder callable under `_ipy_magic_inner` in this module's
# globals so references to the name resolve at runtime as well
globals()[magic + "_inner"] = lambda: None


def add_except_hook(logger=L):
    """
    Install a `sys.excepthook` that drops into an embedded shell whenever an
    uncaught exception reaches the top level.

    :param logger: logger handed to `get_except_hook` for reporting the
        exception and the chosen frame
    """
    hook = get_except_hook(logger)
    sys.excepthook = hook


def get_except_hook(logger):
    """
    get exception hook to embed a shell when an exception is raised

    The returned hook logs the exception with its traceback, then walks the
    traceback starting from the frame closest to where the exception was
    raised and embeds a shell (via `embed`) in the first frame whose file is
    not inside `site-packages` and not under `pandas/_libs/`. The exception
    info tuple is available in the shell as `_ipy_magic_exc`.

    :param logger: use this logger to print exceptions and info
    :return: a function with the signature expected of `sys.excepthook`
    """

    def excepthook(etype, value, tb):
        """
        exception hook to embed a shell when an exception is raised
        :param etype: exception type
        :param value: exception
        :param tb: traceback (if None, the exception is only logged)
        """
        exc_info = (etype, value, tb)
        logger.error("uncaught exception", exc_info=exc_info)
        if tb is None:
            # nothing to inspect; `inspect.getinnerframes(None)` would fail
            return
        # `getinnerframes` lists frames from `tb` inwards; reverse it so the
        # frame nearest to the raise point is tried first
        for record in reversed(inspect.getinnerframes(tb)):
            if "/site-packages/" in record.filename:
                continue
            if record.filename.startswith("pandas/_libs/"):
                continue
            logger.info("frame: %s", record)
            frame = record.frame
            msg = "Entering IPython console at {0.f_code.co_filename} at line {0.f_lineno}".format(
                frame)
            savehook = sys.excepthook  # save the exception hook
            try:
                embed(header=msg,
                      frame=frame,
                      extra_locals={magic + "_exc": exc_info})
            finally:
                # reset IPython's change to the exception hook, even if the
                # shell itself failed
                sys.excepthook = savehook
            break
        else:
            logger.info("no suitable frame found to embed a shell")

    return excepthook


def embed(funcs: Union[List[types.FunctionType], types.FunctionType] = None,
          *,
          frame: types.FrameType = None,
          extra_locals: Dict[str, Any] = None,
          header="",
          compile_flags=None,
          **kwargs):
    """
    This is copied from IPython/terminal/embed.py but modified to allow closures
    over local variables.

    Allowing closures over local variables is done by using a transformer added
    to shell.ast_transformers that replaces the AST completely with a call to a
    temporary injected function generated by a special closure. More precisely,
    we first dynamically create and run an outer function scope wrapping the
    cell code with the original local variables as its parameters. For
    illustration, we have the following generated when the code has `statement0;
    statement1; last_expression` and the local variables are x, y, z:

    ```
    def _ipy_magic_outer(x, y, z):
        del x, y, z
        def _ipy_magic_inner():
            nonlocal _ipy_magic_inner
            nonlocal x, y, z
            statement0
            statement1
            return last_expression
        return _ipy_magic_inner
    ```

    Then we patch the returned function's closure in order to carry forward the
    cells (and their contents) from previous cell executions. This will allow
    even inner closures to be closed over the right cells. The argument `funcs`
    allows you to have additional closed variables as the resolving of cells in
    this patching process will include cells in the closure of functions in
    `funcs`. Note that only such variable modifications are able to persist
    after the shell exits. This means that they must originally be closed over,
    because otherwise they are not stored as cells in the corresponding frames.
    Calling embed inside `try_all_statements` or another `embed` should use
    `_ipy_magic_inner` in `funcs` in order to get the right closure cells.

    Some known issues or unusual behaviour:
    * changes to local variables are only visible to code within the shell
    session (it is possible to make this work for cell variables in CPython, but
    it may be prone to breakage), unless `funcs` has a function with a closure
    containing the cells for these local variables
    * outer scope variables must already have had a closure referencing it to be
    visible (`nonlocal` does not work); changes to these are persistent
    * use of variables starting with `_ipy_magic` may result in undefined
    behaviour
    * return statements without an enclosing function do not give an error and
    the value returned is the result of the cell
    * local variables deleted in the shell (`del x`) are skipped when the
    final values are written back to the frame

    :param funcs: list of functions to get closure cells from
    :param frame: frame to get locals and globals from (default: the caller's
        frame)
    :param extra_locals: extra variable -> value dict to add to locals
    :param header: banner shown when the shell starts
    :param compile_flags: flags used to compile interactive code
    :param kwargs: passed on to `InteractiveShellEmbed.instance`; `config` and
        `using` are handled as in `IPython.embed`

    --- original docs below ---

    Call this to embed IPython at the current point in your program.

    The first invocation of this will create an :class:`InteractiveShellEmbed`
    instance and then call it.  Consecutive calls just call the already
    created instance.

    If you don't want the kernel to initialize the namespace
    from the scope of the surrounding function,
    and/or you want to load full IPython configuration,
    you probably want `IPython.start_ipython()` instead.

    Here is a simple example::

        from IPython import embed
        a = 10
        b = 20
        embed(header='First time')
        c = 30
        d = 40
        embed()

    Full customization can be done by passing a :class:`Config` in as the
    config argument.
    """
    from IPython.core.interactiveshell import InteractiveShell
    from IPython.terminal.embed import InteractiveShellEmbed
    from IPython.terminal.ipapp import load_default_config
    config = kwargs.get('config')
    if config is None:
        config = load_default_config()
        config.InteractiveShellEmbed = config.TerminalInteractiveShell
        kwargs['config'] = config
    using = kwargs.get('using', 'sync')
    if using:
        kwargs['config'].update({
            'TerminalInteractiveShell': {
                'loop_runner': using,
                'colors': 'NoColor',
                'autoawait': using != 'sync'
            }
        })
    # save ps1/ps2 if defined
    ps1 = None
    ps2 = None
    try:
        ps1 = sys.ps1
        ps2 = sys.ps2
    except AttributeError:
        pass
    # save previous instance; only one singleton may be active at a time
    saved_shell_instance = InteractiveShell._instance
    if saved_shell_instance is not None:
        type(saved_shell_instance).clear_instance()
    if frame is None:
        frame = sys._getframe(1)
    location_id = '%s:%s' % (frame.f_code.co_filename, frame.f_lineno)
    try:
        shell = InteractiveShellEmbed.instance(_init_location_id=location_id,
                                               **kwargs)
        local_ns, module, cell_dict, write_back_vars = setup_embedded_shell(
            shell, funcs, frame, extra_locals)
        shell(header=header,
              local_ns=local_ns,
              module=module,
              compile_flags=compile_flags,
              _call_location_id=location_id)
    finally:
        # restore the previous instance and prompts even if the shell failed
        InteractiveShellEmbed.clear_instance()
        if saved_shell_instance is not None:
            cls = type(saved_shell_instance)
            cls.clear_instance()
            for subclass in cls._walk_mro():
                subclass._instance = saved_shell_instance
        if ps1 is not None:
            sys.ps1 = ps1
            sys.ps2 = ps2
    # write back final values; cells emptied by `del` in the shell would raise
    # ValueError on `.cell_contents`, so they are skipped
    update_locals(
        dict(
            iter_cell_dict_contents(
                {k: cell_dict[k] for k in write_back_vars})), frame)


def embed_kernel(funcs: Union[List[types.FunctionType],
                              types.FunctionType] = None,
                 *,
                 frame: types.FrameType = None,
                 extra_locals: Dict[str, Any] = None,
                 **kwargs):
    """
    This is copied from ipykernel/embed.py but modified to allow closures
    over local variables. See `embed` for full documentation.

    Unlike the upstream function, `module` and `local_ns` are not accepted:
    both are derived from `frame`. Local variables deleted in the session
    (`del x`) are skipped when final values are written back to the frame.

    :param funcs: list of functions to get closure cells from
    :param frame: frame to get locals and globals from (default: the caller's
        frame)
    :param extra_locals: extra variable -> value dict to add to locals
    :param kwargs: passed on to the `IPKernelApp` constructor
    :raises AssertionError: if an `IPKernelApp` is already initialized

    --- original docs below ---

    Embed and start an IPython kernel in a given scope.

    Parameters
    ----------
    module : ModuleType, optional
        The module to load into IPython globals (default: caller)
    local_ns : dict, optional
        The namespace to load into IPython user namespace (default: caller)
    kwargs : dict, optional
        Further keyword args are relayed to the IPKernelApp constructor,
        allowing configuration of the Kernel.  Will only have an effect
        on the first embed_kernel call for a given process.

    """
    from ipykernel.kernelapp import IPKernelApp

    # a fresh app is required: reusing an already-initialized one is not
    # supported here
    assert not IPKernelApp.initialized()
    app = IPKernelApp.instance(**kwargs)
    app.initialize([])
    # Undo unnecessary sys module mangling from init_sys_modules.
    # This would not be necessary if we could prevent it
    # in the first place by using a different InteractiveShell
    # subclass, as in the regular embed case.
    main = app.kernel.shell._orig_sys_modules_main_mod
    if main is not None:
        sys.modules[app.kernel.shell._orig_sys_modules_main_name] = main
    if frame is None:
        frame = sys._getframe(1)
    try:
        local_ns, module, cell_dict, write_back_vars = setup_embedded_shell(
            app.kernel.shell, funcs, frame, extra_locals)
        app.kernel.user_module = module
        app.kernel.user_ns = local_ns
        app.shell.set_completer_frame()
        app.start()
    finally:
        # tear the singletons down even if the kernel failed, so a later
        # embed_kernel call does not trip over the assertion above
        app.close()
        app.kernel.shell_class.clear_instance()
        app.kernel_class.clear_instance()
        IPKernelApp.clear_instance()
    # write back final values; cells emptied by `del` in the session would
    # raise ValueError on `.cell_contents`, so they are skipped
    update_locals(
        dict(
            iter_cell_dict_contents(
                {k: cell_dict[k] for k in write_back_vars})), frame)


def setup_embedded_shell(shell, funcs: Union[List[types.FunctionType],
                                             types.FunctionType],
                         frame: types.FrameType, extra_locals: Dict[str, Any]):
    """
    setup embedded shell for both `embed` and `embed_kernel`

    Every local variable of `frame` that is not already provided by a closure
    cell of `funcs` gets a fresh cell. Together with `extra_locals` and
    `_ipy_magic_shell` (the shell itself), these cells back the variables seen
    by code run in the shell through a `FixLocals` AST transformer.

    The shell's traceback retrieval is wrapped to drop the synthetic
    `<module>` frame that precedes each `_ipy_magic_cell` frame. Its compiler
    is replaced by one that records the filename on the parsed AST.

    :param shell: shell to augment
    :param funcs: list of functions to get closure cells from
    :param frame: frame to get locals and globals from
    :param extra_locals: extra variable -> value dict to add to locals
    :return: local_ns, module, cell_dict, write_back_vars
        local_ns: dict of locals (`frame.f_locals`, updated in place with the
        cell contents)
        module: a module that the shell should run in the scope of; has
        __dict__ as the globals (looked up in `sys.modules` by the frame's
        `__name__`, or a `DummyMod` wrapping `frame.f_globals` if absent)
        cell_dict: variable name -> cell shared with code run in the shell
        write_back_vars: names of frame locals that were given fresh cells
        here; their final values should be written back to the frame
    """
    local_ns = frame.f_locals
    cell_dict = get_cell_dict_from_funcs(funcs)
    write_back_vars = []
    for k, v in local_ns.items():
        if k not in cell_dict:
            write_back_vars.append(k)
            cell_dict[k] = types.CellType(v)
    # the following allows something like `_ipy_magic_shell.keep_running =
    # False` (embed) or `_ipy_magic_shell.ask_exit()` (embed_kernel) to exit the
    # shell without requiring an EOF or killing the process
    cell_dict[magic + "_shell"] = types.CellType(shell)
    if extra_locals:
        for k, v in extra_locals.items():
            cell_dict[k] = types.CellType(v)
    extra_globals = set()
    shell.ast_transformers.append(
        FixLocals(shell, cell_dict, extra_globals, magic))
    local_ns.update(iter_cell_dict_contents(cell_dict))
    old_get_exc_info = shell._get_exc_info

    def get_exc_info(exc_tuple=None):
        # errors from the original lookup (e.g. "no exception to find") are
        # part of IPython's normal flow and are propagated untouched
        etype, value, tb = old_get_exc_info(exc_tuple)
        if tb is None:
            return etype, value, tb
        try:
            final_tb = None
            pprev = None
            prev = tb
            tb = prev.tb_next
            while tb:
                # unlink the interactive `<module>` frame that only calls the
                # generated `_ipy_magic_cell` function
                if (prev.tb_frame.f_code.co_filename.startswith(
                        ("<ipython-input-", "/tmp/ipykernel_"))
                        and prev.tb_frame.f_code.co_name == "<module>"
                        and tb.tb_frame.f_code.co_name == magic + "_cell"):
                    prev = tb
                    if pprev is not None:
                        pprev.tb_next = prev
                if final_tb is None:
                    final_tb = prev
                pprev = prev
                prev = tb
                tb = tb.tb_next
        except Exception:
            L.error("error adjusting traceback", exc_info=True)
            raise
        # a single-frame traceback never enters the loop; keep it as is
        # instead of returning None
        return etype, value, final_tb if final_tb is not None else prev

    shell._get_exc_info = get_exc_info
    old_compiler_class = shell.compiler_class

    class Compiler(old_compiler_class):

        def __init__(self):
            super().__init__()

        def ast_parse(self, source, filename="<unknown>", symbol="exec"):
            code_ast: ast.Module = super().ast_parse(source, filename, symbol)
            # remembered so FixLocals can compile the rewritten cell under
            # the same filename
            code_ast.filename = filename
            return code_ast

    shell.compiler_class = Compiler
    shell.compile = Compiler()

    module = sys.modules.get(frame.f_globals.get("__name__"))
    if module is None:
        # e.g. code run via exec() with a custom globals dict; mirror
        # mainloop's fallback instead of raising KeyError
        from IPython.core.interactiveshell import DummyMod
        module = DummyMod()
        module.__dict__ = frame.f_globals
    return local_ns, module, cell_dict, write_back_vars


def iter_cell_dict_contents(cell_dict: Dict[str, types.CellType]):
    """
    Yield ``(name, value)`` for every cell in `cell_dict` that holds a value.

    Empty cells (never filled, or emptied by ``del``) are skipped.

    :param cell_dict: variable name -> closure cell
    :return: generator of ``(name, contents)`` pairs in `cell_dict` order
    """
    for key, cell in cell_dict.items():
        try:
            contents = cell.cell_contents
        except ValueError:
            # reading an empty cell raises ValueError("Cell is empty")
            continue
        # yield outside the try so GeneratorExit/throw() are not swallowed
        yield key, contents


# class AstModule(ast.Module):

#     def __init__(self, body, type_ignores, filename):
#         super().__init__(body, type_ignores)
#         self.filename = filename


class FixLocals(object):
    """
    use the helper to generate an inner closure with all the locals

    Used as an IPython AST transformer: each cell is compiled (via
    `run_statements_helper`) into a function `<magic>_inner` whose closure is
    patched with the cells in `cell_dict`. The AST handed back to IPython then
    only calls that function and refreshes the shell namespace.
    """

    def __init__(self, shell, cell_dict: Dict[str, types.CellType],
                 extra_globals: Set[str], magic):
        """
        :param shell: the IPython shell whose cells are transformed
        :param cell_dict: variable name -> cell shared across cells
        :param extra_globals: names declared ``global`` in earlier cells;
            updated in place
        :param magic: prefix for injected names
        """
        self.shell = shell
        self.cell_dict = cell_dict
        self.extra_globals = extra_globals
        self.magic = magic

    def visit(self, module_ast: ast.Module):
        """
        Rewrite a parsed cell.

        :param module_ast: the parsed cell
        :return: an AST evaluating `(<magic>_inner(), <magic>_update())[0]`;
            `module_ast` unchanged if the cell has no statements; or an AST
            re-raising the stored error (`raise <magic>_error`) if the rewrite
            failed, so the error surfaces in the shell instead of making
            IPython drop this transformer
        """
        if not module_ast.body:
            # e.g. a comment-only cell: nothing to run, and `body[-1]` below
            # would raise IndexError
            return module_ast
        try:
            ExcludeNlsFromGlobals(self.extra_globals).visit(module_ast)
            if len(self.extra_globals):
                # re-declare names made global by earlier cells
                module_ast.body.insert(
                    0,
                    ast.copy_location(
                        ast.Global(names=list(self.extra_globals)),
                        module_ast.body[0]))
            CollectGlobals(self.extra_globals,
                           self.cell_dict).visit(module_ast)
            patcher_cell = types.CellType()
            statement = module_ast.body[-1]
            if isinstance(statement, ast.Expr):
                # the value of a trailing expression becomes the cell result
                module_ast.body[-1] = ast.copy_location(
                    ast.Return(value=statement.value), statement)
            run_statements_helper(
                patcher_cell, module_ast.body, None,
                self.magic + "_cell", None, self.shell.user_global_ns,
                list(self.cell_dict.keys()), [], [],
                getattr(module_ast, "filename", "<" + self.magic + "_source>"),
                0, self.shell.compile.flags, self.magic, False, None)
            patched = patcher_cell.cell_contents(self.cell_dict)
            self.shell.user_ns[self.magic + "_update"] = self.update_ns
            self.shell.user_ns[self.magic + "_inner"] = patched
            return self.shell.compile.ast_parse("(" + self.magic +
                                                "_inner(), " + self.magic +
                                                "_update())[0]")
        except Exception as e:
            # the error is stored in the shell namespace (FixLocals itself
            # has no `user_ns`) and re-raised by the returned code
            self.shell.user_ns[self.magic + "_error"] = e
            return self.shell.compile.ast_parse("raise " + self.magic +
                                                "_error")

    def update_ns(self):
        """copy the current (non-empty) cell contents into the shell namespace"""
        self.shell.user_ns.update(iter_cell_dict_contents(self.cell_dict))


class CollectGlobals(ast.NodeVisitor):
    """
    Gather the names declared ``global`` in a cell, skipping the bodies of
    nested ``def``/``async def`` definitions and any name that is already
    backed by a cell in `cell_dict`.
    """

    def __init__(self, found_globals: Set[str],
                 cell_dict: Dict[str, types.CellType]):
        super().__init__()
        # updated in place so the caller's set accumulates the names
        self.found_globals = found_globals
        self.cell_dict = cell_dict

    def _skip_nested_scope(self, node):
        """nested functions have their own declarations; do not descend"""
        return None

    visit_FunctionDef = _skip_nested_scope
    visit_AsyncFunctionDef = _skip_nested_scope

    def visit_Global(self, node: ast.Global):
        uncelled = [name for name in node.names if name not in self.cell_dict]
        self.found_globals.update(uncelled)


class ExcludeNlsFromGlobals(ast.NodeVisitor):
    """
    Remove from `found_globals` every name a cell declares ``nonlocal``,
    without descending into nested ``def``/``async def`` definitions.
    """

    def __init__(self, found_globals: Set[str]):
        super().__init__()
        # mutated in place so the caller's set is updated
        self.found_globals = found_globals

    def visit_Nonlocal(self, node: ast.Nonlocal):
        for name in node.names:
            self.found_globals.discard(name)

    def _stop_at_function(self, node):
        """nested functions have their own declarations; do not descend"""
        return None

    visit_FunctionDef = _stop_at_function
    visit_AsyncFunctionDef = _stop_at_function


def embed2(*, frame=None, header="", compile_flags=None, **kwargs):
    """
    embed a shell using the mocked-up globals method (not as powerful)

    A single dict holding the frame's globals, overlaid with its locals, is
    used as both the globals and the locals of the shell. After the shell
    exits, the values of the names in `frame.f_locals` are written back with
    `update_locals`; names deleted in the shell are skipped.

    :param frame: frame to use (default: the caller's frame)
    :param header: header to display when starting the shell
    :param compile_flags: flags used to compile interactive code
    :param kwargs: passed on to `InteractiveShellEmbed.instance`; `config` and
        `using` are handled as in `IPython.embed`
    """
    if frame is None:
        frame = sys._getframe(1)
    # import IPython
    # IPython.start_ipython(argv=[], user_ns=env, config=config)
    from IPython.core.interactiveshell import DummyMod, InteractiveShell
    from IPython.terminal.embed import InteractiveShellEmbed
    from IPython.terminal.ipapp import load_default_config
    config = kwargs.get('config')
    if config is None:
        config = load_default_config()
        config.InteractiveShellEmbed = config.TerminalInteractiveShell
        kwargs['config'] = config
    using = kwargs.get('using', 'sync')
    if using:
        kwargs['config'].update({
            'TerminalInteractiveShell': {
                'loop_runner': using,
                'colors': 'NoColor',
                'autoawait': using != 'sync'
            }
        })
    # save ps1/ps2 if defined
    ps1 = None
    ps2 = None
    try:
        ps1 = sys.ps1
        ps2 = sys.ps2
    except AttributeError:
        pass
    # save previous instance; only one singleton may be active at a time
    saved_shell_instance = InteractiveShell._instance
    if saved_shell_instance is not None:
        type(saved_shell_instance).clear_instance()
    location_id = '%s:%s' % (frame.f_code.co_filename, frame.f_lineno)
    try:
        shell = InteractiveShellEmbed.instance(_init_location_id=location_id,
                                               **kwargs)

        # this almost worked in place of the mocked up globals, except that
        # setting a global in a newly created function with the global keyword
        # did not work, probably because they made a direct call to the
        # underlying dict "set_item"
        # env = Glocals(frame.f_globals, {**frame.f_locals}, shell.user_ns_hidden)
        env = {}
        for (k, v) in itertools.chain(frame.f_globals.items(),
                                      frame.f_locals.items()):
            if k not in shell.user_ns_hidden.keys():
                env[k] = v
        env[magic + "_shell"] = shell
        module = DummyMod()
        module.__dict__ = env
        # install the patched mainloop on this instance only; assigning it to
        # the class would leak into every later InteractiveShellEmbed,
        # including the ones created by `embed`
        shell.mainloop = types.MethodType(mainloop, shell)
        shell(header=header,
              local_ns=env,
              module=module,
              compile_flags=compile_flags,
              _call_location_id=location_id)
    finally:
        # restore the previous instance and prompts even if the shell failed
        InteractiveShellEmbed.clear_instance()
        if saved_shell_instance is not None:
            cls = type(saved_shell_instance)
            cls.clear_instance()
            for subclass in cls._walk_mro():
                subclass._instance = saved_shell_instance
        if ps1 is not None:
            sys.ps1 = ps1
            sys.ps2 = ps2
    # names deleted in the shell are no longer in env; skip them rather than
    # raising KeyError
    update_locals({k: env[k] for k in frame.f_locals if k in env}, frame)


# class Glocals(dict):

#     def __init__(self, gs, ls, hidden_keys):
#         self.gs = gs
#         self.ls = ls
#         self.hidden_keys = hidden_keys
#         self.extras = {}

#     def __getitem__(self, k):
#         if k not in self.hidden_keys:
#             if k in self.gs:
#                 return self.gs[k]
#             if k in self.ls:
#                 return self.ls[k]
#         return self.extras[k]

#     def __setitem__(self, k, v):
#         if k not in self.hidden_keys:
#             if k in self.gs:
#                 self.gs[k] = v
#                 return
#             if k in self.ls:
#                 print("writing ls[%s]" % k)
#                 self.ls[k] = v
#                 return
#         self.extras[k] = v

#     def update(self, other):
#         for k, v in other.items():
#             self[k] = v


def mainloop(
    self,
    local_ns=None,
    module=None,
    stack_depth=0,
    compile_flags=None,
):
    """
    copied from IPython/terminal/embed.py but modified to allow a mocked-up
    globals to be used as both globals and locals, causing exec to behave
    correctly

    This is a plain function taking `self` so that it can be installed as the
    `mainloop` of an `InteractiveShellEmbed`. The only difference from
    upstream is the following. When `local_ns` is the very same object as
    `self.user_module.__dict__`, that dict is used directly as the user
    namespace. No filtered copy is made, and nothing is copied back into
    `local_ns` when the session ends.

    NOTE(review): `interact()` is not wrapped in try/finally, so if it raises,
    the original namespace, module and compile flags are not restored (same
    as the upstream code this was copied from).

    --- original docs below ---

    Embeds IPython into a running python program.

    Parameters
    ----------
    local_ns, module
        Working local namespace (a dict) and module (a module or similar
        object). If given as None, they are automatically taken from the scope
        where the shell was called, so that program variables become visible.
    stack_depth : int
        How many levels in the stack to go to looking for namespaces (when
        local_ns or module is None). This allows an intermediate caller to
        make sure that this function gets the namespace from the intended
        level in the stack. By default (0) it will get its locals and globals
        from the immediate caller.
    compile_flags
        A bit field identifying the __future__ features
        that are enabled, as passed to the builtin :func:`compile` function.
        If given as None, they are automatically taken from the scope where
        the shell was called.

    """
    from IPython.core import compilerop
    from IPython.core.interactiveshell import DummyMod

    # Get locals and globals from caller
    if ((local_ns is None or module is None or compile_flags is None)
            and self.default_user_namespaces):
        # the frame `stack_depth` levels above this one, then one more up
        call_frame = sys._getframe(stack_depth).f_back

        if local_ns is None:
            local_ns = call_frame.f_locals
        if module is None:
            global_ns = call_frame.f_globals
            try:
                module = sys.modules[global_ns['__name__']]
            except KeyError:
                warnings.warn("Failed to get module %s" % \
                    global_ns.get('__name__', 'unknown module')
                )
                # wrap the raw globals dict in a module-like object
                module = DummyMod()
                module.__dict__ = global_ns
        if compile_flags is None:
            compile_flags = (call_frame.f_code.co_flags & compilerop.PyCF_MASK)

    # Save original namespace and module so we can restore them after
    # embedding; otherwise the shell doesn't shut down correctly.
    orig_user_module = self.user_module
    orig_user_ns = self.user_ns
    orig_compile_flags = self.compile.flags

    # Update namespaces and fire up interpreter

    # The global one is easy, we can just throw it in
    if module is not None:
        self.user_module = module

    # But the user/local one is tricky: ipython needs it to store internal
    # data, but we also need the locals. We'll throw our hidden variables
    # like _ih and get_ipython() into the local namespace, but delete them
    # later.
    # FIX: allow passing self.user_module.__dict__ as local_ns, resulting in
    # the same object as self.user_ns
    # BEGIN_FIX
    #if local_ns is not None:
    if local_ns is self.user_module.__dict__:
        # shared dict: globals and locals are one object, so exec sees a
        # single namespace. NOTE(review): presumably init_user_ns() adds
        # IPython's hidden names into this shared dict, and they are not
        # purged afterwards (see the check after interact()) — confirm
        self.user_ns = local_ns
        self.init_user_ns()
    elif local_ns is not None:
        # END_FIX
        # upstream behaviour: work on a copy without IPython's hidden names
        reentrant_local_ns = {
            k: v
            for (k, v) in local_ns.items()
            if k not in self.user_ns_hidden.keys()
        }
        self.user_ns = reentrant_local_ns
        self.init_user_ns()

    # Compiler flags
    if compile_flags is not None:
        self.compile.flags = compile_flags

    # make sure the tab-completer has the correct frame information, so it
    # actually completes using the frame's locals/globals
    self.set_completer_frame()

    with self.builtin_trap, self.display_trap:
        self.interact()

    # now, purge out the local namespace of IPython's hidden variables.
    # FIX: do not revert namespace if it was the special case
    #if local_ns is not None:
    if local_ns is not None and local_ns is not self.user_module.__dict__:
        # copy the session's non-hidden variables back into the caller's dict
        local_ns.update({
            k: v
            for (k, v) in self.user_ns.items()
            if k not in self.user_ns_hidden.keys()
        })

    # Restore original namespace so shell can shut down when we exit.
    self.user_module = orig_user_module
    self.user_ns = orig_user_ns
    self.compile.flags = orig_compile_flags


def try_all_statements(f: types.FunctionType, stream=sys.stderr):
    """
    This is a decorator to convert a function using `run_statements_helper`, and
    `try` all statements, allowing the user to modify the function on the fly or
    drop into a shell when an exception occurs. Refer to `run_statements_helper`
    for more details.

    On each call, the arguments are bound exactly as a direct call to `f`
    would bind them, using `inspect.Signature.bind` with defaults applied.
    Missing or unexpected arguments therefore raise `TypeError`. The bound
    values are stored in the cells of `cell_dict` before the statements run.

    :param f: function to decorate
    :param stream: stream to use to print exception and retry info
    :return: wrapped function
    """
    (source, filename, flags, func_line_num,
     private_prefix) = uncompile(f.__code__)
    fmod_ast = parse_snippet(source, filename, "exec", flags, func_line_num,
                             private_prefix)
    code = f.__code__
    co_freevars = code.co_freevars
    patcher_cell = types.CellType()
    # TODO: add recursive try_all loop helper
    runner = run_statements_helper(patcher_cell, fmod_ast.body[0].body,
                                   f.__module__, f.__name__, f.__qualname__,
                                   f.__globals__, code.co_varnames,
                                   code.co_cellvars, co_freevars,
                                   filename, fmod_ast.body[0].lineno, flags,
                                   magic, True, stream)
    f_closure = f.__closure__ or ()
    # `f`'s own signature (not that of anything it may wrap) matches
    # `co_varnames`, which is what the generated code uses
    signature = inspect.signature(f, follow_wrapped=False)

    @functools.wraps(f)
    def wrapper(*args, **kwargs):
        # bind first so bad calls fail before any cells are created; this
        # also accepts required arguments passed by keyword and handles
        # positional-only parameters, which the old manual binding did not
        bound = signature.bind(*args, **kwargs)
        bound.apply_defaults()
        cell_dict = dict(zip(co_freevars, f_closure))
        # the patcher adds a cell for every local of `f` to `cell_dict`
        patched = patcher_cell.cell_contents(cell_dict)
        # *args / **kwargs parameters arrive as a tuple / dict, like a call
        for name, value in bound.arguments.items():
            cell_dict[name].cell_contents = value
        return runner(patched, 0, cell_dict)

    # stop inspect from following back to `f`
    wrapper.__wrapped__ = None
    return wrapper


def get_cell_dict_from_funcs(funcs: Union[List[types.FunctionType],
                                          types.FunctionType]):
    """
    Map each free variable name of the given function(s) to its closure cell.

    When several functions close over the same name, the cell of the later
    function wins.

    :param funcs: a single function, a list of functions, or a falsy value
    :return: cell_dict (empty if `funcs` is falsy)
    """
    cells: Dict[str, types.CellType] = {}
    if not funcs:
        return cells
    if isinstance(funcs, types.FunctionType):
        func_list = [funcs]
    else:
        func_list = funcs
    for fn in func_list:
        names = fn.__code__.co_freevars
        closure = fn.__closure__ or ()
        cells.update(zip(names, closure))
    return cells


class TryBlockTransformer(ast.NodeTransformer):
    """
    meant to allow recursive application of try_all_statements within code
    blocks (abandoned)

    As it stands, it only tracks a marker flag and leaves the tree unchanged.
    The flag is raised on `lambda try_all: ...` (exactly one positional
    parameter named `try_all`) and lowered on the next `for` loop visited.
    """

    def __init__(self):
        super().__init__()
        self.flag = False

    def visit_Lambda(self, node: ast.Lambda) -> Any:
        param_names = [param.arg for param in node.args.args]
        if param_names == ["try_all"]:
            self.flag = True
        return self.generic_visit(node)

    def visit_For(self, node: ast.For) -> Any:
        # consume the marker (a no-op when it is already lowered)
        self.flag = False
        # TODO: new idea: convert `for x, (y, z) in it:` within try_all_statements to
        #   it2 = iter(it)
        #   if i == 0:
        #       try:
        #           _ = next(it)
        #       except:
        #           return True
        #   for _ in [0]:
        #       if i==1:
        #           statement1
        #       if i==2:
        #           statement2
        #       ...

        # TODO: automatically convert `for x, (y, z) in it:` to:
        #   @run_func(it, (1, (1, 1)))
        #   @try_all_statements
        #   def _(x, y, z):
        # and `continue` to `return` and `break` to `return True`
        # how to make variables that become local into nonlocal?
        #   leak locals into cell_dict
        return self.generic_visit(node)


def run_statements_helper(patcher_cell: types.CellType,
                          statements: List[ast.stmt], module: str, name: str,
                          qualname: str, globals: dict, co_varnames: List[str],
                          co_cellvars: List[str], co_freevars: List[str],
                          filename: str, func_line_num: int, flags: int,
                          magic: str, to_try: bool, stream: TextIO):
    """
    compile the statements of a function body into a closure whose cells can
    be swapped out, and construct a runner for it

    Shown below with `magic == "_ipy_magic"` and locals `a, b, c`. If `to_try`
    is true, this changes a sequence of statements to a function that can run
    individual statements like so (if it hits the end of the statement without
    returning, it returns `_ipy_magic_inner` itself):
    ```
    def _ipy_magic_outer(a, b, c):
        _ipy_magic_embed = None
        del a, b, c
        def _ipy_magic_inner(_ipy_magic_i):
            nonlocal _ipy_magic_inner, _ipy_magic_embed
            nonlocal a, b, c
            if _ipy_magic_i == 0:
                statement0
                return _ipy_magic_inner
            if _ipy_magic_i == 1:
                statement1
                return _ipy_magic_inner
            ...
        return _ipy_magic_inner
    ```

    If `to_try` is false, the statements are just wrapped as follows:
    ```
    def _ipy_magic_outer(a, b, c):
        _ipy_magic_embed = None
        del a, b, c
        def _ipy_magic_inner():
            nonlocal _ipy_magic_inner, _ipy_magic_embed
            nonlocal a, b, c
            statement0
            ...
        return _ipy_magic_inner
    ```

    It then compiles the generated code, runs `_ipy_magic_outer` to get an
    instance of `_ipy_magic_inner`, and constructs a patcher, which takes a
    mapping of var names to the associated cell in order to update the closure
    of `_ipy_magic_inner`. This is to maintain the local variable state across
    multiple calls to allow the statements to be run sequentially. The
    `nonlocal` in `_ipy_magic_inner` converts both local variables and cell
    variables to free (closure) variables. If some of these were not fully
    converted, we add the extra variables to `_ipy_magic_outer` and recompile.

    We provide the instance of the real function running the code as
    `_ipy_magic_inner` in the scope of the statements, and `embed` as
    `_ipy_magic_embed`, so that it is possible to call e.g.
    `_ipy_magic_embed(funcs=[_ipy_magic_inner])` to achieve full control over
    local variables in the embedded shell.

    We then construct the runner function (only when `to_try` is true). The
    statements are run in sequence from the given starting index. When an
    uncaught exception is raised, we catch it and allow the user to update the
    source code of the function and specify a line to start running from.
    After doing so, we redo the whole process of the compilation of
    `_ipy_magic_outer` and patcher/runner construction, patch with the current
    cell dict and start running the function from the specified line using
    the new runner.

    Note that we use a cell to return the patcher because by passing this cell
    to future recursive calls, the original wrapper can be updated with the
    patched function as it will use this cell's contents as the patcher.

    The continuation point is specified as follows:
    `[[<filename>;]<func_line_num>;]<next_statement_line_num>`. If the user
    enters a negative number `-x` for the next statement line, an embedded shell
    will be injected into the statement at line `x`, after which a raise will
    once again drop it back here. The special case of `0` is treated as `-x`
    where `x` is the default next statement line (the line of the statement
    which failed). If the first field is empty (e.g. just a semicolon is
    given), the exception is re-raised. The exception is also re-raised if the
    prompt hits end-of-file, since no answer could ever be read.

    General Info:
    If `f` is a function, then `f.__code__.co_freevars` are variables in the
    parent functions which were closed over (by any code or child functions)
    and `f.__code__.co_cellvars` are variables in the scope of `f` which are
    closed over by a child function. `f.__closure__` gives all the cells in
    correspondence with `f.__code__.co_freevars`. We use FunctionType to alter
    which cells the closure refers to, as mentioned by
    https://stackoverflow.com/questions/59276834/how-to-set-the-content-of-a-closure-cell
    . We are not able to use pure Python to get the cell of a cell variable in
    the stack frame (any variable which was closed over), and we would
    probably have to use `ctypes.pythonapi` and follow something like
    `framelocalsproxy_getval`. Fortunately, this is not needed for our limited
    use-case where we already have the function to be converted. However, we
    might need this in the future to allow modifications to a function's cell
    variables (it is almost impossible to modify non-cell variables).

    :param patcher_cell: cell to place the patcher into
    :param statements: statements to wrap
    :param module: name of module (optional)
    :param name: name of function
    :param qualname: qualified name of function (optional)
    :param globals: function globals
    :param co_varnames: function co_varnames
    :param co_cellvars: function co_cellvars
    :param co_freevars: function co_freevars
    :param filename: filename
    :param func_line_num: function line number
    :param flags: function flags
    :param magic: the magic string to use
    :param to_try: True iff exceptions from every statement are to be caught
    :param stream: if to_try, use stream to print exception and retry info
    :return: runner if `to_try` is true, else None
        runner(patched, start_i, cell_dict): runs statements from index
        `start_i` after specialising with the `patcher` and `cell_dict`
    """

    # strip annotations of names that will be declared `nonlocal` below
    statements = [
        AnnotationRemover().visit(statement) for statement in statements
    ]
    # `.0` is an iteration variable that appears because of comprehension
    # inlining in Python 3.12
    co_varnames = [x for x in co_varnames if x != ".0"]
    co_cellvars = list(co_cellvars)
    co_varnames_set = set(co_varnames)
    co_cellvars_set = set(co_cellvars)
    # recompile until `_inner` has no local/cell variables beyond those that
    # are passed through `_outer` and declared nonlocal
    while True:
        # L.info("%s", f'{co_varnames=} {co_cellvars=} {co_freevars=}')
        # co_cellvars might be repeated in co_varnames if it is a parameter
        all_locals = [
            *co_varnames,
            *[x for x in co_cellvars if x not in co_varnames_set], *co_freevars
        ]
        # if func uses more nonlocal variables (more bindings from outer
        # functions), the below will throw an error when compiled, otherwise we
        # can update the list of all locals
        # magic = "_ipy_magic_%s_%d" % (name, id(closure))
        if to_try:
            module_ast = ast.parse("""
                def {0}_outer({1}):
                    {0}_embed = None
                    {2}del {1}
                    def {0}_inner({0}_i):
                        nonlocal {0}_inner, {0}_embed
                        {2}nonlocal {1}
                        if {0}_i == 0:
                            return {0}_inner
                    return {0}_inner
            """.strip().format(magic, ",".join(all_locals),
                               "" if len(all_locals) else "pass # "))
            outer_ast: ast.FunctionDef = module_ast.body[0]
            inner_ast: ast.FunctionDef = outer_ast.body[2]
            # the trailing `if {0}_i == 0:` is a template for one statement
            if_ast_template: ast.If = inner_ast.body.pop()
            if_asts = []
            for i, statement in enumerate(statements):
                if_ast = copy.deepcopy(if_ast_template)
                # make it compare `_i` with the real statement index
                if_ast.test.comparators[0].value = i
                if_ast.body.insert(0, statement)
                if_asts.append(if_ast)
            # add all the if statements to the body of `_inner`
            inner_ast.body.extend(if_asts)
        else:
            module_ast = ast.parse("""
                def {0}_outer({1}):
                    {0}_embed = None
                    {2}del {1}
                    def {0}_inner():
                        nonlocal {0}_inner, {0}_embed
                        {2}nonlocal {1}
                    return {0}_inner
            """.strip().format(magic, ",".join(all_locals),
                               "" if len(all_locals) else "pass # "))
            outer_ast: ast.FunctionDef = module_ast.body[0]
            inner_ast: ast.FunctionDef = outer_ast.body[2]
            inner_ast.body.extend(statements)
        # execute definition of `_outer` with empty locals and correct globals
        local_dict = {}
        exec(compile(module_ast, filename, "exec", flags, dont_inherit=True),
             globals, local_dict)
        # execute `_outer` to get `_inner`
        _outer = local_dict[magic + "_outer"]
        _inner: types.FunctionType = _outer(
            *[None for i in range(len(all_locals))])
        # L.info(
        #     "%s",
        #     f'{_inner.__code__.co_varnames=} {_inner.__code__.co_cellvars=}')
        # if there are extra local variables in `co_varnames` and `co_cellvars`,
        # add it to `all_locals` and start over
        if to_try:
            assert _inner.__code__.co_varnames[0] == magic + "_i"
            co_varnames_new = _inner.__code__.co_varnames[1:]
        else:
            co_varnames_new = _inner.__code__.co_varnames
        # must check original set due to a bug in comprehension inlining since
        # Python 3.12. the bug causes co_varnames to include comprehension
        # variables arguably incorrectly, and variables can disappear from
        # co_varnames expectedly. See
        # https://github.com/python/cpython/issues/121377 .
        vars_changed = False
        for varname in co_varnames_new:
            if varname not in co_varnames_set:
                co_varnames.append(varname)
                vars_changed = True
        co_varnames_set = set(co_varnames)
        for varname in _inner.__code__.co_cellvars:
            if varname not in co_cellvars_set:
                co_cellvars.append(varname)
                vars_changed = True
        co_cellvars_set = set(co_cellvars)
        if vars_changed:
            continue
        else:
            break

    all_locals_len = len(all_locals)

    # construct patcher
    def patcher(cell_dict: Dict[str, types.CellType]):
        # patch `_inner` with the correct closure, recalling original free
        # variables, previously added cells, and remembering new cells
        _inner: types.FunctionType = _outer(
            *[None for i in range(all_locals_len)])
        c = _inner.__code__
        inner_code = c.replace(co_filename=filename, co_name=name)
        patched_cell = None
        new_closure = []
        for x, y in zip(inner_code.co_freevars, _inner.__closure__ or ()):
            if x == magic + "_inner":
                # filled with the patched function below
                patched_cell = y
                new_closure.append(y)
            elif x == magic + "_embed":
                # expose `embed` to the statements as `{magic}_embed`
                y.cell_contents = embed
                new_closure.append(y)
            else:
                # reuse a known cell, or remember this fresh one for later
                new_closure.append(cell_dict.setdefault(x, y))
        patched = types.FunctionType(inner_code,
                                     globals,
                                     name,
                                     argdefs=None,
                                     closure=tuple(new_closure))
        patched_cell.cell_contents = patched
        if module:
            patched.__module__ = module
        if qualname:
            patched.__qualname__ = qualname
        return patched

    patcher_cell.cell_contents = patcher

    if not to_try:

        # # currently not used
        # def runner(patched, start_i: int, cell_dict: Dict[str,
        #                                                   types.CellType]):
        #     for i in range(start_i, len(statements)):
        #         # L.info("running statement %d", i)
        #         ret = patched(i)
        #     return ret

        # return runner
        return None

    # construct runner
    def runner(patched, start_i: int, cell_dict: Dict[str, types.CellType]):
        # run statements until end or error
        for i in range(start_i, len(statements)):
            try:
                # L.info("running statement %d", i)
                ret = patched(i)
                # anything but `patched` itself comes from an explicit return
                if ret is not patched:
                    return ret
            except Exception as e:
                traceback.print_exc(file=stream)
                stream.flush()
                while True:
                    try:
                        statements_new = get_func_statements(
                            filename, func_line_num, flags)
                        line_nums = [s.lineno for s in statements_new]
                        next_line_num = line_nums[i] if i < len(
                            line_nums) else 0
                    except:
                        next_line_num = 0
                    try:
                        resp = prompt_with_default(
                            "next statement", "%s;%d;%d" %
                            (filename, func_line_num, next_line_num),
                            stream).split(";", 2)
                    except EOFError:
                        # no answer can ever be read once stdin is exhausted:
                        # behave like an explicit ";" instead of re-prompting
                        # (and printing tracebacks) forever
                        resp = [""]
                    except:
                        traceback.print_exc(file=stream)
                        stream.flush()
                        continue
                    if resp[0] == "":
                        raise e
                    try:
                        filename_new = filename if len(resp) < 3 else resp[-3]
                        func_line_num_new = func_line_num if len(
                            resp) < 2 else int(resp[-2])
                        next_line_num = int(resp[-1])
                    except:
                        traceback.print_exc(file=stream)
                        stream.flush()
                        continue
                    try:
                        statements_new = get_func_statements(
                            filename_new, func_line_num_new, flags)
                    except:
                        if next_line_num != 0:
                            traceback.print_exc(file=stream)
                            stream.flush()
                            continue
                        # source unavailable but the user asked to embed at
                        # the failing statement: use placeholder statements so
                        # that index `i` exists
                        statements_new = get_func_statements(
                            None, 1, flags, "def _():\n" + " pass\n" * (i + 1))
                    try:
                        line_nums = [s.lineno for s in statements_new]
                        use_embed = next_line_num <= 0
                        if next_line_num == 0:
                            next_line_num = line_nums[i]
                        elif next_line_num < 0:
                            next_line_num = -next_line_num
                        next_i = line_nums.index(next_line_num)
                        if use_embed:
                            # replace the statement with an embedded shell
                            # followed by a bare `raise`
                            statement_orig = statements_new[next_i]
                            statement_new: ast.If = ast.parse(
                                "if True:\n {0}_embed(funcs=[{0}_inner])\n raise"
                                .format(magic)).body[0]
                            statements_new[next_i] = ast.copy_location(
                                statement_new, statement_orig)
                        # run the compilation and wrapper construction anew
                        runner_new = run_statements_helper(
                            patcher_cell, statements_new, module, name,
                            qualname, globals, co_varnames, co_cellvars,
                            co_freevars, filename_new, func_line_num_new,
                            flags, magic, True, stream)
                        # TODO: patch global functions
                        patched_new = patcher_cell.cell_contents(cell_dict)
                        break
                    except:
                        traceback.print_exc(file=stream)
                        stream.flush()
                # finally run the new wrapper
                return runner_new(patched_new, next_i, cell_dict)
        return None

    return runner


class AnnotationRemover(ast.NodeTransformer):
    """
    removes type annotations of variables that are currently local but would
    become non-local

    At the top level of the visited statements (i.e. outside nested function
    and class definitions), `x: T = v` becomes `x = v` and a bare `x: T`
    becomes `pass`. The annotation is dropped without being evaluated, just as
    annotations of function-local variables are never evaluated (PEP 526).
    Annotations inside nested defs/classes are kept, as those names stay
    local to their own scope.
    """

    def __init__(self):
        super().__init__()
        # False while visiting the inside of a nested def/class
        self.is_top_level = True

    def visit_FunctionDef(self, node: ast.FunctionDef) -> Any:
        self.is_top_level = False
        return self.generic_visit(node)

    def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef) -> Any:
        self.is_top_level = False
        return self.generic_visit(node)

    def visit_ClassDef(self, node: ast.ClassDef) -> Any:
        self.is_top_level = False
        return self.generic_visit(node)

    def generic_visit(self, node):
        # same as `ast.NodeTransformer.generic_visit`, except that
        # `is_top_level` is restored after each child, so leaving a nested
        # def/class does not affect its siblings
        is_top_level = self.is_top_level
        for field, old_value in ast.iter_fields(node):
            if isinstance(old_value, list):
                new_values = []
                for value in old_value:
                    if isinstance(value, ast.AST):
                        value = self.visit(value)
                        self.is_top_level = is_top_level
                        if value is None:
                            continue
                        elif not isinstance(value, ast.AST):
                            new_values.extend(value)
                            continue
                    new_values.append(value)
                old_value[:] = new_values
            elif isinstance(old_value, ast.AST):
                new_node = self.visit(old_value)
                self.is_top_level = is_top_level
                if new_node is None:
                    delattr(node, field)
                else:
                    setattr(node, field, new_node)
        return node

    def visit_AnnAssign(self, node: ast.AnnAssign) -> Any:
        if self.is_top_level and isinstance(node.target, ast.Name):
            if node.value is None:
                # a bare `x: T` only marks `x` as local, which is moot once
                # `x` is declared nonlocal
                return ast.copy_location(ast.Pass(), node)
            # keep the assignment, drop the (never evaluated) annotation
            return ast.copy_location(
                ast.Assign(targets=[node.target],
                           value=self.visit(node.value)), node)
        return self.generic_visit(node)


def get_func_statements(filename: str,
                        func_line_num: int,
                        flags: int,
                        module_src: str = None):
    """
    get statements of function at line of source code
    :param filename: file to read the module source from; if falsy,
        `module_src` is used instead and "<string>" is used as filename
    :param func_line_num: function line number (compared against the
        `lineno` of `FunctionDef` nodes)
    :param flags: compiler flags (e.g. `__future__` features) for parsing
    :param module_src: module source code, only used when `filename` is falsy
    :raises ValueError: if no function is defined at `func_line_num`
    :return: statements in function
    """
    if filename:
        with open(filename, "rb") as src_f:
            module_src = src_f.read()
    else:
        filename = "<string>"
    # get AST using flag PyCF_ONLY_AST
    module_ast: ast.Module = compile(module_src,
                                     filename,
                                     "exec",
                                     flags
                                     | ast.PyCF_ONLY_AST,
                                     dont_inherit=True)
    # retrieve function at specified line
    helper = GetFuncAtLine(func_line_num)
    func_ast: ast.FunctionDef = helper.visit(module_ast)
    if func_ast is None:
        # report clearly instead of failing with `None.body`
        raise ValueError("no function defined at line %d of %s" %
                         (func_line_num, filename))
    # get all statements in function
    return func_ast.body


class GetFuncAtLine(ast.NodeVisitor):
    """
    visits AST to figure out what function is at a certain line

    `visit` returns the first `FunctionDef` (in depth-first, source order)
    whose `lineno` equals `func_line_num`, or None if there is none.
    """

    def __init__(self, func_line_num):
        super().__init__()
        self.func_line_num = func_line_num

    def generic_visit(self, node):
        """Called if no explicit visitor function exists for a node."""
        for _field, value in ast.iter_fields(node):
            children = value if isinstance(value, list) else [value]
            for child in children:
                if not isinstance(child, ast.AST):
                    continue
                found = self.visit(child)
                if found:
                    return found
        return None

    def visit_FunctionDef(self, node: ast.FunctionDef):
        # L.info("%s", f'{node.lineno=}')
        if node.lineno != self.func_line_num:
            return self.generic_visit(node)
        return node


# bitmask of every `__future__` feature's compiler flag; `uncompile` uses it to
# keep only the future-feature bits of a code object's `co_flags`
PyCF_MASK = functools.reduce(
    operator.or_,
    [
        getattr(__future__, feature).compiler_flag
        for feature in __future__.all_feature_names
    ],
    0,
)


def uncompile(c: types.CodeType):
    """
    "uncompile" code object by getting the source lines
    :param c: code object
    :raises NotImplementedError: for lambdas and code compiled from a string
    :raises RuntimeError: if the source code cannot be retrieved
    :return: source, filename, flags, func_line_num, private_prefix
        (`flags` holds only the `__future__` bits of `c.co_flags`)
    """
    if c.co_name == "<lambda>":
        raise NotImplementedError("Lambda functions not supported")
    if c.co_filename == "<string>":
        raise NotImplementedError("Code without source file not supported")

    filename = inspect.getfile(c)
    try:
        source_lines, first_line = inspect.getsourcelines(c)
    except OSError:
        raise RuntimeError("Source code not available")

    # `__X` is mangled to `_ClassName__X` in methods, so the first name in
    # `co_names` that looks mangled reveals the prefix
    matches = (re.match("^(_[A-Za-z][A-Za-z0-9_]*)__.*$", candidate)
               for candidate in c.co_names)
    private_prefix = next((m.group(1) for m in matches if m), None)

    future_flags = c.co_flags & PyCF_MASK
    return ("".join(source_lines), filename, future_flags, first_line,
            private_prefix)


def parse_snippet(source: str,
                  filename: str,
                  mode: str,
                  flags: int,
                  firstlineno: int,
                  privateprefix: str = None) -> ast.Module:
    """
    like ast.parse, but accepts indented code snippet with a line number offset

    Either way the snippet's first line ends up on line 2 of what is compiled:
    a blank line is prepended, or, if the snippet is indented, a dummy
    `with 0:` header whose body is peeled off afterwards. All line numbers are
    then shifted so that the snippet starts at `firstlineno`.

    :param source: source code
    :param filename: filename
    :param mode: mode
    :param flags: flags
    :param firstlineno: first line number
    :param privateprefix: private prefix (unused)
    :return: parsed AST
    """
    ast_flags = flags | ast.PyCF_ONLY_AST
    try:
        tree: ast.Module = compile("\n" + source, filename, mode, ast_flags,
                                   True)
    except IndentationError:
        tree: ast.Module = compile("with 0:\n" + source, filename, mode,
                                   ast_flags, True)
        tree.body = tree.body[0].body
    ast.increment_lineno(tree, firstlineno - 2)
    return tree


def recompile(source,
              filename,
              mode,
              flags=0,
              firstlineno=1,
              privateprefix=None):
    """
    recompile output of uncompile back to a code object. source may also be
    preparsed AST.
    :param source: source code
    :param filename: filename
    :param mode: mode
    :param flags: flags
    :param firstlineno: first line number
    :param privateprefix: private prefix (unused)
    :raises SyntaxError: when the source text fails to parse
    :raises RuntimeError: when the first statement is not a function
        definition, including when there are no statements at all
    :return: compiled code
    """
    if isinstance(source, ast.AST):
        a = source
    else:
        a = parse_snippet(source, filename, mode, flags, firstlineno)

    # an empty body would otherwise surface as an IndexError
    if not a.body or not isinstance(a.body[0], ast.FunctionDef):
        raise RuntimeError("Expecting function AST node")

    return compile(a, filename, mode, flags, True)


def prompt_with_default(prompt, def_val, stream, transform=(lambda x: x)):
    """
    prompts for input with default value and retries until transform does not
    raise an exception

    The prompt is written to `stream`, while the answer is read with
    `input()`. An empty answer returns `def_val` as-is (it is not passed
    through `transform`). Only `Exception`s raised by `transform` trigger a
    retry; e.g. `KeyboardInterrupt` and `SystemExit` propagate.

    :param prompt: prompt to use
    :param def_val: default value
    :param stream: stream to use
    :param transform: transform function
    :raises EOFError: propagated from `input()` when input is exhausted
    :return: transformed value
    """

    while True:
        print("%s [%s]: " % (prompt, def_val), file=stream, end="", flush=True)
        input_str = input()
        if input_str == "":
            return def_val
        try:
            return transform(input_str)
        except Exception:
            # invalid answer: ask again
            pass


def update_locals(ls, target_frame=None):
    """
    schedules an update to the locals of the target frame (caller's frame by
    default) using `sys.settrace`

    The update is not applied immediately. A local trace function is installed
    on `target_frame` and a do-nothing global trace function is installed with
    `sys.settrace`. The next time the local trace function is invoked for
    `target_frame`, it writes `ls` into `frame.f_locals`. It then restores the
    previous global trace function and the frame's previous `f_trace`. Errors
    while writing the locals are silently ignored.

    :param ls: locals dictionary with updated values
    :param target_frame: target frame
    """
    if target_frame is None:
        # default to the frame that called `update_locals`
        target_frame = sys._getframe(1)
    # remember the current tracing state so it can be restored afterwards
    old_f_trace = target_frame.f_trace
    old_trace = sys.gettrace()

    def trace_dummy(frame, event, arg):
        # global trace function that returns None, so no new local tracing is
        # set up for other frames; presumably installed only so that tracing
        # is active and `target_frame.f_trace` gets called -- NOTE(review):
        # confirm on the supported Python versions
        return None

    def update_locals_helper(frame, event, arg):
        # local trace function of the target frame: apply `ls` once, then
        # restore the previous global and local trace functions
        if frame is target_frame:
            try:
                # must not get frame.f_locals after update
                # (presumably because, before PEP 667 / Python 3.13, reading
                # `f_locals` re-syncs the dict from the real locals and would
                # discard the update -- TODO confirm)
                f_locals = frame.f_locals
                f_locals.update(ls)
            except:
                # best effort: failure to write the locals is ignored
                pass
            sys.settrace(old_trace)
            frame.f_trace = old_f_trace

    target_frame.f_trace = update_locals_helper
    sys.settrace(trace_dummy)


def run_func(it=None, structure=1, *args, **kwargs):
    """
    This is a decorator to run a function, iterating with `it` as first argument
    and with the remaining args as subsequent arguments

    The decorated function is called right away (once, or once per item of
    `it`) and the decorator returns None, so the decorated name is bound to
    None. Iteration stops early when a call returns the string "break".

    :param it: iterable to loop over, or None to call the function once with
        just `args` and `kwargs`
    :param structure: how each item is passed: 1 as a single argument, -1 not
        at all, 0 unpacked with `*`, any other int `n` unpacked with `*` after
        asserting the item has length `n`, a tuple flattened via
        `flatten_tuple`
    :param args: extra positional arguments passed after the item
    :param kwargs: keyword arguments passed to every call
    :return: decorator
    """
    if it is None:

        def run_once(f: types.FunctionType):
            f(*args, **kwargs)

        return run_once

    # pick how a single item is turned into arguments
    if isinstance(structure, tuple):

        def call(f, item):
            return f(*flatten_tuple(item, structure), *args, **kwargs)
    elif structure == 1:

        def call(f, item):
            return f(item, *args, **kwargs)
    elif structure == -1:

        def call(f, item):
            return f(*args, **kwargs)
    elif structure == 0:

        def call(f, item):
            return f(*item, *args, **kwargs)
    else:

        def call(f, item):
            assert len(item) == structure
            return f(*item, *args, **kwargs)

    def run_loop(f: types.FunctionType):
        for item in it:
            if call(f, item) == "break":
                break

    return run_loop


def flatten_tuple(data, structure, cache={}):
    """
    constructs a flattened tuple by JIT compilation of destructuring code, for
    e.g. the structure ((3,), (2,), (3,)) will destructure tuples that look like
    ((1, 2, 3), (4, 5), (6, 7, 8))

    :param data: original (nested) tuple
    :param structure: structure of the tuple: a positive int `k` stands for `k`
        consecutive elements, `0` for a starred element (collected into a
        list), `-1` for an ignored element, and a tuple for a nested tuple
    :param cache: cache to store JIT compilations (a deliberately shared
        mutable default, so compiled unpackers persist across calls)
    :return: flattened tuple (always a tuple, even for zero or one element)
    """
    unpacker = cache.get(structure)
    if unpacker is None:

        def get_unstruct(structure, n_vars):
            # return the assignment-target source for `structure` with
            # variables numbered from `n_vars`, and the next free number
            if isinstance(structure, int):
                if structure == 0:
                    return "*x%d" % n_vars, n_vars + 1
                if structure == -1:
                    return "_", n_vars
                new_n_vars = n_vars + structure
                return (",".join("x%d" % i for i in range(n_vars, new_n_vars)),
                        new_n_vars)
            assert isinstance(structure, tuple)
            unstruct_code = "("
            for x in structure:
                unstruct_code_one, n_vars = get_unstruct(x, n_vars)
                unstruct_code += unstruct_code_one + ","
            return unstruct_code + ")", n_vars

        unstruct_code, n_vars = get_unstruct(structure, 0)
        # always emit a tuple display: `()` for no variables and `(x0,)` for
        # one, instead of the invalid `(,)` or a bare (non-tuple) `x0`
        return_code = "(" + "".join("x%d," % i for i in range(n_vars)) + ")"
        local_dict = {}
        exec(
            compile(
                "def unpacker(data):\n %s = data\n return %s" %
                (unstruct_code, return_code),
                "",
                "exec",
                dont_inherit=True), {}, local_dict)
        unpacker = local_dict["unpacker"]
        cache[structure] = unpacker
    return unpacker(data)


def reload_module(module_str, other_globals_list=None):
    """
    reload a module and rebind names in other global namespaces that referred
    to objects of the old module, e.g. names created by `from mod import f` or
    `from mod import f as g`

    A name `k2` is rebound to the reloaded value of module global `k` when,
    before the reload, `k2` referred to the very same object as `k`. A
    same-named module global is preferred. Otherwise objects of simple
    immutable types (None, numbers, str, bytes) are not matched, because they
    are commonly shared by identity (cached small ints, interned strings) and
    would rebind unrelated names. Dunder names such as `__doc__` are never
    rebound.

    :param module_str: name of the module in `sys.modules`
    :param other_globals_list: a globals dict, or a list of them, to update;
        defaults to the caller's globals
    """
    import importlib

    def is_dunder(key):
        return key.startswith("__") and key.endswith("__")

    cur_module = sys.modules[module_str]
    cur_gs = cur_module.__dict__
    shared_types = (type(None), int, float, complex, str, bytes)
    # identity -> module name, for objects that identity can attribute to
    # this module
    orig_mapping = {
        id(x): k
        for k, x in cur_gs.items()
        if not is_dunder(k) and not isinstance(x, shared_types)
    }
    if other_globals_list is None:
        other_globals_list = [sys._getframe(1).f_globals]
    elif not isinstance(other_globals_list, list):
        other_globals_list = [other_globals_list]
    # collect before reloading: the reload updates the module's namespace in
    # place, so the old objects can no longer be looked up afterwards
    all_mappings = []
    for gs in other_globals_list:
        for k2, x in gs.items():
            if is_dunder(k2):
                continue
            if k2 in cur_gs and cur_gs[k2] is x:
                k = k2
            else:
                k = orig_mapping.get(id(x))
            if k is not None:
                all_mappings.append((gs, k2, k))
    new_module_gs = importlib.reload(cur_module).__dict__
    for gs, k2, k in all_mappings:
        if k in new_module_gs:
            gs[k2] = new_module_gs[k]
