#!/usr/bin/python
# -*- mode: python; coding: utf-8; -*-

"""
Parse input text into elementary discourse segments and output them

USAGE:
discourse_segmenter [GLOBAL_OPTIONS] type [TYPE_SPECIFIC_OPTIONS] [FILEs]

@author: Wladimir Sidorenko <Uladzimir Sidarenka>

"""

##################################################################
# Imports
from __future__ import print_function, unicode_literals
from dsegmenter.bparseg import BparSegmenter, CTree, read_trees, read_segments, \
    trees2segs
from dsegmenter.edseg import EDSSegmenter, CONLL

from collections import defaultdict
from sklearn.cross_validation import KFold
import argparse
import codecs
import glob
import sys
import os
import re

##################################################################
# Constants and Variables
# text encoding used as the default of the `--encoding` option and of the
# `a_encoding` parameters of the functions below
DEFAULT_ENCODING = "utf-8"
# encoding used as the default of the `a_encoding` parameter of `crossval()`
ENCODING = DEFAULT_ENCODING
# names of the top-level sub-commands (types of discourse segmenters)
EDSEG = "edseg"
BPARSEG = "bparseg"
# names of the modes (nested sub-commands) of the `bparseg` sub-command
CV = "cv"
TEST = "test"
TRAIN = "train"
SEGMENT = "segment"
# NOTE(review): not referenced anywhere else in the visible part of this file
Segmenter = None
# default number of folds used by `crossval()`
N_FOLDS = 10

##################################################################
# Methods
def _set_train_test_args(a_parser):
    """Register the CLI arguments shared by the training and testing modes.

    Two optional file-name suffixes (both defaulting to the empty string) are
    registered first, followed by two mandatory positional directories.

    @param a_parser - ArgumentParser instance which should get the arguments

    @return void

    """
    # optional suffixes of the file names
    suffix_options = (
        ("--bpar-sfx", """suffix of the names of BitPar files"""),
        ("--seg-sfx", """suffix of the names of segmentation files"""),
    )
    for flag, description in suffix_options:
        a_parser.add_argument(flag, help=description, type=str, default="")
    # mandatory input directories
    directory_arguments = (
        ("bpar_dir", "directory containing BitPar files"),
        ("seg_dir", "directory containing segmentation files"),
    )
    for name, description in directory_arguments:
        a_parser.add_argument(name, help=description, type=str)

def _read_files(a_files, a_encoding = DEFAULT_ENCODING, a_skip_line = ""):
    """Iterate over the lines of the given files or of the standard input.

    Lines which are equal to `a_skip_line` are printed to the standard output
    instead of being yielded; all other lines are yielded with trailing
    whitespace removed.

    NOTE: the comparison with `a_skip_line` is done on the raw line, i.e.
    before the trailing whitespace (including the newline) is stripped.

    @param a_files - names of the files to read from (standard input is read
                     if this argument is empty)
    @param a_encoding - text encoding used for input/output
    @param a_skip_line - line which should be printed instead of being yielded

    @return iterator over input lines

    """
    if a_files:
        for fname in a_files:
            with codecs.open(fname, encoding=a_encoding, errors="replace") as ifile:
                for iline in ifile:
                    if iline != a_skip_line:
                        yield iline.rstrip()
                    else:
                        print(iline.encode(a_encoding))
    else:
        # standard input delivers byte strings, which have to be decoded first
        for raw_line in sys.stdin:
            iline = raw_line.decode(a_encoding)
            if iline != a_skip_line:
                yield iline.rstrip()
            else:
                print(iline.encode(a_encoding))

