#!/usr/bin/env python

'''
Quick runs Clojure code quickly.

Commands:

  eval          FORM                             Evals given form.
  repl          [PORT]                           Connects a repl to a running nREPL server.
  run           NAMESPACE[/FUNCTION] [ARGS...]   Runs existing defn.
  lein          [TASK ARGS...]                   Runs a Leiningen task.
  start         [PORT]                           Start a nREPL server.
  kill          [PORT]                           Kill a running nREPL server.
  restart       [PORT]                           Restart a running nREPL server.

Running with no arguments will read code from stdin.
'''

# for python 3 compatibility, in case nrepl author ever ports to python 3
from __future__ import (absolute_import, division, print_function, unicode_literals)
from future import standard_library
from future.builtins import *
from future.utils import native_str

import os, os.path, sys, subprocess, re, readline, warnings, uuid
from pprint import pformat
import nrepl

# Destination for protocol tracing: stderr when DEBUG is present in the
# environment, the null device otherwise. The null device is opened in text
# mode because print() hands it str, which a binary-mode file rejects on
# Python 3.
debug = sys.stderr if 'DEBUG' in os.environ else open(os.devnull, 'w')

def find_root(cwd=None):
    '''
    Locate the project directory that contains a given path.

    Starts at cwd (the current working directory when cwd is None) and climbs
    through its parents. Returns the absolute path of the first directory that
    holds a project.clj file, or None when no directory on the way up has one.
    '''
    # work on an absolute path so a relative argument still climbs through
    # the real parent directories
    current = os.path.abspath(os.getcwd() if cwd is None else cwd)
    while True:
        if os.path.isfile(os.path.join(current, 'project.clj')):
            return current
        parent = os.path.dirname(current)
        # dirname() of the top-most directory is that directory itself, so
        # there is nothing left to climb (the top-most directory was checked)
        if parent == current:
            return None
        current = parent

def _parse_port(text):
    ''' Return text as an int when it is all digits (surrounding whitespace ignored), else None. '''
    if text is None:
        return None
    text = text.strip()
    return int(text) if text.isdigit() else None

def _read_port_file(path):
    ''' Return the port stored in the file at path, or None if the file is missing or holds no port. '''
    if not os.path.isfile(path):
        return None
    with open(path) as port_file:
        return _parse_port(port_file.read())

def get_port():
    '''
    Figure out what port the nREPL server is set to.

    Sources are tried in this order and the first usable one wins:
      1. the LEIN_REPL_PORT environment variable,
      2. a .nrepl-port file in the project directory found by find_root(),
      3. the ~/.lein/repl-port file.

    Returns the port as an int, or None if no port number could be found.
    '''
    # check to see if LEIN_REPL_PORT is set.
    port = _parse_port(os.environ.get('LEIN_REPL_PORT'))
    if port is not None:
        return port

    # look for first project.clj in current directory and all parent directories to
    # find the project directory. Check to see if a .nrepl-port file exists in the
    # project directory.
    root = find_root()
    if root is not None:
        port = _read_port_file(os.path.join(root, '.nrepl-port'))
        if port is not None:
            return port

    # Check to see if ~/.lein/repl-port file exists; None means that no
    # running port number could be found.
    return _read_port_file(os.path.expanduser('~/.lein/repl-port'))

# TODO: fix race condition
def start_nrepl(port=None):
    '''
    Spawn "lein repl :headless" as a background process and return its port.

    port -- optional port (int or string) handed to lein through ":port".

    The first line the process writes to stdout is expected to announce the
    server ("nREPL server started on port N on host H", optionally followed
    by more text). Returns N as an int; raises Exception if that first line
    does not have the expected shape.
    '''
    command = ['lein', 'repl', ':headless']
    if port is not None:
        # Popen only accepts string arguments, while callers may pass an int
        command += [':port', '%s' % (port)]
    # the child process gets its own copy of the descriptor, so ours can be
    # closed as soon as the process has been spawned
    with open(os.devnull, 'wb') as devnull:
        proc = subprocess.Popen(command, stdout=subprocess.PIPE, stderr=devnull)
    # the pipe yields bytes; decode them before matching a text pattern.
    # proc.stdout is deliberately left open: the server keeps running.
    line = proc.stdout.readline().decode('utf-8', 'replace').rstrip()
    # no end-of-line anchor, so an announcement with trailing text still matches
    match = re.match(r"^nREPL server started on port ([0-9]+) on host ([^ ]+)", line)
    if match is None:
        raise Exception("Could not match lein repl output: "+line)
    return int(match.group(1))

