# AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/06_pathway.ipynb.

# %% auto 0
__all__ = ['get_reactome_raw', 'get_reactome', 'query_reactome', 'add_reactome_ref', 'plot_path', 'get_overlap']

# %% ../nbs/06_pathway.ipynb 3
# reactome, pip install reactome2py
from pandas import json_normalize
from reactome2py import analysis
from matplotlib import pyplot as plt
import numpy as np, pandas as pd
from .data import *

# %% ../nbs/06_pathway.ipynb 5
def get_reactome_raw(gene_list):
    """Reactome pathway analysis for a given gene set; returns raw output in dataframe.

    `gene_list` is an iterable of identifiers (e.g. UniProt IDs). A single identifier may also be
    passed as a plain string. Identifiers are converted to `str` and comma-joined for reactome2py's
    `analysis.identifiers`. The returned DataFrame is `result['pathways']` flattened with
    `json_normalize`, so nested fields become dotted columns such as `entities.fdr`.
    """
    # A bare string would otherwise be iterated character by character by `join`
    if isinstance(gene_list, str):
        gene_list = [gene_list]
    gene_str = ','.join(map(str, gene_list))
    # page_size/page of -1 return all pathway results; results are sorted by entities FDR;
    # projection=True is consistent with the official Reactome web analysis
    result = analysis.identifiers(gene_str, page_size='-1', page='-1', sort_by='ENTITIES_FDR', projection=True)
    return json_normalize(result['pathways'])

# %% ../nbs/06_pathway.ipynb 8
def get_reactome(gene_list,
                 p_type='FDR', # or p
                ):
    """Reactome pathway analysis for a given gene set; returns formatted output with an additional -log10(p).

    The result has columns `name`, `reactome_id`, `p_type` (the raw 'p' or 'FDR' value) and
    `-log10_{p_type}` (rounded to 3 decimals). Rows keep the order returned by `get_reactome_raw`.
    When Reactome reports no pathways, an empty DataFrame with those same columns is returned
    instead of failing on the column selection.
    """
    assert p_type in ['p','FDR']
    col = 'entities.pValue' if p_type == 'p' else 'entities.fdr'
    log_col = f'-log10_{p_type}'
    print('Running pathway analysis')
    raw = get_reactome_raw(gene_list)
    print('Done')
    if raw.empty:
        # json_normalize([]) has no columns at all, so selecting them would raise KeyError
        return pd.DataFrame({
            'name': pd.Series(dtype=object),
            'reactome_id': pd.Series(dtype=object),
            p_type: pd.Series(dtype=float),
            log_col: pd.Series(dtype=float),
        })
    out = raw[['name', 'stId', col]].rename(columns={col: p_type, 'stId': 'reactome_id'})
    out[log_col] = -np.log10(out[p_type]).round(3)
    return out

# %% ../nbs/06_pathway.ipynb 16
def query_reactome(uniprot_id):
    """Query uniprot ID in Reactome all level pathway database.

    Returns one row per `reactome_id` containing `uniprot_id`, with columns `reactome_id`,
    `uniprot`, `pathway`, `type`, `species` and `lowest`.
    - `type` holds all distinct types for that pathway, sorted and joined with ", ".
    - `lowest` is 1 if the pathway also appears for this protein in
      `Data.get_reactome_pathway_lo()` (presumably the lowest-level table), else 0.
    """
    all_levels = Data.get_reactome_pathway()
    lowest_levels = Data.get_reactome_pathway_lo()

    # Keep only the rows for the requested protein
    hits = all_levels.loc[all_levels['uniprot'] == uniprot_id].copy()
    lowest_hits = lowest_levels.loc[lowest_levels['uniprot'] == uniprot_id].copy()

    def _join_types(values):
        # One pathway can be listed under several types; merge them into one string
        return ", ".join(sorted(set(values)))

    aggregations = {
        "uniprot": "first",
        "pathway": "first",
        "type": _join_types,
        "species": "first",
    }
    per_pathway = hits.groupby("reactome_id", as_index=False).agg(aggregations)

    is_lowest = per_pathway['reactome_id'].isin(lowest_hits['reactome_id'])
    per_pathway['lowest'] = is_lowest.astype(int)
    return per_pathway.reset_index(drop=True)

# %% ../nbs/06_pathway.ipynb 20
def add_reactome_ref(df,uniprot):
    """Return a copy of `df` with two 0/1 columns marking pathways linked to the protein `uniprot`.

    Membership is tested on `df.reactome_id` against the result of `query_reactome(uniprot)`:
    - `{uniprot}_path_all`: the pathway appears in that result at all.
    - `{uniprot}_path_lowest`: the pathway appears there with `lowest == 1`.
    The input `df` is not modified.
    """
    pathways = query_reactome(uniprot)
    lowest_ids = pathways.loc[pathways["lowest"] == 1, 'reactome_id']

    annotated = df.copy()
    in_any = annotated['reactome_id'].isin(pathways['reactome_id'])
    in_lowest = annotated['reactome_id'].isin(lowest_ids)
    annotated[f'{uniprot}_path_all'] = in_any.astype(int)
    annotated[f'{uniprot}_path_lowest'] = in_lowest.astype(int)
    return annotated

