#!/usr/bin/env python

import tensorflow_datasets as tfds
import SOF_hip
import tensorflow as tf


def convert_example(example):
    """Turn one dataset example into a ``tf.train.Example`` record.

    The image is PNG-encoded and the SHA-256 hex digest of the encoded
    bytes is stored next to it. Bounding box, keypoints, class and image
    metadata are copied into ``image/...`` features.

    Args:
        example: Mapping of eager tensors (``.numpy()`` is called on them)
            with the keys ``image``, ``image/filename``, ``image/id``,
            ``image/visit``, ``image/left_right``, ``image/upside_down``,
            ``object/bbox``, ``object/keypoints`` and ``object/class``.

    Returns:
        A ``tf.train.Example`` holding the converted features.
    """
    from hashlib import sha256

    import sof_utils.dataset_utils as utils

    class_names = ("Complete", "Incomplete", "Implant")

    # Image payload and its content hash.
    img = example['image']
    png_bytes = tf.image.encode_png(image=img).numpy()
    digest = sha256(png_bytes).hexdigest()

    # Annotations. The bbox is read as [ymin, xmin, ymax, xmax] and the last
    # keypoint axis as (y, x), following the feature names they are stored
    # under -- NOTE(review): confirm against the dataset definition.
    box = example['object/bbox'].numpy()
    points = example['object/keypoints'].numpy()
    points_y = points[..., 0].tolist()
    points_x = points[..., 1].tolist()
    class_id = example['object/class'].numpy()
    class_name = class_names[class_id]

    # Source id has the form "<id>V<visit><left_right>".
    image_id = example['image/id'].numpy()
    visit = example['image/visit'].numpy()
    side = example['image/left_right'].numpy().decode('utf8')
    source_id = f"{image_id}V{visit}{side}"

    features = {}
    features['image/height'] = utils.int64_feature(img.shape[0])
    features['image/width'] = utils.int64_feature(img.shape[1])
    features['image/filename'] = utils.bytes_feature(
        example['image/filename'].numpy())
    features['image/source_id'] = utils.bytes_feature(source_id.encode('utf8'))
    features['image/sha256'] = utils.bytes_feature(digest.encode('utf8'))
    features['image/encoded'] = utils.bytes_feature(png_bytes)
    features['image/format'] = utils.bytes_feature("png".encode('utf8'))
    features['image/upside_down'] = utils.int64_feature(
        example['image/upside_down'].numpy())
    features['image/object/class/label'] = utils.int64_list_feature([class_id])
    features['image/object/class/text'] = utils.bytes_list_feature(
        [class_name.encode('utf8')])
    features['image/object/bbox/ymin'] = utils.float_list_feature([box[0]])
    features['image/object/bbox/xmin'] = utils.float_list_feature([box[1]])
    features['image/object/bbox/ymax'] = utils.float_list_feature([box[2]])
    features['image/object/bbox/xmax'] = utils.float_list_feature([box[3]])
    features['image/object/keypoint/y'] = utils.float_list_feature(points_y)
    features['image/object/keypoint/x'] = utils.float_list_feature(points_x)

    return tf.train.Example(features=tf.train.Features(feature=features))


def convert_to_TFObjectDetection(ds, out_file, num_shards=1):
    """Write every example of ``ds`` to sharded TFRecord files.

    Each example is converted with ``convert_example`` and written
    round-robin (``index % num_shards``) to the writers returned by
    ``tf_record_creation_util.open_sharded_output_tfrecords``.

    Args:
        ds: Iterable of examples accepted by ``convert_example``.
        out_file: Output path handed to ``open_sharded_output_tfrecords``.
        num_shards: Number of output shards; must be at least 1.

    Raises:
        ValueError: If ``num_shards`` is smaller than 1.
    """
    # Fail fast: a non-positive shard count would otherwise surface later as
    # a ZeroDivisionError / IndexError, after output files were opened.
    if num_shards < 1:
        raise ValueError(f"num_shards must be >= 1, got {num_shards}")

    import contextlib2
    from tqdm import tqdm
    from sof_utils import tf_record_creation_util

    with contextlib2.ExitStack() as tf_record_close_stack:
        output_tfrecords = tf_record_creation_util.open_sharded_output_tfrecords(
            tf_record_close_stack, out_file, num_shards)
        # Wrap the dataset itself (not the enumerate object) so tqdm can pick
        # up len(ds) when it is available and display a total / ETA.
        progress = tqdm(ds, unit=' examples', desc='Converting dataset')
        for idx, example in enumerate(progress):
            converted_example = convert_example(example)
            shard_idx = idx % num_shards
            output_tfrecords[shard_idx].write(converted_example.SerializeToString())


def main():
    """Command-line entry point: convert one split of the SOF_hip dataset.

    Parses the command line, loads ``SOF_hip/<configuration>`` through
    ``tfds.load``, selects the requested split and writes it to
    ``output_file`` in the requested format.

    ``-V``/``--version`` prints ``sof_utils.__version__`` and exits with
    status 0, even when the positional ``output_file`` is omitted.
    """
    import argparse
    import sys

    class _PrintVersion(argparse.Action):
        """Print the sof_utils version and exit as soon as the flag is seen.

        Doing this inside an action (instead of after ``parse_args``) keeps
        ``-V`` usable without the otherwise required ``output_file``.
        """

        def __init__(self, option_strings, dest, **kwargs):
            # nargs=0 makes this a plain flag that consumes no value.
            super().__init__(option_strings, dest, nargs=0, **kwargs)

        def __call__(self, parser, namespace, values, option_string=None):
            # Imported lazily: only needed when the version is requested.
            import sof_utils
            print(sof_utils.__version__)
            parser.exit(0)

    parser = argparse.ArgumentParser(
        description='Convert SOF_hip TFDS dataset into another format.')
    parser.add_argument('output_file', type=str,
                        help='Path to destination file.')
    parser.add_argument('--data_dir', type=str,
                        help='TFDS data dir.')
    parser.add_argument('-V', '--version', action=_PrintVersion,
                        help="Print the version string")
    parser.add_argument('--configuration', '-c', type=str, default='keypoint_detection',
                        choices=['keypoint_detection'],
                        help='Dataset configuration to use. Note: not all configuration are supported by all output format. Default is \'keypoint_detection\'')
    parser.add_argument('--format', '-f', type=str, default='TFObjectDetection',
                        choices=['TFObjectDetection'],
                        help='Output format to write. Note: not all configuration are supported by all output format. Default is \'TFObjectDetection\'')
    parser.add_argument('--split', '-s', type=str, default='train',
                        help='Split to convert. Default to \'train\'')
    parser.add_argument('--num_shards', '-n', type=int, default=1,
                        help='Number of shards to split the resulting dataset into.')

    args = parser.parse_args()

    ds = tfds.load(f"SOF_hip/{args.configuration}", data_dir=args.data_dir)[args.split]

    if args.format == 'TFObjectDetection':
        convert_to_TFObjectDetection(ds, args.output_file, args.num_shards)
    else:
        # Defensive: argparse `choices` already rejects unknown formats.
        print(f"Error: Unknown format: {args.format}", file=sys.stderr)
        # sys.exit rather than the site-provided exit(), which is not
        # guaranteed to exist (e.g. under `python -S`).
        sys.exit(1)


# Run the conversion CLI only when this file is executed as a script, so that
# importing the module has no side effects beyond its definitions.
if __name__ == '__main__':
    main()
