#!/usr/bin/env python
import argparse
import os
import subprocess
import re
import sys
import itertools
import os.path
import time
import collections
import zlib
from six import with_metaclass

def readCL():
    """Parse the command line.

    Returns:
        tuple (infile, verbose, print_rules, rules, ok) where ``rules`` is the
        list of positional rule arguments and the rest come from the flags.
    """
    cli = argparse.ArgumentParser()
    cli.add_argument("-f", "--infile", default="Snakefile")
    for short_flag, long_flag in (("-v", "--verbose"), ("-p", "--print_rules")):
        cli.add_argument(short_flag, long_flag, action="store_true")
    cli.add_argument("--ok", action="store_true", help="mark all rules as okay")
    cli.add_argument("rules", nargs="*")
    parsed = cli.parse_args()
    return (parsed.infile,
            parsed.verbose,
            parsed.print_rules,
            parsed.rules,
            parsed.ok)


class Singleton(type):
    """
    Metaclass that makes each class using it construct at most one instance;
    later calls return the cached instance and ignore their arguments.
    Pattern from
    http://stackoverflow.com/questions/6760685/creating-a-singleton-in-python
    """
    _instances = {}

    def __call__(cls, *args, **kwargs):
        try:
            return cls._instances[cls]
        except KeyError:
            instance = super(Singleton, cls).__call__(*args, **kwargs)
            cls._instances[cls] = instance
            return instance

class DependencyGraphSingleton(with_metaclass(Singleton,object)):
    """Registry of every Rule defined by the Snakefile, plus the logic that
    decides whether a single rule needs to run.

    Because of the Singleton metaclass every ``DependencyGraphSingleton()``
    call returns the same instance, so ``__init__`` runs only once.
    """
    def __init__(self):
        self._rules = []  # rules in definition order
        self._required_in_nodes = set()  # not populated anywhere in this file
        self._forward_edges = {} #input node -> output node (not populated in this file)
        self._backward_edges = {} #output node -> rule that creates it
        # Start with no active tags so tags_ok()/check() work even before the
        # active_tags setter is called (previously this raised AttributeError).
        self._active_tags = set()
    def add_rule(self, rule):
        """Register *rule*.

        Raises:
            Exception: if one of the rule's outputs is already produced by a
                registered rule (or is listed twice in this rule).
        """
        out_nodes = rule.out_nodes()
        # Validate everything first so a rejected rule is not left half-registered.
        seen = set()
        for n in out_nodes:
            if n in self._backward_edges or n in seen:
                raise Exception("ERROR: output file {n} listed in multiple rules".format(n=n))
            seen.add(n)
        self._rules.append(rule)
        for n in out_nodes:
            self._backward_edges[n] = rule
    def backward_edges(self):
        """Dictionary from files to the rule that creates that file"""
        return self._backward_edges
    def rules(self):
        """All registered rules, in definition order (the internal list itself)."""
        return self._rules
    def all_out_nodes(self):
        """Every output node of every rule, in definition order."""
        return [n for r in self._rules for n in r.out_nodes()]
    def setup(self, snake_dir):
        """Record the directory holding per-rule state files (ok markers, cached commands)."""
        self.snake_dir = snake_dir
    @property
    def active_tags(self):
        """Store the list of currently active tags. Tags blocks rules from being run
        unless the rule was forced or all the rule's tags are active.
        """
        return self._active_tags
    @active_tags.setter
    def active_tags(self, tag_list):
        self._active_tags = tag_list
    def tags_ok(self, rule):
        """Return True unless the rule mentions a tag that is not currently active."""
        return not any(is_tag(f) and f not in self._active_tags
                       for f in rule.out_nodes() + rule.in_nodes())
    def check(self, rule):
        """Check whether to run the rule. This check depends only on the rule itself, not other rules upstream or downstream

        Returns:
            (boolean, reason_string)
        Raises:
            Exception: an input file is missing and no registered rule creates it.
        """
        # Missing (non-tag) inputs are only acceptable if some rule will generate them.
        missing_inputs = [f for f in rule.in_nodes() if not is_tag(f) and not os.path.exists(f)]
        uncreatable_inputs = [f for f in missing_inputs if f not in self._backward_edges]
        if uncreatable_inputs:
            raise Exception("ERROR: can't find input files " + str(uncreatable_inputs))
        if not self.tags_ok(rule):
            return False, "blocked by tags"
        # With every input present, any missing file must be a missing output.
        any_missing = any(not os.path.exists(f) for f in rule.all_nodes() if not is_tag(f))
        if any_missing and not missing_inputs:
            return True,  "missing output"
        # Re-run when the generated bash differs from the copy cached on the last run;
        # test the flag first so the cache file is not read when the check is disabled.
        if rule._cachecheck and rule.cmd_cache_stale():
            return True,  "bash command changed since last run"
        return self.check_timestamp(rule)
    def check_timestamp(self, rule):
        """Decide whether the rule is out of date by comparing modification times.

        Returns:
            (boolean, reason_string): True only when the newest existing input is
            newer than both the oldest existing output and the last time the
            rule was marked ok.
        """
        outputs = [f for f in rule.out_nodes() if os.path.exists(f)]
        inputs = [f for f in rule.in_nodes() if os.path.exists(f)]
        if not inputs:
            return False, "no input files"
        # Checked before any mtime / ok-file I/O; the outcome is unchanged.
        if not rule._timecheck:
            return False, "timecheck turned off"
        # Sentinels: no existing outputs -> nothing is "older" than the inputs;
        # no real inputs -> nothing is "newer".
        min_out_time = min([os.path.getmtime(f) for f in outputs if not is_tag(f)] + [1e20])
        max_in_time = max([os.path.getmtime(f) for f in inputs if not is_tag(f)] + [-1])
        if max_in_time <= rule.last_ok_time():  # last time the rule was marked as ok
            return False, "marked ok"
        if max_in_time > min_out_time:
            return True, "timecheck"
        return False, "timecheck failed"