# %% ../nbs/06_pathway.ipynb 23
def plot_path(react_df, # the output df of get_reactome
              p_type='FDR', 
              ref_id_list=None, # list of reactome_id
              ref_col = None, # column in react_df, 1 or 0 to indicate whether it's in ref
              top_n=10, 
              max_label_length=80 ):
    """
    Plot the first `top_n` rows of the output of get_reactome as a horizontal bar chart of -log10(p).

    Bars for pathways in the reference set are shown in dark red; all others use the default
    colour 'C0'. The reference set is taken from `ref_id_list` if given. Otherwise it comes from
    the rows of `react_df` where `ref_col == 1`. If neither is given, nothing is highlighted.
    Pathway names longer than `max_label_length` are truncated and suffixed with '...'.

    Returns the matplotlib Axes holding the plot.
    Raises ValueError if `react_df` has no rows to plot.
    """
    assert p_type in ['p','FDR']
    p_col = f'-log10_{p_type}'
    subset = react_df.head(top_n)
    if subset.empty:
        raise ValueError("react_df has no rows to plot")

    # Reference pathway ids to highlight (None -> no highlighting)
    if ref_id_list is not None:
        ref_ids = set(ref_id_list)
    elif ref_col is not None:
        ref_ids = set(react_df.loc[react_df[ref_col] == 1, 'reactome_id'])
    else:
        ref_ids = None

    # barh draws the first entry at the bottom, so reverse to put the top row at the top
    data = subset.set_index('name')[p_col].iloc[::-1]
    if ref_ids is None:
        colors = 'C0'
    else:
        colors = ['darkred' if rid in ref_ids else 'C0'
                  for rid in subset['reactome_id'].iloc[::-1]]

    data.index = [label[:max_label_length] + '...' if len(label) > max_label_length else label
                  for label in data.index]

    # Width grows with the longest label. Height follows the number of bars actually drawn,
    # not top_n, with a floor so that very short plots still fit their axis labels.
    width = 2 + max(len(label) for label in data.index) * 0.1
    height = max(3 * len(data) / 10, 1.5)

    ax = data.plot.barh(figsize=(width, height), color=colors)
    ax.set_ylabel('')
    ax.set_xlabel(p_col.replace('_', '(', 1) + ')')
    ax.figure.tight_layout()
    return ax

# %% ../nbs/06_pathway.ipynb 30
def get_overlap(react_df, 
                 ref_id_list=None,
                 ref_col=None,  # column in react_df, 1 or 0 to indicate whether it's in ref
                 p_type='FDR',
                 thr=0.05,  # original threshold of p value, will be log10 transformed
                 plot=True,
                ):
    """Return the fraction of reference pathways in `react_df` whose p value (or FDR) is below `thr`.

    The reference rows are chosen in this order:
    - rows whose `reactome_id` is in `ref_id_list`, if that is given;
    - otherwise rows where `ref_col == 1`;
    - otherwise a ValueError is raised.

    A pathway passes when its raw `p_type` column is `< thr`. If that column is absent, the
    `-log10_{p_type}` column is compared against `-log10(thr)` instead. The raw column is
    preferred because `get_reactome` rounds the -log10 values to 3 decimals, which can
    misclassify values just below `thr`.

    If `plot` is True, a histogram of `-log10_{p_type}` is drawn with the threshold line and the
    pass rate. Returns NaN when no reference pathway is present in `react_df`.
    """
    assert p_type in ['p', 'FDR']
    p_col = f'-log10_{p_type}'
    p_col_convert = p_col.replace('_', '(', 1) + ')'  # e.g., -log10(FDR)

    threshold = -np.log10(thr)

    # Subset based on input
    if ref_id_list is not None:
        subset = react_df[react_df.reactome_id.isin(ref_id_list)].copy()
    elif ref_col is not None:
        subset = react_df[react_df[ref_col] == 1].copy()
    else:
        raise ValueError("Need to give values to ref_id_list or ref_col")

    num_total = len(subset)
    if p_type in subset.columns:
        passed = subset[p_type] < thr
    else:
        passed = subset[p_col] > threshold
    num_pass = int(passed.sum())

    # The ratio is undefined when the reference set does not overlap react_df
    fraction = num_pass / num_total if num_total else float('nan')
    percent_pass = fraction * 100

    # Plot histogram
    if plot:
        subset[p_col].hist(bins=100)

        # Add threshold line
        plt.axvline(x=threshold, color='red', linestyle='--', label=f'{p_type} = {thr}')
        plt.legend()

        # Label axes
        plt.xlabel(p_col_convert)
        plt.title(f'Histogram of {p_col_convert}')

        plt.text(0.66, 0.85, f'{percent_pass:.1f}% ({num_pass}/{num_total}) pass',
                 transform=plt.gca().transAxes,
                 ha='right', va='top', fontsize=10, color='green')

    return float(fraction)
