"""Functions to plot motif logo, heatmap, scatter plot, and others."""

# AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/05_plot.ipynb.

# %% auto 0
__all__ = ['sty_color', 'group_color', 'pspa_category_color', 'set_sns', 'save_svg', 'save_pdf', 'save_show', 'get_color_dict',
           'get_subfamily_color', 'get_plt_color', 'get_hue_big', 'reduce_feature', 'plot_2d', 'plot_cluster',
           'plot_bokeh', 'plot_rank', 'plot_hist', 'plot_count', 'plot_bar', 'plot_group_bar', 'plot_stacked',
           'plot_violin', 'add_stats', 'plot_box', 'plot_rel', 'get_similarity', 'plot_corr', 'get_AUCDF',
           'plot_confusion_matrix', 'plot_pie', 'calculate_pct', 'plot_composition', 'plot_cnt']

# %% ../nbs/05_plot.ipynb 3
import joblib,logomaker
import pandas as pd, numpy as np, seaborn as sns
from adjustText import adjust_text
from pathlib import Path
from tqdm import tqdm
from fastcore.meta import delegates

from scipy.stats import spearmanr, pearsonr
from sklearn.metrics import confusion_matrix
from matplotlib import pyplot as plt
from matplotlib.ticker import MultipleLocator
from numpy import trapz

# Katlas
from .data import *

# Bokeh
from bokeh.io import output_notebook, show
from bokeh.plotting import figure, ColumnDataSource
from bokeh.models import HoverTool, AutocompleteInput, CustomJS
from bokeh.layouts import column
from bokeh.palettes import Category20_20
from itertools import cycle
import math

# Dimension Reduction
from sklearn import set_config
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE
from umap.umap_ import UMAP

from sklearn.metrics import pairwise_distances

import matplotlib.ticker as mticker

import matplotlib as mpl

# for statistical annotations
from statannotations.Annotator import Annotator
import itertools

# %% ../nbs/05_plot.ipynb 5
def set_sns(dpi=300):
    "Apply the seaborn defaults used for notebook display: resolution, 'notebook' context and 'ticks' style"
    # savefig.dpi is ignored when saved in svg or pdf
    resolution = {"figure.dpi": dpi, 'savefig.dpi': dpi}
    sns.set(rc=resolution)
    sns.set_context('notebook')
    sns.set_style("ticks")

# %% ../nbs/05_plot.ipynb 6
def save_svg(path):
    "Save the current figure to `path` as a tightly cropped SVG with a transparent background"
    # 'none' is the svg.fonttype setting applied before every SVG export
    plt.rcParams.update({'svg.fonttype': 'none'})
    save_options = dict(format='svg', bbox_inches='tight', transparent=True)
    return plt.savefig(path, **save_options)

# %% ../nbs/05_plot.ipynb 7
def save_pdf(path):
    "Save the current figure to `path` as a tightly cropped PDF with a transparent background"
    mpl.rcParams.update({
        'pdf.fonttype': 42,  # Use TrueType fonts for Illustrator compatibility
        'ps.fonttype': 42,   # Also good for EPS, if needed
    })
    save_options = dict(format='pdf', bbox_inches='tight', transparent=True)
    plt.savefig(path, **save_options)

# %% ../nbs/05_plot.ipynb 8
def save_show(path=None, # image path, e.g., img.svg, if not None, will save, else plt.show()
              show_only=False, # if True, always display and never save, even when a path is given
             ):
    "Display the current figure, or save it to `path`; all figures are closed afterwards"
    display_only = show_only or path is None
    if display_only:
        plt.show()
    else:
        plt.savefig(path, bbox_inches='tight', pad_inches=0.05, transparent=True)
    # close all figures to avoid memory leak
    plt.close('all')

# %% ../nbs/05_plot.ipynb 10
def get_color_dict(categories, # list of names to assign color
                   palette: str='tab20', # choose from sns.color_palette
                   ):
    """Assign colors to a list of names (allow duplicates), returns a dictionary of unique name with corresponding color.

    Names are de-duplicated in order of first appearance before colors are handed out,
    so a repeated name never consumes an extra palette slot. The palette is cycled when
    there are more unique names than colors.
    """
    # dict.fromkeys keeps first-appearance order while dropping repeats; without this,
    # every duplicate advanced the color cycle and distinct names could collide.
    unique_names = dict.fromkeys(categories)
    color_cycle = cycle(sns.color_palette(palette))
    return dict(zip(unique_names, color_cycle))

# %% ../nbs/05_plot.ipynb 11
# Color lookup for the letters S/T/Y and s/t/y. The upper- and lower-case dicts come from
# two separate get_color_dict calls, so each starts from the first color of the default palette.
sty_color=get_color_dict(['S','T','Y'])| get_color_dict(['s','t','y'])

# %% ../nbs/05_plot.ipynb 14
# Color lookup keyed by group name, built with get_color_dict's default palette.
# Names are listed two per line; the trailing comment on each line names the hue of that pair.
group_color=get_color_dict(
            ['CMGC','AGC', # blue
             'TK','TKL', # orange
             'CAMK','STE', # green
             'CK1', 'NEK', # red
             'Atypical','Other', # purple
             'RGC'
            ]
)

