#!/usr/bin/env python

from typing import Union, Set
import tensorflow as tf


def ids_from_file(filename: Union[None, str]) -> Set[int]:
    """ Parse whitespace-separated integer IDs from a text file.

    The whole file is read at once and split on any whitespace, so IDs may be
    separated by newlines, spaces or tabs. Duplicates collapse because the
    result is a set.

    :param filename: Path to file containing IDs; a falsy value (``None`` or ``''``) gives an empty set
    :return: set of integer IDs
    :raises ValueError: if a token in the file cannot be converted with ``int``
    """
    if filename:
        with open(filename, 'r') as handle:
            return {int(token) for token in handle.read().split()}
    return set()


def load_model(model_path):
    """ Load the TensorFlow SavedModel stored in ``<model_path>/saved_model``.

    A progress message and the elapsed loading time are printed to stdout.

    :param model_path: Directory containing the ``saved_model`` folder (a string; it is extended with ``+``)
    :return: the object returned by ``tf.saved_model.load``, which callers invoke as the detection function
    """
    from time import time

    saved_model_dir = model_path + "/saved_model"

    print('Loading model...', end='')
    started = time()

    # Load saved model and build the detection function
    detect_fn = tf.saved_model.load(saved_model_dir)

    print('Done! Took {} seconds'.format(time() - started))
    return detect_fn


def preprocess_image(image):
    """ Split an image into a left half and a mirrored right half for the detector.

    The image is padded on the right to an even width, cast to ``uint8`` and cut
    in the middle. The right half is flipped horizontally, so both halves have
    the same orientation. Each half then gets its last axis repeated three times
    and a leading batch axis of size one.

    :param image: image tensor indexed as (height, width, ...) -- presumably single-channel; TODO confirm
    :return: tuple ``(left_image, right_image)`` of tensors ready to be passed to the detection function
    """
    height, width = image.shape[0], image.shape[1]
    # One extra column is added when the width is odd, so that both halves are equally wide.
    even_width = width + width % 2
    padded = tf.image.pad_to_bounding_box(image, 0, 0, height, even_width)
    pixels = tf.cast(padded, tf.uint8).numpy()

    assert pixels.shape[1] % 2 == 0

    split_at = int(pixels.shape[1] / 2)
    left_half = pixels[:, :split_at, ...]
    right_half = tf.image.flip_left_right(pixels[:, split_at:, ...]).numpy()

    def as_three_channel_batch(half):
        # Repeat the last axis three times and prepend the batch axis.
        return tf.repeat(half, 3, -1)[tf.newaxis, ...]

    return as_three_channel_batch(left_half), as_three_channel_batch(right_half)


def detect_keypoints(image, detection_fn):
    """ Run the detector separately on the left and the mirrored right half of an image.

    :param image: full image, split by ``preprocess_image``
    :param detection_fn: callable applied to each preprocessed half
    :return: tuple ``(left_detections, right_detections)`` with the raw detector outputs
    """
    halves = preprocess_image(image)
    # The left half is evaluated first, then the right half.
    left_detections, right_detections = (detection_fn(half) for half in halves)
    return left_detections, right_detections


def rel_x_to_right(rel_x, w):
    """ Convert an x-coordinate from the mirrored right half to the full image.

    The coordinate is scaled to pixels of the half image, the mirroring is
    undone, the result is shifted by half the image width and finally
    normalised by ``w - 1``.

    :param rel_x: x-coordinate relative to the (flipped) right half image
    :param w: width of the full, unpadded image in pixels (``w == 1`` divides by zero)
    :return: x-coordinate relative to the full image
    """
    import math

    half = math.floor(w / 2)
    # One pixel is subtracted for odd widths -- presumably to compensate the padding column; TODO confirm.
    odd_pad = 1 if w % 2 != 0 else 0

    mirrored = rel_x * half - odd_pad
    unmirrored = half - mirrored
    return (unmirrored + half) / (w - 1)


