#!/usr/bin/env python

'''
Quick runs Clojure code quickly.

Commands:

  eval          FORM                             Evals given form.
  repl          [PORT]                           Connects a repl to a running nREPL server.
  run           NAMESPACE[/FUNCTION] [ARGS...]   Runs existing defn.
  lein          [TASK ARGS...]                   Runs a Leiningen task.
  start         [PORT]                           Start a nREPL server.
  kill          [PORT]                           Kill a running nREPL server.
  restart       [PORT]                           Restart a running nREPL server.

Running with no arguments will read code from stdin.
'''

# for python 3 compatibility, in case nrepl author ever ports to python 3
from __future__ import (absolute_import, division, print_function, unicode_literals)
from future import standard_library
from future.builtins import *
from future.utils import native_str, native

import os, os.path, sys, subprocess, re, readline, warnings, uuid, __builtin__, time
from pprint import pformat
import nrepl

# Write-only handle on the null device; used to discard subprocess output and
# as the sink for diagnostics when debugging is off. Kept open for the whole
# lifetime of the script.
devnull = open(os.devnull, 'w')

# Diagnostic output goes to stderr only when DEBUG is present in the
# environment (its value is irrelevant); otherwise it is thrown away.
if 'DEBUG' in os.environ:
    debug = sys.stderr
else:
    debug = devnull

def find_root(cwd=None):
    '''
    Locate the nearest enclosing directory that contains a project.clj file.

    Starts at `cwd` (the process working directory when omitted) and walks up
    through the parent directories. The walk stops as soon as a directory is
    its own parent, and that last directory is not itself inspected.

    Returns the matching directory path, or None when nothing was found.
    '''
    current = os.getcwd() if cwd is None else cwd
    while True:
        parent = os.path.dirname(current)
        if parent == current:
            # reached the top of the path; nothing left to climb
            return None
        marker = os.path.join(current, 'project.clj')
        if os.path.isfile(marker):
            return current
        current = parent

def _read_port_file(path):
    ''' Return the port number stored in the file at `path`, or None. '''
    if not os.path.isfile(path):
        return None
    with open(path) as handle:
        # tolerate surrounding whitespace, e.g. a trailing newline
        text = handle.read().strip()
    return int(text) if text.isdigit() else None

def get_port(toplevel=None):
    '''
    Figure out what port the nREPL server is set to.

    Lookup order:

      1. The LEIN_REPL_PORT environment variable, when it holds digits only.
      2. Unless `toplevel` is True: the .nrepl-port file in the project root
         (nearest parent directory containing project.clj). When a project
         root exists but has no .nrepl-port file, the search stops here and
         None is returned.
      3. Unless `toplevel` is False: the ~/.lein/repl-port file.

    `toplevel` may be None (try both files), False (project file only) or
    True (~/.lein/repl-port only).

    Returns the port as an int, or None when no port could be found.
    '''
    # check to see if LEIN_REPL_PORT is set.
    env_port = os.environ.get('LEIN_REPL_PORT')
    if env_port is not None and env_port.isdigit():
        return int(env_port)

    # look for first project.clj in current directory and all parent directories to
    # find the project directory. Check to see if a .nrepl-port file exists in the
    # project directory.
    if toplevel is None or toplevel == False:
        root = find_root()
        if root is not None:
            nrepl_port = os.path.join(root, '.nrepl-port')
            if not os.path.isfile(nrepl_port):
                # inside a project without a port file: do not fall back to
                # the user-wide port file
                return None
            port = _read_port_file(nrepl_port)
            if port is not None:
                return port

    # Check to see if ~/.lein/repl-port file exists.
    if toplevel is None or toplevel == True:
        port = _read_port_file(os.path.expanduser('~/.lein/repl-port'))
        if port is not None:
            return port

    # no running port number could be found
    return None

# TODO: fix race condition
def start_nrepl(port=None, toplevel=False):
    '''
    Spawn "lein repl :headless" in the background and return its port.

    port     -- optional port to request via ":port"; may be an int or a string.
    toplevel -- when true, the process is started in the user's home
                directory instead of the current working directory.

    The output of the child process is echoed to stderr line by line until a
    line containing "port <digits>" is seen. Raises Exception when the output
    ends before any such line appears.
    '''
    popen_opts = {'cwd': os.path.expanduser('~')} if toplevel else {}
    command = ['lein', 'repl', ':headless']
    if port is not None:
        # Popen needs string arguments, so stringify an integer port
        command += [':port', '%s' % (port,)]
    proc = subprocess.Popen(
        command,
        stdout=subprocess.PIPE, stderr=devnull, stdin=devnull, **popen_opts)

    found = None
    while found is None:
        line = proc.stdout.readline()
        if not line:
            # readline() returns an empty string at EOF (never None): the
            # process closed stdout without announcing a port
            break
        print(line, end='', file=sys.stderr)
        match = re.match(r"^.*port ([0-9]+).*$", line.rstrip())
        if match is not None:
            found = int(match.group(1))

    if found is None:
        raise Exception("Could not determine port of nREPL instance")
    return found

