#!/usr/bin/env python -u
# coding=utf-8
""" A Pythonic implementation of pdsh powered by sshreader
"""
# Copyright (C) 2015-2017 Jesse Almanrode
#
#     This program is free software: you can redistribute it and/or modify
#     it under the terms of the GNU Lesser General Public License as published by
#     the Free Software Foundation, either version 3 of the License, or
#     (at your option) any later version.
#
#     This program is distributed in the hope that it will be useful,
#     but WITHOUT ANY WARRANTY; without even the implied warranty of
#     MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
#     GNU Lesser General Public License for more details.
#
#     You should have received a copy of the GNU Lesser General Public License
#     along with this program.  If not, see <http://www.gnu.org/licenses/>.
from __future__ import print_function
from collections import defaultdict
from hashlib import md5
from hostlist import expand_hostlist, collect_hostlist
import click
import logging
import os
import sshreader
import sys
# GLOBALS
# Module metadata; __version__ is the value handed to click.version_option on the cli command
__author__ = 'Jesse Almanrode'
__version__ = '2.1'
# Usage examples passed to click.command as the epilog of the cli command.
# NOTE(review): the leading \b is presumably click's "do not re-wrap this paragraph" marker -- confirm against the
# click help-text documentation for the click version in use
__examples__ = """\b
Examples:
    pydsh -w host1,host2,host3 "uname -r"
    pydsh -u root -k /root/.ssh/id_rsa -w host[1,3] "uname -r"
    pydsh -u root -P Password123 -w host[1-3] "uname -r"
    pydsh -F -w host[01-10] myscript.sh
"""


def copy_script(scriptfile, job):
    """ Upload a local script into /tmp on the job's host through an SFTP session

    Any failure is handed back as the return value instead of being raised.

    :param scriptfile: Local path of the script to upload
    :param job: <ServerJob> object supplying the host name, credentials and ssh timeout
    :return: True when the upload succeeds, otherwise the exception that was caught
    """
    remote_name = os.path.basename(scriptfile)
    try:
        with sshreader.SSH(job.name, username=job.username, password=job.password, keyfile=job.key,
                           timeout=job.sshtimeout) as session:
            session.sftp_put(scriptfile, '/tmp/' + remote_name)
    except Exception as failure:
        return failure
    else:
        return True


def output(thisjob):
    """ Echo a finished job's output line by line, each line prefixed with the job's name

    The resulting `name: line` layout could be piped to dshbak.

    :param thisjob: <ServerJob> object
    :return: None
    """
    # With status 255 the first result itself is the text; otherwise the text is that result's stdout
    text = thisjob.results[0] if thisjob.status == 255 else thisjob.results[0].stdout
    if len(text) == 0:
        return None
    prefix = str(thisjob.name) + ': '
    for row in text.split('\n'):
        sshreader.echo(prefix + str(row))
    return None


def dshbak(jobresults):
    """ Print each job's output beneath a banner carrying the job's name

    Similar to piping output to `dshbak`. Jobs whose output is empty are skipped.

    :param jobresults: List of <ServerJob> objects
    :return: None
    """
    rule = str('-' * 16)
    for finished in jobresults:
        # With status 255 the first result itself is the text; otherwise the text is that result's stdout
        text = finished.results[0] if finished.status == 255 else finished.results[0].stdout
        if len(text) == 0:
            continue
        click.echo(rule + '\n' + str(finished.name) + '\n' + rule)
        click.echo(text)
    return None


def coalesce(jobresults):
    """ Print job output once per distinct text, headed by the collected list of hosts that produced it

    Similar to piping output to `dshbak -c`. Output is grouped by the md5 digest of its text and
    empty output is not printed.

    :param jobresults: List of <ServerJob> objects
    :return: None
    """
    hosts_by_digest = defaultdict(list)
    text_by_digest = dict()
    for job in jobresults:
        # With status 255 the first result itself is the text; otherwise the text is that result's stdout
        text = job.results[0] if job.status == 255 else job.results[0].stdout
        digest = md5(text.encode()).hexdigest()
        hosts_by_digest[digest].append(job.name)
        # Remember the text the first time a digest shows up
        text_by_digest.setdefault(digest, text)

    rule = str('-' * 16)
    for digest in text_by_digest:
        text = text_by_digest[digest]
        if len(text) == 0:
            continue
        click.echo(rule + '\n' + collect_hostlist(hosts_by_digest[digest]) + '\n' + rule)
        click.echo(text)
    return None


def validate_hostlist(ctx, param, value):
    """ Callback for click to expand hostlist expressions or error

    :param ctx: Click context
    :param param: Click parameter the callback is attached to
    :param value: Hostlist expression to expand
    :return: List of expanded hosts
    :raises click.BadParameter: If the expression cannot be expanded
    """
    try:
        return expand_hostlist(value)
    except Exception:
        # BadParameter is the error click expects from a parameter callback and it accepts a lone message.
        # BadOptionUsage requires an option_name before the message in click 7 and later, so the previous
        # one-argument call ended in a TypeError traceback instead of a usage error.
        raise click.BadParameter('Invalid hostlist expression', ctx=ctx, param=param)