# %% ../nbs/05_plot.ipynb 16
def get_subfamily_color():
    "Return a dict mapping each subfamily to the color tuple of its `modi_group` in `group_color`"
    # one row per group: the group name plus its color components as columns
    group_table = pd.DataFrame(group_color).T.reset_index(names='modi_group')
    kinase_info = Data.get_kinase_info()
    # merge on the shared 'modi_group' column to attach the group color to every subfamily
    merged = kinase_info[['modi_group','subfamily']].merge(group_table)
    color_table = merged.drop(columns=['modi_group']).set_index('subfamily')
    return color_table.apply(tuple, axis=1).to_dict()

# %% ../nbs/05_plot.ipynb 18
# Color lookup keyed by PSPA category name; colors are assigned in list order
# from get_color_dict's default palette.
pspa_category_color = get_color_dict(['Basophilic', 'Pro-directed', 'Acidophilic', 'Map3k', 'Map4k',
       'Alpha/mlk', 'Fgf and vegf receptors', 'Assorted', 'Ripk/wnk', 'Pkc',
       'Ephrin receptors', 'Eif2ak/tlk', 'Nek/ask', 'Pdgf receptors', 'Src',
       'Jak', 'Ulk/ttbk', 'Cmgc', 'Tec', 'Tam receptors'])

# %% ../nbs/05_plot.ipynb 20
def get_plt_color(palette, # dict, list, or set name (tab10)
                  columns, # columns in the df for plot
                 ):
    """Given a dict, list or palette name, return the list of colors to plot `columns` with.

    - dict: colors are looked up per column (in column order); missing keys get '#cccccc'.
    - str: a seaborn palette name, expanded to one color per column.
    - list: returned unchanged.

    Raises TypeError for any other palette type.
    """
    if isinstance(palette, dict):
        # Match colors to column order; fallback color if missing
        return [palette.get(col, '#cccccc') for col in columns]
    if isinstance(palette, str):
        return sns.color_palette(palette, n_colors=len(columns))
    if isinstance(palette, list):
        return palette
    # Previously fell through to an UnboundLocalError on `colors`
    raise TypeError(f"palette must be a dict, str or list, got {type(palette).__name__}")

# %% ../nbs/05_plot.ipynb 22
def get_hue_big(df,
                hue_col, # column of hue
                cnt_thr=10, # higher or equal to this threshold will be considered
               ):
    "Return the entries of `df[hue_col]` whose value occurs at least `cnt_thr` times; useful when there are too many groups."
    hue = df[hue_col]
    counts = hue.value_counts()
    frequent = counts.loc[counts >= cnt_thr].index
    keep = hue.isin(frequent)
    return hue[keep]

# %% ../nbs/05_plot.ipynb 27
def reduce_feature(df: pd.DataFrame, 
                   method: str='pca', # dimensionality reduction method, accept both capital and lower case
                   complexity: int=20, # not used by PCA; perplexity for TSNE, recommend: 30; n_neighbors for UMAP, recommend: 15
                   n: int=2, # n_components
                   load: str=None, # load a previous model, e.g. model.pkl
                   save: str=None, # pkl file to be saved, e.g. pca_model.pkl
                   seed: int=123, # seed for random_state
                   **kwargs, # arguments from PCA, TSNE, or UMAP depends on which method to use
                  ):
    """Reduce the dimensionality given a dataframe of values.

    Returns a dataframe indexed like `df` with one column per component, named
    `<METHOD>1`, `<METHOD>2`, ... Raises ValueError for an unknown `method`.
    """
    
    method = method.lower()
    # explicit raise instead of assert, so validation survives `python -O`
    if method not in ('pca', 'tsne', 'umap'):
        raise ValueError("Please choose a method among PCA, TSNE, and UMAP")
    
    if load is not None:
        # NOTE: joblib.load unpickles the file; only load models from trusted sources.
        reducer = joblib.load(load)
    elif method == 'pca':
        reducer = PCA(n_components=n, random_state=seed, **kwargs)
    elif method == 'tsne':
        reducer = TSNE(n_components=n,
                       random_state=seed,
                       perplexity=complexity, # default from official is 30
                       **kwargs)
    else:
        reducer = UMAP(n_components=n,
                       random_state=seed,
                       n_neighbors=complexity, # default from official is 15, try 15-200
                       **kwargs)

    # NOTE(review): a loaded reducer is re-fitted on `df` here (fit_transform), not applied
    # with its stored fit — confirm this is intended.
    proj = reducer.fit_transform(df)
    embedding_df = pd.DataFrame(proj).set_index(df.index)
    # Name columns from the actual projection width: a loaded model may have a
    # different number of components than `n`.
    embedding_df.columns = [f"{method.upper()}{i}" for i in range(1, proj.shape[1] + 1)]

    if save is not None:
        # parents=True so nested output folders are created as well
        Path(save).parent.mkdir(parents=True, exist_ok=True)
        joblib.dump(reducer, save)

    return embedding_df

