#!/usr/bin/env python3
import __future__

import ast
import copy
import functools
import inspect
import itertools
import logging
import operator
import os
import re
import sys
import types
import warnings
from typing import Any, Dict, List, Set

# Module-wide logger; its name is built from this file's path.
L = logging.getLogger(f"ipython_utils.{__file__}")

# Prefix shared by every name generated by the code-rewriting helpers in this
# file (e.g. `<magic>_outer`, `<magic>_inner`); user variables starting with it
# may clash with generated code.
magic = "_ipy_magic"
_ipy_magic_inner: None  # dummy to resolve IDE errors
# Bind a do-nothing placeholder under `<magic>_inner` in this module's globals,
# presumably so references to `_ipy_magic_inner` resolve even outside generated
# code (the annotation above only serves static analysers).
globals()[f"{magic}_inner"] = lambda: None


def add_except_hook():
    """Install :func:`excepthook` as the interpreter's hook for uncaught exceptions.

    After this call, an uncaught exception is logged and an embedded IPython
    shell is opened in the failing frame (see :func:`excepthook`).
    """
    sys.excepthook = excepthook


def excepthook(etype, value, tb, logger=None):
    """Log an uncaught exception, then open an embedded shell at the failing frame.

    Meant to be installed as ``sys.excepthook`` (see :func:`add_except_hook`).
    The traceback is searched from the innermost frame outwards, and the first
    frame whose filename does not contain ``/site-packages/`` is chosen. An
    embedded IPython console is then opened in that frame. If every frame is
    inside ``site-packages``, or there is no traceback, the exception is only
    logged.

    :param etype: exception type
    :param value: exception instance
    :param tb: traceback object (may be ``None``)
    :param logger: logger for the report; defaults to the module logger ``L``
    """
    if logger is None:
        logger = L
    logger.error("uncaught exception", exc_info=(etype, value, tb))
    if tb is None:
        # no frames to inspect, nothing to embed into
        return
    for record in reversed(inspect.getinnerframes(tb)):
        if "/site-packages/" in record.filename:
            # skip library internals; we want to land in the user's own code
            continue
        logger.info("frame: %s", record)
        frame = record.frame
        msg = "Entering IPython console at {0.f_code.co_filename} at line {0.f_lineno}".format(
            frame)
        saved_hook = sys.excepthook  # save the exception hook
        try:
            embed(header=msg, frame=frame)
        finally:
            # reset IPython's change to the exception hook, even if embedding failed
            sys.excepthook = saved_hook
        break


def embed(funcs: List[types.FunctionType] = None,
          *,
          frame=None,
          header="",
          compile_flags=None,
          **kwargs):
    """
    This is copied from IPython/terminal/embed.py but modified to allow closures
    over local variables.

    Closures over local variables are allowed by adding a transformer to
    shell.ast_transformers. The transformer replaces the AST completely with a
    call to a temporary injected function, generated by a special closure. More
    precisely, we first dynamically create and run an outer function scope that
    wraps the cell code, with the original local variables as its parameters.
    For illustration, this is what gets generated when the code is
    `statement0; statement1; last_expression` and the local variables are
    x, y, z:

    ```
    def _ipy_magic_outer(x, y, z):
        del x, y, z
        def _ipy_magic_inner():
            nonlocal _ipy_magic_inner
            nonlocal x, y, z
            statement0
            statement1
            return last_expression
        return _ipy_magic_inner
    ```

    Then we patch the returned function's closure to carry forward the cells
    (and their contents) from previous cell executions. This lets even inner
    closures close over the right cells. The argument `funcs` allows you to
    have additional closed variables: when resolving cells during patching, the
    cells in the closures of the functions in `funcs` are included. Note that
    only such variable modifications persist after the shell exits. This means
    the variables must originally be closed over, because otherwise they are
    not stored as cells in the corresponding frames. Calling embed inside
    `try_all_statements` or another `embed` should use `_ipy_magic_inner` in
    `funcs` in order to get the right closure cells.

    Some known issues or unusual behaviour:
    * changes to local variables are only visible to code within the shell
    session (it is possible to make this work for cell variables in CPython, but
    it may be prone to breakage), unless `funcs` has a function with a closure
    containing the cells for these local variables
    * outer scope variables must already have had a closure referencing them to
    be visible (`nonlocal` does not work); changes to these are persistent
    * use of variables starting with `_ipy_magic` may result in undefined
    behaviour
    * return statements without an enclosing function do not give an error, and
    the value returned is the result of the cell

    The previous shell instance and `sys.ps1`/`sys.ps2` are restored on exit,
    including when the shell raises.

    :param funcs: a function, or list of functions, to get closure cells from;
        cells of later functions take precedence over earlier ones, and all of
        them take precedence over the frame's plain locals
    :param frame: frame to get locals and globals from (defaults to the caller)
    :param header: header text passed to the shell
    :param compile_flags: compiler flags passed to the shell
    :param kwargs: passed to `InteractiveShellEmbed.instance`; `config` and
        `using` (default 'sync') are also used to configure the shell

    --- original docs below ---

    Call this to embed IPython at the current point in your program.

    The first invocation of this will create an :class:`InteractiveShellEmbed`
    instance and then call it.  Consecutive calls just call the already
    created instance.

    If you don't want the kernel to initialize the namespace
    from the scope of the surrounding function,
    and/or you want to load full IPython configuration,
    you probably want `IPython.start_ipython()` instead.

    Here is a simple example::

        from IPython import embed
        a = 10
        b = 20
        embed(header='First time')
        c = 30
        d = 40
        embed()

    Full customization can be done by passing a :class:`Config` in as the
    config argument.
    """
    from IPython.core.interactiveshell import DummyMod, InteractiveShell
    from IPython.terminal.embed import InteractiveShellEmbed
    from IPython.terminal.ipapp import load_default_config
    config = kwargs.get('config')
    if config is None:
        config = load_default_config()
        config.InteractiveShellEmbed = config.TerminalInteractiveShell
        kwargs['config'] = config
    using = kwargs.get('using', 'sync')
    if using:
        kwargs['config'].update({
            'TerminalInteractiveShell': {
                'loop_runner': using,
                'colors': 'NoColor',
                'autoawait': using != 'sync'
            }
        })
    # save ps1/ps2 if defined
    ps1 = None
    ps2 = None
    try:
        ps1 = sys.ps1
        ps2 = sys.ps2
    except AttributeError:
        pass
    # save (and detach) the previous shell instance
    saved_shell_instance = InteractiveShell._instance
    if saved_shell_instance is not None:
        type(saved_shell_instance).clear_instance()
    if frame is None:
        frame = sys._getframe(1)
    location_id = '%s:%s' % (frame.f_code.co_filename, frame.f_lineno)
    try:
        shell = InteractiveShellEmbed.instance(_init_location_id=location_id,
                                               **kwargs)
        # a single function is accepted as a shorthand for a one-element list
        if isinstance(funcs, types.FunctionType):
            funcs = [funcs]
        cell_dict = {}
        for func in funcs or ():
            if func.__closure__:
                cell_dict.update(
                    zip(func.__code__.co_freevars, func.__closure__))
        # plain locals get fresh cells, so rebinding them only affects the shell
        for k, v in frame.f_locals.items():
            if k not in cell_dict:
                cell_dict[k] = types.CellType(v)
        extra_globals = set()
        shell.ast_transformers.append(
            FixLocals(shell, frame, cell_dict, extra_globals, magic))
        module = DummyMod()
        module.__dict__ = frame.f_globals
        shell(header=header,
              local_ns=frame.f_locals,
              module=module,
              compile_flags=compile_flags,
              _call_location_id=location_id)
    finally:
        InteractiveShellEmbed.clear_instance()
        # restore previous instance
        if saved_shell_instance is not None:
            cls = type(saved_shell_instance)
            cls.clear_instance()
            for subclass in cls._walk_mro():
                subclass._instance = saved_shell_instance
        if ps1 is not None:
            sys.ps1 = ps1
            sys.ps2 = ps2