def _align_files(a_bpar_dir, a_seg_dir, a_bpar_sfx, a_seg_sfx):
    """Align BitPar and segment files in two directories.

    For every file in `a_bpar_dir` whose name matches `'*' + a_bpar_sfx`, the
    BitPar suffix is removed from the END of the base name and `a_seg_sfx` is
    appended in order to obtain the name of the counterpart file in
    `a_seg_dir`.  A warning is printed to the standard error stream for each
    BitPar file without a readable counterpart.

    @param a_bpar_dir - directory containing files with BitPar trees
    @param a_seg_dir - directory containing files with discourse segments
    @param a_bpar_sfx - suffix of the names of BitPar files
    @param a_seg_sfx - suffix of the names of segmentation files

    @return iterator over 2-tuples with the paths of BitPar and segment files

    """
    sfx_len = len(a_bpar_sfx)
    for bpf in glob.iglob(os.path.join(a_bpar_dir, '*' + a_bpar_sfx)):
        basefname = os.path.basename(bpf)
        # only the trailing suffix must be removed: other occurrences of the
        # same character sequence inside the name belong to the base name
        if sfx_len and basefname.endswith(a_bpar_sfx):
            basefname = basefname[:-sfx_len]
        segf = os.path.join(a_seg_dir, basefname + a_seg_sfx)
        if os.path.isfile(segf) and os.access(segf, os.R_OK):
            yield (bpf, segf)
        else:
            print(
                "WARNING: No counterpart file found for BitPar file '{:s}'.".format(bpf),
                file=sys.stderr)

def _read_trees_segments(a_bpar_dir, a_seg_dir, a_bpar_sfx, a_seg_sfx, \
                             a_fname2item = False, a_encoding = DEFAULT_ENCODING):
    """Read aligned BitPar trees and discourse segments from two directories.

    @param a_bpar_dir - directory containing files with BitPar trees
    @param a_seg_dir - directory containing files with discourse segments
    @param a_bpar_sfx - suffix of the names of BitPar files
    @param a_seg_sfx - suffix of the names of segmentation files
    @param a_fname2item - if true, return mappings from file names to lists of
                          items (trees are keyed by the BitPar file name,
                          segments by the segment file name) instead of two
                          flat lists
    @param a_encoding - text encoding of the input files

    @return 2-tuple with the trees and the corresponding segments (in this
    order)

    """
    if a_fname2item:
        trees, segments = defaultdict(list), defaultdict(list)
    else:
        trees, segments = [], []
    # containers which receive the items of the file currently being read
    tree_sink, seg_sink = trees, segments
    for bpf, segf in _align_files(a_bpar_dir, a_seg_dir, a_bpar_sfx, a_seg_sfx):
        if a_fname2item:
            tree_sink, seg_sink = trees[bpf], segments[segf]
        with codecs.open(bpf, 'r', encoding=a_encoding) as ibpf:
            toks2trees, _ = read_trees(ibpf)
        with codecs.open(segf, 'r', encoding=a_encoding) as isegf:
            toks2segs = read_segments(isegf)
        # do tree/segment alignment
        for tree, seg in trees2segs(toks2trees, toks2segs).iteritems():
            tree_sink.append(tree)
            seg_sink.append(seg)
    return (trees, segments)

def _output_segment_forrest(a_forrest, a_segmenter, a_output, a_encoding):
    """Segment the sentences of a CONLL forrest and print the result.

    Nothing is done for an empty forrest.  Otherwise, the forrest itself is
    optionally printed, every sentence is segmented, the obtained segment
    structures are pretty-printed, and the forrest is cleared afterwards.

    @param a_forrest - CONLL forrest with the collected sentences
    @param a_segmenter - discourse segmenter applied to every sentence
    @param a_output - boolean flag indicating whether dependency trees
    should be printed
    @param a_encoding - text encoding used for output

    @return void

    """
    if a_forrest.is_empty():
        return
    if a_output:
        print(unicode(a_forrest).encode(a_encoding))
    # all sentences are segmented before the first segment gets printed
    segmented = []
    for sentence in a_forrest:
        segmented.append(a_segmenter.segment(sentence))
    for discourse_units in segmented:
        discourse_units.pretty_print(a_encoding=a_encoding)
    a_forrest.clear()