def start(port=None, quiet=True):
    '''
    Return the port of an nREPL server, launching one when none is found.

    port  -- forwarded to start_nrepl() if a new server has to be spawned.
    quiet -- when False, report on stderr that a server is already running.
    '''
    existing = get_port()
    if not existing:
        # no port could be found: spawn a "lein repl :headless" as a
        # background process and use the port number it reports on stdout
        return start_nrepl(port)
    if not quiet:
        print('nREPL server is already running on port %s' % (existing), file=sys.stderr)
    return existing

def kill(port=None, quiet=True):
    '''
    Shut down an nREPL server by evaluating (System/exit 0) on it.

    port  -- port of the server to stop (int or numeric string); when None
             the port is discovered with get_port().
    quiet -- when False, print a confirmation once the exit form was sent.

    Prints a message and exits the process with status 1 if no port was given
    and none could be discovered.
    '''
    # honour an explicitly requested port; fall back to discovery otherwise
    target = int(port) if port is not None else get_port()
    if target is None:
        print('No nREPL session is currently running')
        sys.exit(1)
    evaluate('(System/exit 0)', port=target)
    if not quiet:
        print('Terminated nREPL instance on port %s' % (target))

def restart(port=None):
    '''
    Stop the running nREPL server, spawn a new one and print its port.

    port -- forwarded to start_nrepl() for the replacement server.
    '''
    kill()
    new_port = start_nrepl(port)
    print("Restarted nREPL on port %s" % (new_port))

class EvalError(Exception):
    ''' Raised by evaluate() when the server answers with an "eval-error" status. '''

# id of the most recent "eval" request sent through write(); None until one
# has been sent
current_id = None
def write(con, message):
    '''
    Send message over the connection con and trace it to the debug stream.

    When message is an "eval" request that carries an 'id', that id is
    remembered in the module-level current_id.
    '''
    # without this declaration the assignment below would only create a
    # local variable and the module-level current_id would stay None
    global current_id
    # only eval requests are remembered: theirs is the id an interrupt has
    # to name, and other requests must not overwrite it
    if message.get('op') == 'eval' and 'id' in message: current_id = message['id']
    con.write(message)
    print('\033[93m'+pformat(message)+'\033[0m', file=debug)

def read(con):
    ''' Fetch the next message from the connection con, trace it to the debug stream and return it. '''
    incoming = con.read()
    trace = '\033[91m'+pformat(incoming)+'\033[0m'
    print(trace, file=debug)
    return incoming