class FixLocals(object):
    """IPython AST transformer that runs every cell inside a generated closure.

    :func:`embed` appends an instance of this class to ``shell.ast_transformers``.
    For each parsed cell, :meth:`visit` wraps the statements with
    ``run_statements_helper`` so that the variables held as cells in
    ``cell_dict`` can be read, assigned and closed over. It stores the patched
    function in the user namespace as ``<magic>_inner`` and replaces the cell
    with a single call to that function.
    """

    def __init__(self, shell, frame: types.FrameType,
                 cell_dict: Dict[str, types.CellType], extra_globals: Set[str],
                 magic):
        # the embedded IPython shell (namespaces and compiler flags come from it)
        self.shell = shell
        # frame the shell was opened in; only its filename is used here
        self.frame = frame
        # variable name -> cell, shared by all cells of the session
        self.cell_dict = cell_dict
        # names declared `global` at the top level of earlier cells
        self.extra_globals = extra_globals
        self.magic = magic

    def visit(self, module_ast: ast.Module) -> ast.Module:
        """Rewrite one parsed cell.

        :param module_ast: the parsed cell
        :return: an AST that calls the generated ``<magic>_inner`` function, or
            ``module_ast`` unchanged when the cell is empty or rewriting failed
        """
        if not module_ast.body:
            # e.g. a cell containing only comments: nothing to wrap or run
            return module_ast
        original_body = list(module_ast.body)
        try:
            if self.extra_globals:
                # keep names declared global in earlier cells global in this one
                module_ast.body.insert(
                    0,
                    ast.copy_location(
                        ast.Global(names=list(self.extra_globals)),
                        module_ast.body[0]))
            CollectGlobals(self.extra_globals).visit(module_ast)
            patcher_cell = types.CellType()
            statement = module_ast.body[-1]
            if isinstance(statement, ast.Expr):
                # the last expression becomes the result of the cell
                module_ast.body[-1] = ast.copy_location(
                    ast.Return(value=statement.value), statement)
            # called for its side effect of storing the patcher in `patcher_cell`
            run_statements_helper(patcher_cell, module_ast.body, None,
                                  self.magic + "_shell", None,
                                  self.shell.user_global_ns,
                                  list(self.cell_dict.keys()), [], [],
                                  self.frame.f_code.co_filename, 0,
                                  self.shell.compile.flags, self.magic, False)
            patched = patcher_cell.cell_contents(self.cell_dict)
            self.shell.user_ns[self.magic + "_inner"] = patched
            return self.shell.compile.ast_parse(self.magic + "_inner()")
        except Exception as e:
            L.error("error %s", e, exc_info=True)
            # undo the in-place edits above: a module-level `return` would not
            # compile, so run the cell as originally written instead
            module_ast.body = original_body
        return module_ast


class CollectGlobals(ast.NodeVisitor):
    """Collect the names declared ``global`` in a cell, outside nested functions.

    Function bodies (both ``def`` and ``async def``) are not descended into,
    because a ``global`` statement there only applies to that function.
    """

    def __init__(self, found_globals: Set[str]):
        super().__init__()
        # updated in place; the caller owns (and keeps) this set
        self.found_globals = found_globals

    def visit_FunctionDef(self, node: ast.FunctionDef) -> Any:
        return None

    # `async def` introduces a function scope exactly like `def`
    visit_AsyncFunctionDef = visit_FunctionDef

    def visit_Global(self, node: ast.Global) -> Any:
        self.found_globals.update(node.names)