def edseg_segment(a_ilines, a_output_trees, a_encoding = DEFAULT_ENCODING):
    """Perform rule-based segmentation of CONLL dependency trees.

    Non-empty lines are collected in a CONLL forrest.  Every empty line
    triggers segmentation and output of the sentences collected so far and is
    then printed itself.

    @param a_ilines - iterator over input lines
    @param a_output_trees - boolean flag indicating whether dependency trees
    should be printed
    @param a_encoding - text encoding used for input/output

    @return void

    """
    conll_forrest = CONLL()
    eds_segmenter = EDSSegmenter()
    for iline in a_ilines:
        if iline:
            # keep collecting the lines of the current forrest
            conll_forrest.add_line(iline)
            continue
        # empty line: print collected sentences and the line itself
        _output_segment_forrest(conll_forrest, eds_segmenter, a_output_trees, a_encoding)
        print(iline.encode(a_encoding))
    # output EDUs of the sentences which were not followed by an empty line
    _output_segment_forrest(conll_forrest, eds_segmenter, a_output_trees, a_encoding)

def bparseg_segment(a_segmenter, a_ilines, a_encoding = DEFAULT_ENCODING, \
                    a_ostream = sys.stdout):
    """Perform machine-learning driven segmentation of BitPar constituency trees.

    Every tree parsed from the input lines is segmented on its own; the last
    element of each returned segment is converted to text, and the texts of
    one tree are printed to the output stream joined with newlines.

    @param a_segmenter - BitPar segmenter
    @param a_ilines - iterator over input lines
    @param a_encoding - text encoding used for input/output
    @param a_ostream - output stream

    @return void

    """
    for ctree in CTree.parse_lines(a_ilines):
        tree_segments = a_segmenter.segment([ctree])
        otext = u'\n'.join(unicode(seg[-1]) for seg in tree_segments)
        print(otext.encode(a_encoding), file=a_ostream)

def bparseg_test(a_segmenter, a_trees, a_segments):
    """Report the quality of segment classification on the given data.

    The macro- and micro-averaged F1-scores returned by the `test()` method
    of the segmenter are printed to the standard error stream.

    @param a_segmenter - BitPar segmenter
    @param a_trees - list of BitPar trees
    @param a_segments - list of discourse segments corresponding
    to BitPar trees

    @return void

    """
    macro_f1, micro_f1 = a_segmenter.test(a_trees, a_segments)
    report = (
        ("Macro F1-score: {:.2%}", macro_f1),
        ("Micro F1-score: {:.2%}", micro_f1),
    )
    for template, score in report:
        print(template.format(score), file=sys.stderr)


def _cnt_stat(a_gold_segs, a_pred_segs):
    """Count true positives, false positives, and false negatives.

    Labels are compared case-insensitively and pairwise (surplus elements of
    the longer sequence are ignored).  A gold label "none" counts as false
    positive if the prediction differs from "none" and is not counted at all
    otherwise.  Any other gold label counts as true positive if the
    prediction is the same and as false negative if it is not.

    @param a_gold_segs - gold segments
    @param a_pred_segs - predicted segments

    @return 3-tuple with true positives, false positives, and false negatives

    """
    n_tp = n_fp = n_fn = 0
    for gold, pred in zip(a_gold_segs, a_pred_segs):
        gold_label = gold.lower()
        pred_label = pred.lower()
        if gold_label != "none":
            if gold_label == pred_label:
                n_tp += 1
            else:
                n_fn += 1
        elif pred_label != "none":
            n_fp += 1
    return n_tp, n_fp, n_fn


