#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import importlib
import argparse
import datetime
import json
import os
import sys
import subprocess
from collections import OrderedDict
from glob import glob
import re
import pwd
import pathlib

from indy_common.config_util import getConfig
from indy_common.config_helper import ConfigHelper
from storage.kv_store_rocksdb_int_keys import KeyValueStorageRocksdbIntKeys
from stp_core.common.log import Logger
from stp_core.common.log import getlogger

# Configuration object, loaded once at import time; parse_args() hands it to
# ConfigHelper to derive the default database and log directories.
config = getConfig()

logger = getlogger()  # to make flake8 happy
clients = {}  # task -> (reader, writer)
# Account name looked up by set_log_owner() to decide who must own the log file.
INDY_USER = "indy"


class BaseUnknown:
    """Holder for a single stat value in which ``None`` stands for "unknown".

    Subclasses customise rendering of a known value by overriding ``_str``.
    """

    def __init__(self, val):
        self._val = val

    @property
    def val(self):
        """The wrapped raw value (``None`` when unknown)."""
        return self._val

    @val.setter
    def val(self, val):
        self._val = val

    def is_unknown(self):
        """Return True when no value is held."""
        return self._val is None

    def _str(self):
        # Rendering hook for a known value; only called via __str__.
        return str(self._val)

    def __str__(self):
        if self.is_unknown():
            return "unknown"
        return self._str()


class NewEncoder(json.JSONEncoder):
    """JSON encoder that knows how to serialise the stat wrapper objects."""

    def default(self, o):
        # Wrappers are replaced by the plain data they carry; anything else
        # is left to the stock encoder (which raises TypeError).
        if isinstance(o, BaseUnknown):
            return o.val
        if isinstance(o, ConnectionStatsOut):
            return o.bindings
        return super().default(o)


class FloatUnknown(BaseUnknown):
    """Value rendered with exactly two digits after the decimal point."""

    def _str(self):
        return format(self.val, ".2f")


class UptimeUnknown(BaseUnknown):
    """Duration in seconds rendered as "D days, H hours, M minutes, S seconds"."""

    def _str(self):
        # Peel off days, hours and minutes; whatever is left is seconds.
        left = self.val
        amounts = []
        for span in (86400, 3600, 60):
            qty, left = divmod(left, span)
            amounts.append(qty)
        amounts.append(left)

        chunks = []
        for unit, qty in zip(('day', 'hour', 'minute', 'second'), amounts):
            # Leading zero units are dropped; once one unit has been
            # emitted every smaller unit is shown, even if it is zero.
            if not qty and not chunks:
                continue
            plural = '' if qty == 1 else 's'
            chunks.append("{} {}{}".format(qty, unit, plural))

        if not chunks:
            return '0 seconds'
        return ", ".join(chunks)


class StateUnknown(BaseUnknown):
    """Process state; rendered as-is, or as 'in unknown state' when missing."""

    def __str__(self):
        if self.is_unknown():
            return 'in unknown state'
        return self.val


class NodesListUnknown(BaseUnknown):
    """Iterable list of node aliases; a missing list is stored as empty."""

    def __init__(self, val):
        nodes = val if val is not None else []
        super().__init__(nodes)

    def __iter__(self):
        return iter(self.val)

    def _str(self):
        # One "#  alias" row per node.
        rows = ["#  {}".format(alias) for alias in self.val]
        return "\n".join(rows)


class BaseStats(OrderedDict):
    """Ordered mapping filled from a raw stats dict according to ``shema``.

    ``shema`` is a list of ``(key, cls)`` pairs. Each key is looked up in the
    raw stats and wrapped with ``cls``; nested ``BaseStats`` classes also
    receive the ``verbose`` flag. A value whose wrapping raises is stored as
    ``None`` and a warning is logged.
    """
    shema = []

    def __init__(self, stats, verbose=False):
        owner = type(self).__name__

        if stats is None:
            logger.debug("{} no stats found".format(owner))

        for key, factory in self.shema:
            raw = stats.get(key) if stats is not None else None
            try:
                if issubclass(factory, BaseStats):
                    self[key] = factory(raw, verbose=verbose)
                else:
                    self[key] = factory(raw)
            except Exception as err:
                logger.warning(
                    "{} Failed to parse attribute '{}': {}".format(
                        owner, key, err))
                self[key] = None

        self._verbose = verbose


