import logging
import os
from urllib.parse import urlparse

from .config import get_bustools_binary_path, get_kallisto_binary_path
from .constants import (
    ADATA_PREFIX,
    BUS_CDNA_PREFIX,
    BUS_FILENAME,
    BUS_FILTERED_FILENAME,
    BUS_INTRON_PREFIX,
    BUS_S_FILENAME,
    BUS_SC_FILENAME,
    BUS_UNFILTERED_FILENAME,
    COUNTS_PREFIX,
    ECMAP_FILENAME,
    FILTER_WHITELIST_FILENAME,
    FILTERED_COUNTS_DIR,
    INSPECT_FILENAME,
    TXNAMES_FILENAME,
    UNFILTERED_COUNTS_DIR,
    WHITELIST_FILENAME,
)
from .utils import (
    copy_whitelist,
    stream_file,
    import_matrix_as_anndata,
    overlay_anndatas,
    run_executable,
    whitelist_provided,
)

logger = logging.getLogger(__name__)


def kallisto_bus(fastqs, index_path, technology, out_dir, threads=8):
    """Runs `kallisto bus` on the given FASTQ files.

    :param fastqs: list of FASTQ file paths, appended to the command in order
    :type fastqs: list
    :param index_path: path to kallisto index (passed as `-i`)
    :type index_path: str
    :param technology: single-cell technology used (passed as `-x`)
    :type technology: str
    :param out_dir: path to output directory (passed as `-o`)
    :type out_dir: str
    :param threads: number of threads to use (passed as `-t`), defaults to `8`
    :type threads: int, optional

    :return: dictionary with keys `bus`, `ecmap` and `txnames`, each mapped to
             the corresponding file path inside `out_dir`
    :rtype: dict
    """
    logger.info('Generating BUS file from')
    indent = ' ' * 8
    for fastq_path in fastqs:
        logger.info(indent + fastq_path)

    command = [
        get_kallisto_binary_path(), 'bus',
        '-i', index_path,
        '-o', out_dir,
        '-x', technology,
        '-t', threads,
    ]
    # FASTQs are positional arguments and must come after all options.
    command.extend(fastqs)
    run_executable(command)

    output_names = (
        ('bus', BUS_FILENAME),
        ('ecmap', ECMAP_FILENAME),
        ('txnames', TXNAMES_FILENAME),
    )
    return {
        key: os.path.join(out_dir, filename)
        for key, filename in output_names
    }


def bustools_sort(bus_path, out_path, temp_dir='tmp', threads=8, memory='4G'):
    """Runs `bustools sort`.

    :param bus_path: path to BUS file to sort
    :type bus_path: str
    :param out_path: path to the sorted BUS file to write (passed as `-o`)
    :type out_path: str
    :param temp_dir: path to temporary directory (passed as `-T`),
                     defaults to `tmp`
    :type temp_dir: str, optional
    :param threads: number of threads to use (passed as `-t`), defaults to `8`
    :type threads: int, optional
    :param memory: amount of memory to use (passed as `-m`), defaults to `4G`
    :type memory: str, optional

    :return: dictionary mapping `bus` to `out_path`
    :rtype: dict
    """
    logger.info('Sorting BUS file {} to {}'.format(bus_path, out_path))
    command = [
        get_bustools_binary_path(), 'sort',
        '-o', out_path,
        '-T', temp_dir,
        '-t', threads,
        '-m', memory,
        bus_path,
    ]
    run_executable(command)
    return dict(bus=out_path)


def bustools_inspect(bus_path, out_path, whitelist_path, ecmap_path):
    """Runs `bustools inspect`.

    :param bus_path: path to BUS file to inspect
    :type bus_path: str
    :param out_path: path to output inspect JSON file (passed as `-o`)
    :type out_path: str
    :param whitelist_path: path to whitelist (passed as `-w`)
    :type whitelist_path: str
    :param ecmap_path: path to ecmap file, as generated by `kallisto bus`
                       (passed as `-e`)
    :type ecmap_path: str

    :return: dictionary mapping `inspect` to `out_path`
    :rtype: dict
    """
    logger.info('Inspecting BUS file {}'.format(bus_path))
    options = (
        ('-o', out_path),
        ('-w', whitelist_path),
        ('-e', ecmap_path),
    )
    command = [get_bustools_binary_path(), 'inspect']
    for flag, value in options:
        command.extend((flag, value))
    # The BUS file is the single positional argument.
    command.append(bus_path)
    run_executable(command)
    return dict(inspect=out_path)