# %% ../nbs/05_plot.ipynb 31
def plot_2d(
    embedding_df: pd.DataFrame,  # a dataframe whose first two columns are the 2D coordinates to plot
    hue: str = None,  # colname of color (or a vector of values aligned with the rows)
    complexity: int = 30,  # not used by this function; kept for signature compatibility
    palette: str = 'tab20',  # color scheme, could be tab10 if less categories
    legend: bool = False,  # whether or not add the legend on the side
    name_list=None,  # a list of names to annotate each dot in the plot
    seed: int = 123,  # not used by this function; kept for signature compatibility
    s: int = 20,  # size of the dot
    legend_title: str = None,  # new argument to override legend title
    **kwargs  # extra arguments forwarded to sns.relplot
):
    """
    Scatter-plot an already reduced 2D embedding.

    The first two columns of `embedding_df` are used as x and y, so the frame may
    carry additional columns (e.g. the one named by `hue`). Axis ticks are removed.
    If `name_list` is given, it must supply one label per row, in row order.
    """
    # Only the first two columns are coordinates; unpacking all columns would fail
    # as soon as the frame also holds a hue column.
    x_col, y_col = embedding_df.columns[:2]
    
    g = sns.relplot(
        data=embedding_df, x=x_col, y=y_col, hue=hue, palette=palette, s=s, alpha=0.8, legend=legend,**kwargs
    )
    plt.xticks([])
    plt.yticks([])

    # Override legend title if specified
    if legend and legend_title is not None:
        if g._legend is not None:
            g._legend.set_title(legend_title)

    # Add text annotations
    if name_list is not None:
        # Materialize positionally: indexing a pandas Series with `name_list[i]` is label-based
        names = list(name_list)
        if len(names) < len(embedding_df):
            raise ValueError(
                f"name_list has {len(names)} entries but embedding_df has {len(embedding_df)} rows"
            )
        ax = g.ax
        texts = [
            ax.text(x_val, y_val, str(name), fontsize=8)
            for x_val, y_val, name in zip(embedding_df[x_col], embedding_df[y_col], names)
        ]
        adjust_text(texts, arrowprops=dict(arrowstyle='-', color='black'))

# %% ../nbs/05_plot.ipynb 34
def plot_cluster(
    df: pd.DataFrame,  # a dataframe of values that is waited for dimensionality reduction
    method: str = 'pca',  # dimensionality reduction method, choose from pca, umap, and tsne
    hue: str = None,  # colname of color
    complexity: int = 30,  # this argument does not affect pca but others; recommend 30 for tsne, 15 for umap
    palette: str = 'tab20',  # color scheme, could be tab10 if less categories
    legend: bool = False,  # whether or not add the legend on the side
    name_list=None,  # a list of names to annotate each dot in the plot
    seed: int = 123,  # seed for dimensionality reduction
    s: int = 50,  # size of the dot
    legend_title: str = None,  # new argument to override legend title
    **kwargs  # arguments for dimensional reduction method to be used
):
    """
    Given a dataframe of values, reduce it to 2D and plot it.
    The method could be 'pca', 'tsne', or 'umap'.

    `**kwargs` go to `reduce_feature` (i.e. to the reducer); the drawing itself is
    delegated to `plot_2d`, which receives only the plotting options.
    """
    embedding_df = reduce_feature(df, method=method, seed=seed, complexity=complexity, **kwargs)

    # Reuse plot_2d instead of duplicating its scatter/legend/annotation code
    plot_2d(
        embedding_df,
        hue=hue,
        palette=palette,
        legend=legend,
        name_list=name_list,
        s=s,
        legend_title=legend_title,
    )