@click.command(epilog=__examples__)
@click.version_option(version=__version__)
@click.option('--hostlist', '-w', metavar='EXPR', required=True, callback=validate_hostlist,
              help='Hostlist expression')
@click.option('--username', '-u', help='Override ssh username')
@click.option('--keyfile', '-k', type=click.Path(exists=True, dir_okay=False, readable=True), help='Override ssh key')
@click.option('--prompt', '-p', is_flag=True, help='Prompt for ssh password')
@click.option('--password', '-P', help='Supply ssh password')
@click.option('--timeout', '-T', default=600, help='Timeout for ssh commands')
@click.option('--dshbak', '-D', is_flag=True, help='Group output by host')
@click.option('--coalesce', '-C', is_flag=True, help='Coalesce similar output from hosts')
@click.option('--file', '-F', is_flag=True, help='Treat CMD as a script file')
@click.option('--debug', '-d', is_flag=True, help='Enable debug output')
@click.option('--redline', is_flag=True, help='Run pydsh faster')
@click.argument('cmd', nargs=1, required=True)
def cli(**kwargs):
    """  Run ssh commands in parallel across hosts
    """
    # The docstring above is what click shows as the --help text, so maintainer notes are kept in comments.
    # kwargs holds one entry per option/argument declared above; 'hostlist' arrives already expanded into a
    # list of hosts by the validate_hostlist callback.
    if kwargs['file']:
        # CMD names a local script: validate the path, arrange for it to be copied to each host before the
        # job runs, then replace CMD with the remote commands that execute the copy and remove it again.
        mkpath = click.Path(exists=True, dir_okay=False)
        script_path = mkpath(kwargs['cmd'])
        script_name = os.path.split(script_path)[1]
        prehook = sshreader.Hook(copy_script, args=[script_path])
        with open(script_path) as s:
            # Only the first line is needed. readline() gives '' for an empty file, so an empty script is
            # rejected by the check below instead of raising IndexError.
            shebang = s.readline()
        if not shebang.startswith('#!'):
            raise click.UsageError('Script must start with #!')
        # Everything after the leading '#!' is the interpreter (plus any arguments) to run the script with
        interpreter = shebang[2:].strip()
        kwargs['cmd'] = [interpreter + ' /tmp/' + script_name, 'rm /tmp/' + script_name]
    sshenv = sshreader.envvars()

    if kwargs['username'] is None:
        if sshenv.username is None:
            raise click.ClickException('Unable to determine ssh username. Please provide one using --username')
        kwargs['username'] = sshenv.username

    # Pick exactly one authentication method, in this order of precedence:
    #   1. an explicit --keyfile (any --password/--prompt is then ignored)
    #   2. a supplied --password, or one asked for with --prompt (any key from the environment is ignored)
    #   3. the rsa key, else the dsa key, reported by sshreader.envvars()
    if kwargs['keyfile'] is not None:
        kwargs['password'] = None
        kwargs['prompt'] = False
    elif kwargs['password'] is not None or kwargs['prompt']:
        if kwargs['password'] is None:
            kwargs['password'] = click.prompt(kwargs['username'] + "'s Password", hide_input=True)
    else:
        kwargs['keyfile'] = sshenv.rsa_key if sshenv.rsa_key is not None else sshenv.dsa_key
        if kwargs['keyfile'] is None:
            raise click.ClickException('Unable to find ssh key to use and password not supplied.')

    if kwargs['debug']:
        logging.getLogger('sshreader').setLevel(logging.INFO)

    # Grouped modes (--dshbak / --coalesce) print after every job has finished; otherwise each job prints
    # its own output through the posthook.
    grouped = kwargs['dshbak'] or kwargs['coalesce']
    if kwargs['keyfile'] is not None:
        auth = {'keyfile': kwargs['keyfile']}
    else:
        auth = {'password': kwargs['password']}
    posthook = sshreader.Hook(target=output)
    jobs = list()
    for host in kwargs['hostlist']:
        job = sshreader.ServerJob(host, kwargs['cmd'], username=kwargs['username'], timeout=kwargs['timeout'],
                                  combine_output=True, **auth)
        if not grouped:
            job.posthook = posthook
        if kwargs['file']:
            job.prehook = prehook
        jobs.append(job)

    # --redline hands sshread pcount=-1 instead of 0; see sshreader.sshread for what the two values select
    pcount = -1 if kwargs['redline'] else 0
    if grouped:
        jobs_finished = sshreader.sshread(jobs, pcount=pcount, tcount=0, progress_bar=True)
        if kwargs['coalesce']:
            coalesce(jobs_finished)
        else:
            dshbak(jobs_finished)
    else:
        sshreader.sshread(jobs, pcount=pcount, tcount=0)
    sys.exit(0)

# Script entry point: invoke the click command only when this file is executed directly
if __name__ == "__main__":
    cli()
