#!/usr/bin/env python -u
# coding=utf-8
"""A "pythonish" implementation of pdsh using sshreader!"""
from __future__ import print_function, unicode_literals
import argparse
import getpass
import os
import sshreader
import sys
from hostlist import expand_hostlist
# GLOBALS: module metadata, help-text examples and shared runtime state.
__author__ = 'Jesse Almanrode'
__version__ = '1.1.1'
# Shown as the parser epilog in --help; argparse fills in the %(prog)s placeholders.
__examples__ = (
    'EXAMPLES:\n'
    '    %(prog)s -k ~/.ssh/id_rsa -w host1,host2,host3 "uname -r"\n'
    '    %(prog)s -u root -k /root/.ssh/id_rsa -w host[1,3] "uname -r"\n'
    '    %(prog)s -u root -p Password123 -w host[1-3] "uname -r"\n'
)
# Switched on by main() when -d/--debug is passed.
debug = False


def print_job(*args):
    """Print the result of one finished job (also used as the posthook for sshreader).

    Two output styles are supported:

    * quick   -- each output line goes through ``sshreader.echo`` prefixed with
      ``"<host>: "`` (pdsh style, suitable for piping to dshbak);
    * default -- a dashed header naming the host, followed by the raw output.

    :param args: Either ``(<ServerJob>,)`` or ``('quick', <ServerJob>)``; the job
        is always the last element.
    :return: None
    """
    thisjob = args[-1]
    if 'quick' in args:
        if thisjob.status == 255:
            sshreader.echo(str(thisjob.name) + ': Unable to establish ssh connection')
        else:
            lines = _job_output_text(thisjob).split('\n')
            # When the output ends with a newline, split() leaves an empty final
            # piece; echoing it would add a blank "<host>: " line for every host.
            if lines[-1] == '':
                lines.pop()
            for line in lines:
                sshreader.echo(str(thisjob.name) + ': ' + line)
    else:
        rule = '-' * 16
        print(rule + '\n' + str(thisjob.name) + '\n' + rule)
        if thisjob.status == 255:
            print(str(thisjob.name) + ': Unable to establish ssh connection')
        else:
            # The empty final piece is kept here on purpose: the blank line it
            # prints separates this host's block from the next header.
            for line in _job_output_text(thisjob).split('\n'):
                print(line)
    return None


def _job_output_text(job):
    """Return the stdout of *job*'s first result as text.

    Bytes are decoded as UTF-8 with undecodable sequences replaced (U+FFFD), so
    binary or non-UTF-8 remote output cannot crash the printer; output that is
    already text is returned unchanged.
    """
    data = job.results[0].stdout
    if isinstance(data, bytes):
        return data.decode('utf-8', 'replace')
    return data


def process_results(jobresults, quiet):
    """Print each job's results, optionally hiding jobs that produced no output.

    :param jobresults: Iterable of <ServerJob> objects; each one is passed to print_job()
    :param quiet: When True, skip jobs whose first result has empty stdout.
        Jobs that could not connect (status 255) are always reported: their error
        message is not "blank output", and print_job() handles them without
        touching ``job.results``.
    :return: None
    """
    for job in jobresults:
        # Only inspect results for jobs that actually connected; a failed
        # connection may not carry a usable results[0].
        if quiet and job.status != 255 and not job.results[0].stdout:
            continue
        print_job(job)
    return None


def main():
    """Parse the command line, run CMD on every host and print the results.

    :return: 0 on Success (invalid arguments exit through argparse instead)
    """
    global debug
    parser = _build_parser()
    args = parser.parse_args()
    if args.debug:
        debug = True
        print(args)

    _resolve_credentials(parser, args)

    hostlist = expand_hostlist(args.hostlist)

    if debug:
        print("Processing " + str(len(hostlist)) + " hosts!")

    post = sshreader.Hook(target=print_job, args=['quick'])
    jobs = []
    for host in hostlist:
        # Authenticate with the key file when one is set, otherwise with the password.
        if args.keyfile is not None:
            auth = {'keyfile': args.keyfile}
        else:
            auth = {'password': args.password}
        job = sshreader.ServerJob(host, args.command, username=args.username,
                                  timeout=args.timeout, combine_output=True, **auth)
        if args.quick:
            # -Q: print each host's output from the posthook as its job returns.
            job.posthook = post
        jobs.append(job)

    # -q and -Q are mutually exclusive, so checking quick first is safe.
    if args.quick:
        # Everything was already printed by the posthook.
        sshreader.sshread(jobs, pcount=0, tcount=0)
        return 0
    if args.quiet:
        job_results = sshreader.sshread(jobs, pcount=0, tcount=0)
    else:
        job_results = sshreader.sshread(jobs, pcount=0, tcount=0, progress_bar=True)

    if args.sort:
        # Successful jobs first, then every job with a non-zero status.
        groups = ([x for x in job_results if x.status == 0],
                  [x for x in job_results if x.status != 0])
    else:
        groups = (job_results,)
    for group in groups:
        process_results(group, args.quiet)
    return 0