# %% ../nbs/05_plot.ipynb 37
def plot_bokeh(X:pd.DataFrame, # a dataframe of two columns from dimensionality reduction
               idx, # pd.Series or list that indicates identities for searching box
               hue=None, # pd.Series or list that indicates category for each sample; None colors every dot navy
               s: int=3, # dot size
               **kwargs # key:args format for information to include in the dot information box
               ):
    
    """Make interactive 2D plot with a searching box and window of dot information when pointing.

    The first two columns of `X` are the x/y coordinates. Selecting an identity in the
    search box recenters the view on that dot and highlights it in red; the previously
    highlighted dot gets its original color and size back.
    """
        
    output_notebook()
    
    idx = list(idx)
    
    def assign_colors(categories, palette):
        "assign each unique name in a list with a color, returns a color list of same length"
        # De-duplicate first: `categories` has one entry per sample, and advancing the
        # cycle on every entry would make unrelated categories share a color.
        color_map = dict(zip(dict.fromkeys(categories), cycle(palette)))
        return [color_map[category] for category in categories]
    
    # Check for None before converting: list(None) raises TypeError
    if hue is not None:
        colors = assign_colors(list(hue), Category20_20)
    else:
        colors = ['navy'] * len(X)
    
    data_dict={
    'x': X.iloc[:,0],
    'y': X.iloc[:,1],
    'identity': idx,
    'color': colors,
    'original_color': colors,
    'size': [s] * len(X), 
    'original_size': [s] * len(X),  # restored when a dot is un-highlighted
    'highlighted': ['no'] * len(X)  # To keep track of which dot is highlighted
    }
    
    for key, value in kwargs.items():
        data_dict[key] = value
    
    source = ColumnDataSource(data=data_dict)
    
    p = figure(tools="pan,box_zoom,wheel_zoom,reset")
    p.scatter('x', 'y', source=source, alpha=0.6, color='color', size='size')

    # Disable grid lines
    p.xgrid.visible = False
    p.ygrid.visible = False
    
    # Add hover tool: identity first, then one row per extra keyword column
    hover = HoverTool()
    tooltips = [("Identity", "@identity")]
    tooltips.extend((key, f"@{key}") for key in kwargs)
    hover.tooltips = tooltips
    p.add_tools(hover)
    
    autocomplete = AutocompleteInput(title="Search by Identity:", completions=idx)

    callback = CustomJS(args=dict(source=source, plot=p), code="""
        const data = source.data;
        const search_val = cb_obj.value.toLowerCase();
        const x = data['x'];
        const y = data['y'];
        const identity = data['identity'];
        const color = data['color'];
        const original_color = data['original_color'];
        const size = data['size'];
        const original_size = data['original_size'];
        const highlighted = data['highlighted'];

        for (let i = 0; i < identity.length; i++) {
            if (highlighted[i] === 'yes') {
                color[i] = original_color[i];
                size[i] = original_size[i];
                highlighted[i] = 'no';
            }
            if (String(identity[i]).toLowerCase() === search_val) {
                plot.x_range.start = x[i] - 5;
                plot.x_range.end = x[i] + 5;
                plot.y_range.start = y[i] - 5;
                plot.y_range.end = y[i] + 5;
                color[i] = 'red';
                size[i] = 15;
                highlighted[i] = 'yes';
            }
        }
        source.change.emit();
    """)
    autocomplete.js_on_change('value', callback)

    # Show layout
    layout = column(autocomplete, p)
    show(layout)

# %% ../nbs/05_plot.ipynb 40
@delegates(sns.scatterplot)
def plot_rank(sorted_df: pd.DataFrame, # a sorted dataframe
              x: str, # column name for x axis
              y: str, # column name for y axis
              n_hi: int=10, # if not None, show the head n names
              n_lo: int=10, # if not None, show the tail n names
              figsize: tuple=(10,8), # figure size
              **kwargs # arguments for sns.scatterplot()
              ):
    
    """Plot rank from a sorted dataframe.

    Draws a scatter of `y` against `x` in row order and labels the first `n_hi` and
    last `n_lo` rows with their `x` value. A row falling in both ranges (short frames)
    is labelled once. Labels are placed at the row position, so `x` is assumed to hold
    names drawn on a categorical axis.
    """

    plt.figure(figsize=figsize)
    
    sorted_df = sorted_df.reset_index(drop=True) # drop customized index
    
    sns_plot = sns.scatterplot(data=sorted_df, 
                               x = x,
                               y = y, **kwargs)

    sns_plot.set_xticks([])
    
    # Collect the rows to annotate: head (highest) and tail (lowest)
    selected = []
    if n_hi is not None:
        selected.append(sorted_df.head(n_hi))
    if n_lo is not None:
        selected.append(sorted_df.tail(n_lo))

    texts = []
    if selected:
        to_label = pd.concat(selected)
        # head and tail overlap when the frame has fewer than n_hi + n_lo rows
        to_label = to_label[~to_label.index.duplicated()]
        texts = [
            plt.text(pos, y_val, name, ha='center', va='bottom')
            for pos, y_val, name in zip(to_label.index, to_label[y], to_label[x])
        ]
            
    if texts:
        # Use adjustText to adjust text positions
        adjust_text(texts, arrowprops=dict(arrowstyle='-', color='black'))
    plt.ylabel(y)
    plt.tight_layout()

# %% ../nbs/05_plot.ipynb 44
@delegates(sns.histplot)
def plot_hist(df: pd.DataFrame, # a dataframe that contain values for plot
              x: str, # column name of values
              figsize: tuple=(6,2),
              **kwargs, # arguments for sns.histplot(); override the defaults below
             ):
    "Plot a histogram of column `x` with a KDE overlay; defaults can be overridden through `kwargs`"
    
    hist_params = {'element':'poly',
              'edgecolor': None,
              'alpha':0.5,
              'bins':100,
              'kde':True}
    # Merge instead of passing both dicts: a caller-supplied key such as `bins`
    # would otherwise raise "got multiple values for keyword argument".
    hist_params.update(kwargs)
    
    plt.figure(figsize=figsize)
    sns.histplot(data=df, x=x, **hist_params)

# %% ../nbs/05_plot.ipynb 48
def plot_count(cnt, # from df['x'].value_counts()
               tick_spacing: float= None, # tick spacing for x axis
               palette: str='tab20'):
    
    "Draw a horizontal bar chart of counts (sorted ascending) with each value written at the bar end"

    ordered = cnt.sort_values(ascending=True).copy()
    
    bar_colors = sns.color_palette(palette)
    ax = ordered.plot.barh(color=bar_colors)
    ax.set_ylabel("")

    # write the count at the tip of every bar
    for row, count in enumerate(ordered):
        plt.text(count, row, str(count), fontsize=10, rotation=-90, va='center')

    # Set x-ticks at regular intervals
    if tick_spacing is not None:
        ax.xaxis.set_major_locator(MultipleLocator(tick_spacing))