def run(cmd):
    """Run the shell string *cmd* with /bin/bash and wait for it to finish.

    Returns:
        (stdout_bytes, stderr_bytes, return_code)
    """
    proc = subprocess.Popen(cmd, shell=True, executable='/bin/bash',
                            stdout=subprocess.PIPE, stderr=subprocess.PIPE)
    out, err = proc.communicate()
    return out, err, proc.returncode

class Rule(object):
    """One build step: create out_nodes from in_nodes by running a bash command.

    Nodes are file paths or tags (names starting with "%"). The command is
    prefixed with ``set -o pipefail`` / ``set -e`` and with assignments that
    bind ``$INPUT<i>`` / ``$OUTPUT<i>`` to the i-th input / output node.
    """
    def __init__(self, out_nodes, in_nodes, cmd, timecheck, cachecheck):
        self._out_nodes = out_nodes
        self._in_nodes = in_nodes
        pipefail_cmd = "set -o pipefail"
        crash_on_first_error = "set -e"
        inputs = [("INPUT{0}".format(i), var) for i, var in enumerate(self._in_nodes)]
        outputs = [("OUTPUT{0}".format(i), var) for i, var in enumerate(self._out_nodes)]
        # NOTE: this exact text becomes part of the cached command compared by
        # cmd_cache_stale(); changing its format marks every rule as changed.
        set_variables_cmd = " ".join(['{0}="{1}";'.format(name, val) for name, val in inputs]) + " " +\
                            " ".join(['{0}="{1}";'.format(name, val) for name, val in outputs])
        self._main_cmd = cmd

        # Human-readable command with $INPUTi/$OUTPUTi replaced by node names.
        # Longest names first, so "$INPUT1" cannot clobber the prefix of
        # "$INPUT10" (which used to print as "<value of INPUT1>0").
        self._print_cmd = self._main_cmd
        for name, val in sorted(inputs + outputs, key=lambda pair: len(pair[0]), reverse=True):
            self._print_cmd = self._print_cmd.replace("$" + name, val)

        self._cmd = pipefail_cmd + "\n" + crash_on_first_error + "\n" + set_variables_cmd + '\n' + self._main_cmd
        self._timecheck = timecheck
        self._cachecheck = cachecheck
    def hash_(self):
        """Non-negative adler32 hash of the output list; names this rule's state files."""
        return abs(zlib.adler32(to_bytes(str(self.out_nodes()))))
    def ok_file(self):
        """Path of the file recording when this rule was last marked ok."""
        out_hash = self.hash_()
        snake_dir = DependencyGraphSingleton().snake_dir
        return "{0}/ok_{1}".format(snake_dir, out_hash)
    def mark_ok(self):
        """Record the current time as the rule's last-ok time and cache its command."""
        with open(self.ok_file(),'w') as f_out:
            f_out.write(str(time.time()) + "\n")
        self.cache_cmd()
    def last_ok_time(self):
        """Time (seconds since the epoch) the rule was last marked ok, or 0 if never."""
        ok_file = self.ok_file()
        if not os.path.exists(ok_file):
            return 0
        # `with` closes the handle (the old bare open() leaked it).
        with open(ok_file) as f_in:
            return float(f_in.read().strip())
    def cmd_file(self):
        """Path of the file caching the command text from the last run."""
        out_hash = self.hash_()
        snake_dir = DependencyGraphSingleton().snake_dir
        return "{0}/cmd_{1}".format(snake_dir, out_hash)
    def cmd_cache_stale(self):
        """True if no command is cached or the cached text differs from cmd()."""
        cache_file = self.cmd_file()
        if not os.path.exists(cache_file):
            return True
        with open(cache_file) as f_in:
            return self.cmd() != f_in.read()
    def cache_cmd(self):
        """Write cmd() to the cache file if it is missing or out of date."""
        cache_file = self.cmd_file()
        if self.cmd_cache_stale():
            with open(cache_file, 'w') as f_out:
                f_out.write(self.cmd())
    def execute(self):
        """Cache the command, then run it; returns (stdout, stderr, return_code)."""
        self.cache_cmd()
        return run(self._cmd)
    def in_nodes(self):
        return self._in_nodes
    def out_nodes(self):
        return self._out_nodes
    def all_nodes(self):
        """Set of all input and output nodes."""
        return set(self._in_nodes).union(set(self._out_nodes))
    def cmd(self):
        """Full bash script actually executed (preamble plus user command)."""
        return self._cmd
    def print_cmd(self):
        """User command with $INPUTi/$OUTPUTi substituted, for display."""
        return self._print_cmd
    def __str__(self):
        return ", ".join([print_filename(f) for f in self._out_nodes]) + " <- " + ", ".join([print_filename(f) for f in self._in_nodes])


