"""Scoring functions to calculate kinase score based on substrate sequence"""

# AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/03_scoring.ipynb.

# %% auto 0
# Public names of this module; this list is what a star-import of the module exports.
__all__ = ['multiply_23', 'multiply_20', 'cut_seq', 'STY2sty', 'get_dict', 'multiply', 'multiply_pspa', 'sumup',
           'duplicate_ref_zero', 'preprocess_ref', 'predict_kinase', 'Params', 'multiply_generic', 'predict_kinase_df',
           'get_pct', 'get_pct_df']

# %% ../nbs/03_scoring.ipynb 3
import numpy as np, pandas as pd
from .data import *
from .utils import *
from .pssm import *
from typing import Callable
from functools import partial

from tqdm.contrib.concurrent import process_map
from tqdm import tqdm

# %% ../nbs/03_scoring.ipynb 6
def cut_seq(input_string: str, # site sequence
            min_position: int, # minimum position relative to its center
            max_position: int, # maximum position relative to its center
            ):
    """
    Extract the part of a sequence that lies in [min_position, max_position] relative to its center.

    The center is the character at index `len(input_string) // 2` (position 0).
    The window is clipped to the string, so a window that lies completely
    outside the sequence yields an empty string.
    """
    length = len(input_string)
    center_position = length // 2

    # Clip both bounds into [0, length]. Clamping the end at 0 matters: a negative
    # end index would otherwise be read by slicing as "count from the right" and
    # return characters that are not inside the requested window.
    start_index = max(center_position + min_position, 0)
    end_index = max(min(center_position + max_position + 1, length), 0)

    return input_string[start_index:end_index]

# %% ../nbs/03_scoring.ipynb 8
def STY2sty(input_string: str):
    "Lowercase every 'S', 'T' and 'Y' in a sequence; all other characters are kept as they are"
    # One pass over the string with a character translation table
    return input_string.translate(str.maketrans('STY', 'sty'))

# %% ../nbs/03_scoring.ipynb 10
def get_dict(input_string:str, # phosphorylation site sequence
            ):
    """
    Get the list of position+residue keys of a site sequence, e.g. ['-1A', '0s', '1P'].

    Positions are counted relative to the center character (index `len // 2`),
    which is position 0. Non-alphabetic characters (such as '_' padding) are
    skipped. An empty sequence gives an empty list.
    No star is needed in the middle; make sure the sequence is 15 or 10 in length.
    """
    center_index = len(input_string) // 2

    return [
        f"{i - center_index}{char}"
        for i, char in enumerate(input_string)
        if char.isalpha()
    ]

# %% ../nbs/03_scoring.ipynb 14
def multiply(values, # list of values, possibilities of amino acids at certain positions
             kinase=None, # not used here; accepted so the function can be called as func(values, kinase)
             num_aa=23, # number of amino acids, 23 for standard CDDM, 20 for all uppercase CDDM
            ):
    "Log2 of the product of the per-position values, scaled by `num_aa` for every position after the first"

    # Work in log space: log2(prod(v) * num_aa**(n-1)) = sum(log2(v)) + (n-1) * log2(num_aa).
    # EPSILON is added to every value first (presumably a small positive constant
    # that keeps log2 finite for zeros -- confirm where it is defined).
    shifted = [value + EPSILON for value in values]
    scale_term = (len(shifted) - 1) * np.log2(num_aa)

    return np.sum(np.log2(shifted)) + scale_term

# %% ../nbs/03_scoring.ipynb 17
# `multiply` with the scale factor fixed to 23 amino acids.
# predict_kinase_df dispatches by comparing `func` against this exact object.
multiply_23 = partial(multiply,num_aa=23)

# %% ../nbs/03_scoring.ipynb 18
# `multiply` with the scale factor fixed to 20 amino acids.
# predict_kinase_df dispatches by comparing `func` against this exact object.
multiply_20 = partial(multiply,num_aa=20)