def evaluate(code, port=None, out=sys.stdout, err=sys.stderr, **kwargs):
    '''
    Evaluate Clojure code on an nREPL server.

    code    -- source text sent with the "eval" op.
    port    -- server port (int or numeric string); when None, start()
               supplies one.
    out/err -- file objects that receive the 'out' / 'err' text of replies.
    kwargs  -- extra fields merged into the eval request (e.g. ns='user').

    Returns the last reply that carried a 'value' key (a dict), or None when
    no such reply arrived. Raises EvalError when the server reports an
    "eval-error" status, and Exception when no session could be cloned or a
    reply carries a status this function cannot handle.

    A KeyboardInterrupt while waiting for replies sends an "interrupt" for
    the running evaluation and keeps reading. The connection is closed before
    the function returns or raises.
    '''
    # connect to the nREPL server
    if port is None: port = start()
    # a port taken from the command line is a string, but %i needs a number
    url = 'nrepl://localhost:%i' % (int(port))
    con = nrepl.connect(native_str(url))

    try:
        # get a new session id
        write(con, {'id': 'init', 'op': 'clone'})
        message = read(con)
        session = message.get('new-session') if isinstance(message, dict) else None
        if session is None: raise Exception("No session returned from clone op")

        # run the eval command on the code
        opts = {'op': 'eval', 'code': code, 'id': 'main', 'session': session}
        opts.update(kwargs)
        # kwargs may have replaced the id, so track whatever was really sent
        eval_id = opts['id']
        write(con, opts)

        # handle any messages sent back until all the task ids are status done
        value = None
        pending = {eval_id: True}
        while pending:
            try:
                message = read(con)
                if message is None:
                    print('Got None message', file=debug)
                    break
                elif not isinstance(message,dict):
                    warnings.warn('Got non-dict message: '+pformat(message))
                elif 'status' in message and isinstance(message['status'],list) and len(message['status'])>0:
                    # a reply may carry several statuses in any order, so
                    # test for membership instead of looking at the first
                    status = message['status']
                    if 'eval-error' in status:
                        # ask the session to print the stack trace of *e
                        write(con, {'op': 'eval', 'code': '(.printStackTrace *e)', 'session': session})
                        stacktrace = read(con)
                        if isinstance(stacktrace, dict) and 'value' in stacktrace:
                            # consume the reply that follows the value
                            read(con)
                        detail = stacktrace.get('err') if isinstance(stacktrace, dict) else None
                        if detail is None:
                            # no trace text came back; fall back to the 'ex'
                            # field of the failing reply when it has one
                            detail = 'Evaluation failed: %s' % (message.get('ex', 'unknown error'))
                        raise EvalError(detail)
                    elif 'need-input' in status:
                        stdin_id = str(uuid.uuid4())
                        # a terminal is read one character at a time, piped
                        # input is forwarded in a single read
                        request = {'op': 'stdin',
                                   'session': session,
                                   'stdin': sys.stdin.read(1) if sys.stdin.isatty() else sys.stdin.read(),
                                   'id': stdin_id}
                        pending[stdin_id] = True
                        write(con, request)
                    elif 'done' in status and 'error' not in status:
                        # replies for ids that are not tracked here (such as
                        # the answer to an interrupt) are simply ignored
                        pending.pop(message.get('id'), None)
                    else:
                        raise Exception('Eval returned bad exit status: %s'
                                        % (', '.join('%s' % (s) for s in status)))
                elif 'value' in message:
                    value = message
                elif 'out' in message:
                    print(message['out'], end='', file=out)
                elif 'err' in message:
                    print(message['err'], end='', file=err)
                else:
                    warnings.warn("Don't know how to handle message: %s" % (message))
            except KeyboardInterrupt:
                # ask the server to abort the evaluation started above, then
                # keep reading so that its remaining replies are consumed
                write(con, {'op': 'interrupt',
                            'session': session,
                            'interrupt-id': eval_id,
                            'id': 'interrupt-'+str(uuid.uuid4())})

        return value
    finally:
        # every call opens its own connection; release it on every exit path
        con.close()
    
def repl(port=None):
    '''
    Run an interactive read-eval-print loop against an nREPL server.

    port -- server port (int or numeric string); None lets evaluate() pick.

    Each line read from the prompt is evaluated in the namespace shown in
    the prompt and its value is printed. An EvalError is reported on stderr
    and the loop continues; end-of-file or Ctrl-C at the prompt ends the
    loop; any other exception is announced and re-raised.
    '''
    # a port taken from the command line is a string
    if port is not None: port = int(port)
    # send a dummy form to get the current namespace
    first = evaluate('nil', port)
    # NOTE(review): 'user' is assumed to be the server's default namespace
    ns = first['ns'] if first is not None and 'ns' in first else 'user'
    # start the repl
    while True:
        try:
            line = input('%s=>' % (ns))
            # every evaluate() call uses a fresh session, so the namespace
            # has to be sent along for a namespace switch to carry over
            result = evaluate(line, port, ns=ns)
            if result is not None:
                if 'ns' in result:
                    ns = result['ns']
                print(result['value'])
        except (EOFError, KeyboardInterrupt):
            print()
            break
        except EvalError as e:
            print(str(e), file=sys.stderr)
        except Exception:
            print('Unexpected error:', sys.exc_info()[0])
            raise

def escape(*symbols):
    '''
    Render every argument as a double-quoted string literal and join the
    results with single spaces.

    Double quotes and backslashes inside an argument are prefixed with a
    backslash so the argument cannot end the literal early. Returns an empty
    string when called without arguments.
    '''
    # str.replace() looks for a literal substring, not a pattern, so the
    # escaping has to go through re.sub()
    return ' '.join('"'+re.sub(r'(["\\])', r'\\\1', x)+'"' for x in symbols)