class ConnectionStatsOut:
    """Printable view of the bindings discovered for one port.

    :param bindings: list of dicts with the keys ``port``, ``protocol``
        and ``ip``
    :param verbose: when true, protocol and address are printed next to
        the port; otherwise only the port number is printed
    """

    def __init__(self, bindings, verbose):
        self.bindings = bindings
        self._verbose = verbose

    def __str__(self):
        if self._verbose:
            entries = (
                "{}/{} on {}".format(b['port'], b['protocol'], b['ip'])
                for b in self.bindings
            )
        else:
            entries = ("{}".format(b['port']) for b in self.bindings)

        # Drop duplicates but keep first-seen order, so the rendered text is
        # stable between runs (a plain set() would shuffle the entries).
        seen = set()
        unique = []
        for entry in entries:
            if entry not in seen:
                seen.add(entry)
                unique.append(entry)

        return ", ".join(unique)


class BindingStats(BaseUnknown):
    """Port value that, on construction, resolves the local bindings of the port.

    The stored port is used to query the local machine with the ``ss`` and
    ``ip`` tools; afterwards ``val`` no longer holds the port but a
    ``ConnectionStatsOut`` describing every binding that was found.
    """

    @staticmethod
    def explore_bindings(port):
        """Find the protocol/address pairs listening on ``port``.

        Runs ``ss -ln4`` through the shell and, for every concrete address
        found, ``ip a`` to learn the netmask of that address.

        :param port: port to look for; ``None`` short-circuits to ``None``
        :return: sorted, de-duplicated list of ``(protocol, ip_with_mask)``
            tuples, or ``None`` when ``port`` is ``None`` or the ``ss``
            pipeline failed (which includes ``grep`` matching nothing)
        """
        # Spellings 'ss' may use for "listening on every IPv4 address".
        ANYADDR_IPV4 = ('*', '0.0.0.0')

        def shell_cmd(command):
            """Run ``command`` in a shell; return stripped stdout or None."""
            try:
                ret = subprocess.check_output(
                    command, stderr=subprocess.STDOUT, shell=True)
            except subprocess.CalledProcessError as e:
                # stderr is merged into stdout above, so the captured text
                # is in e.output (e.stderr is always None in this setup).
                logger.warning(
                    "Shell command '{}' failed, "
                    "return code {}, stderr: {}".format(
                        command, e.returncode, e.output)
                )
                return None
            except Exception as e:
                logger.warning(
                    "Failed to process shell command: '{}', "
                    "unexpected error: {}".format(command, e)
                )
                return None

            logger.debug("command '{}': stdout '{}'".format(command, ret))
            return ret.decode().strip()

        if port is None:
            return None

        # TODO
        # 1. need to filter more (e.g. by pid) for such cases as:
        #   - SO_REUSEPORT or SO_REUSEADDR
        #   - tcp + udp
        # 2. process ipv6 as well
        #
        # NOTE(review): 'port' is interpolated into a shell command line;
        # it is assumed to be a plain port number - confirm against callers.
        #
        # parse listening ip using 'ss' tool
        command = r"ss -ln4 | sort -u | grep ':{}\s'".format(port)
        listing = shell_cmd(command)

        if listing is None:
            return None

        whitespace = re.compile(r"\s+")
        ips = []
        ips_with_netmasks = {}  # bare ip -> "ip/mask", filled on demand
        for line in listing.split('\n'):
            try:
                parts = whitespace.split(line)
                # format:
                # Netid State Recv-Q Send-Q Local Address:Port Peer Address:Port
                protocol, ip = parts[0], parts[4].split(":")[0]
            except Exception as e:
                logger.warning(
                    "Failed to parse protocol, ip from '{}', "
                    "error: {}".format(line, e)
                )
                continue

            if ip in ANYADDR_IPV4:
                # TODO mask here seems not necessary,
                # but requested in INDY-715
                ip = "0.0.0.0/0"
            else:
                if ip not in ips_with_netmasks:
                    # parse mask using 'ip' tool
                    # TODO more filtering by 'ip' tool itself if possible
                    command = "ip a | grep 'inet {}'".format(ip)
                    out = shell_cmd(command)

                    try:
                        ip_with_netmask = re.match(
                            r"^inet\s([^\s]+)", out).group(1)
                    except Exception as e:
                        logger.debug(
                            "Failed to parse ip with mask: command {}, "
                            "stdout: {}, error {}".format(command, out, e))
                        # Fall back to a placeholder mask. The value is
                        # always bound here and cached under the bare ip, so
                        # a failed lookup neither raises nor reuses the mask
                        # of a previously resolved address.
                        ip_with_netmask = "{}/unknown".format(ip)

                    ips_with_netmasks[ip] = ip_with_netmask

                ip = ips_with_netmasks[ip]

            ips.append((protocol, ip))

        # sorted() keeps the de-duplicated result deterministic.
        return sorted(set(ips))

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        # change schema: ignoring any data received except port,
        # resolve it ourselves (requested in original task INDY-715)
        # TODO refactor

        bindings = self.explore_bindings(self.val)
        logger.info(
            "Found the following bindings "
            "with port {}: {}".format(self.val, bindings)
        )
        self.val = ConnectionStatsOut([] if bindings is None else [
            dict(port=self.val, protocol=protocol, ip=ip)
            for protocol, ip in bindings
        ], False)