def bustools_correct(bus_path, out_path, whitelist_path):
    """Runs `bustools correct`.

    :param bus_path: path to BUS file to correct
    :type bus_path: str
    :param out_path: path to output corrected BUS file (passed as `-o`)
    :type out_path: str
    :param whitelist_path: path to whitelist (passed as `-w`)
    :type whitelist_path: str

    :return: dictionary mapping `bus` to `out_path`
    :rtype: dict
    """
    message = 'Correcting BUS records in {} to {} with whitelist {}'.format(
        bus_path, out_path, whitelist_path
    )
    logger.info(message)
    command = [
        get_bustools_binary_path(), 'correct',
        '-o', out_path,
        '-w', whitelist_path,
        bus_path,
    ]
    run_executable(command)
    return dict(bus=out_path)


def bustools_count(bus_path, out_prefix, t2g_path, ecmap_path, txnames_path):
    """Runs `bustools count` with the `--genecounts` flag.

    :param bus_path: path to BUS file to count
    :type bus_path: str
    :param out_prefix: prefix of the output files to generate (passed as `-o`)
    :type out_prefix: str
    :param t2g_path: path to transcript-to-gene mapping (passed as `-g`)
    :type t2g_path: str
    :param ecmap_path: path to ecmap file, as generated by `kallisto bus`
                       (passed as `-e`)
    :type ecmap_path: str
    :param txnames_path: path to transcript names file, as generated by
                         `kallisto bus` (passed as `-t`)
    :type txnames_path: str

    :return: dictionary with keys `mtx`, `genes` and `barcodes`, mapped to
             `<out_prefix>.mtx`, `<out_prefix>.genes.txt` and
             `<out_prefix>.barcodes.txt` respectively
    :rtype: dict
    """
    message = 'Generating count matrix {} from BUS file {}'.format(
        out_prefix, bus_path
    )
    logger.info(message)
    command = [
        get_bustools_binary_path(), 'count',
        '-o', out_prefix,
        '-g', t2g_path,
        '-e', ecmap_path,
        '-t', txnames_path,
        '--genecounts',
        bus_path,
    ]
    run_executable(command)

    suffixes = (
        ('mtx', '{}.mtx'),
        ('genes', '{}.genes.txt'),
        ('barcodes', '{}.barcodes.txt'),
    )
    return {key: template.format(out_prefix) for key, template in suffixes}


def bustools_capture(
        bus_path,
        out_path,
        capture_path,
        ecmap_path,
        txnames_path,
        capture_type='transcripts'
):
    """Runs `bustools capture`.

    :param bus_path: path to BUS file to capture records from
    :type bus_path: str
    :param out_path: path to the captured BUS file to generate
                     (passed as `-o`)
    :type out_path: str
    :param capture_path: path to the capture list (passed as `-c`)
    :type capture_path: str
    :param ecmap_path: path to ecmap file, as generated by `kallisto bus`
                       (passed as `-e`)
    :type ecmap_path: str
    :param txnames_path: path to transcript names file, as generated by
                         `kallisto bus` (passed as `-t`)
    :type txnames_path: str
    :param capture_type: the type of information in the capture list; it is
                         turned into the `--<capture_type>` flag. Can be one
                         of `transcripts`, `umis`, `barcode`, defaults to
                         `transcripts`
    :type capture_type: str, optional

    :return: dictionary mapping `bus` to the path of the captured BUS file
             (`out_path`)
    :rtype: dict
    """
    logger.info(
        'Capturing records from BUS file {} to {} with capture list {}'.format(
            bus_path, out_path, capture_path
        )
    )
    command = [get_bustools_binary_path(), 'capture']
    command += ['-o', out_path]
    command += ['-c', capture_path]
    command += ['-e', ecmap_path]
    command += ['-t', txnames_path]
    command += ['--{}'.format(capture_type)]
    command += [bus_path]
    run_executable(command)
    # Return the *generated* file, not the input: callers feed this path into
    # the next pipeline step (sorting/counting), so returning `bus_path` would
    # make them silently process the uncaptured records.
    return {'bus': out_path}


def bustools_whitelist(bus_path, out_path):
    """Runs `bustools whitelist`.

    :param bus_path: path to BUS file to generate the whitelist from
    :type bus_path: str
    :param out_path: path to output whitelist (passed as `-o`)
    :type out_path: str

    :return: dictionary mapping `whitelist` to `out_path`
    :rtype: dict
    """
    message = 'Generating whitelist {} from BUS file {}'.format(
        out_path, bus_path
    )
    logger.info(message)
    command = [get_bustools_binary_path(), 'whitelist']
    command += ['-o', out_path]
    command += [bus_path]
    run_executable(command)
    return dict(whitelist=out_path)