def embed2(*, frame=None, header="", compile_flags=None, **kwargs):
    """Embed an IPython shell over a flat snapshot of a frame's namespace.

    This is a simpler alternative to :func:`embed`. The shell runs in a fresh
    dict built from ``frame.f_globals`` overlaid with ``frame.f_locals``
    (locals win on name clashes; names in the shell's ``user_ns_hidden`` are
    skipped). That dict serves as both the global and the local namespace, so
    rebinding a name in the shell does not rebind it in ``frame``.

    Side effect: ``InteractiveShellEmbed.mainloop`` is replaced with this
    module's ``mainloop``, and the replacement is not undone afterwards. The
    previous shell instance and ``sys.ps1``/``sys.ps2`` are restored on exit,
    including when the shell raises.

    :param frame: frame to take the namespace from (defaults to the caller)
    :param header: header text passed to the shell
    :param compile_flags: compiler flags passed to the shell
    :param kwargs: passed to ``InteractiveShellEmbed.instance``; ``config`` and
        ``using`` (default ``'sync'``) are also used to configure the shell
    """
    if frame is None:
        frame = sys._getframe(1)
    from IPython.core.interactiveshell import DummyMod, InteractiveShell
    from IPython.terminal.embed import InteractiveShellEmbed
    from IPython.terminal.ipapp import load_default_config
    config = kwargs.get('config')
    if config is None:
        config = load_default_config()
        config.InteractiveShellEmbed = config.TerminalInteractiveShell
        kwargs['config'] = config
    using = kwargs.get('using', 'sync')
    if using:
        kwargs['config'].update({
            'TerminalInteractiveShell': {
                'loop_runner': using,
                'colors': 'NoColor',
                'autoawait': using != 'sync'
            }
        })
    # save ps1/ps2 if defined
    ps1 = None
    ps2 = None
    try:
        ps1 = sys.ps1
        ps2 = sys.ps2
    except AttributeError:
        pass
    # save (and detach) the previous shell instance
    saved_shell_instance = InteractiveShell._instance
    if saved_shell_instance is not None:
        type(saved_shell_instance).clear_instance()
    location_id = '%s:%s' % (frame.f_code.co_filename, frame.f_lineno)
    try:
        shell = InteractiveShellEmbed.instance(_init_location_id=location_id,
                                               **kwargs)
        # A write-through dict (the commented-out `Glocals` prototype) almost
        # worked in place of this snapshot. The exception: setting a global in a
        # newly created function with the `global` keyword did not go through,
        # probably because the interpreter writes to the underlying dict directly.
        env = {
            k: v
            for k, v in itertools.chain(frame.f_globals.items(),
                                        frame.f_locals.items())
            if k not in shell.user_ns_hidden
        }
        module = DummyMod()
        module.__dict__ = env
        InteractiveShellEmbed.mainloop = mainloop
        shell(header=header,
              local_ns=env,
              module=module,
              compile_flags=compile_flags,
              _call_location_id=location_id)
    finally:
        InteractiveShellEmbed.clear_instance()
        # restore previous instance
        if saved_shell_instance is not None:
            cls = type(saved_shell_instance)
            cls.clear_instance()
            for subclass in cls._walk_mro():
                subclass._instance = saved_shell_instance
        if ps1 is not None:
            sys.ps1 = ps1
            sys.ps2 = ps2


# class Glocals(dict):

#     def __init__(self, gs, ls, hidden_keys):
#         self.gs = gs
#         self.ls = ls
#         self.hidden_keys = hidden_keys
#         self.extras = {}

#     def __getitem__(self, k):
#         if k not in self.hidden_keys:
#             if k in self.gs:
#                 return self.gs[k]
#             if k in self.ls:
#                 return self.ls[k]
#         return self.extras[k]

#     def __setitem__(self, k, v):
#         if k not in self.hidden_keys:
#             if k in self.gs:
#                 self.gs[k] = v
#                 return
#             if k in self.ls:
#                 print("writing ls[%s]" % k)
#                 self.ls[k] = v
#                 return
#         self.extras[k] = v

#     def update(self, other):
#         for k, v in other.items():
#             self[k] = v


def mainloop(
    self,
    local_ns=None,
    module=None,
    stack_depth=0,
    compile_flags=None,
):
    """Embeds IPython into a running python program.

    Patched copy of ``InteractiveShellEmbed.mainloop``; ``embed2`` installs it
    on that class. The changes are marked with ``FIX`` comments. When
    ``local_ns`` is the very same dict as ``module.__dict__``, it is used
    directly as the shell's ``user_ns``, rather than being copied in and then
    copied back after the session.

    Parameters
    ----------
    self
        The ``InteractiveShellEmbed`` instance this function is installed on.
    local_ns, module
        Working local namespace (a dict) and module (a module or similar
        object). If given as None, they are automatically taken from the scope
        where the shell was called, so that program variables become visible.
    stack_depth : int
        How many levels in the stack to go to looking for namespaces (when
        local_ns or module is None). This allows an intermediate caller to
        make sure that this function gets the namespace from the intended
        level in the stack. By default (0) it will get its locals and globals
        from the immediate caller.
    compile_flags
        A bit field identifying the __future__ features
        that are enabled, as passed to the builtin :func:`compile` function.
        If given as None, they are automatically taken from the scope where
        the shell was called.

    """
    from IPython.core import compilerop
    from IPython.core.interactiveshell import DummyMod

    # Get locals and globals from caller
    if ((local_ns is None or module is None or compile_flags is None)
            and self.default_user_namespaces):
        call_frame = sys._getframe(stack_depth).f_back

        if local_ns is None:
            local_ns = call_frame.f_locals
        if module is None:
            global_ns = call_frame.f_globals
            try:
                module = sys.modules[global_ns['__name__']]
            except KeyError:
                warnings.warn("Failed to get module %s" % \
                    global_ns.get('__name__', 'unknown module')
                )
                module = DummyMod()
                module.__dict__ = global_ns
        if compile_flags is None:
            compile_flags = (call_frame.f_code.co_flags & compilerop.PyCF_MASK)

    # Save original namespace and module so we can restore them after
    # embedding; otherwise the shell doesn't shut down correctly.
    # NOTE(review): they are only restored on normal return; if interact()
    # raises, the lines at the end of this function are skipped.
    orig_user_module = self.user_module
    orig_user_ns = self.user_ns
    orig_compile_flags = self.compile.flags

    # Update namespaces and fire up interpreter

    # The global one is easy, we can just throw it in
    if module is not None:
        self.user_module = module

    # But the user/local one is tricky: ipython needs it to store internal
    # data, but we also need the locals. We'll throw our hidden variables
    # like _ih and get_ipython() into the local namespace, but delete them
    # later.
    # FIX: allow passing self.user_module.__dict__ as local_ns, resulting in
    # the same object as self.user_ns
    # BEGIN_FIX
    #if local_ns is not None:
    if local_ns is self.user_module.__dict__:
        # shared globals/locals dict: use it directly, no copy
        self.user_ns = local_ns
        self.init_user_ns()
    elif local_ns is not None:
        # END_FIX
        # copy the caller's locals (minus IPython's hidden names) into a fresh
        # dict for the shell to work in
        reentrant_local_ns = {
            k: v
            for (k, v) in local_ns.items()
            if k not in self.user_ns_hidden.keys()
        }
        self.user_ns = reentrant_local_ns
        self.init_user_ns()

    # Compiler flags
    if compile_flags is not None:
        self.compile.flags = compile_flags

    # make sure the tab-completer has the correct frame information, so it
    # actually completes using the frame's locals/globals
    self.set_completer_frame()

    with self.builtin_trap, self.display_trap:
        self.interact()

    # now, copy the shell's variables (excluding IPython's hidden ones) back
    # into the caller's local namespace.
    # FIX: do not revert namespace if it was the special case
    #if local_ns is not None:
    if local_ns is not None and local_ns is not self.user_module.__dict__:
        local_ns.update({
            k: v
            for (k, v) in self.user_ns.items()
            if k not in self.user_ns_hidden.keys()
        })

    # Restore original namespace so shell can shut down when we exit.
    self.user_module = orig_user_module
    self.user_ns = orig_user_ns
    self.compile.flags = orig_compile_flags