class TransactionsStats(BaseStats):
    """Per-ledger transaction counters found under "transaction-count".

    "audit" is part of the schema because ``ValidatorStats.__str__`` prints
    the audit counter; without the entry that lookup raises ``KeyError``.
    """
    shema = [
        ("config", BaseUnknown),
        ("ledger", BaseUnknown),
        ("pool", BaseUnknown),
        ("audit", BaseUnknown)
    ]


class AverageStats(BaseStats):
    """Schema for the "average-per-second" section: two float rates."""
    shema = [
        ("read-transactions", FloatUnknown),
        ("write-transactions", FloatUnknown)
    ]


class MetricsStats(BaseStats):
    """Schema for the "Metrics" section: uptime plus two nested sections."""
    shema = [
        ("uptime", UptimeUnknown),
        ("transaction-count", TransactionsStats),
        ("average-per-second", AverageStats)
    ]


class NodeStats(BaseStats):
    """Schema for the "Node_info" section.

    Both port entries are wrapped in BindingStats, which runs shell commands
    on the local machine when it is constructed.
    """
    shema = [
        ("Name", BaseUnknown),
        ("did", BaseUnknown),
        ("verkey", BaseUnknown),
        ("Node_port", BindingStats),
        ("Client_port", BindingStats),
        ("Metrics", MetricsStats)
    ]


class PoolStats(BaseStats):
    """Schema for the "Pool_info" section: node counts and alias lists."""
    shema = [
        ("Total_nodes_count", BaseUnknown),
        ("Reachable_nodes", NodesListUnknown),
        ("Reachable_nodes_count", BaseUnknown),
        ("Unreachable_nodes", NodesListUnknown),
        ("Unreachable_nodes_count", BaseUnknown)
    ]


class SoftwareStats(BaseStats):
    """Package versions; values missing from the stats are looked up locally
    by importing the package and reading its ``__version__``."""
    shema = [
        ("indy-node", BaseUnknown),
        ("sovrin", BaseUnknown)
    ]

    @staticmethod
    def pkgVersion(pkgName):
        """Return ``__version__`` of the importable ``pkgName`` or ``None``."""
        try:
            module = importlib.import_module(pkgName)
        except ImportError as e:
            logger.warning("Failed to import {}: {}".format(pkgName, e))
            return None

        try:
            return module.__version__
        except AttributeError as e:
            logger.warning(
                "Failed to get version of {}: {}".format(pkgName, e))
            return None

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        # Schema keys whose importable module name differs from the key.
        pkgMappings = {
            'indy-node': 'indy_node'
        }

        for pkgName, obj in self.items():
            if obj is not None and not obj.is_unknown():
                continue
            moduleName = pkgMappings.get(pkgName, pkgName)
            self[pkgName] = BaseUnknown(self.pkgVersion(moduleName))