def stream_fastqs(fastqs, temp_dir='tmp'):
    """Given a list of fastqs (that may be local or remote paths), stream any
    remote files.

    A path is treated as remote when its URL scheme is one of `http`, `https`,
    `ftp` or `ftps`. Each remote path is handed to `stream_file` together with
    a destination built from `temp_dir` and the basename of the path, and is
    replaced by whatever `stream_file` returns. All other paths are passed
    through untouched.

    :param fastqs: list of (remote or local) fastq paths
    :type fastqs: list
    :param temp_dir: temporary directory, defaults to `tmp`
    :type temp_dir: str, optional

    :return: list in the same order as `fastqs`, with every remote path
             substituted by the result of `stream_file`
    :rtype: list
    """
    remote_schemes = ('http', 'https', 'ftp', 'ftps')
    resolved = []
    for path in fastqs:
        if urlparse(path).scheme in remote_schemes:
            destination = os.path.join(temp_dir, os.path.basename(path))
            resolved.append(stream_file(path, destination))
        else:
            resolved.append(path)
    return resolved


def copy_or_create_whitelist(technology, bus_path, out_dir):
    """Copies a pre-packaged whitelist if one is provided for the technology.
    Otherwise, runs `bustools whitelist` to generate a whitelist.

    :param technology: single-cell technology used
    :type technology: str
    :param bus_path: path to BUS file to generate the whitelist from (only
                     used when no pre-packaged whitelist is available)
    :type bus_path: str
    :param out_dir: path to output directory
    :type out_dir: str

    :return: path to copied or generated whitelist
    :rtype: str
    """
    if not whitelist_provided(technology):
        # No packaged whitelist for this technology: derive one from the BUS
        # file and place it in the output directory.
        generated = bustools_whitelist(
            bus_path, os.path.join(out_dir, WHITELIST_FILENAME)
        )
        return generated['whitelist']

    message = 'Copying pre-packaged {} whitelist to {}'.format(
        technology.upper(), out_dir
    )
    logger.info(message)
    return copy_whitelist(technology, out_dir)


def convert_matrix_to_loom(matrix_path, barcodes_path, genes_path, out_path):
    """Converts a matrix to loom.

    The matrix is loaded with `import_matrix_as_anndata` and written out with
    the resulting object's `write_loom` method.

    :param matrix_path: path to matrix mtx file
    :type matrix_path: str
    :param barcodes_path: path to list of barcodes
    :type barcodes_path: str
    :param genes_path: path to list of genes
    :type genes_path: str
    :param out_path: path to output loom file
    :type out_path: str

    :return: path to loom file (`out_path`)
    :rtype: str
    """
    message = 'Converting matrix {} to loom {}'.format(matrix_path, out_path)
    logger.info(message)
    import_matrix_as_anndata(
        matrix_path, barcodes_path, genes_path
    ).write_loom(out_path)
    return out_path


def convert_matrix_to_h5ad(matrix_path, barcodes_path, genes_path, out_path):
    """Converts a matrix to h5ad.

    The matrix is loaded with `import_matrix_as_anndata` and written out with
    the resulting object's `write` method.

    :param matrix_path: path to matrix mtx file
    :type matrix_path: str
    :param barcodes_path: path to list of barcodes
    :type barcodes_path: str
    :param genes_path: path to list of genes
    :type genes_path: str
    :param out_path: path to output h5ad file
    :type out_path: str

    :return: path to h5ad file (`out_path`)
    :rtype: str
    """
    message = 'Converting matrix {} to h5ad {}'.format(matrix_path, out_path)
    logger.info(message)
    import_matrix_as_anndata(
        matrix_path, barcodes_path, genes_path
    ).write(out_path)
    return out_path


