#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import importlib
import argparse
import asyncio
import concurrent.futures
import curses
import datetime
import json
import os
import sys
import time
import subprocess
from collections import OrderedDict
from glob import glob
import re

from indy_common.config_util import getConfig
from indy_common.config_helper import ConfigHelper
from stp_core.common.log import Logger
from stp_core.common.log import getlogger

# Node configuration, loaded once at import time. main() wraps it in a
# ConfigHelper for default paths and sets `enableStdOutLogging` on it.
config = getConfig()

# Module-wide logger; main() rebinds it once the command line is parsed.
logger = None  # to make flake8 happy
# Registry of open connections, filled and emptied by accept_client().
clients = {}   # task -> (reader, writer)


class BaseUnknown:
    """Wrapper around a single stat value that may be missing.

    A wrapped ``None`` means the value is unknown and is rendered as the
    text "unknown"; any other value is rendered through ``_str()``, which
    subclasses override to customise formatting.
    """

    def __init__(self, val):
        self._val = val

    @property
    def val(self):
        """The raw wrapped value (``None`` when unknown)."""
        return self._val

    def is_unknown(self):
        """Return True when no value is wrapped."""
        return self._val is None

    def _str(self):
        """Format a known value; only called when the value is not None."""
        return str(self._val)

    def __str__(self):
        if self.is_unknown():
            return "unknown"
        return self._str()


class NewEncoder(json.JSONEncoder):
    """JSON encoder that understands the stat wrapper types of this module.

    ``BaseUnknown`` instances are serialised as their raw value and
    ``ConnectionStatsOut`` instances as their list of bindings; everything
    else is left to the standard encoder.
    """

    def default(self, o):
        if isinstance(o, BaseUnknown):
            return o.val
        if isinstance(o, ConnectionStatsOut):
            return o.bindings
        # not one of ours: let the base class raise its usual TypeError
        return super().default(o)


class FloatUnknown(BaseUnknown):
    """Numeric stat rendered with exactly two digits after the point."""

    def _str(self):
        return format(self.val, ".2f")


class TimestampUnknown(BaseUnknown):
    """Timestamp stat rendered as a human readable local date and time.

    The wrapped value is passed to ``datetime.fromtimestamp``.
    """

    def _str(self):
        # platform specific strftime flag put in front of the day and hour
        # directives: '#' on Windows, '-' elsewhere (presumably to drop
        # the zero padding)
        flag = '#' if os.name == 'nt' else '-'
        pattern = "%A, %B %{0}d, %Y %{0}I:%M:%S %p".format(flag)
        moment = datetime.datetime.fromtimestamp(self.val)
        return "{}".format(moment.strftime(pattern))


class UptimeUnknown(BaseUnknown):
    """Duration in seconds rendered as "D days, H hours, M minutes, S seconds".

    Leading units that are zero are omitted; once a unit has been printed
    all smaller units follow, even when they are zero.
    """

    def _str(self):
        days, rest = divmod(self.val, 86400)
        hours, rest = divmod(rest, 3600)
        minutes, seconds = divmod(rest, 60)

        units = (
            ('day', days),
            ('hour', hours),
            ('minute', minutes),
            ('second', seconds),
        )

        pieces = []
        for label, amount in units:
            # skip leading zero units only
            if not amount and not pieces:
                continue
            plural = '' if amount == 1 else 's'
            pieces.append("{} {}{}".format(amount, label, plural))

        if not pieces:
            return '0 seconds'
        return ", ".join(pieces)


class StateUnknown(BaseUnknown):
    """Process state rendered for the "Validator X is <state>" line.

    An unknown state is rendered as "in unknown state" so that the
    sentence still reads naturally.
    """

    def __str__(self):
        if self.is_unknown():
            return 'in unknown state'
        # __str__ must return a str object; the state read from a stats
        # file is not guaranteed to be one, so convert explicitly
        return str(self.val)


class NodesListUnknown(BaseUnknown):
    """List of node aliases; a missing list is stored as an empty list.

    Instances are iterable over the aliases.
    """

    def __init__(self, val):
        aliases = val if val is not None else []
        super().__init__(aliases)

    def __iter__(self):
        return iter(self.val)

    def _str(self):
        # one alias per line, prefixed the same way as the verbose-only
        # lines of ValidatorStats
        rows = ["#  {}".format(alias) for alias in self.val]
        return "\n".join(rows)