class ValidatorStats(BaseStats):
    """Top-level validator record with a human readable text rendering.

    When the record carries no "state" or "enabled" value, the local service
    manager selected by the ``INDY_CONTROL`` environment variable
    ("systemctl" by default, or "supervisorctl") is asked instead.
    """
    shema = [
        ("response-version", BaseUnknown),
        ("timestamp", BaseUnknown),
        ("Node_info", NodeStats),
        ("state", StateUnknown),
        ("enabled", BaseUnknown),
        ("Pool_info", PoolStats),
        ("software", SoftwareStats)
    ]

    @staticmethod
    def get_process_state():
        """Return 'running', 'stopped', None, or an error text for a bad
        ``INDY_CONTROL`` value."""
        ctl = os.getenv('INDY_CONTROL', 'systemctl')
        if ctl == "systemctl":
            return ValidatorStats.get_process_state_via_systemctl()
        if ctl == "supervisorctl":
            return ValidatorStats.get_process_state_via_supervisorctl()
        return "Invalid value for INDY_CONTROL environment variable: '%s'" % ctl

    @staticmethod
    def get_process_state_via_systemctl():
        raw = subprocess.check_output(
            'systemctl is-failed indy-node; exit 0',
            stderr=subprocess.STDOUT, shell=True
        )
        status = raw.decode().strip()
        known = {'inactive': 'stopped', 'active': 'running'}
        if status in known:
            return known[status]

        logger.info(
            "Non-expected output for indy-node "
            "is-failed state: {}".format(status)
        )
        return None

    @staticmethod
    def get_process_state_via_supervisorctl():
        raw = subprocess.check_output(
            "supervisorctl status indy-node | awk '{print $2}'; exit 0",
            stderr=subprocess.STDOUT, shell=True
        )
        status = raw.decode().strip()
        known = {'STOPPED': 'stopped', 'RUNNING': 'running'}
        if status in known:
            return known[status]

        logger.info(
            "Non-expected output for indy-node "
            "status: {}".format(status)
        )
        return None

    @staticmethod
    def get_enabled_state():
        """Return True, False, None, or an error text for a bad
        ``INDY_CONTROL`` value."""
        ctl = os.getenv('INDY_CONTROL', 'systemctl')
        if ctl == "systemctl":
            return ValidatorStats.get_enabled_state_via_systemctl()
        if ctl == "supervisorctl":
            return ValidatorStats.get_enabled_state_via_supervisorctl()
        return "Invalid value for INDY_CONTROL environment variable: '%s'" % ctl

    @staticmethod
    def get_enabled_state_via_systemctl():
        raw = subprocess.check_output(
            'systemctl is-enabled indy-node; exit 0',
            stderr=subprocess.STDOUT, shell=True
        )
        status = raw.decode().strip()
        known = {'enabled': True, 'static': True, 'disabled': False}
        if status in known:
            return known[status]

        logger.info(
            "Non-expected output for indy-node "
            "is-enabled state: {}".format(status)
        )
        return None

    @staticmethod
    def get_enabled_state_via_supervisorctl():
        raw = subprocess.check_output(
            "supervisorctl status indy-node | awk '{print $2}'; exit 0",
            stderr=subprocess.STDOUT, shell=True
        )
        status = raw.decode().strip()
        known = {
            'RUNNING': True, 'BACKOFF': True, 'STARTING': True,
            'STOPPED': False
        }
        if status in known:
            return known[status]

        logger.info(
            "Non-expected output for indy-node "
            "is-enabled state: {}".format(status)
        )
        return None

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # TODO move that to classes too

        # Fill the gaps from the local service manager.
        if self['state'].is_unknown():
            self['state'] = StateUnknown(self.get_process_state())

        if self['enabled'].is_unknown():
            self['enabled'] = BaseUnknown(self.get_enabled_state())

    def __str__(self):
        # TODO moving parts to other classes seems reasonable but
        # will drop visibility of output
        node = self['Node_info']
        metrics = node['Metrics']
        txns = metrics['transaction-count']
        rates = metrics['average-per-second']
        pool = self['Pool_info']
        software = self['software']

        # A leading '#' marks a line that is printed in verbose mode only.
        lines = []
        lines.append("Validator {} is {}".format(node['Name'], self['state']))
        lines.append("Timestamp:        {}".format(self['timestamp']))
        lines.append("Validator DID:    {}".format(node['did']))
        lines.append("Verification Key: {}".format(node['verkey']))
        lines.append("Node Port:        {}".format(node['Node_port']))
        lines.append("Client Port:      {}".format(node['Client_port']))
        lines.append("Metrics:")
        lines.append("  Uptime: {}".format(metrics['uptime']))
        lines.append(
            "#  Total Config Transactions:  {}".format(txns['config']))
        lines.append(
            "  Total Ledger Transactions:  {}".format(txns['ledger']))
        lines.append(
            "  Total Pool Transactions:    {}".format(txns['pool']))
        lines.append(
            "  Total Audit Transactions:    {}".format(txns['audit']))
        lines.append(
            "  Read Transactions/Seconds:  {}".format(
                rates['read-transactions']))
        lines.append(
            "  Write Transactions/Seconds: {}".format(
                rates['write-transactions']))

        lines.append("Reachable Hosts:   {}/{}".format(
            pool['Reachable_nodes_count'], pool['Total_nodes_count']))
        for alias in pool['Reachable_nodes']:
            lines.append("#  {}".format(alias))

        lines.append("Unreachable Hosts: {}/{}".format(
            pool['Unreachable_nodes_count'], pool['Total_nodes_count']))
        for alias in pool['Unreachable_nodes']:
            lines.append("#  {}".format(alias))

        lines.append("#Software Versions:")
        for pkgName in software.keys():
            lines.append("#  {}: {}".format(pkgName, software[pkgName]))

        # skip lines with started with '#' if not verbose
        # or remove '#' otherwise
        shown = []
        for line in lines:
            if line[0] == '#':
                if not self._verbose:
                    continue
                line = line[1:]
            shown.append(line)

        return "\n".join(shown)


