#!/usr/bin/env python3
#
# README
#
# This script assumes:
# --------------------
#
# 1. `STACK_COMPANION=1` was set in `/etc/indy/indy_config.py` on the node where
#     the recording took place.
# 2. The 'recording' passed as the argument to this script has been captured by
#    running `nscapture` AFTER stopping the node (indy-node service). There are
#    plans to research how node state (recording in this case) can be captured
#    without bringing down the indy-node service.
#
# Installation:
# -------------
#
# If this script is used on a server or workstation that has indy-node installed
# and indy-node is known to be working, you do not need to install anything.
# Otherwise, install indy-node and its dependencies before running this script.
#
# TODO:
#
# 1. Create an issue in indy-node stating that the following must be done on
#    macos for pbc to compile (pbc has a dependency on openssl headers)
#    > cd /usr/local/include
#    > ln -s ../opt/openssl/include/openssl .
#    The above fix was taken from:
#    https://www.anintegratedworld.com/mac-osx-fatal-error-opensslsha-h-file-not-found/

from indy_common.config_helper import NodeConfigHelper
from indy_common.config_util import getConfig
from indy_node.server.config_helper import create_config_dirs
from indy_node.server.node import Node
from os import listdir, makedirs
from plenum.common.constants import CLIENT_STACK_SUFFIX, KeyValueStorageType
from plenum.recorder.recorder import Recorder
from plenum.recorder.replayable_node import create_replayable_node_class
from plenum.recorder.replayer import get_recorders_from_node_data_dir, \
    prepare_node_for_replay_and_replay
from plenum.server.replicas import Replicas
from plenum.server.replica import Replica
from plenum.server.general_config import ubuntu_platform_config
from storage.helper import initKeyValueStorageIntKeys
from stp_core.common.log import Logger
from stp_core.loop.looper import Looper
from stp_core.types import HA
from typing import Tuple

import argparse
import importlib
import json
import logging
import os
import shutil
import sys
import tarfile
import tempfile
import zipfile

# Module-wide logger. getLogger() called without a name returns the root
# logger, so level changes made through it apply to the root logger.
logger = logging.getLogger()


def get_system_dirs():
    """Return the platform configuration object describing system directories.

    Only the Ubuntu layout is available; it is returned unconditionally.
    """
    # TODO detect the system someday (help support windows)
    platform_config = ubuntu_platform_config
    return platform_config


# TODO: Promote code reuse between nscapture, nsreplay, and nsdiff.
#       Move to common module?
class Error(Exception):
    """Root of the exception hierarchy raised by this module."""


# TODO: Promote code reuse between nscapture, nsreplay, and nsdiff.
#       Move to common module?
class ArchiveError(Error):
    """Exception raised for errors encountered while identifying and
    extracting tarfiles and zipfiles.

    Attributes:
        message -- explanation of the error
    """

    def __init__(self, message):
        # Forward the message to Exception so that str(e) and e.args carry it
        # no matter how the exception was constructed (positionally or with
        # message=...). Without this, keyword construction yields an empty
        # str(e).
        super().__init__(message)
        self.message = message