def try_all_statements(f: types.FunctionType):
    """
    This is a decorator that converts a function using `run_statements_helper`
    and `try`s every statement. The user can then modify the function on the
    fly, or drop into a shell, when an exception occurs. Refer to
    `run_statements_helper` for more details.

    The wrapper binds its call arguments exactly like a normal call to `f`
    would: positional, keyword, default, keyword-only, `*args` and `**kwargs`
    are all handled via `inspect.Signature.bind`, so invalid calls raise
    `TypeError` like they would for `f` itself. The bound values are then
    written into the cells in `cell_dict`, which back `f`'s parameters in the
    generated code.

    :param f: function to decorate
    :return: wrapped function
    """
    (source, filename, flags, func_line_num,
     private_prefix) = uncompile(f.__code__)
    fmod_ast = parse_snippet(source, filename, "exec", flags, func_line_num,
                             private_prefix)
    func_ast = fmod_ast.body[0]
    co_freevars = f.__code__.co_freevars
    patcher_cell = types.CellType()
    # TODO: add recursive try_all loop helper
    runner = run_statements_helper(patcher_cell, func_ast.body,
                                   f.__module__, f.__name__, f.__qualname__,
                                   f.__globals__, f.__code__.co_varnames,
                                   f.__code__.co_cellvars, co_freevars,
                                   filename, func_ast.lineno, flags,
                                   magic, True)

    f_closure = f.__closure__ or ()
    # describe `f` itself (not whatever it may wrap), because the cells to fill
    # are named after `f.__code__.co_varnames`
    signature = inspect.signature(f, follow_wrapped=False)

    @functools.wraps(f)
    def wrapper(*args, **kwargs):
        # raises TypeError for missing/unexpected/duplicate arguments
        bound = signature.bind(*args, **kwargs)
        # fill in defaults, and empty tuple/dict for absent *args/**kwargs
        bound.apply_defaults()
        cell_dict = dict(zip(co_freevars, f_closure))
        # read the patcher at call time: the runner may replace it after edits
        patched = patcher_cell.cell_contents(cell_dict)
        # copy the arguments into the cells backing `f`'s parameters
        for name, value in bound.arguments.items():
            cell_dict[name].cell_contents = value
        return runner(patched, 0, cell_dict)

    wrapper.__wrapped__ = None
    return wrapper


class TryBlockTransformer(ast.NodeTransformer):
    """Experimental transformer that marks `for` loops for `try_all` handling.

    Visiting a ``lambda try_all: ...`` sets ``flag``, and the next ``for`` loop
    visited clears it again. No rewriting happens yet: every node is returned
    after its children have been visited (see the TODO notes in `visit_For`).
    """

    def __init__(self):
        # True between seeing a `lambda try_all: ...` marker and the next `for`
        self.flag = False

    def visit_Lambda(self, node: ast.Lambda) -> Any:
        positional = node.args.args
        is_marker = len(positional) == 1 and positional[0].arg == "try_all"
        if is_marker:
            self.flag = True
        return self.generic_visit(node)

    def visit_For(self, node: ast.For) -> Any:
        if not self.flag:
            return self.generic_visit(node)
        self.flag = False
        # TODO: new idea: convert `for x, (y, z) in it:` within try_all_statements to
        #   it2 = iter(it)
        #   if i == 0:
        #       try:
        #           _ = next(it)
        #       except:
        #           return True
        #   for _ in [0]:
        #       if i==1:
        #           statement1
        #       if i==2:
        #           statement2
        #       ...

        # TODO: automatically convert `for x, (y, z) in it:` to:
        #   @run_func(it, (1, (1, 1)))
        #   @try_all_statements
        #   def _(x, y, z):
        # and `continue` to `return` and `break` to `return True`
        # how to make variables that become local into nonlocal?
        #   leak locals into cell_dict
        return self.generic_visit(node)