def str_from_timestamp(ts):
    """Render a POSIX timestamp as local time, "YYYY-MM-DD HH:MM:SS"."""
    moment = datetime.datetime.fromtimestamp(ts)
    return moment.strftime("%Y-%m-%d %H:%M:%S")


def get_validator_stats(stats, verbose, _json):
    """Wrap raw ``stats`` in ValidatorStats.

    Returns the ValidatorStats object itself, or its JSON text (indent 2)
    when ``_json`` is true.
    """
    logger.debug("Data {}".format(stats))
    vstats = ValidatorStats(stats, verbose)
    return json.dumps(vstats, indent=2, cls=NewEncoder) if _json else vstats


def format_key(key):
    """Return ``"key": `` (quoted key, colon, space) padded to 15 columns."""
    label = '"{}": '.format(key)
    return label.ljust(15)


def make_indent(indent):
    """Return the leading whitespace for nesting level ``indent`` (5 spaces each)."""
    step = " " * 5
    return indent * step


def format_value(value):
    """Return one space followed by ``str(value)`` padded to 10 columns."""
    return " " + str(value).ljust(10)


def create_print_tree(stats: dict, indent=0, lines=None):
    """Flatten a nested stats dict into indented text lines.

    :param stats: mapping to render; dict values are rendered recursively
        one level deeper, list values one item per line
    :param indent: current nesting level
    :param lines: list to append to; a fresh list is created when omitted,
        so separate calls never share accumulated output
    :return: the list of rendered lines (``lines`` itself when it was given)
    """
    if lines is None:
        lines = []

    for key, value in stats.items():
        if isinstance(value, (dict, list)) and not value:
            # Empty containers have nothing to expand; show a placeholder
            # next to the key instead of a dangling header line.
            lines.append(
                make_indent(indent) + format_key(key) + format_value('n/a'))
        elif isinstance(value, dict):
            lines.append(make_indent(indent) + format_key(key))
            create_print_tree(value, indent + 1, lines)
        elif isinstance(value, list):
            lines.append(make_indent(indent) + format_key(key))
            for line in value:
                lines.append(make_indent(indent + 1) + format_value(line))
        else:
            lines.append(
                make_indent(indent) + format_key(key) + format_value(value))
    return lines


def set_log_owner(log_path):
    """Make sure ``log_path`` exists and belongs to the INDY_USER account.

    A missing file is created first. If the ownership cannot be changed, a
    hint is printed and the process exits with status 1.
    """
    account = pwd.getpwnam(INDY_USER)
    uid, gid = account.pw_uid, account.pw_gid

    if not os.path.exists(log_path):
        pathlib.Path(log_path).touch()
    else:
        current = os.stat(log_path)
        if current.st_uid == uid and current.st_gid == gid:
            # Already owned by the right user and group.
            return

    try:
        os.chown(log_path, uid, gid)
    except PermissionError:
        print("Cannot set owner of {} file to indy".format(log_path))
        print("The owner of {} must be {}:{}".format(log_path, INDY_USER, INDY_USER))
        sys.exit(1)


def remove_log_handlers():
    """Detach every handler from the root logger and from ``logger``."""
    # removeHandler() mutates the very list being walked, so iterate over a
    # snapshot; looping over the live list skips every other handler.
    for hndl in list(logger.root.handlers):
        logger.root.removeHandler(hndl)

    for hndl in list(logger.handlers):
        logger.removeHandler(hndl)