def start(port=None, quiet=True, toplevel=None):
    '''
    Return the port of an nREPL server, launching one when none is found.

    The port reported by get_port(toplevel=toplevel) is reused when there is
    one; unless `quiet` is set, a notice is then printed to stderr. Otherwise
    the work is delegated to start_nrepl, which receives `port` and
    `toplevel` unchanged.
    '''
    existing = get_port(toplevel=toplevel)
    if not existing:
        # no port could be found: spawn a "lein repl :headless" as a
        # background process and read the port number from its stdout.
        return start_nrepl(port=port, toplevel=toplevel)
    if not quiet:
        print('nREPL server is already running on port %s' % (existing), file=sys.stderr)
    return existing

def kill(port=None, quiet=True):
    '''
    Terminate an nREPL server by evaluating (System/exit 0) on it.

    port  -- port of the server to terminate (int or digit string). When
             omitted, the port is discovered with get_port().
    quiet -- suppress the confirmation message when true.

    Exits the script with status 1 when no port was given and none could be
    discovered.
    '''
    # honour an explicitly requested port; previously it was silently ignored
    running_port = int(port) if port is not None else get_port()
    if running_port is None:
        print('No nREPL session is currently running')
        sys.exit(1)
    evaluate('(System/exit 0)', port=running_port)
    if not quiet:
        print('Terminated nREPL instance on port %s' % (running_port))

def restart(port=None):
    '''
    Kill the currently detected nREPL server and launch a fresh one.

    `port` is handed to start_nrepl as the port to request for the new
    server; the port it actually reports is printed to stdout.
    '''
    kill()
    new_port = start_nrepl(port)
    notice = "Restarted nREPL on port %s" % (new_port)
    print(notice)

class EvalError(Exception):
    ''' Error type reserved for failures while evaluating code; the repl
    loop catches it and prints its message to stderr. '''

# id of the most recent message sent through write(); None until a message
# carrying an 'id' has been written.
current_id = None
def write(con, message):
    '''
    Send `message` over the connection `con` and log it to the debug stream.

    When the message carries an 'id', it is remembered in the module-level
    `current_id`.
    '''
    # without this declaration the assignment below would only create a local
    # variable and the module-level current_id would stay None forever
    global current_id
    if 'id' in message: current_id = message['id']
    con.write(message)
    print('\033[93m'+pformat(message)+'\033[0m', file=debug)

def read(con):
    '''
    Fetch the next message from the connection `con`, log a pretty-printed
    copy of it to the debug stream, and hand it back to the caller.
    '''
    received = con.read()
    rendered = '\033[91m' + pformat(received) + '\033[0m'
    print(rendered, file=debug)
    return received

def load_repload(port=None, con=None, session=None, id='main'):
    '''
    Evaluate the forms that require the `repload` namespace and call
    (repload), sending both the output and the error stream of the evaluation
    to the debug stream.

    `port`, `con` and `session` are forwarded to evaluate; `id` is accepted
    but not used.
    '''
    reload_forms = """(require 'repload)
              (require '[repload :refer [repload]])
              (repload)"""
    evaluate(reload_forms, err=debug, out=debug, session=session, con=con, port=port)