def run_statements_helper(patcher_cell: types.CellType,
                          statements: List[ast.stmt], module: str, name: str,
                          qualname: str, globals: dict, co_varnames: List[str],
                          co_cellvars: List[str], co_freevars: List[str],
                          filename: str, func_line_num: int, flags: int,
                          magic: str, to_try: bool):
    """
    If `to_try` is true, this changes a sequence of statements to a function
    that can run individual statements like so (if it hit the end of the
    statement without returning, it will return _ipy_magic_inner):
    ```
    def _ipy_magic_outer(a, b, c):
        del a, b, c
        def _ipy_magic_inner(_i):
            nonlocal _ipy_magic_inner
            nonlocal a, b, c
            if _i == 0:
                statement0
                return _ipy_magic_inner
            elif _i == 1:
                statement1
                return _ipy_magic_inner
            ...
        return _ipy_magic_inner
    ```

    If `to_try` is false, the function is just wrapped as follows:
    ```
    def _ipy_magic_outer(a, b, c):
        del a, b, c
        def _ipy_magic_inner():
            nonlocal _ipy_magic_inner
            nonlocal a, b, c
            statement0
            ...
        return _inner
    ```

    It then compiles the generated code, runs `_ipy_magic_outer` to get an
    instance of `_ipy_magic_inner`, and constructs a patcher, which takes a
    mapping of var names to the associated cell in order to update the closure
    of `_ipy_magic_inner`. This is to maintain the local variable state across
    multiple calls to allow the statements to be run sequentially. The
    `nonlocal` in `_inner` converts both local variables and cell variables to
    free (closure) variables. If some of these were not fully converted, we add
    the extra variables to `_ipy_magic_outer` and recompile.
    
    We provide the instance of the real function running the code as
    `_ipy_magic_inner` in the scope of the statements so that it is possible to
    call `embed(func=_ipy_magic_inner)` to achieve full control over local
    variables in the embedded shell.

    We then construct the runner function. If `to_try` is true, the statements
    are run in sequence from the given starting index. When an uncaught
    exception is raised, we catch it and allow the user to update the source
    code of the function and specify a line to start running from. After doing
    so, we redo the whole process of the compilation of `_ipy_magic_outer` and
    patcher/runner construction, patch with the current cell dict and start
    running the function from the specified line using the new runner.

    Note that we use a cell to return the patcher because by passing this cell
    to future recursive calls, the original wrapper can be updated with the
    patched function as it will use this cell's contents as the patcher.

    If the user enters a negative number `-x` for the next statement line, an
    embedded shell will be injected into the statement at line `x`, after which
    a raise will once again drop it back here. The special case of `0` is
    treated as `-x` where `x` is the default next statement line (the line of
    the statement which failed).

    General Info:
    If `f` is a function, then
    `f.__code__.co_freevars` are variables in the parent functions which were
    closed over (by any code or child functions) and `f.__code__.co_cellvars`
    variables in the scope of `f` which are closed over by a child function.
    `f.__closure__` gives all the cells in correspondence with
    `f.__code__.co_freevars`. We use FunctionType to alter which cells the
    closure refers to, as mentioned by
    https://stackoverflow.com/questions/59276834/how-to-set-the-content-of-a-closure-cell
    . We are not able to use pure Python to get the cell of a cell variable in
    the stack frame (any variable which was closed over), and we have to
    probably use `ctypes.pythonapi` and follow something like
    `framelocalsproxy_getval`. Fortunately, this is not needed for our limited
    use-case where we already have the function to be converted.

    We are not able to use pure Python to get the cell of a cell variable in
    the stack frame (any variable which was closed over), and we have to
    probably use `ctypes.pythonapi` and follow something like
    `framelocalsproxy_getval`. Fortunately, this is not needed for our limited
    use-case where we already have the function to be converted. However, we
    might need this in the future to allow modifications to a function's cell
    variables (it is almost impossible to modify non-cell variables).

    :param patcher_cell: cell to place the patcher into
    :param statements: statements to wrap
    :param module: name of module (optional)
    :param name: name of function
    :param qualname: qualified name of function (optional)
    :param globals: function globals
    :param co_varnames: function co_varnames
    :param co_cellvars: function co_cellvars
    :param co_freevars: function co_freevars
    :param filename: filename
    :param func_line_num: function line number
    :param flags: function flags
    :param magic: the magic string to use
    :param to_try: True iff exceptions from every statement is to be caught
    :return: runner
        runner(patched, start_i, cell_dict): runs statements from index
        `start_i` after specialising with the `patcher` and `cell_dict`
    """

    statements = [
        AnnotationRemover().visit(statement) for statement in statements
    ]
    # `.0` is an iteration variable that appears because of comprehension
    # inlining in Python 3.12
    co_varnames = [x for x in co_varnames if x != ".0"]
    co_cellvars = list(co_cellvars)
    co_varnames_set = set(co_varnames)
    co_cellvars_set = set(co_cellvars)
    while True:
        # L.info("%s", f'{co_varnames=} {co_cellvars=} {co_freevars=}')
        # co_cellvars might be repeated in co_varnames if it is a parameter
        all_locals = co_varnames + [
            x for x in co_cellvars if x not in co_varnames_set
        ] + list(co_freevars)
        # if func uses more nonlocal variables (more bindings from outer
        # functions), the below will throw an error when compiled, otherwise we
        # can update the list of all locals
        # magic = "_ipy_magic_%s_%d" % (name, id(closure))
        if to_try:
            module_ast = ast.parse("""
                def {0}_outer({1}):
                    {2}del {1}
                    def {0}_inner({0}_i):
                        nonlocal {0}_inner
                        {2}nonlocal {1}
                        if {0}_i == 0:
                            return {0}_inner
                    return {0}_inner
            """.strip().format(magic, ",".join(all_locals),
                               "" if len(all_locals) else "pass # "))
            outer_ast: ast.FunctionDef = module_ast.body[0]
            inner_ast: ast.FunctionDef = outer_ast.body[1]
            if_ast_template: ast.If = inner_ast.body.pop()
            if_asts = []
            for i, statement in enumerate(statements):
                if_ast = copy.deepcopy(if_ast_template)
                # make it compare `_i` with the real statement index
                if_ast.test.comparators[0].value = i
                return_ast: ast.Return = if_ast.body[0]
                if_ast.body.insert(0, statement)
                if_asts.append(if_ast)
            # add all the if statements to the body of `_inner`
            inner_ast.body.extend(if_asts)
        else:
            module_ast = ast.parse("""
                def {0}_outer({1}):
                    {2}del {1}
                    def {0}_inner():
                        nonlocal {0}_inner
                        {2}nonlocal {1}
                    return {0}_inner
            """.strip().format(magic, ",".join(all_locals),
                               "" if len(all_locals) else "pass # "))
            outer_ast: ast.FunctionDef = module_ast.body[0]
            inner_ast: ast.FunctionDef = outer_ast.body[1]
            inner_ast.body.extend(statements)
        # execute definition of `_outer` with empty locals and correct globals
        local_dict = {}
        exec(compile(module_ast, filename, "exec", flags, dont_inherit=True),
             globals, local_dict)
        # execute `_outer` to get `_inner`
        _outer = local_dict[magic + "_outer"]
        _inner: types.FunctionType = _outer(
            *[None for i in range(len(all_locals))])
        # L.info(
        #     "%s",
        #     f'{_inner.__code__.co_varnames=} {_inner.__code__.co_cellvars=}')
        # if there are extra local variables in `co_varnames` and `co_cellvars`,
        # add it to `all_locals` and start over
        if to_try:
            assert _inner.__code__.co_varnames[0] == magic + "_i"
            co_varnames_new = _inner.__code__.co_varnames[1:]
        else:
            co_varnames_new = _inner.__code__.co_varnames
        # must check original set due to a bug in comprehension inlining since
        # Python 3.12. the bug causes co_varnames to include comprehension
        # variables arguably incorrectly, and variables can disappear from
        # co_varnames expectedly. See
        # https://github.com/python/cpython/issues/121377 .
        vars_changed = False
        for varname in co_varnames_new:
            if varname not in co_varnames_set:
                co_varnames.append(varname)
                vars_changed = True
        co_varnames_set = set(co_varnames)
        for varname in _inner.__code__.co_cellvars:
            if varname not in co_cellvars_set:
                co_cellvars.append(varname)
                vars_changed = True
        co_cellvars_set = set(co_cellvars)
        if vars_changed:
            continue
        else:
            break

    all_locals_len = len(all_locals)

    # construct patcher
    def patcher(cell_dict: Dict[str, types.CellType]):
        # patch `_inner` with the correct closure, recalling original free
        # variables, previously added cells, and remembering new cells
        _inner: types.FunctionType = _outer(
            *[None for i in range(all_locals_len)])
        c = _inner.__code__
        inner_code = c.replace(co_filename=filename, co_name=name)
        new_closure = [
            cell_dict.setdefault(x, y)
            for x, y in zip(inner_code.co_freevars, _inner.__closure__)
        ]
        patched = types.FunctionType(inner_code,
                                     globals,
                                     name,
                                     argdefs=None,
                                     closure=tuple(new_closure))
        cell_dict[magic + "_inner"].cell_contents = patched
        if module:
            patched.__module__ = module
        if qualname:
            patched.__qualname__ = qualname
        return patched

    patcher_cell.cell_contents = patcher

    if not to_try:

        def runner(patched, start_i: int, cell_dict: Dict[str,
                                                          types.CellType]):
            for i in range(start_i, len(statements)):
                # L.info("running statement %d", i)
                ret = patched(i)
            return ret

        return runner

    # construct runner
    def runner(patched, start_i: int, cell_dict: Dict[str, types.CellType]):
        # run statements until end or error
        for i in range(start_i, len(statements)):
            try:
                # L.info("running statement %d", i)
                ret = patched(i)
                if ret is not patched:
                    return ret
            except:
                L.info("exception raised", exc_info=True)
                while True:
                    filename_new = prompt_with_default("filename", filename)
                    func_line_num_new = prompt_with_default(
                        "function line num", func_line_num, int)
                    try:
                        with open(filename, "rb") as src_f:
                            module_src = src_f.read()
                        # get AST using flag PyCF_ONLY_AST
                        module_ast: ast.Module = compile(module_src,
                                                         filename,
                                                         "exec",
                                                         flags
                                                         | ast.PyCF_ONLY_AST,
                                                         dont_inherit=True)
                        # retrieve function at specified line
                        helper = GetFuncAtLine(func_line_num_new)
                        func_ast: ast.FunctionDef = helper.visit(module_ast)
                        # get all statements in function
                        statements_new = func_ast.body
                        line_nums = [s.lineno for s in statements_new]
                        # prompt for next statement line
                        next_line_num = prompt_with_default(
                            "next statement line num", line_nums[i], int)
                        use_embed = next_line_num <= 0
                        if next_line_num == 0:
                            next_line_num = line_nums[i]
                        elif next_line_num < 0:
                            next_line_num = -next_line_num
                        next_i = [
                            i for i, num in enumerate(line_nums)
                            if num == next_line_num
                        ][0]
                        if use_embed:
                            magic_embed = magic + "_embed"
                            if magic_embed not in co_varnames_set:
                                co_varnames.append(magic_embed)
                                co_varnames_set.add(magic_embed)
                            if magic_embed not in cell_dict:
                                cell_dict[magic_embed] = types.CellType(embed)
                            statement_orig = statements_new[next_i]
                            statement_new: ast.With = ast.parse(
                                "if True:\n {0}_embed(funcs=[{0}_inner])\n raise"
                                .format(magic)).body[0]
                            statements_new[next_i] = ast.copy_location(
                                statement_new, statement_orig)
                        # run the compilation and wrapper construction anew
                        runner_new = run_statements_helper(
                            patcher_cell, statements_new, module, name,
                            qualname, globals, co_varnames, co_cellvars,
                            co_freevars, filename_new, func_line_num_new,
                            flags, magic, True)
                        # TODO: patch global functions
                        patched_new = patcher_cell.cell_contents(cell_dict)
                        break
                    except:
                        L.error("error in setting up new function",
                                exc_info=True)
                # finally run the new wrapper
                return runner_new(patched_new, next_i, cell_dict)
        return None

    return runner