# %% ../nbs/05_plot.ipynb 51
@delegates(sns.barplot)
def plot_bar(df, 
             value, # colname of value
             group, # colname of group
             title = None,
             figsize = (12,5),
             fontsize=14,
             dots = True, # whether or not add dots in the graph
             rotation=90,
             ascending=False,
             ymin=None,
             **kwargs
              ):
    
    "Bar plot of `value` per `group` from a long-form dataframe, bars ordered by group mean, optionally overlaid with the individual points"
    
    plt.figure(figsize=figsize)
    
    # order the bars by the mean of each group
    group_means = df.groupby(group)[value].mean()
    bar_order = group_means.sort_values(ascending=ascending).index
    
    sns.barplot(data=df, x=group, y=value, order=bar_order, hue=group, legend=False, **kwargs)
    
    if dots:
        # white dots with a black outline on top of the bars
        sns.stripplot(data=df,
                      x=group,
                      y=value,
                      order=bar_order,
                      alpha=0.8,
                      marker='o',
                      color='white',
                      edgecolor='black',
                      linewidth=1.5,
                      jitter=True,
                      s=5)
        
    # Same tick label size on both axes
    for axis_name in ('x', 'y'):
        plt.tick_params(axis=axis_name, labelsize=fontsize)
    
    # No x label; y label is the value column name
    plt.xlabel('', fontsize=fontsize)
    plt.ylabel(value, fontsize=fontsize)
    
    plt.xticks(rotation=rotation)
    
    if title is not None:
        plt.title(title, fontsize=fontsize)

    # Optional lower bound for the y axis
    if ymin is not None:
        plt.ylim(bottom=ymin)

    plt.gca().spines[['right', 'top']].set_visible(False)

# %% ../nbs/05_plot.ipynb 54
@delegates(sns.barplot)
def plot_group_bar(df, 
                   value_cols,  # list of column names for values, the order depends on the first item
                   group,       # column name of group (e.g., 'kinase')
                   figsize=(12, 5),
                   order=None,
                   title=None,
                   fontsize=14,
                   rotation=90,
                   **kwargs):
    
    "Grouped bar plot: one cluster of bars per `group` value, one bar per column in `value_cols`"

    # Wide -> long: each value column becomes a level of 'Ranking'
    long_df = df.melt(id_vars=group, value_vars=value_cols, var_name='Ranking', value_name='Value')

    plt.figure(figsize=figsize)
    
    error_bar_style = {'linewidth': 1.5,'color': 'gray'}
    sns.barplot(data=long_df,
                x=group,
                y='Value',
                hue='Ranking',
                order=order,
                capsize=0.1,
                err_kws=error_bar_style,
                alpha=1.0,
                **kwargs)
    
    # Same tick label size on both axes
    for axis_name in ('x', 'y'):
        plt.tick_params(axis=axis_name, labelsize=fontsize)
    
    plt.xlabel('', fontsize=fontsize)
    plt.ylabel('Value', fontsize=fontsize)
    plt.xticks(rotation=rotation)
    
    if title is not None:
        plt.title(title, fontsize=fontsize)
    
    plt.gca().spines[['right', 'top']].set_visible(False)

    # Legend sits outside the axes, anchored to the upper-right corner
    legend_options = dict(
        fontsize=fontsize,
        loc="upper left",
        bbox_to_anchor=(1.02, 1),
        borderaxespad=0,
    )
    plt.legend(**legend_options)

# %% ../nbs/05_plot.ipynb 57
def plot_stacked(df, column, hue, figsize=(5, 4),xlabel=None, ylabel=None, add_value=True, **kwargs):
    """Plot stacked counts of `column`, split by `hue`, optionally with the total count above each bar.

    `kwargs` are forwarded to `sns.histplot`. Totals come from `df[column].value_counts()`.
    """
    plt.figure(figsize=figsize)
    
    ax = sns.histplot(
        data=df,
        x=column,
        hue=hue,
        multiple='stack',
        discrete=True,
        shrink=0.8,
        alpha=1.0,
        **kwargs
    )

    plt.xlabel(xlabel)
    plt.ylabel(ylabel)
    plt.xticks(rotation=0)
    # thousands separator on the count axis
    ax.yaxis.set_major_formatter(mticker.FuncFormatter(lambda x, _: f'{int(x):,}'))

    # Add total count on top of each bar
    if add_value:
        values = df[column]
        total_counts = values.value_counts().sort_index()

        # Find where each bar is drawn, so the label lands on the right bar.
        if isinstance(values.dtype, pd.CategoricalDtype):
            # categorical dtype: bars follow the declared category order
            bar_position = {label: pos for pos, label in enumerate(values.cat.categories)}
        elif pd.api.types.is_numeric_dtype(values):
            # numeric data with discrete=True: bars are centered on the values themselves
            bar_position = {label: label for label in total_counts.index}
        else:
            # other (e.g. string) data: bars follow order of first appearance
            bar_position = {label: pos for pos, label in enumerate(pd.unique(values.dropna()))}

        for label, count in total_counts.items():
            ax.text(bar_position[label], count + 1, str(count), ha='center', va='bottom', fontsize=9)

    plt.tight_layout()