def count(
        index_path,
        t2g_path,
        technology,
        out_dir,
        fastqs,
        whitelist_path=None,
        filter=None,
        temp_dir='tmp',
        threads=8,
        memory='4G',
        overwrite=False,
        loom=False,
        h5ad=False,
):
    """Generates count matrices for single-cell RNA seq.

    Pipeline: `kallisto bus` (skipped when its outputs already exist and
    `overwrite` is false) -> sort -> inspect -> correct -> sort -> count,
    optionally followed by a whitelist-based filtering pass that produces a
    second, filtered count matrix.

    :param index_path: path to kallisto index
    :type index_path: str
    :param t2g_path: path to transcript-to-gene mapping
    :type t2g_path: str
    :param technology: single-cell technology used
    :type technology: str
    :param out_dir: path to output directory (created if missing)
    :type out_dir: str
    :param fastqs: list of FASTQ file paths (remote paths are streamed)
    :type fastqs: list
    :param whitelist_path: path to whitelist; when falsy a whitelist is copied
                           or generated into `out_dir`, defaults to `None`
    :type whitelist_path: str, optional
    :param filter: filter to use to generate a filtered count matrix. Any
                   truthy value nests the unfiltered outputs under the
                   `unfiltered` key; only `bustools` actually runs the
                   filtering steps, defaults to `None`
    :type filter: str, optional
    :param temp_dir: path to temporary directory, defaults to `tmp`
    :type temp_dir: str, optional
    :param threads: number of threads to use, defaults to `8`
    :type threads: int, optional
    :param memory: amount of memory to use, defaults to `4G`
    :type memory: str, optional
    :param overwrite: rerun `kallisto bus` even if its output files already
                      exist, defaults to `False`
    :type overwrite: bool, optional
    :param loom: whether to convert the final count matrix into a loom file,
                 defaults to `False`
    :type loom: bool, optional
    :param h5ad: whether to convert the final count matrix into a h5ad file,
                 defaults to `False`
    :type h5ad: bool, optional

    :return: dictionary of generated file paths. Without a filter the paths
             are top-level entries; with a filter they are grouped under
             `unfiltered` (and `filtered` when `filter` is `bustools`)
    :rtype: dict
    """
    os.makedirs(out_dir, exist_ok=True)

    results = {}
    # When filtering is requested the unfiltered outputs get their own
    # sub-dictionary; otherwise they live directly in the top-level results.
    if filter:
        unfiltered_results = results.setdefault('unfiltered', {})
    else:
        unfiltered_results = results

    # Options shared by every `bustools sort` invocation below.
    sort_options = dict(temp_dir=temp_dir, threads=threads, memory=memory)

    def convert_outputs(matrix_result, matrix_dir, target):
        # Writes the requested loom/h5ad conversions of one count matrix and
        # records the generated paths in `target`.
        if loom:
            target['loom'] = convert_matrix_to_loom(
                matrix_result['mtx'],
                matrix_result['barcodes'],
                matrix_result['genes'],
                os.path.join(matrix_dir, '{}.loom'.format(ADATA_PREFIX)),
            )
        if h5ad:
            target['h5ad'] = convert_matrix_to_h5ad(
                matrix_result['mtx'],
                matrix_result['barcodes'],
                matrix_result['genes'],
                os.path.join(matrix_dir, '{}.h5ad'.format(ADATA_PREFIX)),
            )

    # Expected `kallisto bus` outputs; used to decide whether it can be
    # skipped.
    bus_result = {
        'bus': os.path.join(out_dir, BUS_FILENAME),
        'ecmap': os.path.join(out_dir, ECMAP_FILENAME),
        'txnames': os.path.join(out_dir, TXNAMES_FILENAME),
    }
    outputs_exist = all(
        os.path.exists(path) for path in bus_result.values()
    )
    if overwrite or not outputs_exist:
        # Pipe any remote files.
        fastqs = stream_fastqs(fastqs, temp_dir=temp_dir)
        bus_result = kallisto_bus(
            fastqs, index_path, technology, out_dir, threads=threads
        )
    else:
        logger.info(
            'Skipping kallisto bus because output files already exist. Use the --overwrite flag to overwrite.'
        )
    unfiltered_results.update(bus_result)

    sorted_bus = bustools_sort(
        bus_result['bus'],
        os.path.join(temp_dir, BUS_S_FILENAME),
        **sort_options
    )['bus']

    if not whitelist_path:
        logger.info('Whitelist not provided')
        whitelist_path = copy_or_create_whitelist(
            technology, sorted_bus, out_dir
        )
        unfiltered_results['whitelist'] = whitelist_path

    unfiltered_results.update(
        bustools_inspect(
            sorted_bus,
            os.path.join(out_dir, INSPECT_FILENAME),
            whitelist_path,
            bus_result['ecmap'],
        )
    )
    corrected_bus = bustools_correct(
        sorted_bus,
        os.path.join(temp_dir, BUS_SC_FILENAME),
        whitelist_path,
    )['bus']
    unfiltered_bus = bustools_sort(
        corrected_bus,
        os.path.join(out_dir, BUS_UNFILTERED_FILENAME),
        **sort_options
    )['bus']
    unfiltered_results['bus_scs'] = unfiltered_bus

    counts_dir = os.path.join(out_dir, UNFILTERED_COUNTS_DIR)
    os.makedirs(counts_dir, exist_ok=True)
    count_result = bustools_count(
        unfiltered_bus,
        os.path.join(counts_dir, COUNTS_PREFIX),
        t2g_path,
        bus_result['ecmap'],
        bus_result['txnames'],
    )
    unfiltered_results.update(count_result)

    # Convert outputs.
    convert_outputs(count_result, counts_dir, unfiltered_results)

    if filter == 'bustools':
        logger.info('Filtering')
        filtered_results = results.setdefault('filtered', {})

        # Build a whitelist from the corrected, sorted BUS file and keep only
        # the records whose barcodes are on it.
        filter_whitelist = bustools_whitelist(
            unfiltered_bus,
            os.path.join(out_dir, FILTER_WHITELIST_FILENAME),
        )
        filtered_results.update(filter_whitelist)

        captured_bus = bustools_capture(
            unfiltered_bus,
            os.path.join(temp_dir, BUS_FILTERED_FILENAME),
            filter_whitelist['whitelist'],
            bus_result['ecmap'],
            bus_result['txnames'],
            capture_type='barcode',
        )['bus']
        filtered_bus = bustools_sort(
            captured_bus,
            os.path.join(out_dir, BUS_FILTERED_FILENAME),
            **sort_options
        )['bus']
        filtered_results['bus_scs'] = filtered_bus

        filtered_counts_dir = os.path.join(out_dir, FILTERED_COUNTS_DIR)
        os.makedirs(filtered_counts_dir, exist_ok=True)
        filtered_count_result = bustools_count(
            filtered_bus,
            os.path.join(filtered_counts_dir, COUNTS_PREFIX),
            t2g_path,
            bus_result['ecmap'],
            bus_result['txnames'],
        )
        filtered_results.update(filtered_count_result)

        convert_outputs(
            filtered_count_result, filtered_counts_dir, filtered_results
        )

    return results