class BaseStats(OrderedDict):
    """Ordered mapping built from a raw stats dict according to ``shema``.

    ``shema`` is a list of ``(key, cls)`` pairs. For every pair the raw
    value found under ``key`` (or ``None`` when ``stats`` itself is
    ``None``) is wrapped with ``cls``. Nested ``BaseStats`` classes also
    receive the ``verbose`` flag. A key whose wrapping fails is logged
    and stored as ``None``.
    """

    shema = []

    def __init__(self, stats, verbose=False):
        owner = type(self).__name__

        if stats is None:
            logger.debug("{} no stats found".format(owner))

        for key, factory in self.shema:
            raw = stats.get(key) if stats is not None else None
            try:
                if issubclass(factory, BaseStats):
                    parsed = factory(raw, verbose=verbose)
                else:
                    parsed = factory(raw)
            except Exception as e:
                logger.warning(
                    "{} Failed to parse attribute '{}': {}".format(
                        owner, key, e))
                parsed = None
            self[key] = parsed

        self._verbose = verbose


class ConnectionStats(BaseStats):
    """Raw ip / port / protocol values of a single endpoint."""

    shema = [
        ("ip", BaseUnknown),
        ("port", BaseUnknown),
        ("protocol", BaseUnknown),
    ]


class ConnectionStatsOut:
    """Printable collection of socket bindings.

    ``bindings`` is a list of dicts with the keys 'port', 'protocol' and
    'ip'. In non-verbose mode only the ports are printed, in verbose mode
    each binding is printed as "<port>/<protocol> on <ip>". Duplicate
    entries are printed once, in the order they first appear.
    """

    def __init__(self, bindings, verbose):
        self.bindings = bindings
        self._verbose = verbose

    def __str__(self):
        if self._verbose:
            data = [
                "{}/{} on {}".format(b['port'], b['protocol'], b['ip'])
                for b in self.bindings
            ]
        else:
            data = ["{}".format(b['port']) for b in self.bindings]

        # de-duplicate while keeping first-seen order: going through a
        # plain set() made the output order differ from run to run
        return ", ".join(OrderedDict.fromkeys(data))