# %% ../nbs/05_plot.ipynb 59
def plot_violin(
    data, 
    value='value',
    group='variable', 
    ylabel=None, 
    dots=True, 
    figsize=(5,3),
    **kwargs
):
    "Draw one violin per `group` level from long-form `data` (optionally with the raw points) and return the Axes."

    plt.figure(figsize=figsize)

    violin_axes = sns.violinplot(
        data=data,
        x=group,
        y=value,
        hue=group,
        inner='box',
        linewidth=1,
        cut=0,          # prevents tails extending beyond data range
        bw_adjust=0.7,  # the smaller, the shape have more curves that fit to the data
        **kwargs
    )

    if dots:
        # overlay the individual observations as small black dots
        sns.stripplot(data=data, x=group, y=value, color='k', size=2, jitter=0.1, alpha=0.6)

    plt.xlabel('')
    if ylabel is not None:
        plt.ylabel(ylabel)
    plt.tight_layout()
    return violin_axes

# %% ../nbs/05_plot.ipynb 61
def add_stats(ax,data,value='value',group='variable',pairs=None,test='t-test_ind',loc='inside',text_format='star',**kwargs):
    "Annotate `ax` with statistical test results between groups; all pairwise combinations are compared when `pairs` is None."
    levels = data[group].unique()
    comparisons = pairs if pairs is not None else list(itertools.combinations(levels, 2))
    stat_annotator = Annotator(ax, comparisons, data=data, x=group, y=value)
    stat_annotator.configure(test=test, text_format=text_format, loc=loc, verbose=False, **kwargs)
    stat_annotator.apply_and_annotate()

# %% ../nbs/05_plot.ipynb 63
@delegates(sns.boxplot)
def plot_box(df,
             value, # colname of value
             group, # colname of group
             title=None, 
             figsize=(6,3),
             fontsize=14,
             dots=True, 
             rotation=90,
             **kwargs
            ):
    
    "Box plot of `value` per `group`, boxes ordered by descending group median, optionally overlaid with the individual points."
    
    plt.figure(figsize=figsize)
    
    # order the boxes from highest to lowest median
    group_medians = df[[group,value]].groupby(group).median()
    box_order = group_medians.sort_values(value,ascending=False).index
    
    sns.boxplot(data=df, x=group, y=value, order=box_order, hue=group, legend=False, **kwargs)
    
    if dots:
        sns.stripplot(x=group, y=value, data=df, order=box_order, jitter=True, color='black', size=3)

    # Same tick label size on both axes
    for axis_name in ('x', 'y'):
        plt.tick_params(axis=axis_name, labelsize=fontsize)

    plt.xlabel('', fontsize=fontsize)
    plt.ylabel(value, fontsize=fontsize)
    plt.xticks(rotation=rotation)
    
    if title is not None:
        plt.title(title, fontsize=fontsize)

# %% ../nbs/05_plot.ipynb 66
@delegates(sns.regplot)
def plot_rel(
    df,  # dataframe that contains data
    x,  # x axis values, or colname of x axis
    y,  # y axis values, or colname of y axis
    text_location=(0.8, 0.1),  # relative coords in Axes (0–1)
    method="spearman",  # correlation method: 'pearson' or 'spearman'
    index_list=None,  # list of indices to annotate
    hue=None,
    reg_line=True,  # whether to draw the regression line (with or without hue)
    **kwargs
):
    """
    Given a dataframe and the name of two columns, 
    plot the two columns' correlation with either Pearson or Spearman.
    Annotate points if their index is in index_list.

    With `hue`, points are drawn by sns.scatterplot (which receives `kwargs`) and the
    regression line is added separately; without it, sns.regplot draws both and receives
    `kwargs`. Any `method` other than 'spearman' (case-insensitive) uses Pearson.
    """
    x_vals = df[x]
    y_vals = df[y]

    # Compute correlation
    if method.lower() == "spearman":
        corr_val, pvalue = spearmanr(x_vals, y_vals)
        corr_label = f"Spearman ρ = {corr_val:.2f}\n p = {pvalue:.2e}"
    else:
        corr_val, pvalue = pearsonr(x_vals, y_vals)
        corr_label = f"Pearson r = {corr_val:.2f}\n p = {pvalue:.2e}"

    # Plot regression line + scatter
    if hue is not None:
        sns.scatterplot(data=df, x=x, y=y, hue=hue, **kwargs)
        if reg_line: sns.regplot(x=x_vals, y=y_vals, scatter=False, line_kws={'color': 'gray','alpha': 0.5})
        plt.legend(
            bbox_to_anchor=(1.05, 1),   # (x, y) anchor relative to axes
            loc="upper left",           # where to attach the legend box
            borderaxespad=0.
        )
    else:
        # Honor reg_line here too (it was previously ignored without hue), and let
        # caller-supplied line_kws / fit_reg take precedence instead of colliding.
        reg_options = {'line_kws': {'color': 'gray'}, 'fit_reg': reg_line, **kwargs}
        sns.regplot(x=x_vals, y=y_vals, **reg_options)

    # Add correlation text
    plt.text(
        x=text_location[0],
        y=text_location[1],
        s=corr_label,
        transform=plt.gca().transAxes,
        ha="center",
        va="center"
    )

    # Annotate selected points if index_list is given
    if index_list is not None:
        texts = [
            plt.text(x_vals.loc[idx], y_vals.loc[idx], str(idx), fontsize=9, ha="center", va="center")
            for idx in index_list
            if idx in df.index  # make sure index exists
        ]
        if texts:
            adjust_text(texts, arrowprops=dict(arrowstyle="->", color="black", lw=0.5))

