#!/usr/bin/env python
import argparse
import pandas as pd
import sys

def readCL():
    """Parse the command line and return the script's settings.

    Returns:
        A tuple ``(infile, group_by, agg_fns, agg_cols, lam, append,
        delimiter)``.  ``infile`` defaults to ``sys.stdin``.  In ``agg_fns``
        the ``dup`` shorthand is replaced by ``val0`` and ``val1``.  ``lam``
        is the list of callables built from the ``--lam`` strings.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("-f","--infile",default=sys.stdin)
    #nargs="*" means values are separate arguments, not one comma-joined string
    parser.add_argument("-g","--group_by",help="space-separated list of columns to aggregate by", nargs="*")
    parser.add_argument("-a","--agg_fns",help="function to aggregate by. Options: 'sum','mean','max','strmax','min','strmin','median','std','cnt','pctile_73','val0','val1','dup'", nargs="*", default=[])
    parser.add_argument("-c","--agg_cols",help="columns to aggregate", nargs="*")
    parser.add_argument("-d","--delimiter",default=",")
    parser.add_argument("--lam", nargs="*", default=[], help="function takes a dataframe and returns a value to add a new column or returns a subset of dataframe rows filter to keep only those rows. Example 'lambda x: x.iloc[0,:]'")
    parser.add_argument("--append",help="append computed columns to existing file", action="store_true")
    args = parser.parse_args()

    #the duplicate pivot uses two primitive functions, val0 and val1
    if "dup" in args.agg_fns:
        args.agg_fns.remove("dup")
        args.agg_fns.extend(["val0", "val1"])

    def parse_lam(s):
        #a bare expression is shorthand for a lambda over x
        if s.strip().startswith("lambda"):
            return s
        return "lambda x:" + s

    #SECURITY: eval runs arbitrary Python taken from the command line.
    #Only pass --lam expressions you wrote yourself; never forward untrusted text.
    args.lam = [eval(parse_lam(i)) for i in args.lam]

    return args.infile, args.group_by, args.agg_fns, args.agg_cols, args.lam, args.append, args.delimiter

def sum_fn(array):
    """Return the sum of the values in ``array``, each converted to float."""
    as_floats = map(float, array)
    return sum(as_floats)
def mean_fn(array):
    """Return the arithmetic mean of ``array``, each value converted to float.

    Raises ZeroDivisionError when ``array`` is empty.
    """
    #convert once: saves a second pass and keeps one-shot iterables working
    values = [float(x) for x in array]
    return sum(values) / len(values)
def max_fn(array):
    """Return the largest value in ``array`` after converting each to float."""
    as_floats = map(float, array)
    return max(as_floats)
def min_fn(array):
    """Return the smallest value in ``array`` after converting each to float."""
    as_floats = map(float, array)
    return min(as_floats)
def strmax_fn(array):
    """Return the greatest value in ``array`` when compared as strings."""
    as_text = map(str, array)
    return max(as_text)
def strmin_fn(array):
    """Return the least value in ``array`` when compared as strings."""
    as_text = map(str, array)
    return min(as_text)
def std_fn(array):
    """Return ``numpy.std`` of the values in ``array``, each converted to float."""
    import numpy
    as_floats = list(map(float, array))
    return numpy.std(as_floats)
def median_fn(array):
    """Return ``numpy.median`` of the values in ``array``, each converted to float."""
    import numpy
    as_floats = list(map(float, array))
    return numpy.median(as_floats)
def pctile_fn(array, pctile):
    """Return ``numpy.percentile`` of ``array`` (values converted to float).

    ``pctile`` is the percentile to compute, in the range [0,100].
    """
    import numpy
    as_floats = list(map(float, array))
    return numpy.percentile(as_floats, pctile)
def cnt_fn(array):
    """Return the number of items in ``array``."""
    item_count = len(array)
    return item_count
def val0_fn(array):
    """Return the first item of ``array``, or "" when it is empty."""
    return list(array)[0] if len(array) > 0 else ""
def val1_fn(array):
    """Return the second item of ``array``, or "" when it has fewer than two."""
    #guard: nothing at position 1 to return
    if len(array) <= 1:
        return ""
    items = list(array)
    return items[1]

def aggstr_to_fn(agg_str):
    """Map an aggregate name from the command line to its function.

    Args:
        agg_str: one of "sum", "mean", "max", "strmax", "min", "strmin",
            "median", "std", "cnt", "val0", "val1", or "pctile_<p>" where
            <p> is a percentile in [0,100] (e.g. "pctile_95" or "pctile_5").

    Returns:
        The aggregate function.  For percentiles a new function named
        "pctile_<p>_fn" is returned that calls ``pctile_fn`` with <p>.

    Raises:
        ValueError: if a "pctile..." string has no numeric percentile in
            [0,100] after the underscore.
        Exception: if ``agg_str`` is not a known aggregate.
    """
    if agg_str == "sum":
        return sum_fn
    elif agg_str == "mean":
        return mean_fn
    elif agg_str == "max":
        return max_fn
    elif agg_str == "strmax":
        return strmax_fn
    elif agg_str == "min":
        return min_fn
    elif agg_str == "strmin":
        return strmin_fn
    elif agg_str == "median":
        return median_fn
    elif agg_str == "std":
        return std_fn
    elif agg_str == "cnt":
        return cnt_fn
    elif agg_str.startswith("pctile"):
        #agg_str = pctile_95 or pctile_5: the percentile follows the underscore
        pctile_str = agg_str.partition("_")[2].strip()
        try:
            pctile = float(pctile_str)
        except ValueError:
            raise ValueError("ERROR: percentile aggregate must look like pctile_95, got: " + agg_str)
        #also rejects nan, which fails both comparisons
        if not 0 <= pctile <= 100:
            raise ValueError("ERROR: percentile must be in range [0,100], got: " + agg_str)
        def fn(array):
            return pctile_fn(array, pctile)
        #rename function because function name is printed later;
        #reuse the text as typed so pctile_95 stays "95" rather than "95.0"
        fn.__name__ = "pctile_" + pctile_str + "_fn"
        return fn
    elif agg_str == "val0":
        return val0_fn
    elif agg_str == "val1":
        return val1_fn
    else:
        raise Exception("ERROR: unknown aggregate string")

        
def groupby_agg(df_groups, cols, agg_fn_list):
    """Apply every aggregate function to every requested column of a groupby.

    Args:
        df_groups: a pandas groupby object.
        cols: names of the columns to aggregate.
        agg_fn_list: aggregate functions; each ``__name__`` is expected to
            end in "_fn" (e.g. "sum_fn", "pctile_95_fn").

    Returns:
        The aggregated frame with flat column names "<column>_<aggregate>".
    """
    agg_dict = dict((c, agg_fn_list) for c in cols)
    df_out = df_groups.agg(agg_dict)

    def flat_name(pair):
        #entries with an empty second element (presumably the group keys)
        #keep their plain name
        if not pair[1]:
            return pair[0]
        #rsplit drops only the trailing "_fn", so "pctile_95_fn" becomes
        #"pctile_95" and two percentiles no longer share a column name;
        #this matches the naming used by df_agg
        return pair[0] + "_" + pair[1].rsplit("_", 1)[0]

    #rename multiindex
    df_out.columns = [flat_name(c) for c in df_out.columns]
    return df_out

def df_agg(df, cols, agg_fn_list):
    """Aggregate whole columns of ``df`` into a frame with one row labelled "0".

    Each (column, function) pair produces one output column named
    "<column>_<aggregate>", where <aggregate> is the function's ``__name__``
    with everything after its last underscore removed.
    """
    result = pd.DataFrame()
    pairs = [(col, fn) for col in cols for fn in agg_fn_list]
    for col, fn in pairs:
        #e.g. "sum_fn" -> "sum"
        short_name = fn.__name__.rsplit("_", 1)[0]
        #select as a one-column frame so apply runs fn on that column
        single_col = df.loc[:, [col]]
        reduced = single_col.apply(fn)
        result.loc["0", col + "_" + short_name] = reduced.values[0]
    return result

def groupby_apply_lambda(df_groups, lambda_list):
    """Apply each callable to the groups and join the results side by side.

    The i-th result is labelled "lambda_<i>": its ``name`` attribute is set
    when it has one, or its columns are renamed when they are exactly ["0"].
    The joined frame is returned with its index reset.
    """
    #jtrigg@20150714: TODO try the lambda returning a dictionary syntax
    #like this?
    #http://stackoverflow.com/questions/15259547/conditional-sums-for-pandas-aggregate

    def label_result(result, position):
        label = "lambda_" + str(position)
        if hasattr(result, "columns") and list(result.columns) == ["0"]:
            result.columns = [label]
        elif hasattr(result, "name"):
            result.name = label
        return result

    #run every callable first, then label the results in order
    applied = []
    for fn in lambda_list:
        applied.append(df_groups.apply(fn))
    labelled = []
    for position, result in enumerate(applied):
        labelled.append(label_result(result, position))

    combined = pd.concat(labelled, axis=1)
    #if the index name is already a column, discard the index rather than
    #inserting it as a second column of the same name
    index_is_column = combined.index.name in combined.columns
    return combined.reset_index(drop=index_is_column)
    
    
# def groupby_transform(df_out, df_groups, cols, agg_fn):
#     for c in cols:
#         df_out["_" + c + "_" + agg_str] = df_groups[c].transform(agg_fn)
#     return df_out

#Script entry point: read the CSV, aggregate it as requested on the command
#line, and write the result as CSV to stdout.
if __name__ == "__main__":
    infile, group_by_cols, agg_str_list, agg_cols, lambda_fn_list, append, delimiter = readCL()
    df = pd.read_csv(infile, delimiter=delimiter)

    #special treatment for cnt: with no explicit columns, count rows per group
    if not agg_cols and group_by_cols and agg_str_list == ["cnt"]:
        cnt_col = "_".join(group_by_cols) + '_cnt'
        df_out = pd.DataFrame({cnt_col : df.groupby(group_by_cols).size()}).reset_index()
    else:
        if not agg_cols:
            #-g is optional, so group_by_cols may be None here; treat that as
            #"no group columns" instead of failing on the membership test
            group_keys = group_by_cols or []
            agg_cols = [d for d in df.columns.values if d not in group_keys]

        agg_fn_list = [aggstr_to_fn(a) for a in agg_str_list]

        if lambda_fn_list:
            #--lam takes precedence over -a
            df_groups = df.groupby(group_by_cols, as_index=True)
            df_out = groupby_apply_lambda(df_groups, lambda_fn_list)
        elif group_by_cols:
            df_groups = df.groupby(group_by_cols, as_index=False)
            df_out = groupby_agg(df_groups, agg_cols, agg_fn_list)
        else:
            #no grouping: aggregate each column over the whole file
            df_out = pd.DataFrame(df_agg(df, agg_cols, agg_fn_list))

    if append:
        #attach the computed columns to the original rows
        df_out = df.merge(df_out, on=group_by_cols, how='left')

    df_out.to_csv(sys.stdout, index=False)