class BindingsStats(BaseStats):
    """Bindings of the node and client endpoints.

    After the regular schema based parsing every entry is replaced with a
    ``ConnectionStatsOut`` describing the bindings discovered on the local
    host: only the port from the incoming data is used, protocol and ip
    are resolved with the ``ss`` and ``ip`` tools.
    """

    shema = [
        ("node", ConnectionStats),
        ("client", ConnectionStats)
    ]

    @staticmethod
    def explore_bindings(port):
        """Find the local IPv4 sockets listening on ``port``.

        Returns a list of unique ``(protocol, ip)`` tuples where ``ip``
        carries a netmask suffix ("<ip>/unknown" when the mask could not
        be resolved), or ``None`` when ``port`` is ``None`` or the ``ss``
        lookup failed.
        """
        # both spellings of the IPv4 wildcard address are accepted
        ANYADDR_IPV4 = ('*', '0.0.0.0')

        def shell_cmd(command):
            """Run ``command`` in a shell; stripped output or None."""
            try:
                ret = subprocess.check_output(
                    command, stderr=subprocess.STDOUT, shell=True)
            except subprocess.CalledProcessError as e:
                # stderr is merged into stdout above, so the captured
                # text is available as e.output (e.stderr is None)
                logger.warning(
                    "Shell command '{}' failed, "
                    "return code {}, stderr: {}".format(
                        command, e.returncode, e.output)
                )
                return None
            except Exception as e:
                logger.warning(
                    "Failed to process shell command: '{}', "
                    "unexpected error: {}".format(command, e)
                )
                return None

            logger.debug("command '{}': stdout '{}'".format(command, ret))
            return ret.decode().strip()

        if port is None:
            return None

        # TODO
        # 1. need to filter more (e.g. by pid) for such cases as:
        #   - SO_REUSEPORT or SO_REUSEADDR
        #   - tcp + udp
        # 2. process ipv6 as well
        #
        # NOTE(review): `port` and `ip` are interpolated into shell
        # command lines; confirm the stats file is a trusted input.
        #
        # parse listening ip using 'ss' tool
        command = r"ss -ln4 | sort -u | grep ':{}\s'".format(port)
        listing = shell_cmd(command)

        if listing is None:
            return None

        whitespace = re.compile(r"\s+")
        ips = []
        ips_with_netmasks = {}
        for line in listing.split('\n'):
            try:
                parts = whitespace.split(line)
                # format:
                # Netid State Recv-Q Send-Q Local Address:Port Peer Address:Port
                protocol, ip = parts[0], parts[4].split(":")[0]
            except Exception as e:
                logger.warning(
                    "Failed to parse protocol, ip from '{}', "
                    "error: {}".format(line, e)
                )
                continue

            if ip in ANYADDR_IPV4:
                # TODO mask here seems not necessary,
                # but requested in INDY-715
                ip = "0.0.0.0/0"
            else:
                if ip not in ips_with_netmasks:
                    # parse mask using 'ip' tool
                    # TODO more filtering by 'ip' tool itself if possible
                    command = "ip a | grep 'inet {}'".format(ip)
                    output = shell_cmd(command)

                    try:
                        ip_with_netmask = re.match(
                            r"^inet\s([^\s]+)", output).group(1)
                    except (TypeError, AttributeError) as e:
                        # TypeError: the command failed (output is None)
                        # AttributeError: the output did not match
                        logger.debug(
                            "Failed to parse ip with mask: command {}, "
                            "stdout: {}, error {}".format(
                                command, output, e))
                        ip_with_netmask = "{}/unknown".format(ip)

                    # cache under the bare ip so that the lookup below
                    # always finds the entry that was just resolved
                    ips_with_netmasks[ip] = ip_with_netmask

                ip = ips_with_netmasks[ip]

            ips.append((protocol, ip))

        # unique entries, in the order they were discovered
        return list(OrderedDict.fromkeys(ips))

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        # iterate over a snapshot because the entries are replaced below
        for target, data in list(self.items()):
            # change schema: ignoring any data received except port,
            # resolve it ourselves (requested in original task INDY-715)
            # TODO refactor

            bindings = self.explore_bindings(data['port'].val)
            logger.info(
                "Found the following bindings for target {} "
                "with port {}: {}".format(target, data['port'], bindings)
            )
            self[target] = ConnectionStatsOut([] if bindings is None else [
                dict(port=data['port'], protocol=protocol, ip=ip)
                for protocol, ip in bindings
            ], self._verbose)


class TransactionsStats(BaseStats):
    """Transaction counters keyed by 'config', 'ledger' and 'pool'."""

    shema = [
        ("config", BaseUnknown),
        ("ledger", BaseUnknown),
        ("pool", BaseUnknown),
    ]


class AverageStats(BaseStats):
    """Read / write transaction averages, printed with two decimals."""

    shema = [
        ("read-transactions", FloatUnknown),
        ("write-transactions", FloatUnknown),
    ]


class MetricsStats(BaseStats):
    """Uptime plus the nested transaction count and average sections."""

    shema = [
        ("uptime", UptimeUnknown),
        ("transaction-count", TransactionsStats),
        ("average-per-second", AverageStats),
    ]


class NodesListStats(BaseStats):
    """A node count together with the list of node aliases."""

    shema = [
        ("count", BaseUnknown),
        ("list", NodesListUnknown),
    ]


class PoolStats(BaseStats):
    """Pool overview: total node count, reachable and unreachable nodes."""

    shema = [
        ("total-count", BaseUnknown),
        ("reachable", NodesListStats),
        ("unreachable", NodesListStats),
    ]


