#!/usr/bin/env python
import pandas as pd
import numpy
import argparse
import sys
import itertools
import csv
import resource

def readCL():
    """Parse the command line for the csv join tool.

    Returns a 9-tuple:
        (files, keys1, keys2, join_type, keep1, keep2,
         default, chunksize, presorted)

    * files      -- positional file arguments; when exactly one file is
                    given, sys.stdin is inserted in front of it
    * keys1/2    -- lists built from --key / --key2, or None when absent
    * join_type  -- "outer", "left", "right" or "inner"; when several flags
                    are given the first of outer/left/right wins
    * keep1/2    -- lists built from --keep1 / --keep2, or None when empty
    * default, chunksize, presorted -- passed through as parsed
    """
    def split_or_none(raw):
        # an absent option (None) and an empty string both mean "not given"
        return raw.split(',') if raw else None

    parser = argparse.ArgumentParser()
    parser.add_argument("-k","--key",help="comma separated list of key columns")
    parser.add_argument("--key2",help="csv list of columns from file2")
    parser.add_argument("--outer",action="store_true")
    parser.add_argument("--left",action="store_true")
    parser.add_argument("--right",action="store_true")
    parser.add_argument("--keep1",help="comma separated list of column names",default="")
    parser.add_argument("--keep2",help="comma separated list of column names",default="")
    parser.add_argument("--default",help="default fill value for NULL in left, right, outer joins")
    parser.add_argument("--chunksize", type=int, help="files won't fit in memory: join piece by piece. will pre-sort files unless you use --sorted")
    parser.add_argument("--presorted",action="store_true", help="files already (string) sorted by keys before joining, so don't need to use unix sort")
    parser.add_argument("files", nargs="*")
    args = parser.parse_args()

    # first flag set wins, in the order outer, left, right; otherwise inner
    join_type = "inner"
    for is_set, name in ((args.outer, "outer"), (args.left, "left"), (args.right, "right")):
        if is_set:
            join_type = name
            break

    #assume stdin if unspecified
    files = args.files
    if len(files) == 1:
        files = [sys.stdin] + files

    return (
        files,
        split_or_none(args.key),
        split_or_none(args.key2),
        join_type,
        split_or_none(args.keep1),
        split_or_none(args.keep2),
        args.default,
        args.chunksize,
        args.presorted,
    )




def pairwise(iterable):
    """s -> (s0,s1), (s1,s2), ..., (s_last, None)

    Yield each item together with its successor. Unlike the classic
    itertools recipe, the final item is also yielded, paired with None,
    because the two iterators are zipped with the "longest" variant.
    An empty iterable yields nothing.
    """
    # izip_longest (Python 2) was renamed zip_longest in Python 3
    zip_longest = getattr(itertools, "zip_longest", None) or itertools.izip_longest
    a, b = itertools.tee(iterable)
    # advance the second copy by one so that it runs one item ahead
    next(b, None)
    return zip_longest(a, b)


def readcsv(f):
    """Yield every data row of the CSV stream f as an OrderedDict.

    The first non-empty row is taken as the header; each later row is
    zipped with it, so the dict keys follow the header's column order.
    Rows shorter or longer than the header are truncated by zip().
    """
    #TODO: add keep here to keep memory down
    # imported locally because the module does not import it at file level
    from collections import OrderedDict
    hdr = None
    for line in csv.reader(f):
        if not hdr:
            hdr = line
        else:
            yield OrderedDict(zip(hdr, line))