# %% ../nbs/03_scoring.ipynb 20
def multiply_pspa(values, kinase, num_aa_dict=Data.get_num_dict()):
    """
    Log2 of the product of `values`, scaled by a kinase-specific factor.

    The scale factor is looked up as `num_aa_dict[kinase]` (the PSPA random
    amino acid number) and applied once for every value after the first.
    Returns NaN when any value is exactly zero.
    Note that the default `num_aa_dict` is evaluated once, when the function is defined.
    """
    # Any exact zero short-circuits to NaN instead of taking log2(0)
    if (np.array(values) == 0).any():
        return np.nan

    divide_factor = num_aa_dict[kinase]

    # log2(a*b) = log2(a) + log2(b): sum the logs and add the scale term
    scale_term = (len(values) - 1) * np.log2(divide_factor)
    return np.sum(np.log2(values)) + scale_term

# %% ../nbs/03_scoring.ipynb 23
def sumup(values, # list of values, possibilities of amino acids at certain positions
          kinase=None, # not used here; accepted so the function can be called as func(values, kinase)
         ):
    "Sum up the possibilities of the amino acids at each position in a phosphorylation site sequence; an empty list sums to 0"
    return sum(values)

# %% ../nbs/03_scoring.ipynb 25
def duplicate_ref_zero(df: pd.DataFrame) -> pd.DataFrame:
    """
    Mirror the center-position columns between upper and lower case on a copy of `df`.

    For each of S, T, Y: when the uppercase column ('0S', '0T', '0Y') exists and
    holds any non-zero value, its values are written to the lowercase column
    ('0s', '0t', '0y'). Otherwise, when the lowercase column exists and holds any
    non-zero value, its values are written to the uppercase column.
    The input dataframe is not modified.
    """
    out = df.copy()

    def _has_signal(col):
        # Column is present and not entirely zero
        return col in out.columns and (out[col] != 0).any()

    for residue in 'STY':
        cap, low = f'0{residue}', f'0{residue.lower()}'
        if _has_signal(cap):
            out[low] = out[cap]
        elif _has_signal(low):
            out[cap] = out[low]

    return out

# %% ../nbs/03_scoring.ipynb 26
def preprocess_ref(ref):
    "Return a copy of `ref` with column names passed through `pSTY2sty` and the 0S/T/Y <-> 0s/t/y columns mirrored."
    prepared = ref.copy()

    # Rename columns on the copy (presumably pS/pT/pY -> s/t/y; see pSTY2sty for the exact mapping)
    prepared.columns = prepared.columns.map(pSTY2sty)

    # Mirror the center-position columns so upper- and lowercase S/T/Y are scored alike
    return duplicate_ref_zero(prepared)

# %% ../nbs/03_scoring.ipynb 27
def predict_kinase(input_string: str, # site sequence
                   ref: pd.DataFrame, # reference dataframe for scoring
                   func: Callable, # function to calculate score
                   to_lower: bool=False, # convert capital STY to lower case
                   to_upper: bool=False, # convert all letter to uppercase
                   verbose=True
                   ):
    """
    Predict kinase given a phosphorylation site sequence.

    Each row of `ref` is treated as one kinase and each column as a
    position+residue key (e.g. '-1A'). For every kinase, the non-NaN reference
    values matching the keys of the sequence are collected and passed to
    `func(values, kinase)`.
    Returns a Series of scores indexed by kinase, sorted in descending order,
    rounded to 3 decimals, with NaN scores dropped.
    """

    input_string = check_seq(input_string)

    if to_lower: input_string = STY2sty(input_string)

    if to_upper: input_string = input_string.upper()

    ref = preprocess_ref(ref)

    # The position+residue keys depend only on the sequence, so build them once
    # instead of once per kinase.
    seq_keys = get_dict(input_string)

    results = []
    # Defined up front so the verbose message below also works when `ref` has no rows
    pos_aa_name = []

    for kinase, row in ref.iterrows():

        # PSSM of this kinase as a dictionary, without NaN entries
        r_dict = row.dropna().to_dict()

        # Keep only the keys of the sequence that this kinase has a value for
        pos_aa_name = [key for key in seq_keys if key in r_dict]
        pos_aa_val = [r_dict[key] for key in pos_aa_name]

        # Calculate the score for this kinase using the specified function
        results.append(func(pos_aa_val, kinase))

    if verbose:
        # Reports the keys matched for the last kinase in `ref`
        print(f'considering string: {pos_aa_name}')

    out = pd.Series(results, index=ref.index).sort_values(ascending=False)

    return out.round(3).dropna()