def rel_x_to_left(rel_x, w):
    """ Convert an x-coordinate from the left half image to the full image.

    The coordinate is scaled to pixels using half the image width and then
    normalised by ``w - 1``.

    :param rel_x: x-coordinate relative to the left half image
    :param w: width of the full image in pixels (``w == 1`` divides by zero)
    :return: x-coordinate relative to the full image
    """
    import math

    half = math.floor(w / 2)
    return rel_x * half / (w - 1)


def postproecess_detections(detections, left_right, width):
    """ Extract label, key points and score from raw detections of one image half.

    The key points are zeroed when the label is not 1. Afterwards the
    x-coordinate (column 1) of the first 12 key points is converted from
    half-image to full-image coordinates.

    :param detections: mapping with the entries ``detection_scores``, ``detection_classes`` and
        ``detection_keypoints`` -- assumed to hold a single detection after squeezing; TODO confirm
    :param left_right: ``'left'`` for the left half; any other value is treated as the right half
    :param width: width of the full image in pixels
    :return: tuple ``(label, keypoints, score)``
    """
    import numpy as np

    def squeezed(key):
        return tf.squeeze(detections[key]).numpy()

    score = squeezed('detection_scores')
    label = squeezed('detection_classes')
    keypoints = squeezed('detection_keypoints')

    # Only label 1 carries usable key points; everything else is blanked out.
    if label != 1:
        keypoints = np.zeros_like(keypoints)

    to_full_width = rel_x_to_left if left_right == 'left' else rel_x_to_right
    for row in range(12):
        keypoints[row, 1] = to_full_width(keypoints[row, 1], width)

    return label, keypoints, score


def label_to_text(label):
    """ Translate a numeric class label into its human-readable name.

    :param label: class label; compared with ``==`` against 1, 2 and 3
    :return: ``"Complete"``, ``"Incomplete"`` or ``"Implant"``; ``"Unknown"`` for any other label
    """
    class_names = ((1, "Complete"), (2, "Incomplete"), (3, "Implant"))
    for class_id, text in class_names:
        if label == class_id:
            return text
    return "Unknown"


def is_upside_down(left_label, left_kpts, left_score, right_label, right_kpts, right_score):
    """ Decide from the detected key points whether the image is upside down.

    Only sides with label 1 provide key points. If both sides qualify, the side
    with the strictly higher score is used, and the right side wins a tie. The
    image counts as upside down when coordinate 0 of key point 3 exceeds that
    of key point 0 by more than ``1e-3``.

    :param left_label: class label of the left side
    :param left_kpts: key points of the left side, indexed as ``[point, coordinate]``
    :param left_score: detection score of the left side
    :param right_label: class label of the right side
    :param right_kpts: key points of the right side, indexed as ``[point, coordinate]``
    :param right_score: detection score of the right side
    :return: truthy if the image is upside down; ``False`` if neither side has label 1
    """
    left_usable = left_label == 1
    right_usable = right_label == 1

    if left_usable and right_usable:
        reference = left_kpts if left_score > right_score else right_kpts
    elif left_usable:
        reference = left_kpts
    elif right_usable:
        reference = right_kpts
    else:
        # Both labels are not "Complete" => We don't have any key points to detect the orientation
        return False

    return reference[3, 0] - reference[0, 0] > 1e-3


def flip_img_and_kpts(image, left_label, left_kpts, left_score, right_label, right_kpts, right_score):
    """ Rotate an image by 180 degrees and update the detections of both sides to match.

    The image is flipped vertically and horizontally. The rotation swaps the
    two sides, so labels and scores are exchanged, and each key point set is
    mirrored as ``1 - kpts`` and assigned to the opposite side.

    :param image: image tensor to rotate
    :param left_label: class label of the left side before the rotation
    :param left_kpts: key points of the left side before the rotation
    :param left_score: detection score of the left side before the rotation
    :param right_label: class label of the right side before the rotation
    :param right_kpts: key points of the right side before the rotation
    :param right_score: detection score of the right side before the rotation
    :return: tuple ``(image, left_label, left_kpts, left_score, right_label, right_kpts, right_score)``
        describing the rotated image
    """
    rotated = tf.image.flip_up_down(image)
    rotated = tf.image.flip_left_right(rotated)

    # What used to be on the right is now on the left and vice versa.
    new_left_kpts = 1 - right_kpts
    new_right_kpts = 1 - left_kpts

    return rotated, right_label, new_left_kpts, right_score, left_label, new_right_kpts, left_score