class AnnotationRemover(ast.NodeTransformer):
    """
    Strip annotations from top-level annotated assignments to plain names.

    Variables that are local now may become non-local once the code is moved
    into a closure, and Python rejects an annotation on a name declared
    `nonlocal`. Only statements outside every function/class definition are
    rewritten:

    * ``name: ann = value`` becomes ``name = (value, ann)[0]``
    * ``name: ann`` becomes the expression statement ``(None, ann)[0]``

    so both the value and the annotation expressions are still evaluated.
    """

    def __init__(self):
        super().__init__()
        # becomes False while visiting anything below a def / async def / class
        self.is_top_level = True

    def visit_FunctionDef(self, node: ast.FunctionDef) -> Any:
        self.is_top_level = False
        return self.generic_visit(node)

    # entering any of these scopes is handled identically
    visit_AsyncFunctionDef = visit_FunctionDef
    visit_ClassDef = visit_FunctionDef

    def generic_visit(self, node):
        # Same as ast.NodeTransformer.generic_visit, except that the
        # `is_top_level` flag is restored after every child, so leaving a
        # nested scope brings it back to the value it had for `node`.
        saved_flag = self.is_top_level
        for field_name, current in ast.iter_fields(node):
            if isinstance(current, ast.AST):
                replacement = self.visit(current)
                self.is_top_level = saved_flag
                if replacement is None:
                    delattr(node, field_name)
                else:
                    setattr(node, field_name, replacement)
                continue
            if not isinstance(current, list):
                continue
            rebuilt = []
            for item in current:
                if not isinstance(item, ast.AST):
                    rebuilt.append(item)
                    continue
                result = self.visit(item)
                self.is_top_level = saved_flag
                if result is None:
                    # the visitor asked for the node to be dropped
                    pass
                elif isinstance(result, ast.AST):
                    rebuilt.append(result)
                else:
                    # the visitor returned several replacement nodes
                    rebuilt.extend(result)
            current[:] = rebuilt
        return node

    def visit_AnnAssign(self, node: ast.AnnAssign) -> Any:
        if not (self.is_top_level and isinstance(node.target, ast.Name)):
            return self.generic_visit(node)

        def located(new_node):
            return ast.copy_location(new_node, node)

        base_visit = super().generic_visit
        load = ast.Load()
        if node.value is None:
            first_item = located(ast.Constant(value=None, kind=None))
        else:
            first_item = base_visit(node.value)
        annotation = base_visit(node.annotation)
        pair = located(ast.Tuple(elts=[first_item, annotation], ctx=load))
        # ast.Index keeps Python 3.8 working; on 3.9+ it returns its value
        index = located(
            ast.Index(value=located(ast.Constant(value=0, kind=None))))
        picked = located(ast.Subscript(value=pair, slice=index, ctx=load))
        if node.value is None:
            return located(ast.Expr(value=picked))
        target = base_visit(node.target)
        return located(ast.Assign(targets=[target], value=picked))