def crossval(a_segmenter, a_path, a_fname2trees, a_fname2segs, \
             a_output = False, a_out_dir = ".", a_out_sfx = ".tree", \
             a_folds = N_FOLDS, a_encoding = ENCODING):
    """Train and evaluate model using n-fold cross-validation.

    The folds are formed over files (not over single trees).  In every fold
    the model of the segmenter is fit on the training files and evaluated on
    the test files; whenever a fold yields a strictly higher macro F1-score
    than all previous folds, the fitted model is dumped to `a_path`.

    @param a_segmenter - pointer to untrained segmenter instance
    @param a_path - path in which to store the model
    @param a_fname2trees - mapping from file names to trees
    @param a_fname2segs - mapping from file names to segments
    @param a_output - boolean flag indicating whether output files should be produced
    @param a_out_dir - directory for writing output files
    @param a_out_sfx - suffix which should be appended to output files
    @param a_folds - maximum number of folds (at most one fold per file is used)
    @param a_encoding - default output encoding

    @return 4-tuple containing list of macro F-scores, list of micro F-scores,
    F1_{tp,fp}, and the index of the fold with the best macro F-score; or -1
    if fewer than two files were provided

    """
    # do necessary imports
    from sklearn.metrics import precision_recall_fscore_support
    from sklearn.externals import joblib
    import numpy as np

    # check conditions
    assert len(a_fname2trees) == len(a_fname2segs), \
        "Unmatching number of files with trees and segments."
    # make file names in `a_fname2trees` and `a_fname2segs` uniform and convert
    # segment classes to strings; the uniform key is the base name of the file
    # without its extension plus `a_out_sfx`, so that trees and segments which
    # originally were keyed by different paths end up under the same key
    a_fname2segs = {os.path.splitext(os.path.basename(k))[0] + a_out_sfx: \
                    [str(iseg) for iseg in v] for k, v in a_fname2segs.iteritems()}
    # remember the original tree file of every uniform name (needed for
    # re-reading the input when output files are generated)
    ofname2ifname = {os.path.splitext(os.path.basename(k))[0] + a_out_sfx: k \
                     for k in a_fname2trees}
    a_fname2trees = {os.path.splitext(os.path.basename(k))[0] + a_out_sfx: v \
                     for k, v in a_fname2trees.iteritems()}
    # estimate the number of and generate folds
    fnames = a_fname2trees.keys()
    n_fnames = len(fnames)
    if n_fnames < 2:
        print("Insufficient number of samples for cross-validation: {:d}.".format(\
            n_fnames), file = sys.stderr)
        return -1
    # `KFold` is imported from `sklearn.cross_validation` at the top of the
    # file; iterating over it below delivers pairs of index collections which
    # are used as positions in `fnames`
    folds = KFold(n_fnames, min(len(fnames), a_folds))
    # generate features for trees
    fname2feats = {fname: [a_segmenter.featgen(t) for t in trees] \
                   for fname, trees in a_fname2trees.iteritems()}
    # initialize auxiliary variables
    # NOTE(review): `F1_macro`, `F1_micro`, `trees`, and `out_fnames` are never
    # read for any result in this function, and `fname2range` is only consumed
    # by the commented-out code below
    F1_macro = F1_micro = F1_tpfp = 0.
    macro_f1 = 0.; macro_F1s = []
    micro_f1 = 0.; micro_F1s = []
    best_macro_f1 = float("-inf")
    best_i = -1                 # index of the best run
    istart = ilen = 0
    trees = []
    pred_segs = []
    out_fnames = []
    fname2range = {}
    # fname2gld_pred = {}
    processed_fnames = {}
    in_fname = test_fname = out_fname = ""
    train_feats = train_segs = None
    test_feats = []; test_segs = []
    tp = fp = fn = tp_i = fp_i = fn_i = 0
    # iterate over folds
    for i, (train, test) in enumerate(folds):
        print("Fold: {:d}".format(i), file = sys.stderr)
        # concatenate features and gold classes of all training files
        train_feats = [feat for k in train for feat in fname2feats[fnames[k]]]
        train_segs = [seg for k in train for seg in a_fname2segs[fnames[k]]]
        istart = 0
        # concatenate features and gold classes of all test files and remember
        # which slice of the concatenated lists belongs to which file
        for k in test:
            ilen = len(fname2feats[fnames[k]])
            fname2range[fnames[k]] = [istart, istart + ilen]
            istart += ilen
            test_feats += fname2feats[fnames[k]]
            test_segs += a_fname2segs[fnames[k]]
        # train classifier model
        # NOTE(review): the same `DEFAULT_PIPELINE` object is assigned in every
        # fold - presumably `fit()` discards the previous state; confirm
        a_segmenter.model = BparSegmenter.DEFAULT_PIPELINE
        a_segmenter.model.fit(train_feats, train_segs)
        # obtain new predictions
        pred_segs = a_segmenter.model.predict(test_feats)
        # update statistics and F1 scores
        tp_i, fp_i, fn_i = _cnt_stat(test_segs, pred_segs)
        tp += tp_i; fp += fp_i; fn += fn_i
        _, _, macro_f1, _ = precision_recall_fscore_support(test_segs, pred_segs, average = 'macro', \
                                                            pos_label = None)
        _, _, micro_f1, _ = precision_recall_fscore_support(test_segs, pred_segs, average = 'micro', \
                                                            pos_label = None)
        macro_F1s.append(macro_f1); micro_F1s.append(micro_f1)
        print("Macro F1: {:.2%}".format(macro_f1), file = sys.stderr)
        print("Micro F1: {:.2%}".format(micro_f1), file = sys.stderr)
        # update maximum macro F-score and store the most successful model
        if macro_f1 > best_macro_f1:
            best_i = i
            best_macro_f1 = macro_f1
            joblib.dump(a_segmenter.model, a_path)
        # generate new output files, if necessary
        if a_output:
            for k in test:
                test_fname = fnames[k]
                # do not overwrite an output file which was produced in a fold
                # with a higher macro F-score
                if test_fname in processed_fnames and processed_fnames[test_fname] > macro_f1:
                    continue
                processed_fnames[test_fname] = macro_f1
                # fname2gld_pred[test_fname] = [(test_segs[i], pred_segs[i]) \
                #                               for i in xrange(*fname2range[test_fname])]
                in_fname = ofname2ifname[test_fname]
                out_fname = os.path.join(a_out_dir, test_fname)
                # segment the original input file with the model of the
                # current fold and wrap the result in a "(TEXT ... )" bracket
                with open(out_fname, "w") as ofile:
                    print("(TEXT", file = ofile)
                    bparseg_segment(a_segmenter, _read_files([in_fname], a_encoding), \
                                    a_encoding, ofile)
                    print(")", file = ofile)
        # reset the per-fold containers
        test_feats = []; test_segs = []
        del trees[:]; del out_fnames[:]; fname2range.clear()
    print("Average macro F1: {:.2%} +/- {:.2%}".format(np.mean(macro_F1s), np.std(macro_F1s)), \
          file = sys.stderr)
    print("Average micro F1: {:.2%} +/- {:.2%}".format(np.mean(micro_F1s), np.std(micro_F1s)), \
          file = sys.stderr)
    # F1 computed from the counts accumulated over all folds; it stays 0 if
    # all three counts are zero
    if tp or fp or fn:
        F1_tpfp = (2. * tp / (2. * tp + fp + fn))
    print("F1_{{tp,fp}} {:.2%}".format(F1_tpfp), file = sys.stderr)
    return (macro_F1s, micro_F1s, F1_tpfp, best_i)