class NodeStateReplayer:
    """Unpack a node recording and replay it through a replayable node.

    A recording is a tarfile, a zipfile, or an already unpacked directory.
    `replay_node` is the entry point; the remaining methods are helpers.

    Attributes:
        logger -- unused placeholder; all logging goes through the
                  module-level logger
        cleanup -- remove the temp dir an archive was extracted into when done
        replay_dir -- absolute path of the directory to replay into, or None
                      to replay into a freshly created temp dir
        log_level -- logging level; contents of the recording are listed when
                     it is logging.INFO or lower
        s1 -- the recording as given by the caller
        s1tempdir -- temp dir the recording was extracted into (archives only)
        s1dir -- directory holding the unpacked recording
        s1_is_temp -- True while s1tempdir exists and is owned by this object
        s1_dirname, s1_filename, s1_name, s1_extension -- parts of the
                     recording path as produced by split_name
    """

    logger = None
    cleanup = True
    replay_dir = None
    log_level = logging.WARNING

    s1 = "zipfile/tarfile"
    s1tempdir = None
    s1dir = None
    s1_is_temp = False
    s1_dirname = None
    s1_filename = None
    s1_name = None
    s1_extension = None

    def __init__(self, log_level=0, cleanup=True, replay_dir=None):
        """Configure logging and (optionally) create the replay directory.

        Arguments:
            log_level -- level applied to the module-level logger and
                         remembered on the instance
            cleanup -- whether _cleanup removes the extraction temp dir
            replay_dir -- directory to replay into; created if missing
        """
        # Remember the level on the instance: the unpack helpers consult
        # self.log_level to decide whether to list the recording's contents.
        self.log_level = log_level
        logger.setLevel(log_level)
        self.cleanup = cleanup
        if replay_dir:
            path = os.path.abspath(replay_dir)
            os.makedirs(path, exist_ok=True)
            self.replay_dir = path
        logger.debug("Initializing NodeStateReplayer...")

    # TODO: Promote code reuse between nscapture, nsreplay, and nsdiff.
    #       Create base class with abstract _cleanup function?
    def _cleanup(self):
        """Remove the extraction temp dir, if there is one and cleanup is on.

        Safe to call more than once: after the directory has been removed the
        call is a no-op.
        """
        logger.debug("Cleaning up...")
        # Nothing to do unless a tar/zip was unpacked into a temp dir
        if not self.s1_is_temp:
            return
        if self.cleanup:
            logger.info("Removing temp dir {} containing {} contents"
                        .format(self.s1tempdir, self.s1))
            shutil.rmtree(self.s1tempdir)
            # Forget the directory so that a repeated call does not try to
            # remove it a second time.
            self.s1tempdir = None
            self.s1_is_temp = False
        else:
            message = "Skipping removal of temp directory {}".format(
                self.s1dir)
            logger.debug(message)
            print(message)

    # TODO: Promote code reuse between nscapture, nsreplay, and nsdiff.
    #       Create base class with _eprint function?
    def _eprint(self, *args, **kwargs):
        """Print an error message to stderr and conditionally exit.

        Accepts the same arguments as print() plus the keyword `exit`; when
        `exit` is truthy, clean up and terminate with status 1.
        """
        doexit = kwargs.pop('exit', None)
        print(*args, file=sys.stderr, **kwargs)
        if doexit:
            self._cleanup()
            sys.exit(1)

    # TODO: Promote code reuse between nscapture, nsreplay, and nsdiff.
    #       Create base class with print_zipfile function?
    def print_zipfile(self, filename):
        """Log (at debug level) the name of every member of a zipfile."""
        # The context manager closes the archive once it has been listed
        with zipfile.ZipFile(filename) as zip_f:
            for info in zip_f.infolist():
                logger.debug(info.filename)

    # TODO: Promote code reuse between nscapture, nsreplay, and nsdiff.
    #       Create base class with print_directory function?
    def print_directory(self, d):
        """Log (at debug level) the files in d, then recurse into its
        subdirectories."""
        entries = os.listdir(d)
        only_files = [e for e in entries
                      if os.path.isfile(os.path.join(d, e))]
        only_dirs = [e for e in entries
                     if os.path.isdir(os.path.join(d, e))]
        logger.debug("In {}:".format(d))
        for file in only_files:
            logger.debug("\t{}".format(file))
        for dir_ent in only_dirs:
            self.print_directory(os.path.join(d, dir_ent))

    def _unpack_to_tempdir(self, file, open_archive):
        """Extract an archive into a new temp dir and return the dir's path.

        Arguments:
            file -- path of the archive
            open_archive -- callable taking (file, mode) and returning a
                            context manager whose value has extractall()
        Raises:
            ArchiveError -- the temp dir could not be created
        """
        try:
            tempdir = tempfile.mkdtemp()
        except OSError as e:
            # mkdtemp reports failure by raising, never by returning a falsy
            # value, so this is where the failure has to be translated.
            s = "Failed to create tempdir in which to unpack {}"
            raise ArchiveError(s.format(file)) from e
        try:
            # NOTE(review): extractall() trusts the member paths stored in
            # the archive; only unpack recordings from a trusted source.
            with open_archive(file, "r") as archive:
                archive.extractall(tempdir)
        except Exception:
            # Do not leave a half-populated temp dir behind
            shutil.rmtree(tempdir, ignore_errors=True)
            raise
        return tempdir

    # TODO: Promote code reuse between nscapture, nsreplay, and nsdiff.
    #       Create base class with _extract_tarball function?
    def _extract_tarball(self, file):
        """Extract a tarball into a new temp dir and return the dir's path."""
        logger.debug("Extracting tarball {}".format(file))
        if self.log_level <= logging.INFO:
            # TODO create print_tarfile method
            # self.print_tarfile(file)
            pass
        return self._unpack_to_tempdir(file, tarfile.open)

    # TODO: Promote code reuse between nscapture, nsreplay, and nsdiff.
    #       Create base class with _extract_zipfile function?
    def _extract_zipfile(self, file):
        """Extract a zipfile into a new temp dir and return the dir's path."""
        logger.debug("Extracting zipfile {}".format(file))
        if self.log_level <= logging.INFO:
            self.print_zipfile(file)
        return self._unpack_to_tempdir(file, zipfile.ZipFile)

    # TODO: Promote code reuse between nscapture, nsreplay, and nsdiff.
    #       Create base class with _extract_archive function?
    def _extract_archive(self, file):
        """Extract a tarball/zipfile into a temporary directory.

        Returns the path of the temp dir.
        Raises ArchiveError if file is neither a tarfile nor a zipfile.
        """
        logger.debug("Extracting archive {}".format(file))
        # Is the file a tarfile (tar, gz2, zip, tar.gz, etc.) or zipfile?
        if tarfile.is_tarfile(file):
            return self._extract_tarball(file)
        if zipfile.is_zipfile(file):
            return self._extract_zipfile(file)
        raise ArchiveError("{} must be a tarfile or a zipfile".format(file))

    def split_name(self, file):
        """Split a path into (directory, base name, name, extension).

        The name is the part of the base name before the first '.', the
        extension is everything after it (e.g. 'tar.gz'). The extension is
        None when the base name contains no '.'.
        """
        abs_file = os.path.abspath(file)
        base_name = os.path.basename(abs_file)
        file_name, dot, remainder = base_name.partition('.')
        file_extension = remainder if dot else None
        return os.path.dirname(abs_file), base_name, file_name, file_extension

    def unpack_recording(self, s1=None):
        """Make the recording s1 available as a directory (self.s1dir).

        An archive is extracted into a temp dir; a directory is used in place.
        On invalid input the errors are printed to stderr and the process
        exits with status 1.
        """
        validation_errors = []

        # Capture the name of s1
        self.s1 = s1
        (self.s1_dirname, self.s1_filename, self.s1_name,
         self.s1_extension) = self.split_name(s1)

        # s1 must either be a tarball, zipfile, or a directory
        if os.path.isfile(s1):
            # If s1 is a file, it must be a tarball or a zipfile
            # Extract s1 contents into a temporary directory
            try:
                tempdir = self._extract_archive(s1)
            except ArchiveError as e:
                validation_errors.append(e)
            else:
                self.s1tempdir = tempdir
                self.s1dir = tempdir
                self.s1_is_temp = True
        elif os.path.isdir(s1):
            if self.log_level <= logging.INFO:
                self.print_directory(s1)
            self.s1dir = s1
        else:
            s = "{} must be a existing tarfile/zipfile or a directory."
            validation_errors.append(s.format(s1))

        # Emit validation errors and exit
        if validation_errors:
            for validation_error in validation_errors:
                self._eprint(validation_error)
            self._cleanup()
            sys.exit(1)

    def get_updated_config(self):
        """Return the configuration object produced by getConfig()."""
        # Update current config with config file from the recording zip
        return getConfig()

    def get_recorders_from_node_data_dir(self, node_data_dir, node_name) -> "Tuple[Recorder, Recorder]":
        """Open the node and client-stack recorders stored under node_data_dir.

        Returns (node recorder, client recorder). replay_node does not use
        this method; it calls the imported module-level function of the same
        name.
        """
        node_rec_path = os.path.join(node_data_dir, node_name, 'recorder')
        client_stack_name = node_name + CLIENT_STACK_SUFFIX
        client_rec_path = os.path.join(node_data_dir, client_stack_name, 'recorder')
        # TODO: Change to rocksdb
        client_rec_kv_store = initKeyValueStorageIntKeys(
            KeyValueStorageType.Leveldb, client_rec_path, client_stack_name)
        node_rec_kv_store = initKeyValueStorageIntKeys(
            KeyValueStorageType.Leveldb,
            node_rec_path,
            node_name)

        return Recorder(node_rec_kv_store, skip_metadata_write=True), \
            Recorder(client_rec_kv_store, skip_metadata_write=True)

    def update_loaded_config(self, config):
        """Set STACK_COMPANION to 2 on config and reload the stack/node
        modules listed below."""
        config.STACK_COMPANION = 2
        # Reload order is kept exactly as is: each module is reloaded after
        # the one listed before it.
        import stp_zmq.kit_zstack
        importlib.reload(stp_zmq.kit_zstack)
        import plenum.common.stacks
        importlib.reload(plenum.common.stacks)
        import plenum.server.node
        importlib.reload(plenum.server.node)
        import indy_node.server.node
        importlib.reload(indy_node.server.node)

    def _stage_replay_dir(self, replay_node_dir):
        """Copy config, genesis/json files, keys and plugins from the
        unpacked recording into replay_node_dir; return the loaded config."""
        recording_dir = self.s1dir

        general_config_dir = create_config_dirs(replay_node_dir)

        recording_config_path = os.path.join(recording_dir, 'config')
        for filename in os.listdir(recording_config_path):
            source = os.path.join(recording_config_path, filename)
            if os.path.isfile(source):
                shutil.copy(source, general_config_dir)

        config = getConfig(general_config_dir)
        self.update_loaded_config(config)

        # strip left '/' for proper path join
        app_dir = os.path.join(replay_node_dir,
                               get_system_dirs().NODE_INFO_DIR.lstrip('/'))
        pool_dir = os.path.join(app_dir, config.NETWORK_NAME)
        os.makedirs(pool_dir, exist_ok=True)

        for file in os.listdir(recording_dir):
            if file.endswith(('.json', '_genesis')):
                shutil.copy(os.path.join(recording_dir, file), pool_dir)

        shutil.copytree(os.path.join(recording_dir, 'keys'),
                        os.path.join(pool_dir, 'keys'))
        shutil.copytree(os.path.join(recording_dir, 'plugins'),
                        os.path.join(app_dir, 'plugins'))
        return config

    def _replay_unpacked(self):
        """Replay the recording that unpack_recording left in self.s1dir."""
        node_name = self.s1_name

        # mkdtemp() returns a directory that exists and stays in place until
        # it is removed explicitly.
        replay_node_dir = self.replay_dir or tempfile.mkdtemp()
        logger.info("Replay node directory {}".format(replay_node_dir))

        config = self._stage_replay_dir(replay_node_dir)

        orig_node_data_dir = os.path.join(self.s1dir, 'data')
        node_rec, client_rec = get_recorders_from_node_data_dir(
            orig_node_data_dir, node_name)
        start_times_file = os.path.join(orig_node_data_dir, node_name,
                                        'start_times')
        with open(start_times_file, 'r') as f:
            start_times = json.load(f)

        replayable_node_class = create_replayable_node_class(Replica,
                                                             Replicas,
                                                             Node)
        node_ha = HA("0.0.0.0", 9701)
        client_ha = HA("0.0.0.0", 9702)

        # Already set by update_loaded_config; repeated here so the value is
        # in place right before the config helper is built.
        config.STACK_COMPANION = 2
        node_config_helper = NodeConfigHelper(node_name, config,
                                              chroot=replay_node_dir)

        log_file_name = os.path.join(node_config_helper.log_dir,
                                     node_name + ".log")

        Logger().apply_config(config)
        Logger().enableFileLogging(log_file_name)

        logger.setLevel(config.logLevel)
        logger.debug('Replay goes to {}'.format(replay_node_dir))
        logger.debug("You can find logs in {}".format(log_file_name))

        with Looper(debug=config.LOOPER_DEBUG) as looper:
            replaying_node = replayable_node_class(node_name,
                                                   config_helper=node_config_helper,
                                                   ha=node_ha, cliha=client_ha)
            replaying_node = prepare_node_for_replay_and_replay(looper,
                                                                replaying_node,
                                                                node_rec,
                                                                client_rec,
                                                                start_times)
            logger.debug('Replaying node, size: {}, root_hash: {}'.format(
                replaying_node.domainLedger.size,
                replaying_node.domainLedger.root_hash
            ))

    def replay_node(self, s1):
        """Unpack the recording s1 and replay it.

        The node name is taken from the part of the recording's base name
        before the first '.'. Whether the replay succeeds or raises, the temp
        dir an archive was extracted into is handed to _cleanup afterwards.
        """
        # unpack_recording cleans up after itself when it rejects the input,
        # so it stays outside the try block.
        self.unpack_recording(s1)
        try:
            self._replay_unpacked()
        finally:
            self._cleanup()