def process_filename(f):
    """Canonicalise a node name.

    Tags are returned untouched. File names get "~" expanded and are turned
    into absolute paths with symlinks resolved (os.path.realpath).
    """
    if not is_tag(f):
        f = os.path.realpath(os.path.expanduser(f))
    return f

def print_filename(f):
    """Print the relative file if it's cleaner.

    Returns *f* relative to the current directory when it lies inside it,
    otherwise *f* unchanged.
    """
    try:
        relfile = os.path.relpath(f)
    except ValueError:
        # e.g. a path on another drive (Windows): no relative form exists.
        return f
    # Only a real parent-directory step means "outside the cwd"; a plain
    # startswith("..") also rejected in-cwd names such as "..hidden".
    if relfile == os.pardir or relfile.startswith(os.pardir + os.sep):
        return f
    return relfile


def define_rule(outfiles, infiles, cmd, options):
    """Build a Rule and register it with the dependency graph.

    This is the call emitted by preprocess_snakefile for each rule header.
    ``options`` maps option names to strings; "timecheck" and "cachecheck"
    are on (1) unless their value is exactly "False" or "0".
    """
    def _flag(name):
        return int(options.get(name, "True") not in ["False", "0"])

    # realpath follows symlinks, so one file always maps to one node
    out_nodes = [process_filename(f) for f in outfiles]
    in_nodes = [process_filename(f) for f in infiles]
    new_rule = Rule(out_nodes, in_nodes, cmd, _flag("timecheck"), _flag("cachecheck"))
    DependencyGraphSingleton().add_rule(new_rule)