def read_json_data(info_path, num: int = 1, from_start=False, _from: int = None, _to: int = None):
    """Read JSON records from the info database at ``info_path``.

    Every returned record gets its 'timestamp' entry replaced with
    "<formatted time> (<raw timestamp>)" built from the record key.

    :param info_path: path of the database; split into directory and name
    :param num: number of records to return, -1 for all; ignored when
        ``_from`` or ``_to`` is given
    :param from_start: count records from the first one instead of the last;
        ignored when ``_from`` or ``_to`` is given
    :param _from: lower timestamp bound (inclusive) or None
    :param _to: upper timestamp bound (inclusive) or None
    :return: list of records in ascending key order, or None when a record
        is not valid JSON
    """
    if _from is not None and _to is not None:
        assert _from <= _to

    db_path, db_name = os.path.split(info_path)
    db = KeyValueStorageRocksdbIntKeys(db_path, db_name, read_only=True)

    it = db.iterator()

    newest_first = False
    if _from is not None:
        it.seek_for_prev(db.to_byte_repr(_from))
        # Handle the case when 'from' timestamp is ahead from the first record
        try:
            # Try to make a step forward and back
            _ = next(it)
            _ = next(reversed(it))
        except StopIteration:
            it.seek_to_first()
    elif _to is not None or from_start:
        it.seek_to_first()
    else:
        it.seek_to_last()
        it = reversed(it)
        newest_first = True

    ranged = _from is not None or _to is not None
    json_data = []
    cnt = 0

    for ts, data in it:
        try:
            obj = json.loads(data.decode())
        except json.JSONDecodeError:
            print("DB '{}' contains invalid records".format(info_path))
            return None

        ts = int(ts)
        obj['timestamp'] = "{} ({})".format(str_from_timestamp(ts), ts)

        if ranged:
            if _from is not None and ts < _from:
                # Not inside the requested range yet.
                continue
            if _to is not None and ts > _to:
                break
            json_data.append(obj)
            continue

        json_data.append(obj)
        if num != -1:
            cnt += 1
            if cnt == num:
                break

    if newest_first:
        # Records were collected newest-first; one reverse at the end gives
        # ascending order without an O(n) insert(0, ...) per record.
        json_data.reverse()

    return json_data


def compile_json_ouput(file_paths):
    """Merge the default record selection of every database into one dict.

    Later records overwrite keys of earlier ones; databases that yield
    nothing are skipped.
    """
    merged = {}
    for path in file_paths:
        for record in read_json_data(path) or ():
            merged.update(record)
    return merged


def get_info_paths(node_names, basedir):
    """Return the paths of the ``*_info_db`` entries found in ``basedir``.

    :param node_names: node names to keep, or a falsy value for all entries;
        names are compared case-insensitively and surrounding whitespace is
        ignored
    :param basedir: directory to search
    :return: list of matching paths, in the order ``glob`` reports them
    """
    postfix = "_info_db"
    found = glob(os.path.join(basedir, "*{}".format(postfix)))
    if not node_names:
        return found

    # Normalise both sides so a mixed-case database name still matches.
    wanted = {node_name.strip().lower() for node_name in node_names}
    return [
        info_path for info_path in found
        if os.path.basename(info_path[:-len(postfix)]).lower() in wanted
    ]


def check_unsigned(s):
    """argparse ``type`` callback: parse ``s`` as an integer >= 0.

    Raises ``argparse.ArgumentTypeError`` for anything else.
    """
    try:
        res = int(s)
    except ValueError:
        res = None

    if res is not None and res >= 0:
        return res

    raise argparse.ArgumentTypeError(
        "{!r} is incorrect, should be int >= 0".format(s))


def check_N_arg(s):
    """argparse ``type`` callback: parse ``s`` as an integer > 0, or -1.

    Raises ``argparse.ArgumentTypeError`` for anything else.
    """
    try:
        res = int(s)
    except ValueError:
        res = None

    if res is not None and (res > 0 or res == -1):
        return res

    raise argparse.ArgumentTypeError(
        "{!r} is incorrect, should be int > 0 or -1".format(s))