# %% ../nbs/05_plot.ipynb 70
def get_similarity(df, metric='euclidean'):
    "Return (dist_df, sim_df): the pairwise distance matrix between rows of `df`, and exp(-d^2 / (2*sigma^2)) with sigma = mean of all entries of the distance matrix."
    distances = pairwise_distances(df, metric=metric)
    row_labels = df.index
    dist_df = pd.DataFrame(distances, index=row_labels, columns=row_labels)

    # kernel width: mean over the full distance matrix
    sigma = np.mean(distances)
    squared = dist_df ** 2
    sim_df = np.exp(-squared / (2 * sigma ** 2))
    return dist_df, sim_df

# %% ../nbs/05_plot.ipynb 71
def plot_corr(df_corr, inverse_color=False,figsize=(15,10),**kwargs):
    """Plot distance/similarity matrix as an annotated heatmap.

    By default the upper triangle (diagonal included) is masked. `inverse_color`
    switches the colormap from 'coolwarm' to 'coolwarm_r'. Any `sns.heatmap`
    argument passed through `kwargs` overrides the defaults below.
    """
    heatmap_options = {
        'cmap': 'coolwarm_r' if inverse_color else 'coolwarm',
        'mask': np.triu(np.ones_like(df_corr, dtype=bool)),  # hide upper triangle + diagonal
        'annot': True,
        'fmt': '.2f',
        'linewidths': 0.1,
        'linecolor': 'white',
    }
    # Merge so that e.g. annot=False or a custom mask no longer raises
    # "got multiple values for keyword argument".
    heatmap_options.update(kwargs)

    plt.figure(figsize=figsize)
    sns.heatmap(df_corr, **heatmap_options)
    plt.xlabel('')
    plt.ylabel('')
    plt.yticks(rotation=0)

# %% ../nbs/05_plot.ipynb 76
def get_AUCDF(df,col, reverse=False,plot=True,xlabel='Rank of reported kinase'):
    
    """Plot CDF curve and get relative area under the curve.

    The sorted values of `df[col]` are the x values and their percentile ranks
    (ties averaged) the y values. The trapezoidal area under that curve is divided
    by the x range (max - min) and returned. With `plot=True` a histogram of the
    values and the CDF curve are drawn on twin axes.

    NOTE(review): with `reverse=True` the y values are flipped and the returned
    ratio is negated — confirm the sign of the result is what callers expect.
    """
    
    # sort col values as x values
    x_values = df[col].sort_values().values
    
    # percentile rank of each value; 'average' takes duplicates into account
    y_values = pd.Series(x_values).rank(method='average', pct=True).values
    
    if reverse:
        y_values = 1 - y_values + y_values.min()  # Adjust for reverse while keeping the distribution's integrity

    # np.trapezoid only exists in NumPy >= 2.0; fall back to np.trapz on older versions
    trapezoid = getattr(np, 'trapezoid', None) or np.trapz
    # calculate the area under the curve using the trapezoidal rule
    area_under_curve = trapezoid(y_values, x_values)
    
    # total area of the bounding box: x range times a y range of 1
    total_area = (x_values[-1] - x_values[0]) * 1

    AUCDF = area_under_curve / total_area
    if reverse:
        AUCDF = -AUCDF
    
    if plot:
        # Create a figure and a primary axis
        fig, ax1 = plt.subplots(figsize=(7,5))
        
        # fontsize
        fontsize=17
        
        # Plot the histogram on the primary axis
        sns.histplot(x_values,bins=20,ax=ax1)
        ax1.set_xlabel(xlabel,fontsize=fontsize)
        ax1.set_ylabel('Substrates',color='darkblue',fontsize=fontsize)
        ax1.tick_params(axis='y', labelcolor='darkblue',labelsize=fontsize)
        ax1.tick_params(axis='x', labelcolor='black',labelsize=fontsize)
        ax1.set_xlim(min(x_values),max(x_values))

        # Create a secondary axis for the CDF
        ax2 = ax1.twinx()

        # Plot the CDF on the secondary axis
        ax2.plot(x_values, y_values, color='darkred', linestyle='-', linewidth=2.0)
        # Diagonal reference line
        if reverse:
            ax2.plot([max(x_values),0],[0, max(y_values)], 'k--')  # 'k--' is for a black dashed line
        else:
            ax2.plot([0, max(x_values)], [0, max(y_values)], 'k--')  # 'k--' is for a black dashed line

        ax2.set_ylabel('Probability', color='darkred',fontsize=fontsize,rotation=270,labelpad=18)
        if reverse:
            ax2.text(0.45, 0.3, f"AUCDF:{AUCDF.round(4)}", transform=plt.gca().transAxes, ha='right', va='bottom',fontsize=fontsize)
        else:
            ax2.text(0.95, 0.3, f"AUCDF:{AUCDF.round(4)}", transform=plt.gca().transAxes, ha='right', va='bottom',fontsize=fontsize)
        ax2.tick_params(axis='y', labelcolor='darkred',labelsize=fontsize)
        ax2.set_ylim(0, 1)  # Probabilities range from 0 to 1

        # Show the plot
        plt.title(f'{len(x_values):,} kinase-substrate pairs',fontsize=fontsize)
        plt.show()
        
    return AUCDF