def get_all(force=False):
    """Evaluate every rule in definition order.

    A rule runs if forced, if one of its inputs is produced by a rule already
    scheduled to run (and its tags allow it), or if graph.check says so.

    Returns:
        OrderedDict mapping each Rule to (should_run, reason).
    """
    graph = DependencyGraphSingleton()
    updating = set()
    evals = collections.OrderedDict()
    for rule in graph.rules():
        if force:
            verdict = (True, "forced")
        elif updating.intersection(rule.in_nodes()) and graph.tags_ok(rule):
            verdict = (True, "input file updating")
        else:
            verdict = graph.check(rule)
        if verdict[0]:
            updating.update(rule.out_nodes())
        evals[rule] = verdict
    for rule in graph.rules():
        evals.setdefault(rule, (False, "not downstream"))
    return evals


def get_upstream(base_node, force=False):
    """Evaluate the rules needed to (re)build *base_node*.

    Args:
        base_node: processed file path or tag to build.
        force: schedule every selected rule regardless of timestamps.
    Returns:
        OrderedDict mapping every Rule to (should_run, reason); rules that do
        not produce an upstream node get (False, "not upstream").
    """
    graph = DependencyGraphSingleton()
    # Single newest-to-oldest pass collecting every node base_node depends on;
    # it only follows chains whose producers are defined before their consumers.
    upstream_nodes = set([base_node])
    for r in graph.rules()[::-1]:
        if upstream_nodes.intersection(r.out_nodes()):
            upstream_nodes.update(r.in_nodes())
    nodes_to_update = set()
    rule_evals = collections.OrderedDict()
    for r in graph.rules():
        if not upstream_nodes.intersection(r.out_nodes()):
            continue
        if force:
            rule_evals[r] = (True, "forced")
        elif nodes_to_update.intersection(r.in_nodes()) and graph.tags_ok(r):
            rule_evals[r] = (True, "input file updating")
        else:
            rule_evals[r] = graph.check(r)
        if rule_evals[r][0]:
            nodes_to_update.update(r.out_nodes())
    # Every selected rule already has a verdict. (A former loop here inserted
    # node *paths* as keys, polluting a mapping that should hold only Rules.)
    for r in graph.rules():
        if r not in rule_evals:
            rule_evals[r] = (False, "not upstream")
    return rule_evals


def get_downstream(base_node, force=False):
    """Evaluate the rules that (transitively) consume *base_node*.

    Args:
        base_node: processed file path or tag whose consumers are selected.
        force: schedule every selected rule regardless of timestamps.
    Returns:
        OrderedDict mapping every Rule to (should_run, reason); unselected
        rules get (False, "not downstream").
    """
    graph = DependencyGraphSingleton()
    # Single forward pass collecting every node derived from base_node; it only
    # follows chains whose consumers are defined after their producers.
    downstream_nodes = set([base_node])
    for r in graph.rules():
        if downstream_nodes.intersection(r.in_nodes()):
            downstream_nodes.update(r.out_nodes())
    nodes_to_update = set()
    rule_evals = collections.OrderedDict()
    for r in graph.rules():
        if not downstream_nodes.intersection(r.in_nodes()):
            continue
        if force:
            # "forced", matching get_all/get_upstream/get_exact (was "force")
            rule_evals[r] = (True, "forced")
        elif nodes_to_update.intersection(r.in_nodes()) and graph.tags_ok(r):
            rule_evals[r] = (True, "input file updating")
        else:
            rule_evals[r] = graph.check(r)
        if rule_evals[r][0]:
            nodes_to_update.update(r.out_nodes())
    # Every selected rule already has a verdict. (A former loop here inserted
    # node *paths* as keys, polluting a mapping that should hold only Rules.)
    for r in graph.rules():
        if r not in rule_evals:
            rule_evals[r] = (False, "not downstream")
    return rule_evals