def put_text(img, text, color, lr, y_offset=0):
    """ Draw a text onto the left or right part of an image.

    The text starts at one eighth of the image width for the left side and at
    five eighths for ``lr == 'right'``. The baseline is 200 pixels from the top
    plus ``y_offset``.

    :param img: image to draw on; its second to last axis is taken as the width
    :param text: value to draw; converted with ``str``
    :param color: color passed to ``cv2.putText``
    :param lr: ``'right'`` for the right side; any other value selects the left side
    :param y_offset: additional vertical offset in pixels
    :return: the image returned by ``cv2.putText``
    """
    import cv2
    import numpy as np

    eighth_of_width = np.ceil(img.shape[-2] / 8)
    x_start = eighth_of_width * 5 if lr == 'right' else eighth_of_width
    origin = (int(x_start), 200 + y_offset)
    return cv2.putText(img, str(text), origin, cv2.FONT_HERSHEY_PLAIN, 10, color, 10)


def draw_keypoint(img, kpts, color, small=False):
    """ Draw a circle for every key point onto an image.

    :param img: image to draw on, indexed as (height, width, ...)
    :param kpts: key points as rows ``(y, x)`` relative to the image size
    :param color: color passed to ``cv2.circle``
    :param small: if truthy, draw circles with radius 1 and thickness 1 instead of radius 15 and thickness 10
    :return: the image with all key points drawn
    """
    import cv2

    radius, thickness = (1, 1) if small else (15, 10)

    for idx in range(kpts.shape[0]):
        rel_y, rel_x = kpts[idx][0], kpts[idx][1]
        # OpenCV expects the center as (x, y) in pixels.
        center = (int(rel_x * img.shape[1]), int(rel_y * img.shape[0]))
        img = cv2.circle(img, center, radius, color, thickness)

    return img


def _label_color(label):
    """ Return the drawing color (RGB, components in 0..1) for a class label.

    Labels other than 1, 2 and 3 get a fallback color instead of leaving the color undefined.
    """
    if label == 1:
        return (0, 1, 0)
    if label == 2:
        return (1, 0, 0)
    if label == 3:
        return (0, 0, 1)
    # Fallback for labels that label_to_text reports as "Unknown".
    return (1, 1, 0)