class SoftwareStats(BaseStats):
    """Versions of the installed software packages.

    A version that is missing from the incoming stats is looked up by
    importing the corresponding Python package and reading its
    ``__version__`` attribute.
    """

    shema = [
        ("indy-node", BaseUnknown),
        ("sovrin", BaseUnknown)
    ]

    @staticmethod
    def pkgVersion(pkgName):
        """Return ``__version__`` of module ``pkgName``, or None.

        None is returned (after logging a warning) when the module can
        not be imported or does not define ``__version__``.
        """
        try:
            module = importlib.import_module(pkgName)
        except ImportError as e:
            logger.warning("Failed to import {}: {}".format(pkgName, e))
            return None

        try:
            version = module.__version__
        except AttributeError as e:
            logger.warning(
                "Failed to get version of {}: {}".format(pkgName, e))
            version = None
        return version

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        # stats key -> importable module name, where the two differ
        pkgMappings = {
            'indy-node': 'indy_node'
        }

        for pkgName, obj in self.items():
            if obj is not None and not obj.is_unknown():
                # version was supplied by the stats, keep it
                continue
            moduleName = pkgMappings.get(pkgName, pkgName)
            self[pkgName] = BaseUnknown(self.pkgVersion(moduleName))


class ValidatorStats(BaseStats):
    """Top level stats of a validator node and their text report.

    When the incoming stats carry no 'state' or 'enabled' value, the
    value is obtained from ``systemctl`` for the ``indy-node`` unit.
    """

    shema = [
        ("response-version", BaseUnknown),
        ("timestamp", TimestampUnknown),
        ("alias", BaseUnknown),
        ("state", StateUnknown),
        ("enabled", BaseUnknown),
        ("did", BaseUnknown),
        ("verkey", BaseUnknown),
        ("bindings", BindingsStats),
        ("metrics", MetricsStats),
        ("pool", PoolStats),
        ("software", SoftwareStats)
    ]

    @staticmethod
    def get_process_state():
        """Map ``systemctl is-failed indy-node`` output to a state name.

        Returns 'stopped', 'running' or None for any other output.
        """
        raw = subprocess.check_output(
            'systemctl is-failed indy-node; exit 0',
            stderr=subprocess.STDOUT, shell=True
        )
        status = raw.decode().strip()

        known = {'inactive': 'stopped', 'active': 'running'}
        if status in known:
            return known[status]

        logger.info(
            "Non-expected output for indy-node "
            "is-failed state: {}".format(status)
        )
        return None

    @staticmethod
    def get_enabled_state():
        """Map ``systemctl is-enabled indy-node`` output to a boolean.

        Returns True for 'enabled' / 'static', False for 'disabled' and
        None for any other output.
        """
        raw = subprocess.check_output(
            'systemctl is-enabled indy-node; exit 0',
            stderr=subprocess.STDOUT, shell=True
        )
        status = raw.decode().strip()

        known = {'enabled': True, 'static': True, 'disabled': False}
        if status in known:
            return known[status]

        logger.info(
            "Non-expected output for indy-node "
            "is-enabled state: {}".format(status)
        )
        return None

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # TODO move that to classes too

        if self['state'].is_unknown():
            self['state'] = StateUnknown(self.get_process_state())

        if self['enabled'].is_unknown():
            self['enabled'] = BaseUnknown(self.get_enabled_state())

    def __str__(self):
        # TODO moving parts to other classes seems reasonable but
        # will drop visibility of output

        # Lines starting with '#' are verbose-only; the marker itself is
        # never printed.
        lines = []
        lines.append(
            "Validator {} is {}".format(self['alias'], self['state']))
        lines.append("#Current time:     {}".format(self['timestamp']))
        lines.append("Validator DID:    {}".format(self['did']))
        lines.append("Verification Key: {}".format(self['verkey']))
        lines.append(
            "Node Port:        {}".format(self['bindings']['node']))
        lines.append(
            "Client Port:      {}".format(self['bindings']['client']))
        lines.append("Metrics:")
        lines.append("  Uptime: {}".format(self['metrics']['uptime']))
        lines.append("#  Total Config Transactions:  {}".format(
            self['metrics']['transaction-count']['config']))
        lines.append("  Total Ledger Transactions:  {}".format(
            self['metrics']['transaction-count']['ledger']))
        lines.append("  Total Pool Transactions:    {}".format(
            self['metrics']['transaction-count']['pool']))
        lines.append("  Read Transactions/Seconds:  {}".format(
            self['metrics']['average-per-second']['read-transactions']))
        lines.append("  Write Transactions/Seconds: {}".format(
            self['metrics']['average-per-second']['write-transactions']))

        lines.append("Reachable Hosts:   {}/{}".format(
            self['pool']['reachable']['count'],
            self['pool']['total-count']))
        for alias in self['pool']['reachable']['list']:
            lines.append("#  {}".format(alias))

        lines.append("Unreachable Hosts: {}/{}".format(
            self['pool']['unreachable']['count'],
            self['pool']['total-count']))
        for alias in self['pool']['unreachable']['list']:
            lines.append("#  {}".format(alias))

        lines.append("#Software Versions:")
        for pkgName in self['software'].keys():
            lines.append(
                "#  {}: {}".format(pkgName, self['software'][pkgName]))

        # skip lines started with '#' if not verbose
        # or remove '#' otherwise
        shown = []
        for line in lines:
            if line[0] == '#':
                if not self._verbose:
                    continue
                line = line[1:]
            shown.append(line)

        return "\n".join(shown)