def get_exact(base_node, force=False):
    """Evaluate only the rule(s) listing *base_node* as an output.

    Returns:
        OrderedDict mapping every Rule to (should_run, reason); all other
        rules get (False, "not considered").
    """
    graph = DependencyGraphSingleton()
    evals = collections.OrderedDict()
    for rule in graph.rules():
        if base_node in rule.out_nodes():
            evals[rule] = (True, "forced") if force else graph.check(rule)
    for rule in graph.rules():
        evals.setdefault(rule, (False, "not considered"))
    return evals


def get_required_rules(expression):
    """Evaluate rules for one command-line target expression.

    A leading "+" forces the selected rules. "=" selects only the rule(s)
    producing the target, and "^" selects the rules downstream of it.
    Otherwise the rules upstream of (and including) the target's producer
    are selected.
    """
    prefix, target = parse_rule_input(expression)
    target = process_filename(target)
    force = prefix.startswith("+")
    if "=" in prefix:
        return get_exact(target, force)
    if "^" in prefix:
        return get_downstream(target, force)
    return get_upstream(target, force)

def regex_arg(arg):
    """Expand a regex target argument into one argument per matching output node.

    If the body of *arg* (after any "+", "^" or "=" prefix) starts with "@",
    the rest of the body is a regular expression. It is searched for in every
    registered output node, and each match yields ``prefix + node``. Any other
    argument is returned unchanged as a one-element list.

    Example:
        with output nodes "/proj/pawk.py" and "/proj/a.py" registered,
        regex_arg("+=@.py") returns ["+=/proj/pawk.py", "+=/proj/a.py"]
        (define_rule stores file outputs as absolute paths).

    Args:
        arg: command-line rule argument (string)
    Returns:
        list of argument strings
    """
    prefix, body = parse_rule_input(arg)
    # startswith() rather than body[0]: an empty body (e.g. arg "+") used to raise IndexError.
    if not body.startswith("@"):
        return [arg]
    pattern = re.compile(body[1:])  # compiled once instead of per node
    return [prefix + n for n in DependencyGraphSingleton().all_out_nodes() if pattern.search(n)]

def combine_rule_evals(list_of_rule_evals):
    """Merge several rule->(should_run, reason) mappings into one OrderedDict.

    A "will run" verdict always replaces an earlier one; otherwise the first
    verdict seen for a rule is kept. Keys keep first-insertion order.
    """
    merged = collections.OrderedDict()
    for evals in list_of_rule_evals:
        for rule, verdict in evals.items():
            run_status, _ = verdict
            if run_status or rule not in merged:
                merged[rule] = verdict
    return merged


def get_tags(expression):
    """Yield the body of *expression* (prefix stripped) if it is a tag; yield nothing otherwise."""
    body = parse_rule_input(expression)[1]
    if not is_tag(body):
        return
    yield body

def is_tag(f):
    """Return True if node name *f* is a tag ("%"-prefixed) rather than a file path."""
    tag_prefix = "%"
    return f.startswith(tag_prefix)

def parse_rule_input(expression):
    """Split a command-line rule argument into ``(prefix, body)``.

    The prefix is an optional "+" (force) followed by an optional "^"
    (downstream) or "=" (exact). The body is the rest of the first line.
    """
    # Raw string: "\+" and "\^" in a plain literal are invalid escape
    # sequences (DeprecationWarning; SyntaxWarning since Python 3.12).
    # Every part of the pattern is optional, so re.match always succeeds.
    match = re.match(r"(\+?(?:\^|=)?)(.*)", expression)
    return match.group(1), match.group(2)