class GetFuncAtLine(ast.NodeVisitor):
    """
    Locate the (possibly nested) function definition starting at a line.

    ``GetFuncAtLine(n).visit(tree)`` performs a depth-first walk and returns
    the first ``ast.FunctionDef`` whose ``def`` line, or whose first
    decorator's line, equals ``n``. It returns None if there is no such
    function.
    """

    def __init__(self, func_line_num):
        super().__init__()
        self.func_line_num = func_line_num

    def generic_visit(self, node):
        """Visit children in order and return the first truthy result."""
        for _field, value in ast.iter_fields(node):
            children = value if isinstance(value, list) else [value]
            for child in children:
                if isinstance(child, ast.AST):
                    found = self.visit(child)
                    if found:
                        return found
        return None

    def visit_FunctionDef(self, node: ast.FunctionDef):
        # Since Python 3.8 `node.lineno` is the `def` line, whereas
        # `co_firstlineno` (and hence `inspect.getsourcelines`) reports the
        # first decorator's line for a decorated function; accept either.
        start_lines = {node.lineno}
        if node.decorator_list:
            start_lines.add(node.decorator_list[0].lineno)
        if self.func_line_num in start_lines:
            return node
        return self.generic_visit(node)


# Bitwise union of the compiler flags of every feature listed in `__future__`;
# ANDing a code object's `co_flags` with it keeps only those feature bits.
PyCF_MASK = functools.reduce(
    operator.or_,
    [getattr(__future__, feature_name).compiler_flag
     for feature_name in __future__.all_feature_names],
    0,
)


def uncompile(c: types.CodeType):
    """
    uncompile(codeobj) -> (source, filename, flags, func_line_num, private_prefix).

    Recover the source text of the function compiled into `c`.

    :param c: code object of a function defined in a source file
    :return: a tuple of
        - source: the function's source lines joined into one string, as
          returned by `inspect.getsourcelines` (may be indented),
        - filename: the file reported by `inspect.getfile`,
        - flags: the bits of `c.co_flags` that belong to `__future__`
          features (see `PyCF_MASK`),
        - func_line_num: the first source line number reported by
          `inspect.getsourcelines`,
        - private_prefix: a guess at the name-mangling prefix, taken from the
          first name in `c.co_names` that looks mangled, or None
    :raises NotImplementedError: for lambdas and for code compiled from a
        plain string
    :raises RuntimeError: when the source cannot be retrieved; the underlying
        `OSError` is attached as the cause
    """
    if c.co_name == "<lambda>":
        raise NotImplementedError("Lambda functions not supported")
    if c.co_filename == "<string>":
        raise NotImplementedError("Code without source file not supported")

    filename = inspect.getfile(c)

    try:
        lines, func_line_num = inspect.getsourcelines(c)
    except OSError as err:
        # chain explicitly so the real reason (e.g. missing file) is kept
        raise RuntimeError("Source code not available") from err

    source = "".join(lines)

    # __X is mangled to _ClassName__X in methods. Find this prefix.
    # NOTE(review): heuristic; the pattern is greedy, so for `_A__b__c` the
    # prefix found is `_A__b`.
    private_prefix = None
    for name in c.co_names:
        m = re.match(r"^(_[A-Za-z][A-Za-z0-9_]*)__.*$", name)
        if m:
            private_prefix = m.group(1)
            break

    return (source, filename, c.co_flags & PyCF_MASK, func_line_num,
            private_prefix)


