#!/usr/bin/env python

def list_input_files(path):
    """Expand a path that may contain a ``*`` glob into a list of file paths.

    The path is split at the first component containing ``*``. Everything
    before that component is the directory to search; that component and
    everything after it form the glob pattern evaluated relative to it.

    Args:
        path: A file path, optionally containing ``*`` wildcards in one or
            more components (e.g. ``"labels/*.json"`` or ``"*.json"``).

    Returns:
        A list of path strings. Without a wildcard this is a one-element
        list holding the normalised input path (its existence is not
        checked). With a wildcard it is every match of the pattern, in the
        order produced by ``Path.glob`` (possibly empty).
    """
    from pathlib import Path

    parts = Path(path).parts
    # Index of the first component holding a wildcard; len(parts) if none.
    glob_index = next(
        (index for index, part in enumerate(parts) if "*" in part),
        len(parts),
    )

    # Path(*()) is Path('.'), so a pattern that starts with the wildcard
    # (e.g. "*.json") is searched relative to the current directory, and a
    # single leading component still yields a Path rather than a bare str.
    base_path = Path(*parts[:glob_index])

    if glob_index == len(parts):
        return [str(base_path)]

    # "/" is accepted as a separator by pathlib glob patterns on all platforms.
    pattern = "/".join(parts[glob_index:])
    return [str(match) for match in base_path.glob(pattern)]


def main():
    """Command-line entry point: convert LabelStudio JSON_MIN labels to CSV.

    Reads every input file matched by the positional argument with
    ``misc.read_label_file``, merges the results, drops labels whose
    ``"filename"`` has already been seen (the first occurrence wins), and
    writes the remaining labels as CSV to ``--file`` or to stdout. When
    ``--skip-file`` is given, the skipped entries reported by
    ``misc.read_label_file`` are written to it, one per line.
    """
    import argparse
    import sys
    from contextlib import nullcontext
    from csv import DictWriter
    from itertools import chain

    import sof_utils
    from sof_utils import misc

    parser = argparse.ArgumentParser(
        description='Convert proximal femur labels from LabelStudio JSON_MIN format to csv format')
    parser.add_argument('input_label_file', type=str,
                        help='Path to input label json file(s). Multiple files can be used using glob pattern.')
    # The built-in 'version' action prints and exits during parsing, so
    # `-V` works even when the required positional argument is omitted.
    parser.add_argument('-V', '--version', action='version',
                        version=str(sof_utils.__version__),
                        help="Print the version string")
    parser.add_argument('-f', '--file', type=str, default=None,
                        help='Write output to file instead to stdout')
    parser.add_argument('-s', '--skip-file', type=str, default=None,
                        help="Write skipped file names to the given file.")

    args = parser.parse_args()

    input_files = list_input_files(args.input_label_file)

    # Each result is indexed as [0] -> labels, [1] -> skipped entries.
    labels_and_skipped = [misc.read_label_file(file) for file in input_files]
    # Flatten the per-file lists.
    labels = list(chain.from_iterable(item[0] for item in labels_and_skipped))
    skipped = list(chain.from_iterable(item[1] for item in labels_and_skipped))

    # Remove duplicates, keeping the first label seen for each filename.
    unique_labels = []
    filenames = set()
    for label in labels:
        if label["filename"] not in filenames:
            filenames.add(label["filename"])
            unique_labels.append(label)

    if unique_labels:
        # nullcontext leaves sys.stdout open after the block; newline='' is
        # what the csv module expects for files it writes to.
        output = open(args.file, 'w', newline='') if args.file else nullcontext(sys.stdout)
        with output as fh:
            # Column names are taken from the first label.
            writer = DictWriter(fh, list(unique_labels[0].keys()))
            writer.writeheader()
            writer.writerows(unique_labels)
    else:
        # Without at least one label there are no column names to write.
        print("No labels found; no CSV output written.", file=sys.stderr)

    if args.skip_file:
        with open(args.skip_file, "w") as fh:
            fh.writelines(f"{item}\n" for item in skipped)


# Run the command-line interface only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == '__main__':
    main()