def sort_rules(rule_set):
    """Yield the members of *rule_set* in the order they were defined in the graph."""
    for rule in DependencyGraphSingleton().rules():
        if rule not in rule_set:
            continue
        yield rule


def pairwise(iterable):
    """Iterate over overlapping consecutive pairs: s -> (s0,s1), (s1,s2), (s2, s3), ..."""
    first, second = itertools.tee(iterable)
    next(second, None)
    zip_fn = zip if sys.version_info > (3, 0) else itertools.izip
    return zip_fn(first, second)

def to_str(s):
    """Return *s* as text: UTF-8-decode bytes on Python 3, plain str() on Python 2."""
    if sys.version_info < (3, 0):
        return str(s)
    return str(s, "utf-8")

def to_bytes(s):
    """Return *s* as bytes: UTF-8-encode on Python 3, plain bytes() on Python 2."""
    if sys.version_info < (3, 0):
        return bytes(s)
    return bytes(s, "utf-8")

def list_condition_blocks(l, fn):
    """
    Group consecutive elements of l that share the same fn value.

    Args:
      l: list (any finite iterable is accepted)
      fn: function acting on elements of l

    Returns:
      a list of lists. sublists are consecutive elements
      from l with the same fn value; [] when l is empty

    Raises:
      Exception: if l contains None

    Example:
    list_condition_blocks(["a","b","abc","d","e"],len)
    returns
    [["a","b"],["abc"],["d","e"]]
    """
    items = list(l)
    # Reject None anywhere; the old pairwise loop only checked the left element
    # of each pair, so a trailing None slipped through.
    if any(x is None for x in items):
        raise Exception("ERROR: can't have None elements in l")
    # groupby starts a new group whenever the fn value changes. An empty input
    # now gives [] instead of an IndexError on l[0].
    return [list(group) for _, group in itertools.groupby(items, key=fn)]