# TODO: Promote code reuse between nscapture, nsreplay, and nsdiff.
#       Move to common module
def str2bool(v):
    """Convert a yes/no style command-line string to a bool.

    Matching is case-insensitive. Raises argparse.ArgumentTypeError for any
    value that is not one of the recognised spellings.
    """
    normalized = v.lower()
    if normalized in ('yes', 'true', 't', 'y', '1'):
        return True
    if normalized in ('no', 'false', 'f', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError(
        'Boolean value (yes, no, true, false, y, n, 1, or 0) expected.')


# TODO: Promote code reuse between nscapture, nsreplay, and nsdiff.
#       Move to common module
# Lower-case level name -> numeric logging level, in ascending severity.
levels = {
    level_name: getattr(logging, level_name.upper())
    for level_name in ('notset', 'debug', 'info', 'warning', 'error',
                       'critical')
}


# TODO: Promote code reuse between nscapture, nsreplay, and nsdiff.
#       Move to common module
def loglevel(v):
    """Convert a level name (case-insensitive) to its numeric logging level.

    Raises argparse.ArgumentTypeError listing the accepted names when v is
    not one of them.
    """
    level_name = v.lower()
    if level_name not in levels:
        accepted = ', '.join(levels.keys())
        raise argparse.ArgumentTypeError(
            'Expected one of the following: {}.'.format(accepted))
    return levels[level_name]


def program_args():
    """Build and return the command-line parser for this script.

    Arguments defined:
        recording -- archive or directory holding the recording (required)
        -c/--cleanup -- whether to remove temporary files (default True)
        -l/--log-level -- logging level (default warning)
        -o/--output-dir -- directory to replay into (default None)
    """
    parser = argparse.ArgumentParser()

    recording_help = (".zip, .tar.gz, .tgz file or directory name containing a"
                      " recording. nscapture can be used to capture the result of a"
                      " recording. A recording is accomplished by setting `STACK_COMPANION=1`"
                      " in `/etc/indy/indy_config.py`")
    parser.add_argument("recording", help=recording_help)

    cleanup_help = ("Cleanup temporary files/directories?"
                    " [CLEANUP]: yes, no, y, n, true, false, t, f, 1, 0"
                    " Default: True")
    parser.add_argument("-c", "--cleanup", type=str2bool, nargs='?', const=True,
                        default=True, help=cleanup_help)

    # The documented default has to match `default=` below (logging.WARNING)
    log_level_help = ("Logging level."
                      " [LOG-LEVEL]: notset, debug, info, warning, error,"
                      " critical Default: warning")
    parser.add_argument("-l", "--log-level", type=loglevel, nargs='?',
                        const=logging.WARNING, default=logging.WARNING, help=log_level_help)

    parser.add_argument('-o', '--output-dir', default=None,
                        help='the directory where the recording will be replayed. Defaults to created temp directory')

    return parser


def parse_args(argv=None, parser=None):
    """Parse command-line arguments and return the resulting namespace.

    Arguments:
        argv -- list of argument strings; None makes argparse read sys.argv
        parser -- parser to use; None builds one with program_args()
    """
    # Build the default parser at call time, not at definition time, so that
    # importing this module has no side effects and calls do not share one
    # parser instance.
    if parser is None:
        parser = program_args()
    return parser.parse_args(args=argv)


def main(args):
    """Replay the recording described by the parsed arguments.

    Arguments:
        args -- namespace with `log_level`, `output_dir` and `recording`, and
                optionally `cleanup` (treated as True when absent)
    Returns 0 once the replay has completed.
    """
    # Pass the --cleanup choice on; previously it was parsed but never used.
    replayer = NodeStateReplayer(log_level=args.log_level,
                                 cleanup=getattr(args, 'cleanup', True),
                                 replay_dir=args.output_dir)
    replayer.replay_node(args.recording)
    return 0


if __name__ == '__main__':
    # Parse the command line, run the replay, and use main()'s return value
    # as the process exit status.
    sys.exit(main(parse_args()))