def save_visualization(example, left_label, left_kpts, left_score, right_label, right_kpts, right_score, upside_down,
                       preview_dir):
    """ Write a PNG preview of an example annotated with the detection results.

    The image is converted to three channels and rotated by 180 degrees if
    ``upside_down`` is 1. Key points are drawn for sides with label 1. Class
    name and score are only drawn for both sides when the shorter image edge
    is at least 500 pixels. The file is written to ``preview_dir`` as
    ``<image/id>V<image/visit>-annotated.png``.

    :param example: mapping with the entries ``image``, ``image/id`` and ``image/visit``
    :param left_label: class label of the left side
    :param left_kpts: key points of the left side as rows ``(y, x)`` relative to the image size
    :param left_score: detection score of the left side
    :param right_label: class label of the right side
    :param right_kpts: key points of the right side as rows ``(y, x)`` relative to the image size
    :param right_score: detection score of the right side
    :param upside_down: 1 if the image has to be rotated by 180 degrees before drawing
    :param preview_dir: directory the annotated image is written to
    """
    from pathlib import Path

    # Every label gets a color, so unknown labels no longer fail with UnboundLocalError when text is drawn.
    left_color = _label_color(left_label)
    right_color = _label_color(right_label)

    out_image = tf.cast(tf.repeat(example['image'], 3, -1), tf.float32) / 255
    if upside_down == 1:
        out_image = tf.image.flip_left_right(tf.image.flip_up_down(out_image))
    out_image = out_image.numpy()

    # Small images get tiny key point markers and no text, which would not fit.
    small = min(out_image.shape[0], out_image.shape[1]) < 500

    if left_label == 1:
        out_image = draw_keypoint(out_image, left_kpts, left_color, small)
    if right_label == 1:
        out_image = draw_keypoint(out_image, right_kpts, right_color, small)

    if not small:
        out_image = put_text(out_image, label_to_text(right_label), right_color, 'right')
        out_image = put_text(out_image, str(right_score), right_color, 'right', 150)
        out_image = put_text(out_image, label_to_text(left_label), left_color, 'left')
        out_image = put_text(out_image, str(left_score), left_color, 'left', 150)

    # Back to 8 bit for PNG encoding.
    out_image = tf.convert_to_tensor(out_image)
    out_image = tf.cast(out_image * 255, tf.uint8)

    encoded_image = tf.image.encode_png(out_image)
    if isinstance(example['image/id'], tf.Tensor):
        filename = f"{example['image/id'].numpy()}V{example['image/visit'].numpy()}-annotated.png"
    else:
        filename = f"{example['image/id']}V{example['image/visit']}-annotated.png"
    tf.io.write_file(str(Path(preview_dir).joinpath(filename)), encoded_image)


def process_example(example, detection_fn, preview_dir=None, filename=None):
    """ Detect key points on one example and return the results as a flat CSV row.

    ``example`` is either a mapping with the entries ``image``, ``image/id`` and
    ``image/visit``, or a bare image tensor. A bare tensor is wrapped into such
    a mapping using ``filename`` as ID and ``'NA'`` as visit. If the image is
    found to be upside down, the image and the detections are rotated before
    the row is built.

    NOTE: the ``left_*`` columns are filled from the right image half and the
    ``right_*`` columns from the left image half -- presumably because the
    image shows the anatomical left on its right side; TODO confirm.

    :param example: dataset example (mapping) or image tensor
    :param detection_fn: detection function applied to both image halves
    :param preview_dir: if given, directory an annotated preview image is written to
    :param filename: ID to use when ``example`` is a bare image tensor
    :return: dict with ID, visit, image size, orientation flag, classes, scores and 12 key points per side
    """
    if isinstance(example, tf.Tensor):
        image = example
        example = {
            'image': image,
            'image/id': filename,
            'image/visit': 'NA'
        }
    else:
        image = example['image']

    full_width = image.shape[1]
    left_detections, right_detections = detect_keypoints(image, detection_fn)
    left_label, left_kpts, left_score = postproecess_detections(left_detections, 'left', full_width)
    right_label, right_kpts, right_score = postproecess_detections(right_detections, 'right', full_width)

    upside_down = 0
    if is_upside_down(left_label, left_kpts, left_score, right_label, right_kpts, right_score):
        upside_down = 1
        rotated = flip_img_and_kpts(image, left_label, left_kpts, left_score, right_label, right_kpts, right_score)
        image, left_label, left_kpts, left_score, right_label, right_kpts, right_score = rotated

    def plain(value):
        # Dataset examples hold tensors, examples built from image files hold plain Python values.
        return value.numpy() if isinstance(value, tf.Tensor) else value

    row = {
        'id': plain(example['image/id']),
        'visit': plain(example['image/visit']),
        'width': image.shape[1],
        'height': image.shape[0],
        'upside_down': upside_down,
        'left_class': label_to_text(right_label),
        'left_score': right_score,
        'right_class': label_to_text(left_label),
        'right_score': left_score
    }
    # Key points are stored as (y, x); the insertion order below defines the CSV column order.
    for i in range(12):
        row[f"left_kp{i}x"] = right_kpts[i, 1]
        row[f"left_kp{i}y"] = right_kpts[i, 0]
    for i in range(12):
        row[f"right_kp{i}x"] = left_kpts[i, 1]
        row[f"right_kp{i}y"] = left_kpts[i, 0]

    if preview_dir:
        save_visualization(example, left_label, left_kpts, left_score, right_label, right_kpts, right_score,
                           upside_down, preview_dir)

    return row