def preprocess_snakefile(snakefile_string):
    """Translate Snakefile text into Python source, yielding one line at a time.

    A rule header is any line containing "<-":

        outputs <- inputs [key:value key:value]

    outputs/inputs are comma-separated Python expressions, copied verbatim into
    the generated code. The optional trailing [...] holds whitespace-separated
    key:value options.

    The rule's command is the "cmd" option if given. Otherwise it is the
    following lines indented deeper than the header, de-indented by the depth
    of the first of them. "$[expr]" inside that command becomes str(expr),
    evaluated when the generated code runs.

    Each rule becomes ``define_rule([outputs],[inputs],cmd,{options})`` at the
    header's indentation. Every other line passes through unchanged. Lines
    whose first non-blank character is "#" are dropped.

    Raises:
        Exception: a rule has neither a "cmd" option nor an indented command,
            or an option is not of the form key:value.
    """
    def indent_depth(x):
        return len(x) - len(x.lstrip())

    def is_rule_def(x):
        return "<-" in x

    def subtract_indent_depth(x, depth):
        """
        Remove leading whitespace up to 'depth'
        """
        return x[depth:]

    def preprocess_bash(bash_str):
        """
        Turn bash_str into a Python string expression, replacing each $[var]
        with str(var) evaluated when the generated code runs
        """
        # Raw strings: "\$", "\[" etc. are invalid escape sequences in plain literals.
        regex1 = r"(\$\[.*?\])"
        regex2 = r"\$\[(.*?)\]"
        out = ""
        for piece in re.split(regex1, bash_str):
            match = re.findall(regex2, piece)
            if match:
                out += "+ str({0}) +".format(match[0])
            else:
                #use repr so when python exec's the code later it gets back the original string
                out += repr(piece)
        return out

    def take_all_rule_blocks(l):
        """
        Split l into consecutive blocks with take_rule_block and yield each one
        """
        while l:
            block, l = take_rule_block(l)
            yield block

    def take_rule_block(l):
        """
        Return (block, remaining_lines). A rule block is its header plus the
        following lines indented deeper than it, stopping at any other rule
        header. Any other block runs until the next rule header.
        """
        block = [l[0]]
        if is_rule_def(l[0]):
            rule_depth = indent_depth(l[0])
            for i in l[1:]:
                if is_rule_def(i) or indent_depth(i) <= rule_depth:
                    break
                block.append(i)
        else:
            for i in l[1:]:
                if is_rule_def(i):
                    break
                block.append(i)
        return block, l[len(block):]

    def process_rule_string(rule_string):
        """
        Parse a rule header into (outputs, inputs, options_dict)
        """
        if re.findall(r"\[(.*)\]\s*$", rule_string):
            outputs, inputs, options = re.findall(r"(^.*)<-(.*)\[(.*)\]\s*$", rule_string)[0]
        else:
            outputs, inputs = re.findall(r"(^.*)<-(.*$)", rule_string)[0]
            options = ""
        outputs = [o.strip() for o in outputs.split(",")]
        inputs = [i.strip() for i in inputs.split(",")]
        option_dict = {}
        for token in options.split():
            if ":" not in token:
                raise Exception("ERROR: rule option {0!r} is not of the form key:value".format(token))
            # Split on the first ":" only, so values may themselves contain ":".
            key, value = token.split(":", 1)
            option_dict[key] = value
        return outputs, inputs, option_dict

    lines = snakefile_string.split('\n')
    lines = [l for l in lines if not l.strip().startswith("#")] #remove comments (TODO: handle comments at the end of lines)
    for block in take_all_rule_blocks(lines):
        header = block[0]
        if not is_rule_def(header):
            for l in block:
                yield l
            continue
        outputs, inputs, options = process_rule_string(header)
        if "cmd" in options:
            cmd_str = options["cmd"]
        elif len(block) > 1:
            # De-indent by the first command line's depth; empty lines are dropped.
            depth = indent_depth(block[1])
            cmd_str = preprocess_bash('\n'.join([subtract_indent_depth(l, depth) for l in block[1:] if l]))
        else:
            raise Exception("ERROR: parser couldn't find provided command!")
        outputs_string = "[" + ",".join(outputs) + "]"
        inputs_string = "[" + ",".join(inputs) + "]"
        leading_whitespace = re.findall(r"^\s*", header)[0]
        yield leading_whitespace + "define_rule({0},{1},{2},{3})".format(outputs_string, inputs_string, cmd_str, options)


def get_input_fn():
    """Return the builtin that reads a line from stdin: input on Python 3, raw_input on Python 2."""
    return input if sys.version_info > (3, 0) else raw_input


def y_n_input(message):
    """Prompt with *message* until the user answers y or n (case-insensitive).

    Returns:
        True for "y", False for "n".
    """
    read_line = get_input_fn()
    answer = read_line(message)
    while answer.lower() not in ("y", "n"):
        answer = read_line("Please input y or n" + '\n')
    return answer.lower() == "y"

def get_snake_dir(infile):
    """Return the absolute path of the ".snake" state directory next to *infile*."""
    return os.path.dirname(os.path.abspath(infile)) + "/.snake"


def mark_rules_ok(rules):
    """List *rules*, ask for confirmation (exiting on "n"), then mark each as ok.

    Marking a rule ok records the current time and caches its command, so it
    is not re-run for inputs older than that time.
    """
    # Fixed typo in the user-facing message ("with" -> "will").
    print("The following rules will be marked as ok:")
    print_rules(rules)
    confirm_or_die()
    for r in rules:
        r.mark_ok()