# TODO: handle exit status
def evaluate(code, port=None, con=None, out=sys.stdout, err=sys.stderr,
             session=None, id='main', toplevel=None, **kwargs):
    '''
    Evaluate `code` on an nREPL server and return the last value message.

    code     -- source text sent in the 'code' field of an eval op.
    port     -- server port; when None, start(toplevel=toplevel) supplies it.
    con      -- existing connection; when None, a new one is opened to
                localhost on `port`.
    out, err -- file objects receiving the 'out' / 'err' text of replies.
    session  -- session id; when None, a new session is cloned first (and
                repload is loaded when REPLOAD is set in the environment).
    id       -- message id used for the eval op.
    kwargs   -- extra fields merged into the eval message (e.g. ns='user').

    Returns the last message that contained a 'value' key, or None when no
    such message arrived. Raises Exception when no session could be cloned or
    when a status message carries none of the recognised flags.

    While waiting, Ctrl-C sends an interrupt op for the running eval instead
    of aborting the script.
    '''
    # connect to the nREPL server
    if port is None: port = start(toplevel=toplevel)
    if con is None:
        url = 'nrepl://localhost:%i' % (port)
        con = nrepl.connect(native_str(url))

    # get a new session id
    if session is None:
        write(con, {'id': 'init', 'op': 'clone'})
        message = read(con)
        # a malformed reply must produce the explicit error below rather
        # than a KeyError/TypeError
        session = message.get('new-session') if isinstance(message, dict) else None
        if session is None: raise Exception("No session returned from clone op")
        if 'REPLOAD' in os.environ: load_repload(port=port, con=con, session=session, id=id)

    # run the eval command on the code
    opts = {'op': 'eval', 'code': code, 'id': id, 'session': session}
    opts.update(kwargs)
    write(con, opts)

    # handle any messages sent back until all the task ids are status done
    value = None
    pending = {id}
    while pending:
        try:
            message = read(con)
            if message is None:
                print('Got None message', file=debug)
                break
            elif not isinstance(message,dict):
                warnings.warn('Got non-dict message: '+pformat(message))
            elif 'status' in message and isinstance(message['status'],list) and len(message['status'])>0:
                # test for membership instead of looking only at the first
                # entry, so the outcome does not depend on the order of flags
                status = message['status']
                if 'need-input' in status:
                    input_id = str(uuid.uuid4())
                    opts = {'op': 'stdin',
                            'session': session,
                            'stdin': sys.stdin.read(1) if sys.stdin.isatty() else sys.stdin.read(),
                            'id': input_id}
                    pending.add(input_id)
                    write(con, opts)
                elif 'done' in status:
                    # "done" may refer to an id that is not tracked here (for
                    # instance the reply to an interrupt op); discard() keeps
                    # that from raising
                    if 'id' in message:
                        pending.discard(message['id'])
                elif 'eval-error' in status or 'interrupted' in status:
                    pass
                else:
                    raise Exception('Eval returned bad exit status: %s' % (status[0]))
            elif 'value' in message:
                value = message
            elif 'out' in message:
                print(message['out'], end='', file=out)
            elif 'err' in message:
                print(message['err'], end='', file=err)
            else:
                warnings.warn("Don't know how to handle message: %s" % (message))
        except KeyboardInterrupt:
            # ask the server to interrupt the eval started above, then keep
            # reading until it reports that the eval is done
            opts = {'op': 'interrupt',
                    'session': session,
                    'interrupt-id': id,
                    'id': 'interrupt-'+str(uuid.uuid4())}
            write(con, opts)
    return value
    
def repl(port=None, con=None):
    '''
    Run an interactive read-eval-print loop against an nREPL server.

    port -- server port (int or digit string, as given on the command line);
            when None, start() supplies it.
    con  -- existing connection; when None, a new one is opened to localhost
            on `port`.

    Each input line is wrapped in a Clojure try form that prints the cause
    trace of any Exception, evaluated in one shared session, and its value is
    printed. The prompt shows the namespace reported by the last result. The
    loop ends on EOF or Ctrl-C at the prompt.
    '''
    readline_config = ['set blink-matching-paren on']
    for line in readline_config:
        readline.parse_and_bind(line)

    # get a connection
    if port is None:
        port = start()
    else:
        # a port taken from the command line is a string, but the URL below
        # is built with %i
        port = int(port)
    if con is None:
        url = 'nrepl://localhost:%i' % (port)
        con = nrepl.connect(native_str(url))

    # get a session
    write(con, {'id': 'init', 'op': 'clone'})
    message = read(con)
    # a malformed reply must produce the explicit error below rather than a
    # KeyError/TypeError
    session = message.get('new-session') if isinstance(message, dict) else None
    if session is None: raise Exception("No session returned from clone op")
    # no id is passed on: there is no local `id` here, so the name would
    # refer to the builtin function
    if 'REPLOAD' in os.environ: load_repload(port=port, con=con, session=session)

    # send a dummy form to get the current namespace
    ns = evaluate('nil', port=port, con=con, session=session)['ns']

    # start the repl
    while (True):
        try:
            line = input('%s=>' % (ns))
            code = """(try %s (catch Exception e (clojure.stacktrace/print-cause-trace e)))""" % (line)
            result = evaluate(code, port=port, con=con, session=session)
            if result is not None:
                if 'ns' in result:
                    ns = result['ns']
                print(result['value'])
        except (EOFError, KeyboardInterrupt):
            print()
            break
        except (EvalError) as e:
            print(str(e), file=sys.stderr)
        except:
            # report what went wrong, then let the error propagate
            print('Unexpected error:', sys.exc_info()[0])
            raise

def escape(*symbols):
    '''
    Render each argument as a double-quoted string literal and join the
    literals with single spaces.

    Double quotes and backslashes inside an argument are prefixed with a
    backslash so they cannot terminate or corrupt the literal. Called without
    arguments, returns the empty string.
    '''
    # str.replace() matches its first argument literally, so a regex pattern
    # handed to it never matched anything; re.sub performs the substitution
    # that was intended
    return ' '.join('"' + re.sub(r'(["\\])', r'\\\1', symbol) + '"'
                    for symbol in symbols)