def main(argv):
    """Read input text and segment it into elementary discourse units.

    Invalid model locations (a model directory which is not writable, or a
    model file which is not readable) are reported with the standard argparse
    error mechanism, i.e. the usage message is printed and the program exits
    with a non-zero status.

    NOTE: the `--cross-validate` option of the `train` mode is accepted by the
    parser but is not consulted anywhere in this function.

    @param argv - command line arguments (without the name of the program)

    @return 0 on success, non-0 otherwise

    """
    # process arguments
    parser = argparse.ArgumentParser(description = """Script for segmenting text
into elementary discourse units.""")

    # define global options
    # (no `nargs` here: with `nargs = 1` a user-supplied encoding would be
    # stored as a one-element list instead of a string)
    parser.add_argument("-e", "--encoding", help = "input encoding of text", \
                            type = str, default = DEFAULT_ENCODING)
    parser.add_argument("-s", "--skip-line", help = """lines which should be ignored during the
processing and output without changes (defaults to empty lines)""", type = str, default = "")

    # add type-specific subparsers
    subparsers = parser.add_subparsers(help="type of discourse segmenter to use", dest = "dtype")

    # edgseg argument parser
    parser_edseg = subparsers.add_parser(EDSEG, help = "rule-based discourse segmenter for CONLL\
 dependency trees")
    parser_edseg.add_argument("-o", "--output-trees", help="output dependency trees along with\
 segments", action = "store_true")
    parser_edseg.add_argument("files", help="input files", nargs = '*', metavar="file")

    # bpar argument parser
    parser_bpar = subparsers.add_parser(BPARSEG, help = """machine-learning driven discourse\
 segmenter for BitPar constituency trees""")
    bpar_subparsers = parser_bpar.add_subparsers(help = "action to perform", dest = "mode")
    parser_bpar_train = bpar_subparsers.add_parser(TRAIN, help = """train new model on BitPar
 and segment files""")
    parser_bpar_train.add_argument("--cross-validate", help = "train model in cross-validation mode")
    parser_bpar_train.add_argument("model", help = "path to file in which to store the trained model", \
                                       type = str)
    _set_train_test_args(parser_bpar_train)

    parser_bpar_cv = bpar_subparsers.add_parser(CV, help = """train and evaluate model
    using cross-validation""")
    parser_bpar_cv.add_argument("-o", "--output-dir", help = """output directory (leave empty for
    no output)""", type = str, default = "")
    parser_bpar_cv.add_argument("model", help = """path to the file in which the best trained model
    should be stored""", type = str)
    _set_train_test_args(parser_bpar_cv)

    parser_bpar_test = bpar_subparsers.add_parser(TEST, help = """test model on BitPar
 and segment files""")
    parser_bpar_test.add_argument("-m", "--model", help = "path to file containing model", \
                                         type = str, default = BparSegmenter.DEFAULT_MODEL)
    _set_train_test_args(parser_bpar_test)

    parser_bpar_segment = bpar_subparsers.add_parser(SEGMENT, help = """split BitPar trees\
 into discourse units""")
    parser_bpar_segment.add_argument("-m", "--model", help = "path to file containing model", \
                                         type = str, default = BparSegmenter.DEFAULT_MODEL)
    parser_bpar_segment.add_argument("files", help="input files", nargs = '*', metavar="file")

    # parse the arguments which were passed to this function (`None` makes
    # argparse fall back to `sys.argv[1:]`)
    args = parser.parse_args(argv)

    # process input files
    ifiles = []
    if hasattr(args, "files"):
        ifiles = _read_files(args.files, args.encoding, args.skip_line)

    # process input with edseg
    if args.dtype == EDSEG:
        edseg_segment(ifiles, args.output_trees, args.encoding)
    # train a bparseg model (with or without cross-validation)
    elif args.mode == TRAIN or args.mode == CV:
        # make sure there is a directory for storing the model
        mdir = os.path.dirname(args.model)
        if mdir == '':
            pass
        elif os.path.exists(mdir):
            if not os.path.isdir(mdir) or not os.access(mdir, os.W_OK):
                parser.error("Can't write to directory '{:s}'.".format(mdir))
        else:
            os.makedirs(mdir)
        segmenter = BparSegmenter()
        trees, segments = _read_trees_segments(args.bpar_dir, args.seg_dir, args.bpar_sfx, \
                                               args.seg_sfx, args.mode == CV, args.encoding)
        if args.mode == TRAIN:
            segmenter.train(trees, segments, args.model)
        else:
            cv_result = crossval(segmenter, args.model, trees, segments, \
                                 bool(args.output_dir), args.output_dir, \
                                 a_encoding = args.encoding)
            # `crossval()` signals an insufficient number of input files by -1
            if cv_result == -1:
                return 1
    # apply an existing bparseg model
    else:
        if not (os.path.isfile(args.model) and os.access(args.model, os.R_OK)):
            parser.error("Can't read model file '{:s}'.".format(args.model))
        segmenter = BparSegmenter(a_model = args.model)
        if args.mode == TEST:
            # the encoding has to be passed by keyword: the fifth positional
            # parameter of `_read_trees_segments()` is `a_fname2item`
            trees, segments = _read_trees_segments(args.bpar_dir, args.seg_dir, args.bpar_sfx, \
                                                   args.seg_sfx, a_encoding = args.encoding)
            bparseg_test(segmenter, trees, segments)
        else:
            bparseg_segment(segmenter, ifiles, args.encoding)
    return 0

##################################################################
# Main
if __name__ == "__main__":
    # propagate the status returned by `main()` to the calling process
    # (`sys.exit(None)` is equivalent to a successful exit)
    sys.exit(main(sys.argv[1:]))