def run(name, *args):
    '''
    Call an existing Clojure function on the nREPL server.

    name -- "NAMESPACE" or "NAMESPACE/FUNCTION"; without a function part the
            namespace's -main is called.
    args -- passed to the function as strings.

    The generated form requires the namespace, resolves the function and
    calls it. An Exception thrown by the call has its cause trace printed,
    unless its ex-data holds a numeric :exit-code of zero. Returns the
    'value' of the evaluation result, or None when there was no result.
    '''
    template = """(do
                (require '[clojure.stacktrace :refer [print-cause-trace]])
                (let [raw (symbol %s)
                      ns (symbol (or (namespace raw) raw))
                      m-sym (if (namespace raw) (symbol (name raw)) '-main)]
                  (require ns)
                  (try ((ns-resolve ns m-sym) %s)
                  (catch Exception e
                    (let [c (:exit-code (ex-data e))]
                      (when-not (and (number? c) (zero? c))
                        (print-cause-trace e)))))))"""
    quoted_name = escape(name)
    quoted_args = escape(*args)
    result = evaluate(template % (quoted_name, quoted_args), ns='user')
    if result is None:
        return None
    return result['value']

def lein(*args):
    '''
    Run a Leiningen task through the nREPL server.

    args -- the task name and its arguments, passed to -main as strings.

    The generated form is evaluated in the leiningen.core.main namespace. It
    binds *cwd* to the project directory (the current directory when no
    project.clj is found) and *exit-process?* to false, records the current
    directory in the "leiningen.original.pwd" system property, redefines the
    :default eval-in method to delegate with :eval-in :nrepl, makes the
    :trampoline method throw, and then calls -main. An ExceptionInfo whose
    ex-data holds a numeric :exit-code of zero is swallowed; any other one
    is rethrown. Returns the 'value' of the evaluation result, or None when
    there was no result.
    '''
    launch_dir = os.getcwd()
    project_dir = find_root()
    if project_dir is None:
        project_dir = launch_dir
    template = """(binding [*cwd* %s, *exit-process?* false]
                       (System/setProperty "leiningen.original.pwd" %s)
                       (defmethod leiningen.core.eval/eval-in :default
                         [project form]
                         (leiningen.core.eval/eval-in
                           (assoc project :eval-in :nrepl) form))
                       (defmethod leiningen.core.eval/eval-in :trampoline
                         [& _] (throw (Exception. "trampoline disabled")))
                       (try (-main %s)
                         (catch clojure.lang.ExceptionInfo e
                           (let [c (:exit-code (ex-data e))]
                             (when-not (and (number? c) (zero? c))
                               (throw e))))))"""
    code = template % (escape(project_dir), escape(launch_dir), escape(*args))
    result = evaluate(code, ns='leiningen.core.main')
    if result is None:
        return None
    return result['value']

# When REPLOAD is set in the environment, evaluate the repload form before
# running the requested command and discard whatever it prints.
if 'REPLOAD' in os.environ:
    # text mode, because evaluate() print()s str output into this file; the
    # with block closes the handle once the form has been evaluated
    with open(os.devnull, 'w') as devnull:
        evaluate("(require 'repload) (require '[repload :refer [repload]]) (repload)",
                 out=devnull)

def _print_value(result):
    ''' Print the 'value' of an evaluate() result; print nothing when there is no result. '''
    # evaluate() returns None when no reply carried a value
    if result is not None:
        print(result['value'])

# parse command-line arguments
if len(sys.argv) > 2 and sys.argv[1] == 'eval': _print_value(evaluate(sys.argv[2]))
elif len(sys.argv) > 1 and sys.argv[1] == 'start': start(*sys.argv[2:3], quiet=False)
elif len(sys.argv) > 1 and sys.argv[1] == 'kill': kill(*sys.argv[2:3], quiet=False)
elif len(sys.argv) > 1 and sys.argv[1] == 'restart': restart(*sys.argv[2:3])
elif len(sys.argv) > 1 and sys.argv[1] == 'repl': repl(*sys.argv[2:3])
elif len(sys.argv) > 2 and sys.argv[1] == 'run': run(*sys.argv[2:])
elif len(sys.argv) > 1 and sys.argv[1] == 'lein': lein(*sys.argv[2:])
# no recognised command: evaluate code piped in on stdin
elif not sys.stdin.isatty(): _print_value(evaluate(sys.stdin.read()))
else:
    # nothing to do: show the usage text and signal failure
    print(__doc__, file=sys.stderr)
    sys.exit(1)

sys.exit(0)