def parse_args():
    """Declare the command line interface and parse ``sys.argv``.

    The defaults for ``--basedir`` and ``--log`` are derived from the node
    configuration through ConfigHelper.
    """
    config_helper = ConfigHelper(config)

    parser = argparse.ArgumentParser(
        description="Tool to explore and gather historycal statistics about running validator",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)

    # --- record selection ---------------------------------------------
    parser.add_argument(
        "-n", "--num", metavar="NUM", type=check_N_arg, default=1,
        help="Number of records to print (from the end by default), -1 means all records")
    parser.add_argument(
        "-s", "--from_start", action="store_true",
        help="Count records from the start")
    parser.add_argument(
        "--names", metavar="Node1,Node2", default=None,
        help="Comma-separated list of nodes to print records for.")
    for flag, word in (("--frm", "from"), ("--to", "to")):
        parser.add_argument(
            flag, metavar="TIMESTAMP", type=check_unsigned, default=None,
            help="Print records {} timestamp (ignores -n and -s)".format(word))

    # --- output formatting --------------------------------------------
    output_opts = parser.add_argument_group(
        "output", "specify output formatting")
    output_opts.add_argument(
        "--fields", metavar="Path.To.Field1,Path.To.Fie.ld2", default=None,
        help="Comma-separated list of JSON nodes to print (ignores -v and --json, case-sensitive).")
    output_opts.add_argument(
        "--json", action="store_true",
        help="Format output as JSON (ignores -v)")
    output_opts.add_argument(
        "-v", "--verbose", action="store_true",
        help="Verbose mode (command line)")

    # --- database location --------------------------------------------
    db_opts = parser.add_argument_group(
        "database", "settings for exploring validator stats from database")
    db_opts.add_argument(
        "-d", "--basedir", metavar="PATH",
        default=config_helper.node_info_dir,
        help="Path to databases base directory")

    # --- everything else ----------------------------------------------
    other_opts = parser.add_argument_group("other", "other settings")
    # Default log file: "<script name>.log" inside the configured log dir.
    default_log = os.path.join(
        config_helper.log_base_dir,
        os.path.basename(sys.argv[0] + ".log"))
    other_opts.add_argument(
        "--log", metavar="FILE", default=default_log,
        help="Path to log file")

    return parser.parse_args()


def check_args(args):
    """Validate option combinations that argparse cannot check on its own.

    :param args: parsed arguments exposing ``frm`` and ``to`` (int or None)
    :return: False, after printing a message, when both bounds are given
        and ``frm`` is greater than ``to``; True otherwise
    """
    if args.frm is not None and args.to is not None and args.frm > args.to:
        print("'From' timestamp can not be greater than 'to' timestamp")
        return False

    return True


def main():
    """Script body: parse options, read the databases and print the records.

    :return: -1 on invalid arguments, when no database is found, or on an
        invalid ``--fields`` path; None after a normal run
    """
    args = parse_args()

    if not check_args(args):
        return -1

    remove_log_handlers()

    if args.log:
        set_log_owner(args.log)

    Logger().enableFileLogging(args.log)

    logger.debug("Cmd line arguments: {}".format(args))

    node_names = str(args.names).split(",") if args.names else None

    info_paths = get_info_paths(node_names, args.basedir)

    if not info_paths:
        print("No storages found")
        return -1

    fields_paths = args.fields.split(",") if args.fields else None

    for info_path in info_paths:
        json_data = read_json_data(info_path, args.num, args.from_start, args.frm, args.to)
        if not json_data:
            continue

        print("====================================================")
        print(json_data[0]['Node_info']['Name'])
        print("====================================================")
        for json_obj in json_data:
            if fields_paths:
                rows = []
                for fields_path in fields_paths:
                    _obj = json_obj
                    for node in fields_path.split("."):
                        try:
                            _obj = _obj[node]
                        except (KeyError, TypeError):
                            # TypeError: the path descends into a value
                            # that is not a mapping (string, number, list);
                            # report it like a missing key instead of
                            # crashing with a traceback.
                            print("Incorrect fields path '{}': there is no such key '{}'".format(fields_path, node))
                            return -1
                    rows.append("{}: {}".format(fields_path, _obj))
                stats = "\n".join(rows)
            elif args.json:
                stats = json.dumps(json_obj)
            elif args.verbose:
                # An explicit fresh list keeps each record's output separate.
                stats = "{}".format(os.linesep).join(create_print_tree(json_obj, lines=[]))
            else:
                stats = get_validator_stats(json_obj, args.verbose, args.json)
            print(stats)
            print('----------------------------------------------------')

    logger.info("Done")


# Run only when executed as a script; main()'s return value is passed to
# sys.exit() and so becomes the process exit status.
if __name__ == "__main__":
    sys.exit(main())