async def handle_client(client_reader, client_writer):
    """Read newline terminated JSON documents from a client and print them.

    Runs until the client closes the connection (EOF), reading fails or
    the task is cancelled. No read timeout is applied. A line that is not
    valid JSON is logged and skipped.

    ``client_writer`` is not used here; the connection is closed by the
    done-callback installed in ``accept_client``.
    """
    while True:
        try:
            data = await client_reader.readline()
        except asyncio.CancelledError:
            # asyncio's own exception class: since Python 3.8 it is no
            # longer the same class as concurrent.futures.CancelledError
            logger.warning("task has been cancelled")
            return
        except Exception as e:
            logger.exception("failed to readline: {}".format(e))
            return

        if data is None:
            logger.warning("Expected data, received None")
            return
        if not data:
            logger.warning("EOF received, closing connection")
            return

        logger.debug("Received data: {}".format(data))
        try:
            stats = json.loads(data.decode())
        except ValueError as e:
            # covers both undecodable bytes (UnicodeDecodeError) and
            # malformed JSON (JSONDecodeError); keep serving the client
            logger.warning(
                "Failed to parse received data: {}".format(e))
            continue

        print(json.dumps(stats, indent=2, cls=NewEncoder))


def accept_client(client_reader, client_writer):
    """Start a ``handle_client`` task for a freshly accepted connection.

    The task is registered in the module level ``clients`` mapping and,
    once it is done, removed again and the writer is closed.
    """
    logger.info("New Connection")

    def client_done(finished):
        # done-callback: receives the finished task
        del clients[finished]
        client_writer.close()
        logger.info("End Connection")

    task = asyncio.Task(handle_client(client_reader, client_writer))
    clients[task] = (client_reader, client_writer)
    task.add_done_callback(client_done)


def get_stats_from_file(fpath, verbose, _json):
    """Load the JSON stats file ``fpath`` and wrap it in ValidatorStats.

    Returns the JSON text of the stats (indent 2) when ``_json`` is true,
    otherwise the ``ValidatorStats`` object itself.
    """
    with open(fpath) as f:
        stats = json.load(f)

    logger.debug("Data {}".format(stats))
    vstats = ValidatorStats(stats, verbose)

    if _json:
        return json.dumps(vstats, indent=2, cls=NewEncoder)
    return vstats


def watch(fpath, verbose, _json, interval=3):
    """Show the stats from ``fpath`` full screen, refreshing periodically.

    The stats file is re-read and redrawn every ``interval`` seconds
    until the user presses Ctrl+C.
    """

    def _watch(stdscr):
        while True:
            text = str(get_stats_from_file(fpath, verbose, _json))

            # Clear screen
            stdscr.clear()
            try:
                stdscr.addstr(0, 0, text)
            except curses.error:
                # raised when the text does not fit into the window;
                # show what could be drawn instead of terminating
                pass
            stdscr.refresh()

            # pause between refreshes (previously the sleep was guarded
            # by a variable that was never set, so it never ran)
            time.sleep(interval)

    try:
        curses.wrapper(_watch)
    except KeyboardInterrupt:
        pass