def run(name, *args):
    '''
    Call an existing Clojure function and return its result as an exit code.

    name -- "NAMESPACE" or "NAMESPACE/FUNCTION"; without a function part,
            -main of the namespace is called.
    args -- passed to the function as string arguments.

    The generated code requires the namespace and calls the function. If the
    call throws an Exception, its cause trace is printed and the expression
    yields the :exit-code from the exception data when that is a number,
    otherwise 0.

    Returns the resulting value as an int, or None when evaluation produced
    no value or a value that is not an integer (for example nil).
    '''
    code = """(let [raw (symbol %s)
                      ns (symbol (or (namespace raw) raw))
                      m-sym (if (namespace raw) (symbol (name raw)) '-main)]
                  (require ns)
                  (try ((ns-resolve ns m-sym) %s)
                    (catch Exception e
                      (let [c (:exit-code (ex-data e))]
                        (clojure.stacktrace/print-cause-trace e)
                        (if (number? c) c 0)))))""" % (escape(name), escape(*args))
    result = evaluate(code, ns='user')
    if result is None:
        return None
    try:
        return int(result['value'])
    except (TypeError, ValueError):
        # the function returned something that is not an integer (most
        # commonly nil); treat it as "no exit code"
        return None

def lein(*args):
    '''
    Run a Leiningen task inside the nREPL server and return its exit code.

    args -- task name and task arguments, passed as strings to -main in the
            leiningen.core.main namespace.

    The generated code binds *cwd* to the project root (or to the current
    directory when no project.clj is found), binds *exit-process?* to false,
    sets the "leiningen.original.pwd" system property to the current
    directory, redefines the :default and :trampoline eval-in methods, and
    then calls -main. An ExceptionInfo thrown by -main is turned into its
    numeric :exit-code, or 0 when it has none.

    Returns None when evaluation produced no value or nil; otherwise the
    value as an integer modulo 128.
    '''
    cwd = os.getcwd()
    root = find_root()
    # the catch clause has to sit inside the try form; with the try closed
    # right after the call to -main, catch was a stray form that the try did
    # not cover
    code = """(binding [*cwd* %s, *exit-process?* false]
                       (System/setProperty "leiningen.original.pwd" %s)
                       (defmethod leiningen.core.eval/eval-in :default
                         [project form]
                         (leiningen.core.eval/eval-in
                           (assoc project :eval-in :nrepl) form))
                       (defmethod leiningen.core.eval/eval-in :trampoline
                         [& _] (throw (Exception. "trampoline disabled")))
                       (try (-main %s)
                         (catch clojure.lang.ExceptionInfo e
                           (let [c (:exit-code (ex-data e))]
                             (if (number? c) c 0)))))""" % (
                                   escape(cwd if root is None else root),
                                   escape(cwd),
                                   escape(*args))
    result = evaluate(code, ns='leiningen.core.main', toplevel=True)
    return None if result is None or result['value'] == 'nil' else __builtin__.int(result['value']) % 128

# parse command-line arguments
exit_code = None

def _print_value(result):
    ''' Print the value of an evaluate() result. Returns None on success, or
    the exit code 1 when the evaluation produced no value at all. '''
    if result is None:
        # nothing to print; fail with a plain exit status instead of a
        # TypeError from subscripting None
        return 1
    print(result['value'])
    return None

# Dispatch on the first command-line argument. A command that lacks its
# required argument falls through to the stdin / usage handling at the end.
if len(sys.argv) > 2 and sys.argv[1] == 'eval':
    exit_code = _print_value(evaluate(sys.argv[2]))
elif len(sys.argv) > 1 and sys.argv[1] == 'start':
    start(*sys.argv[2:3], quiet=False)
elif len(sys.argv) > 1 and sys.argv[1] == 'kill':
    kill(*sys.argv[2:3], quiet=False)
elif len(sys.argv) > 1 and sys.argv[1] == 'restart':
    restart(*sys.argv[2:3])
elif len(sys.argv) > 1 and sys.argv[1] == 'repl':
    repl(*sys.argv[2:3])
elif len(sys.argv) > 2 and sys.argv[1] == 'run':
    exit_code = run(*sys.argv[2:])
elif len(sys.argv) > 1 and sys.argv[1] == 'lein':
    exit_code = lein(*sys.argv[2:])
elif not sys.stdin.isatty():
    # no recognised command, but code is being piped in on stdin
    exit_code = _print_value(evaluate(sys.stdin.read()))
else:
    print(__doc__, file=sys.stderr)
    sys.exit(1)

# commands that do not report a status count as successful
exit_code = 0 if exit_code is None else exit_code
print("Exited with status: "+pformat(exit_code), file=debug)
sys.exit(exit_code)