def parse_snippet(source: str,
                  filename: str,
                  mode: str,
                  flags: int,
                  firstlineno: int,
                  privateprefix_ignored: str = None) -> ast.Module:
    """
    Parse a possibly indented code snippet into an AST module.

    Works like `ast.parse`, but the snippet may be indented (e.g. a method
    cut out of a class body), and every node's line number is shifted so
    that the snippet's first line is numbered `firstlineno`.
    `privateprefix_ignored` is accepted but not used.
    """
    ast_flags = flags | ast.PyCF_ONLY_AST

    def parse_text(text: str) -> ast.Module:
        return compile(text, filename, mode, ast_flags, True)

    try:
        tree: ast.Module = parse_text("\n" + source)
    except IndentationError:
        # Indented snippet: nest it under a dummy one-line compound statement
        # and keep only that statement's body.
        tree = parse_text("with 0:\n" + source)
        tree.body = tree.body[0].body
    # Either prefix is exactly one line, so the snippet starts on line 2.
    ast.increment_lineno(tree, firstlineno - 2)
    return tree


def recompile(source,
              filename,
              mode,
              flags=0,
              firstlineno=1,
              privateprefix=None):
    """
    Recompile output of `uncompile` back to a code object.

    :param source: source text (possibly indented) or an already parsed AST
        whose first statement must be a function definition
    :param filename: filename recorded in the resulting code object
    :param mode: compile mode passed to the built-in `compile`
    :param flags: compiler flags, e.g. the `__future__` flags from `uncompile`
    :param firstlineno: line number of the first source line; only used when
        `source` is text
    :param privateprefix: accepted for symmetry with `uncompile`; unused
    :return: the code object produced by compiling the whole AST (for mode
        "exec", a module code object that defines the function when run)
    :raises RuntimeError: if the first statement is not a function
        definition, including when there are no statements at all
    """
    if isinstance(source, ast.AST):
        a = source
    else:
        a = parse_snippet(source, filename, mode, flags, firstlineno)

    # an empty module must fail the same way as a non-function statement
    node = a.body[0] if a.body else None

    if not isinstance(node, ast.FunctionDef):
        raise RuntimeError("Expecting function AST node")

    return compile(a, filename, mode, flags, True)


def prompt_with_default(prompt, def_val, transform=(lambda x: x)):
    """
    Ask for a value on stdin, asking again until an acceptable one is given.

    :param prompt: label shown before the default value, as "prompt [def]:"
    :param def_val: returned unchanged when the user just presses Enter
    :param transform: converts the entered string; if it raises an
        `Exception`, the reason is printed and the user is asked again
    :return: `def_val`, or `transform(entered_text)`
    """
    while True:
        input_str = input("%s [%s]:" % (prompt, def_val))
        if input_str == "":
            return def_val
        try:
            return transform(input_str)
        except Exception as err:
            # Report why the value was rejected instead of silently
            # re-prompting; KeyboardInterrupt / SystemExit still propagate.
            print("invalid value %r: %s" % (input_str, err))


def run_func(it=None, structure=1, *args, **kwargs):
    """
    Decorator factory that runs the decorated function right away.

    With `it` None, the function is called once as ``f(*args, **kwargs)``.
    Otherwise it is called once per item ``x`` of `it`, where `structure`
    decides how ``x`` becomes the leading positional arguments:

    * ``1``: ``f(x, *args, **kwargs)``
    * ``0``: ``f(*x, *args, **kwargs)``
    * ``-1``: ``f(*args, **kwargs)`` (the item itself is not passed)
    * any other int ``n``: like ``0``, after asserting ``len(x) == n``
    * a tuple: ``f(*flatten_tuple(x, structure), *args, **kwargs)``

    The loop stops early when the function returns the string "break".
    The decorator itself returns None, so the decorated name ends up None.
    """
    if it is None:

        def run_once(f: types.FunctionType):
            f(*args, **kwargs)

        return run_once

    if isinstance(structure, tuple):

        def spread(item):
            return flatten_tuple(item, structure)
    elif structure == 1:

        def spread(item):
            return (item,)
    elif structure == -1:

        def spread(item):
            return ()
    elif structure == 0:

        def spread(item):
            return item
    else:

        def spread(item):
            assert len(item) == structure
            return item

    def run_loop(f: types.FunctionType):
        for item in it:
            if f(*spread(item), *args, **kwargs) == "break":
                break

    return run_loop


def flatten_tuple(data, structure, cache={}):
    """
    Destructure `data` following `structure` and return the captured values
    as one flat tuple.

    `structure` mirrors the layout of `data`:

    * a positive int ``k`` matches ``k`` consecutive items at that level;
    * ``0`` matches the remaining items at that level as a starred capture
      (captured as a list);
    * ``-1`` matches one item that is discarded;
    * a tuple matches a nested sequence, each element describing one part.

    Example: ``flatten_tuple((1, (2, 3)), (1, (1, 1))) == (1, 2, 3)``.

    The result is always a tuple, including for zero or one captured value.
    A small unpacking function is generated per distinct `structure` and
    memoized in `cache`; the shared mutable default is deliberate.
    """
    if not (unpacker := cache.get(structure)):

        def get_unstruct(structure, n_vars):
            """Return (assignment-target source, next unused variable index)."""
            if isinstance(structure, int):
                if structure == 0:
                    return "*x%d" % n_vars, n_vars + 1
                if structure == -1:
                    return "_", n_vars
                new_n_vars = n_vars + structure
                names = ",".join("x%d" % i for i in range(n_vars, new_n_vars))
                return names, new_n_vars
            assert isinstance(structure, tuple)
            unstruct_code = "("
            for x in structure:
                unstruct_code_one, n_vars = get_unstruct(x, n_vars)
                unstruct_code += unstruct_code_one + ","
            return unstruct_code + ")", n_vars

        unstruct_code, n_vars = get_unstruct(structure, 0)
        if unstruct_code.startswith("*"):
            # a bare `*x0 = data` is a SyntaxError; `*x0, = data` is valid
            unstruct_code += ","
        # one trailing comma per name: a single name still yields a 1-tuple
        # and no names yields `()`
        returned = "(" + "".join("x%d," % i for i in range(n_vars)) + ")"
        local_dict = {}
        exec(
            compile(
                "def unpacker(data):\n %s = data\n return %s" %
                (unstruct_code, returned),
                "",
                "exec",
                dont_inherit=True), {}, local_dict)
        unpacker = local_dict["unpacker"]
        cache[structure] = unpacker
    return unpacker(data)