class CsvKeyReader():
    """Row-at-a-time reader for a key-sorted CSV with one row of lookahead.

    Attributes:
        hdr          header row (list of column names)
        key_indices  positions of the key columns inside hdr
        topline      the row pop() will return next, or None when exhausted
        nextline     the row after topline, or None when there is none
        topval       key values of topline (list of str), or None
        nextval      key values of nextline (list of str), or None

    The key columns must be strictly ascending under string comparison;
    repeated keys are rejected as well.

    Raises:
        ValueError  if the input has no header row, if a key is not in the
                    header, or if two consecutive rows are not in strictly
                    ascending key order.
    """
    def __init__(self, fin, keys):
        # the next() builtin works on both Python 2 and Python 3 iterators
        self.reader = csv.reader(fin)
        self.hdr = next(self.reader, None)
        if self.hdr is None:
            raise ValueError("ERROR: empty input, no header row found")
        self.key_indices = [self.hdr.index(k) for k in keys]
        # either line may be None: a header-only file has no rows at all and
        # a one-row file has no lookahead row
        self.topline = next(self.reader, None)
        self.nextline = next(self.reader, None)
        self.topval = self._keyval(self.topline)
        self.nextval = self._keyval(self.nextline)
        # validate the very first pair too, not only the pairs seen by pop()
        self._check_sorted()

    def _keyval(self, line):
        """Return the key values of a row, or None for a missing row."""
        if line is None:
            return None
        return [line[k] for k in self.key_indices]

    def _check_sorted(self):
        """Raise if topline and nextline are not in strictly ascending key order."""
        if self.topline is None or self.nextline is None:
            return
        if self.topval >= self.nextval:
            #NOTE: current popping pattern doesn't support multiple keys with same value
            raise ValueError("ERROR: key columns not string sorted in ascending order. {} >= {}".format(self.topval, self.nextval))

    def pop(self):
        """Return the current top row and advance by one row.

        Returns None once the input is exhausted.
        """
        #save value to be returned, then update
        out = self.topline
        if out is None:
            return None
        # shift the lookahead row into the top slot and read a new lookahead
        self.topline, self.topval = self.nextline, self.nextval
        self.nextline = next(self.reader, None)
        self.nextval = self._keyval(self.nextline)
        self._check_sorted()
        return out

    def at_breakpt(self):
        """Return True when the top row's key differs from the next row's key."""
        return self.topval != self.nextval
    

def check_memory_usage():
    """Return the peak resident set size of this process, in bytes."""
    #http://stackoverflow.com/questions/938733/total-memory-used-by-python-process
    import resource
    import sys
    peak = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss
    # ru_maxrss is reported in bytes on macOS and in kilobytes (1024 bytes)
    # on Linux; other platforms are assumed to follow Linux -- TODO confirm
    if sys.platform == "darwin":
        return peak
    return 1024 * peak
                    

class MergeCsvReader():
    """Iterate over two key-sorted CSV files in aligned chunks.

    Each iteration yields a tuple (df1, df2) of DataFrames holding the rows
    popped from file 1 and file 2. Rows are popped in ascending key order
    from whichever file has the smaller key (both files when the keys are
    equal), so at every chunk boundary all rows with a key up to some value
    have been returned from both files.

    Parameters:
        fin1, fin2    open file objects for the two CSV inputs
        keys1, keys2  lists of key column names for each file
        chunksize     target number of rows (both files combined) per chunk
    """
    def __init__(self, fin1, keys1, fin2, keys2, chunksize):
        self.f1 = CsvKeyReader(fin1, keys1)
        self.f2 = CsvKeyReader(fin2, keys2)
        self.pop_cnt = 0
        self.chunksize = chunksize

    def __iter__(self):
        return self

    def next(self):
        """
        pops from file1 and file2 a total number of lines
        approximately equal to chunksize (chunksize + 1 or + 2, fewer
        when the inputs run out) and returns dataframes df1 and df2
        from those files respectively

        Raises StopIteration once both files are exhausted.
        """
        df1_rows = []
        df2_rows = []
        # use the instance's chunk size, not a module-level global
        end = self.pop_cnt + self.chunksize
        while True:
            r1, r2 = self.pop_one()
            if r1 is None and r2 is None:
                # both inputs exhausted
                break
            if r1 is not None:
                df1_rows.append(r1)
            if r2 is not None:
                df2_rows.append(r2)
            if self.pop_cnt > end:
                break

        if not df1_rows and not df2_rows:
            raise StopIteration

        if df1_rows:
            df1 = pd.DataFrame(df1_rows, columns=self.f1.hdr)
        else:
            df1 = pd.DataFrame(columns=self.f1.hdr)

        if df2_rows:
            df2 = pd.DataFrame(df2_rows, columns=self.f2.hdr)
        else:
            df2 = pd.DataFrame(columns=self.f2.hdr)
        return df1,df2

    # Python 3 iterator protocol; next() above serves Python 2
    __next__ = next

    def pop_one(self):
        """Pop the row(s) with the smallest key and return [row1, row2].

        An entry is None when nothing was popped from that file; both are
        None when both files are exhausted. pop_cnt is increased by the
        number of rows popped.
        """
        if (self.f1.topline is None and self.f2.topline is None):
            return [None, None]
        elif self.f1.topline is None:
            self.pop_cnt += 1
            return [None,self.f2.pop()]
        elif self.f2.topline is None:
            self.pop_cnt += 1
            return [self.f1.pop(),None]
        else:
            if self.f1.topval < self.f2.topval:
                self.pop_cnt += 1
                return [self.f1.pop(),None]
            elif self.f1.topval > self.f2.topval:
                self.pop_cnt += 1
                return [None,self.f2.pop()]
            else:
                #when equal pop both to preserve constraint:
                #at all times, for some value C
                #all values from both files <=C have been
                #returned
                self.pop_cnt += 2
                return [self.f1.pop(), self.f2.pop()]


            