def main():
    """ Command line entry point: detect key points and write the results as CSV.

    The images come either from the SOF_hip TFDS dataset or, with
    ``--from-images``, from the PNG files of a directory. One CSV row per image
    is written to the file given with ``--file`` or to stdout. The process
    exits with an error message if no image was processed.
    """
    import argparse
    import sys
    from csv import DictWriter
    from pathlib import Path

    class _PrintVersion(argparse.Action):
        """ Print the sof_utils version and exit as soon as the option is parsed.

        Acting during parsing lets ``-V`` work without the otherwise required ``model_path`` argument.
        """

        def __init__(self, option_strings, dest, **kwargs):
            super().__init__(option_strings, dest, nargs=0, **kwargs)

        def __call__(self, parser, namespace, values, option_string=None):
            import sof_utils
            print(sof_utils.__version__)
            parser.exit(0)

    parser = argparse.ArgumentParser(
        description='Detect keypoints on hip radiographs. Important: the TFDS SOF_hip package must be in the python '
                    'path.')
    parser.add_argument('model_path', type=str,
                        help='Path to saved detection model.')
    parser.add_argument('-f', '--file', type=str, default=None,
                        help='Write CSV data to the given file instead of stdout.')
    parser.add_argument('-V', '--version', action=_PrintVersion,
                        help="Print the version string")
    parser.add_argument('--data_dir', '-d', type=str, default=None,
                        help='Path to the TFDS dataset.')
    parser.add_argument('--configuration', '-c', type=str, default='unsupervised_raw',
                        choices=['unsupervised_raw', 'unsupervised_raw_tiny'],
                        help='Dataset configuration.')
    parser.add_argument('--preview-dir', '-p', type=str,
                        help='If given, path to write images with visualizations of keypoints to.')
    parser.add_argument('--from-images', '-i', type=str, default=None,
                        help='Detect keypoint from png images instead of the SOF_hip TFDS dataset')

    args = parser.parse_args()

    detection_fn = load_model(args.model_path)

    if not args.from_images:
        # The dataset tooling is only imported when needed, so --from-images works without it.
        import tensorflow_datasets as tfds
        from tqdm import tqdm
        # Imported for its side effect only -- presumably registers the dataset with TFDS; TODO confirm.
        import SOF_hip  # noqa: F401

        ds_name = f"SOF_hip/{args.configuration}"
        ds = tfds.load(ds_name, split='train', data_dir=args.data_dir or None)

        rows = [process_example(example, detection_fn, args.preview_dir) for
                example in tqdm(ds, desc="Detecting keypoints ", unit=' images')]
    else:
        # glob yields files in arbitrary order; sorting makes the CSV row order reproducible.
        image_files = sorted(Path(args.from_images).glob('*.png'))

        rows = []
        for file in image_files:
            encoded_img = tf.io.read_file(str(file))
            img = tf.image.decode_png(encoded_img, 1)
            rows.append(process_example(img, detection_fn, args.preview_dir, filename=file.stem))

    if not rows:
        # Without rows there is no header to derive; fail with a clear message instead of an IndexError.
        sys.exit('No images were processed; no CSV output written.')

    fieldnames = list(rows[0].keys())

    # newline='' is what the csv module expects for files it writes to.
    fh = open(args.file, "w", newline='') if args.file else sys.stdout
    try:
        writer = DictWriter(fh, fieldnames=fieldnames)
        writer.writeheader()
        writer.writerows(rows)
    finally:
        # Only close what was opened here; sys.stdout must stay usable.
        if fh is not sys.stdout:
            fh.close()


# Run the command line interface only when this file is executed as a script, not when it is imported.
if __name__ == '__main__':
    main()