# %% ../nbs/03_scoring.ipynb 40
def Params(name=None, load=True):
    """
    Look up a predefined scoring configuration.

    With `name=None`, return the list of available configuration names.
    Otherwise return the configuration dict of `name`, holding 'ref', 'func'
    and, for some entries, 'to_upper'. With `load=True` the reference is loaded
    and cast to float32; with `load=False` 'ref' stays a zero-argument callable
    that does so when called.
    """
    def lazy(loader):
        # Wrap a loader so the reference is only read (and cast) when called
        def _load():
            return loader().astype('float32')
        return _load

    params = {
        "CDDM": {'ref': lazy(Data.get_cddm_LO), 'func': sumup},
        "CDDM_upper": {'ref': lazy(Data.get_cddm_LO_upper), 'func': sumup, 'to_upper': True},
        "PSPA_st": {'ref': lazy(Data.get_pspa_st), 'func': multiply_pspa},
        "PSPA_y": {'ref': lazy(Data.get_pspa_tyr), 'func': multiply_pspa},
        "PSPA": {'ref': lazy(Data.get_pspa), 'func': multiply_pspa},
    }

    if name is None:
        return list(params)

    cfg = params[name]
    ref_loader = cfg['ref']
    if load and callable(ref_loader):
        cfg['ref'] = ref_loader()  # replace the loader with the loaded reference
    return cfg

# %% ../nbs/03_scoring.ipynb 44
def multiply_generic(merged_df, kinases, df_index, divide_factor_func):
    """
    Multiply-based log-sum aggregation across kinases.

    `merged_df` needs an 'input_index' column plus one value column per kinase.
    For each kinase, the log2 of its non-NaN values (plus EPSILON) is summed per
    'input_index', and `(count - 1) * log2(divide_factor_func(kinase))` is added.
    Returns a DataFrame with one column per kinase, reindexed to `df_index`.
    """
    scores = {}

    for kinase in tqdm(kinases, desc="Computing multiply_generic"):
        divide_factor = divide_factor_func(kinase)
        observed = merged_df[['input_index', kinase]].dropna()

        # No values at all for this kinase: an all-NaN column
        if observed.empty:
            scores[kinase] = pd.Series(index=df_index, dtype=float)
            continue

        per_input = (
            observed.assign(log_value=np.log2(observed[kinase] + EPSILON))
                    .groupby('input_index')['log_value']
        )

        # sum(log2(v)) + (n - 1) * log2(divide_factor), computed per input row
        scores[kinase] = per_input.sum() + (per_input.count() - 1) * np.log2(divide_factor)

    return pd.DataFrame(scores).reindex(df_index)