def check_dup(df, key_list):
    """Warn on stderr about every key column of df that holds repeated values.

    Each column in key_list is checked on its own. Nothing is done when
    key_list is empty or None. Returns None.
    """
    for column in (key_list or ()):
        has_repeats = df.duplicated(column).any()
        if has_repeats:
            sys.stderr.write("WARNING: duplicate key in column " + repr(column) + "\n")


def join(df1, df2, keys, keys2, join_type, default):
    """Merge two DataFrames and return the merged frame.

    With keys given, df1 is joined on `keys` against df2 on `keys2`.
    Without keys, the merge falls back to pandas' default (the columns the
    two frames share) and a warning naming those columns goes to stderr.
    In both cases the key columns are checked for duplicates first.

    join_type is passed to DataFrame.merge as `how`; a truthy `default`
    replaces every NULL in the result.
    """
    if keys:
        check_dup(df1, keys)
        check_dup(df2, keys2)
        merged = df1.merge(df2, left_on=keys, right_on=keys2, sort=False, how=join_type)
    else:
        # no explicit keys: report the shared columns pandas will join on
        left_cols = set(df1.columns.values)
        right_cols = set(df2.columns.values)
        join_keys = left_cols.intersection(right_cols)
        sys.stderr.write("WARNING: joining on " + str(join_keys) + '\n')
        for frame in (df1, df2):
            check_dup(frame, join_keys)
        merged = df1.merge(df2, sort=False,how=join_type)

    if default:
        merged.fillna(default, inplace=True)
    return merged

def join_all(df_list, keys, keys2, join_type, default):
    """Join every frame in df_list, left to right, into one DataFrame.

    The first frame is the starting point; each following frame is joined
    onto the running result with the same keys, join type and default.
    A single-element df_list is returned unchanged.
    """
    first, rest = df_list[0], df_list[1:]
    result = first
    for frame in rest:
        result = join(result, frame, keys, keys2, join_type, default)
    return result

def cat_dfs(partial_dfs):
    """Write the given DataFrames to stdout as one continuous CSV.

    Only the first frame is written with its header line; every later
    frame is appended as data rows only. The index is never written.
    """
    for position, frame in enumerate(partial_dfs):
        if position == 0:
            frame.to_csv(sys.stdout, index=False)
            continue
        # header=None (rather than header=False) is kept from the original code
        frame.to_csv(sys.stdout, mode="a", index=False, header=None)

            
if __name__ == "__main__":
    # Entry point: join the CSV files named on the command line and write
    # the merged CSV to stdout.
    filelist, keys, keys2, join_type, keep1, keep2, default, chunksize, presorted = readCL()

    # --key2 falls back to --key when only --key is given
    if keys and not keys2:
        keys2 = keys

    if chunksize:
        # Chunked mode: stream the first two inputs in key order and join them
        # piece by piece. No sorting happens here and `presorted` is not
        # consulted, so the inputs must already be string-sorted by key.
        if not keys:
            sys.stderr.write("ERROR: --chunksize requires --key\n")
            sys.exit(1)

        opened = []
        try:
            handles = []
            for source in filelist[:2]:
                if source is sys.stdin:
                    # stdin placeholder inserted by readCL: use it directly
                    handles.append(source)
                else:
                    handle = open(source)
                    opened.append(handle)
                    handles.append(handle)

            hdr = False
            for df_list in MergeCsvReader(handles[0], keys, handles[1], keys2, chunksize):
                df = join_all(df_list, keys, keys2, join_type, default)
                if not hdr:
                    # only the first chunk carries the header line
                    df.to_csv(sys.stdout, index=False)
                    hdr = True
                else:
                    df.to_csv(sys.stdout, mode="a", index=False, header=None)
        finally:
            # close only the files opened here, never stdin
            for handle in opened:
                handle.close()
        # chunked join finished successfully: stop with a success status
        sys.exit(0)

    if keys:
        names1 = None
        if keep1:
            names1 = keys[:] + keep1

        names2 = None
        if keep2:
            names2 = keys2[:] + keep2

        #list of dataframes to join
        #first one uses names1, others use names2
        df_list = [pd.read_csv(filelist[0], dtype=numpy.str_, usecols = names1)] + \
                  [pd.read_csv(infile, dtype=numpy.str_, usecols = names2) for infile in filelist[1:]]

    else:
        df_list = [pd.read_csv(infile, dtype=numpy.str_) for infile in filelist]

    merged = join_all(df_list, keys, keys2, join_type, default)
    merged.to_csv(sys.stdout, index=False)