def count_lamanno(
        index_path,
        t2g_path,
        cdna_t2c_path,
        intron_t2c_path,
        technology,
        out_dir,
        fastqs,
        whitelist_path=None,
        temp_dir='tmp',
        threads=8,
        memory='4G',
        overwrite=False,
        loom=False,
        h5ad=False,
):
    """Generates RNA velocity matrices for single-cell RNA seq.

    Pipeline: `kallisto bus` (skipped when its outputs already exist and
    `overwrite` is false) -> sort -> inspect -> correct -> sort, then for each
    of the cDNA and intron capture lists: capture -> sort -> count.

    :param index_path: path to kallisto index
    :type index_path: str
    :param t2g_path: path to transcript-to-gene mapping
    :type t2g_path: str
    :param cdna_t2c_path: path to cDNA transcripts-to-capture file
    :type cdna_t2c_path: str
    :param intron_t2c_path: path to intron transcripts-to-capture file
    :type intron_t2c_path: str
    :param technology: single-cell technology used
    :type technology: str
    :param out_dir: path to output directory (created if missing)
    :type out_dir: str
    :param fastqs: list of FASTQ file paths (remote paths are streamed)
    :type fastqs: list
    :param whitelist_path: path to whitelist; when falsy a whitelist is copied
                           or generated into `out_dir`, defaults to `None`
    :type whitelist_path: str, optional
    :param temp_dir: path to temporary directory, defaults to `tmp`
    :type temp_dir: str, optional
    :param threads: number of threads to use, defaults to `8`
    :type threads: int, optional
    :param memory: amount of memory to use, defaults to `4G`
    :type memory: str, optional
    :param overwrite: rerun `kallisto bus` even if its output files already
                      exist, defaults to `False`
    :type overwrite: bool, optional
    :param loom: whether to convert the final count matrix into a loom file,
                 defaults to `False`
    :type loom: bool, optional
    :param h5ad: whether to convert the final count matrix into a h5ad file,
                 defaults to `False`
    :type h5ad: bool, optional

    :return: dictionary of generated file paths; per-capture outputs are
             nested under the `BUS_CDNA_PREFIX` and `BUS_INTRON_PREFIX` keys
    :rtype: dict
    """
    results = {}

    # Same as `count`: several steps below write directly into `out_dir`, so
    # make sure it exists before anything runs.
    os.makedirs(out_dir, exist_ok=True)

    # Expected `kallisto bus` outputs; used to decide whether it can be
    # skipped.
    bus_result = {
        'bus': os.path.join(out_dir, BUS_FILENAME),
        'ecmap': os.path.join(out_dir, ECMAP_FILENAME),
        'txnames': os.path.join(out_dir, TXNAMES_FILENAME),
    }
    if any(not os.path.exists(path)
           for path in bus_result.values()) or overwrite:
        # Pipe any remote files.
        fastqs = stream_fastqs(fastqs, temp_dir=temp_dir)
        bus_result = kallisto_bus(
            fastqs, index_path, technology, out_dir, threads=threads
        )
    else:
        logger.info(
            'Skipping kallisto bus because output files already exist. Use the --overwrite flag to overwrite.'
        )
    results.update(bus_result)

    sort_result = bustools_sort(
        bus_result['bus'],
        os.path.join(temp_dir, BUS_S_FILENAME),
        temp_dir=temp_dir,
        threads=threads,
        memory=memory
    )
    if not whitelist_path:
        logger.info('Whitelist not provided')
        whitelist_path = copy_or_create_whitelist(
            technology, sort_result['bus'], out_dir
        )
        results.update({'whitelist': whitelist_path})

    inspect_result = bustools_inspect(
        sort_result['bus'], os.path.join(out_dir, INSPECT_FILENAME),
        whitelist_path, bus_result['ecmap']
    )
    results.update(inspect_result)
    correct_result = bustools_correct(
        sort_result['bus'], os.path.join(temp_dir, BUS_SC_FILENAME),
        whitelist_path
    )
    sort2_result = bustools_sort(
        correct_result['bus'],
        os.path.join(out_dir, BUS_UNFILTERED_FILENAME),
        temp_dir=temp_dir,
        threads=threads,
        memory=memory
    )
    results.update({'bus_scs': sort2_result['bus']})

    # One capture -> sort -> count pass per capture list.
    prefix_to_t2c = {
        BUS_CDNA_PREFIX: cdna_t2c_path,
        BUS_INTRON_PREFIX: intron_t2c_path,
    }
    counts_dir = os.path.join(out_dir, UNFILTERED_COUNTS_DIR)
    os.makedirs(counts_dir, exist_ok=True)
    for prefix, t2c_path in prefix_to_t2c.items():
        capture_result = bustools_capture(
            sort2_result['bus'],
            os.path.join(temp_dir, '{}.bus'.format(prefix)), t2c_path,
            bus_result['ecmap'], bus_result['txnames']
        )
        # Distinct name so the first sort result above is not shadowed.
        captured_sort_result = bustools_sort(
            capture_result['bus'],
            os.path.join(out_dir, '{}.s.bus'.format(prefix)),
            temp_dir=temp_dir,
            threads=threads,
            memory=memory
        )

        prefix_results = results.setdefault(prefix, {})
        prefix_results.update({'bus_s': captured_sort_result['bus']})

        counts_prefix = os.path.join(counts_dir, prefix)
        count_result = bustools_count(
            captured_sort_result['bus'],
            counts_prefix,
            t2g_path,
            bus_result['ecmap'],
            bus_result['txnames'],
        )
        prefix_results.update(count_result)

    if loom or h5ad:
        adatas = {
            prefix: import_matrix_as_anndata(
                results[prefix]['mtx'], results[prefix]['barcodes'],
                results[prefix]['genes']
            )
            for prefix in prefix_to_t2c
        }
        # Look the matrices up with the same constants they were stored
        # under, rather than hard-coded strings that merely happen to match.
        adata = overlay_anndatas(
            adatas[BUS_CDNA_PREFIX], adatas[BUS_INTRON_PREFIX]
        )
        if loom:
            loom_path = os.path.join(counts_dir, '{}.loom'.format(ADATA_PREFIX))
            logger.info('Writing matrices to loom {}'.format(loom_path))
            adata.write_loom(loom_path)
            results.update({'loom': loom_path})
        if h5ad:
            h5ad_path = os.path.join(counts_dir, '{}.h5ad'.format(ADATA_PREFIX))
            logger.info('Writing matrices to h5ad {}'.format(h5ad_path))
            adata.write(h5ad_path)
            results.update({'h5ad': h5ad_path})
    return results