# %% ../nbs/05_plot.ipynb 79
def plot_confusion_matrix(target, # pd.Series 
                          pred, # pd.Series
                          class_names:list=['0','1'],
                          normalize=False,
                          title='Confusion matrix',
                          cmap=plt.cm.Blues):
    
    """Plot the confusion matrix.

    Rows are true labels and columns predicted labels. With `normalize=True`
    each row is divided by its sum, so cells show per-true-class fractions.
    `class_names` supplies the tick labels.
    """
    
    cm = confusion_matrix(target, pred)
    
    if normalize:
        cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
        print("Normalized confusion matrix")
        annot_format = '.2f'
    else:
        print('Confusion matrix, without normalization')
        # integer format: the heatmap default ('.2g') renders counts >= 100 as e.g. 1.2e+02
        annot_format = 'd'

    plt.figure(figsize=(6,6))
    sns.heatmap(cm, annot=True, fmt=annot_format, cmap=cmap)  # Plot the heatmap
    plt.title(title)
    plt.ylabel('True label')
    plt.xlabel('Predicted label')
    # heatmap cells are centered at i + 0.5
    tick_positions = np.arange(len(class_names)) + 0.5
    plt.xticks(tick_positions, class_names)
    plt.yticks(tick_positions, class_names, rotation=0)

# %% ../nbs/05_plot.ipynb 83
def plot_pie(value_counts, # value counts
             hue_order=None, # list of strings
             labeldistance=0.8,
             fontsize=12,
             fontcolor='black',
             palette='tab20' ,
             figsize=(4,3)
            ):
    "Draw a pie chart of `value_counts` with percentage labels; the title shows the total n."
    # optional reordering of the slices
    if hue_order is not None:
        value_counts = value_counts.reindex(hue_order)

    slice_colors = sns.color_palette(palette, n_colors=len(value_counts))
    text_style = {'fontsize': fontsize, 'color': fontcolor}

    value_counts.plot.pie(
        autopct='%1.1f%%',            # Show percentage inside slices
        labeldistance=labeldistance,  # Move labels closer to center
        textprops=text_style,
        colors=slice_colors,
        figsize=figsize,
    )
    plt.ylabel('')
    total = value_counts.sum()
    plt.title(f'n={total:,}')

# %% ../nbs/05_plot.ipynb 87
def calculate_pct(df,bin_col, hue_col):
    "Return a bin x hue table of percentages; the hue values in each bin (row) add up to 100."
    # counts per (bin, hue) pair, pivoted so hue levels become columns
    pair_sizes = df.groupby([bin_col, hue_col], observed=False).size()
    count_df = pair_sizes.unstack(fill_value=0)
    # normalize every row by its own total
    row_totals = count_df.sum(axis=1)
    return count_df.div(row_totals, axis=0) * 100

# %% ../nbs/05_plot.ipynb 88
def plot_composition(df, bin_col, hue_col,palette='tab20',legend_title=None,rotate=45,xlabel=None,ylabel='Percentage',figsize=(5,3)):
    "Stacked bar chart of the percentage of each `hue_col` value within every `bin_col` bin."
    pct_df = calculate_pct(df, bin_col, hue_col)
    bar_colors = get_plt_color(palette, pct_df.columns)

    pct_df.plot(kind='bar', figsize=figsize, stacked=True, color=bar_colors)

    plt.ylabel(ylabel)
    plt.xlabel(xlabel)
    plt.xticks(rotation=rotate)

    # legend title defaults to the hue column name; legend sits outside the axes
    title = hue_col if legend_title is None else legend_title
    plt.legend(title=title, bbox_to_anchor=(1.05, 1), loc='upper left')

# %% ../nbs/05_plot.ipynb 90
def plot_cnt(cnt, xlabel=None,ylabel='Count',figsize=(6, 3)):
    "Vertical bar chart of counts with each value written above its bar."
    _, ax = plt.subplots(figsize=figsize)
    cnt.plot.bar(ax=ax)

    # Add text on top of each bar
    for position, height in enumerate(cnt):
        ax.text(position, height + 0.5, f"{height:,}", ha='center', va='bottom', fontsize=10)

    for side in ('top', 'right'):
        ax.spines[side].set_visible(False)

    ax.set_ylabel(ylabel)
    ax.set_xlabel(xlabel)
    plt.xticks(rotation=0)
    plt.tight_layout()