def run_rules(rules, verbose, rule_evals=None):
    """Execute *rules* in order, echoing each command's stdout and stderr.

    On a failing command, print the error and offer to delete the rule's
    (non-tag) output files, then exit with status -1. ``rule_evals`` is
    accepted for interface compatibility but not used.
    """
    print("---")
    for rule in rules:
        print("Running: " + str(rule) + '\n')
        if verbose:
            print(rule.print_cmd())
            print("---")
        stdout, stderr, return_code = rule.execute()
        # 141 is the status bash reports for SIGPIPE (128 + 13); treat as success.
        if return_code in [0, 141]:
            print(to_str(stdout))
            print(to_str(stderr))
            continue
        sys.stderr.write("ERROR in command: \n" + rule.print_cmd() + '\n\n' + to_str(stderr) + '\n')
        if y_n_input("Erase output files from step that errored? [y/n]" + '\n'):
            for node in rule.out_nodes():
                if not is_tag(node) and os.path.exists(node):
                    print("Removing... ", node)
                    os.remove(node)
        sys.exit(-1)

def print_rules_by_status(rules, rule_evals):
    """Print the rules that will run, then those that will not, each with its reason.

    Numbering is the rule's 1-based position in *rules* in both lists.
    """
    def _print_matching(should_run):
        for number, rule in enumerate(rules, 1):
            run_status, reason = rule_evals[rule]
            if bool(run_status) is should_run:
                print("  " + str(number) + ":" + " " + str(rule) + " " + "[reason: " + reason + "]")

    print("The following steps will be run:")
    _print_matching(True)
    print()
    print("The following steps will *not* be run:")
    _print_matching(False)

def print_rules(rules, rule_evals=None):
    """Print *rules* as a 1-based numbered list, appending each rule's reason
    when a non-empty *rule_evals* mapping is given."""
    for number, rule in enumerate(rules, start=1):
        line = "  {0}: {1}".format(number, str(rule))
        if rule_evals:
            _, reason = rule_evals[rule]
            line += " [reason: " + reason + "]"
        print(line)

def confirm_or_die():
    """Ask the user to confirm; on "n" print a message and exit with status -1."""
    if y_n_input("Confirm? [y/n]" + "\n"):
        return
    print("Exiting...")
    sys.exit(-1)


if __name__ == "__main__":
    # Command-line entry point: load the Snakefile, decide which rules are out
    # of date, confirm with the user, then run them in definition order.
    infile, verbose, do_print_rules, args, ok = readCL()

    # Per-project state (ok markers, cached commands) lives next to the Snakefile.
    snake_dir = get_snake_dir(infile)
    if not os.path.exists(snake_dir):
        os.mkdir(snake_dir)

    with open(infile) as f_in:
        # Work from the Snakefile's directory so relative paths in it resolve there.
        curr_dir = os.path.dirname(infile)
        if not curr_dir: curr_dir = "."
        os.chdir(curr_dir)
        parsed = '\n'.join(list(preprocess_snakefile(f_in.read())))
        # NOTE: the translated Snakefile runs as arbitrary Python in this
        # module's namespace (its define_rule(...) calls populate the graph);
        # only run trusted Snakefiles.
        exec(parsed)

    #initialize DependencyGraph:
    graph = DependencyGraphSingleton()
    graph.setup(snake_dir)

    # Tags named on the command line become active (see tags_ok).
    graph.active_tags = set([t for a in args for t in get_tags(a)])

    if not args:
        rule_evals = get_all()
    else:
        # "@regex" arguments expand to one argument per matching output node.
        rule_evals = combine_rule_evals([get_required_rules(k) for a in args for k in regex_arg(a)])

    required_rules = [r for r, r_eval in rule_evals.items() if r_eval[0]]

    #print_rules
    if do_print_rules and not args:
        print_rules(graph.rules())
        sys.exit(0)

    if ok:
        mark_rules_ok(required_rules)
        # Marking rules ok is a successful outcome: exit 0 (was -1, i.e. status 255).
        sys.exit(0)

    #run_rules
    if verbose:
        print_rules_by_status(graph.rules(), rule_evals)
    if not required_rules:
        # Everything up to date is success (like make), not an error. This also
        # skips the pointless confirmation prompt in verbose mode.
        print("Nothing to do.")
        sys.exit(0)
    if not verbose:
        print("The following steps will be run, in order:")
        print_rules(required_rules)
    confirm_or_die()

    run_rules(required_rules, verbose, rule_evals)