def main():
    """Command line entry point.

    Parses the command line, sets up logging and prints the stats of
    every ``*_info.json`` file found in the base directory, either as a
    text report or as JSON.
    """
    global logger

    def check_unsigned(s):
        # argparse `type=` callback accepting integers > 0 only; at the
        # moment it is referenced just by the disabled socket options
        try:
            res = int(s)
        except ValueError:
            res = None

        if res is not None and res > 0:
            return res

        raise argparse.ArgumentTypeError(("{!r} is incorrect, "
                                          "should be int > 0").format(s,))

    config_helper = ConfigHelper(config)

    parser = argparse.ArgumentParser(
        description=(
            "Tool to explore and gather statistics about running validator"
        ),
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    parser.add_argument(
        "-v", "--verbose",
        action="store_true",
        help="Verbose mode (command line)"
    )
    parser.add_argument(
        "--json",
        action="store_true",
        help="Format output as JSON (ignores -v)"
    )

    statfile_group = parser.add_argument_group(
        "statfile", "settings for exploring validator stats from stat file"
    )

    statfile_group.add_argument(
        "--basedir",
        metavar="PATH",
        default=config_helper.node_info_dir,
        help=("Path to stats files")
    )
    # statfile_group.add_argument(
    #     "--watch", action="store_true", help="Watch for stats file updates"
    # )

    # socket group is disabled for now due the feature is unsupported
    # socket_group = parser.add_argument_group(
    #     "socket", "settings for exploring validator stats from socket"
    # )
    #
    # socket_group.add_argument(
    #     "--listen", action="store_true",
    #     help="Listen socket for stats (ignores --statfile)"
    # )
    #
    # socket_group.add_argument(
    #     "-i", "--ip", metavar="IP", default=config.STATS_SERVER_IP,
    #     help="Server IP"
    # )
    # socket_group.add_argument(
    #     "-p", "--port", metavar="PORT", default=config.STATS_SERVER_PORT,
    #     type=check_unsigned, help="Server port"
    # )

    other_group = parser.add_argument_group("other", "other settings")

    # by default the log file is named after the script and is placed
    # into the configured log directory
    default_log = os.path.join(
        config_helper.log_base_dir,
        os.path.basename(sys.argv[0] + ".log")
    )
    other_group.add_argument(
        "--log",
        metavar="FILE",
        default=default_log,
        help="Path to log file"
    )
    other_group.add_argument(
        "--stdlog",
        action="store_true",
        help="Enable logging to stdout"
    )

    args = parser.parse_args()

    # the stdout logging switch is set before the logger is requested
    config.enableStdOutLogging = args.stdlog

    logger = getlogger()
    Logger().enableFileLogging(args.log)

    logger.debug("Cmd line arguments: {}".format(args))

    # is not supported for now
    # if args.listen:
    #     logger.info("Starting server on {}:{} ...".format(
    #       args.ip, args.port))
    #     print("Starting server on {}:{} ...".format(args.ip, args.port))
    #
    #     loop = asyncio.get_event_loop()
    #     coro = asyncio.start_server(accept_client,
    #                                 args.ip, args.port, loop=loop)
    #     server = loop.run_until_complete(coro)
    #
    #     logger.info("Serving on {}:{} ...".format(args.ip, args.port))
    #     print('Serving on {} ...'.format(server.sockets[0].getsockname()))
    #
    #     # Serve requests until Ctrl+C is pressed
    #     try:
    #         loop.run_forever()
    #     except KeyboardInterrupt:
    #         pass
    #
    #     logger.info("Stopping server ...")
    #     print("Stopping server ...")
    #
    #     # Close the server
    #     server.close()
    #     for task in clients.keys():
    #         task.cancel()
    #     loop.run_until_complete(server.wait_closed())
    #     loop.close()
    # else:
    pattern = os.path.join(args.basedir, "*_info.json")
    paths = glob(pattern)
    if not paths:
        print('There are no info files in {}'.format(args.basedir))
        return

    for file_path in paths:
        logger.info("Reading file {} ...".format(file_path))
        report = get_stats_from_file(file_path, args.verbose, args.json)
        print(report)
        print('\n')

    logger.info("Done")


# Run the tool only when executed as a script. main() returns None on its
# normal paths, which sys.exit() turns into exit status 0.
if __name__ == "__main__":
    sys.exit(main())