# %% ../nbs/03_scoring.ipynb 45
def predict_kinase_df(df, seq_col, ref, func, to_lower=False, to_upper=False):
    """
    Score every sequence in `df[seq_col]` against every kinase (row) of `ref`.

    Sequences are checked, optionally case-converted, trimmed to the position
    range covered by the columns of `ref`, expanded to position+residue keys and
    merged with the transposed reference. The merged values are then aggregated
    per input row according to `func`, which must be one of `sumup`,
    `multiply_pspa`, `multiply_23` or `multiply_20`; anything else raises ValueError.
    Returns a DataFrame indexed like `df`, rounded to 3 decimals.
    """
    print(f"Input dataframe has {len(df)} rows")
    print("Preprocessing...")

    ref = preprocess_ref(ref)
    df = df.copy()
    df[seq_col] = check_seqs(df[seq_col])

    if to_lower:
        df[seq_col] = df[seq_col].apply(STY2sty)
    if to_upper:
        df[seq_col] = df[seq_col].str.upper()

    # Column names are "<position><residue>"; trim sequences to the covered position range
    positions = ref.columns.str[:-1].astype(int)
    lo, hi = positions.min(), positions.max()
    df[seq_col] = df[seq_col].apply(lambda seq: cut_seq(seq, min_position=lo, max_position=hi))

    print("Preprocessing done. Expanding sequences...")

    # One row per (input row, position+residue key), indexed by the key
    with_keys = df.assign(keys=df[seq_col].apply(get_dict))
    long_keys = with_keys.explode('keys').reset_index(names='input_index')
    input_keys_df = (
        long_keys[['input_index', 'keys']]
        .rename(columns={'keys': 'key'})
        .set_index('key')
    )

    print("Merging reference...")
    ref_T = ref.T.astype('float32')
    merged_df = input_keys_df.merge(ref_T, left_index=True, right_index=True, how='inner')
    print("Merge complete.")

    # Aggregate per input row, depending on the scoring function
    if func == sumup:
        out = merged_df.groupby('input_index').sum().reindex(df.index)
    elif func == multiply_pspa:
        # Kinase-specific scale factor
        num_dict = Data.get_num_dict()
        out = multiply_generic(merged_df, ref_T.columns, df.index,
                               divide_factor_func=lambda k: num_dict[k])
    elif func == multiply_23 or func == multiply_20:
        # Same scale factor for every kinase
        num_aa = 23 if func == multiply_23 else 20
        out = multiply_generic(merged_df, ref_T.columns, df.index,
                               divide_factor_func=lambda k: num_aa)
    else:
        raise ValueError(f"Unknown function: {func}")

    return out.round(3)

# %% ../nbs/03_scoring.ipynb 50
def get_pct(site,ref,func,pct_ref):

    """
    Replicate the percentile results from The Kinase Library.

    Scores the uppercased `site` with `predict_kinase` and ranks each kinase score
    against the values in `pct_ref[kinase]`, counting ties as half.
    Returns a DataFrame with columns 'log2(score)' and 'percentile'.
    """

    # As here we try to replicate the results, we use site.upper(); consider removing it for future version.
    score = predict_kinase(site.upper(),ref=ref,func=func)

    def _percentile(kinase):
        background = pct_ref[kinase].values
        value = score[kinase]
        below = np.sum(background < value)   # reference values strictly lower than the score
        tied = np.sum(background == value)   # reference values equal to the score
        return (below + 0.5 * tied) / len(background) * 100

    pct = pd.Series({kinase: _percentile(kinase) for kinase in score.index})

    final = pd.concat([score,pct],axis=1)
    final.columns=['log2(score)','percentile']
    return final

# %% ../nbs/03_scoring.ipynb 54
def get_pct_df(score_df, # output from predict_kinase_df 
               pct_ref, # a reference df for percentile calculation
              ):

    """
    Replicate the percentile results from The Kinase Library.

    For every kinase column of `score_df`, each score is ranked against the values
    in `pct_ref[kinase]`: the percentile is the share of reference values that are
    less than or equal to the score. Missing (NaN) scores give NaN percentiles.
    Returns a DataFrame shaped and labelled like `score_df`, rounded to 3 decimals.
    """

    # Create an array to hold percentile ranks
    percentiles = np.zeros(score_df.shape)

    # One vectorized lookup per kinase column
    for i, kinase in tqdm(enumerate(score_df.columns),total=len(score_df.columns)):
        ref_values = np.sort(pct_ref[kinase].values)
        scores = score_df[kinase].values

        # Number of reference values <= each score
        indices = np.searchsorted(ref_values, scores, side='right')
        pct = indices / len(ref_values) * 100

        # NaN is ordered after every number, so searchsorted would rank a missing
        # score at 100%; keep it missing instead.
        percentiles[:, i] = np.where(pd.isna(scores), np.nan, pct)

    # Convert the array to a DataFrame with appropriate indices and columns
    percentiles_df = pd.DataFrame(percentiles, index=score_df.index, columns=score_df.columns).astype(float).round(3)

    return percentiles_df