def _build_parser():
    """Return the argparse parser describing this tool's command line."""
    parser = argparse.ArgumentParser(description=__doc__, formatter_class=argparse.RawTextHelpFormatter)
    parser.epilog = __examples__
    parser.add_argument("-v", "--version", action='version', version="%(prog)s " + __version__)
    parser.add_argument("command", metavar="CMD", type=str, nargs=1, help="The command to run on the remote hosts")
    parser.add_argument("-u", "--username", type=str, metavar="UID",
                        help="Username to ssh as (default is current user)")
    parser.add_argument('-T', '--timeout', type=int, default=30, metavar='INT', dest='timeout',
                        help='Timeout for ssh commands in seconds (default = 30)')
    requiredargs = parser.add_argument_group("required arguments")
    requiredargs.add_argument("-w", "--hostlist", type=str, metavar="EXPR", required=True, help="hostlist expression")
    authtype = parser.add_mutually_exclusive_group()
    authtype.add_argument("-k", "--keyfile", type=str, metavar="PATH", help="Provide an ssh key to use")
    # A bare -p (no value) stores '' which means "prompt for the password".
    authtype.add_argument("-p", "--password", type=str, metavar="PWD", nargs='?', action='store', const='',
                          help="Password to use")
    outputoptions = parser.add_argument_group("output styles")
    outputstyle = outputoptions.add_mutually_exclusive_group()
    outputstyle.add_argument("-q", "--quiet", default=False, action="store_true",
                             help="Disable progress bar and blank output")
    outputstyle.add_argument("-Q", "--quick", default=False, action="store_true",
                             help="Output results as they return (can be piped to dshbak)")
    outputoptions.add_argument("-s", "--sort", default=False, action="store_true",
                               help="Sort the output of command statuses (Ignored if -Q or --quick is used!)")
    options = parser.add_argument_group("debug options")
    options.add_argument("-d", "--debug", default=False, action="store_true", help="Turn on debug mode")
    return parser


def _resolve_credentials(parser, args):
    """Fill in ``args.username``, ``args.keyfile`` and ``args.password`` in place.

    Without -u the current user is used, and a default key from ~/.ssh is picked
    when neither -k nor -p was given. With -u an explicit -k or -p is required.
    A bare -p (no value) prompts for the password, with or without -u.
    Invalid combinations exit through ``parser.error()``.
    """
    if args.username is None:
        current_user = getpass.getuser()
        try:
            login = os.getlogin()
        except OSError:
            # os.getlogin() fails without a controlling terminal (cron, CI, some
            # containers); there is nothing to cross-check, so trust getuser().
            login = current_user
        if login != current_user:
            parser.error('Unable to determine username.  Please specify one!')
        args.username = current_user
        if args.keyfile is None and args.password is None:
            args.keyfile = _find_default_keyfile(parser)
    elif args.keyfile is None and args.password is None:
        parser.error("Please specify a key location or password when specifying a username!")

    # Prompting lives outside the username branches so that "-p" without "-u"
    # is also asked for a password instead of silently connecting with ''.
    if args.password == '':
        args.password = None
        # getpass() never returns None; re-prompt until something is entered.
        while not args.password:
            args.password = getpass.getpass(args.username + "'s Password: ")


def _find_default_keyfile(parser):
    """Return the path of the user's default ssh private key, or exit via ``parser.error()``."""
    sshdir = os.path.join(os.path.expanduser('~'), '.ssh')
    try:
        keyfiles = os.listdir(sshdir)
    except OSError:
        # A missing or unreadable ~/.ssh simply means there is no default key.
        keyfiles = []
    for candidate in ('id_dsa', 'id_rsa'):  # same preference order as before
        if candidate in keyfiles:
            return os.path.join(sshdir, candidate)
    parser.error("Unable to find an ssh key. Please specify one or use the password flag!")

if __name__ == "__main__":
    # Hand main()'s return value to the interpreter as the process exit status.
    exit_status = main()
    sys.exit(exit_status)
