"""Functions to preprocess sequence to prepare kinase substrate dataset"""

# AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/01_utils.ipynb.

# %% auto 0
# Names exported when this module is star-imported (`from <module> import *`).
__all__ = ['prepare_path', 'get_diff', 'pSTY2sty', 'sty2pSTY', 'check_seq', 'check_seqs', 'validate_site', 'validate_site_df',
           'phosphorylate_seq', 'phosphorylate_seq_df', 'extract_site_seq', 'get_fasta', 'run_clustalo', 'aln2df',
           'get_aln_freq']

# %% ../nbs/01_utils.ipynb 3
import numpy as np, pandas as pd
from tqdm import tqdm
from .data import *
from fastcore.meta import delegates
from pathlib import Path

# for alignment
from Bio.Seq import Seq
from Bio.SeqRecord import SeqRecord
from Bio import SeqIO, AlignIO
import subprocess


# %% ../nbs/01_utils.ipynb 6
def prepare_path(path):
    """Expand `~` in `path`, create its parent directories if missing, and return the resulting `Path`."""
    target = Path(path).expanduser()
    parent_dir = target.parent
    # Only the parent is created; the file at `target` itself is left untouched.
    parent_dir.mkdir(parents=True, exist_ok=True)
    return target

# %% ../nbs/01_utils.ipynb 8
def get_diff(df1, df2, col1, col2=None):
    "Return the rows of each dataframe whose key value does not appear in the other dataframe."
    # When only one column name is given, the same name is used for both dataframes.
    key2 = col1 if col2 is None else col2
    found_in_df2 = df1[col1].isin(df2[key2])
    found_in_df1 = df2[key2].isin(df1[col1])
    return df1[~found_in_df2], df2[~found_in_df1]

# %% ../nbs/01_utils.ipynb 12
def pSTY2sty(string):
    "Replace every pS/pT/pY token in a string with the lowercase letter s/t/y."
    converted = string
    for token, letter in (('pS', 's'), ('pT', 't'), ('pY', 'y')):
        converted = converted.replace(token, letter)
    return converted

# %% ../nbs/01_utils.ipynb 13
def sty2pSTY(string):
    "Replace every lowercase s/t/y in a string with the token pS/pT/pY."
    converted = string
    for letter, token in (('s', 'pS'), ('t', 'pT'), ('y', 'pY')):
        converted = converted.replace(letter, token)
    return converted

# %% ../nbs/01_utils.ipynb 16
def check_seq(seq):
    """Normalise a site sequence: keep lowercase s/t/y, uppercase everything else, and turn any character outside the allowed amino-acid set into an underscore.

    The character at index `len(seq) // 2` must be s/t/y (either case); otherwise an AssertionError is raised.
    """
    mid = len(seq) // 2
    acceptor = seq[mid]
    assert acceptor.lower() in {'s', 't', 'y'}, f"{seq} has {acceptor} at position {mid}; need to have one of 's', 't', or 'y' in the center"

    phospho_letters = {'s', 't', 'y'}
    allowed_chars = set("PGACSTVILMFYWHKRQNDEsty")

    cleaned = []
    for char in seq:
        if char in phospho_letters:
            # lowercase s/t/y are kept as they are
            cleaned.append(char)
            continue
        upper = char.upper()
        cleaned.append(upper if upper in allowed_chars else '_')
    return "".join(cleaned)

# %% ../nbs/01_utils.ipynb 19
def check_seqs(data,col=None):
    "Apply `check_seq` to every sequence in a DataFrame column, a Series, or a list, after checking that all sequences share one length."
    # Reject unsupported containers up front.
    if not isinstance(data, (pd.DataFrame, pd.Series, list)):
        raise TypeError("Input must be a DataFrame, Series, or list.")

    if isinstance(data, pd.DataFrame):
        if col is None:
            raise ValueError("Must specify 'col' when passing a DataFrame.")
        seqs = data[col]
    else:
        seqs = pd.Series(data)

    # One distinct length means every sequence has the same length.
    length_counts = seqs.str.len().value_counts()
    assert len(length_counts) == 1, "Inconsistent sequence length detected."
    return seqs.apply(check_seq)

# %% ../nbs/01_utils.ipynb 22
def validate_site(site_info,
                  seq):
    "Return 1 if the residue at the 1-based position given in `site_info` (e.g., S140) matches its leading letter, else 0."
    idx = int(site_info[1:]) - 1  # convert the 1-based site position to a 0-based index
    in_range = 0 <= idx < len(seq)
    # Positions outside the sequence count as a failed validation.
    return int(in_range and seq[idx] == site_info[0])

# %% ../nbs/01_utils.ipynb 25
def validate_site_df(df, 
                     site_info_col,
                     protein_seq_col): 
    "Run `validate_site` on every row of a dataframe, using the given site and protein sequence columns."
    def _row_is_valid(row):
        return validate_site(row[site_info_col], row[protein_seq_col])

    return df.apply(_row_is_valid, axis=1)

# %% ../nbs/01_utils.ipynb 28
def phosphorylate_seq(seq, # full protein sequence
                      *sites, # site info, e.g., S140
                      ):
    "Lowercase the residue at each given phosphosite (e.g., S140) in a protein sequence."
    residues = list(seq)
    n_residues = len(residues)

    for site in sites:
        expected = site[0]
        idx = int(site[1:]) - 1  # site positions are 1-based

        if not 0 <= idx < n_residues:
            raise IndexError(f"Position {idx+1} out of range for sequence length {n_residues}")

        found = residues[idx]
        if found != expected:
            raise ValueError(f"Mismatch at position {idx+1}: expected {expected}, found {found}")

        residues[idx] = expected.lower()

    return ''.join(residues)

# %% ../nbs/01_utils.ipynb 30
def phosphorylate_seq_df(df,
                         id_col='substrate_uniprot', # column of sequence ID
                         seq_col='substrate_sequence', # column that contains protein sequence
                         site_col='site', # column that contains site info, e.g., S140
                         
                        ):
    "Group a dataframe by sequence ID and add a `phosphoseq` column with all of that ID's unique sites lowercased in its sequence."
    def _unique_sites(sites):
        return sites.unique()

    def _apply_sites(row):
        return phosphorylate_seq(row[seq_col], *row[site_col])

    # One row per ID: the unique sites plus the first sequence seen for that ID.
    aggregations = {site_col: _unique_sites, seq_col: 'first'}
    per_protein = df.groupby(id_col).agg(aggregations).reset_index()

    per_protein['phosphoseq'] = per_protein.apply(_apply_sites, axis=1)
    return per_protein

# %% ../nbs/01_utils.ipynb 34
def extract_site_seq(df: pd.DataFrame, # dataframe that contains protein sequence
                     seq_col: str, # column name of protein sequence
                     site_col: str, # column name of site information (e.g., S10)
                     n=7, # length of surrounding sequence (default -7 to +7)
                    ):
    "Extract the -n to +n window around each site from its protein sequence, padding with underscores past either end."

    windows = []
    for _, row in tqdm(df.iterrows(), total=len(df)):
        protein = row[seq_col]
        protein_len = len(protein)

        center = int(row[site_col][1:]) - 1  # 1-based site position -> 0-based index
        left, right = center - n, center + n + 1

        # Underscore padding stands in for the part of the window that falls outside the protein.
        left_pad = "_" * abs(left) if left < 0 else ""
        right_pad = "_" * (right - protein_len) if right > protein_len else ""

        core = protein[max(0, left):min(protein_len, right)]
        windows.append(left_pad + core + right_pad)

    return np.array(windows)

# %% ../nbs/01_utils.ipynb 39
def get_fasta(df,seq_col='kd_seq',id_col='kd_ID',path='out.fasta'):
    "Write one fasta record per dataframe row to `path` and print the number of records."
    records = []
    for _, row in df.iterrows():
        sequence = Seq(str(row[seq_col]))
        # An empty description keeps the header line to just the record ID.
        records.append(SeqRecord(sequence, id=str(row[id_col]), description=""))

    SeqIO.write(records, path, "fasta")
    print(len(records))

# %% ../nbs/01_utils.ipynb 43
def run_clustalo(input_fasta,  # .fasta fname
                 output_aln, # .aln output fname
                 outfmt="clu"): # output format passed to clustalo's --outfmt option
    """Run Clustal Omega to perform multiple sequence alignment.

    The parent directory of `output_aln` is created if it does not exist, and
    any existing output file is overwritten (`--force`).

    `outfmt` is forwarded to the `--outfmt` option of the `clustalo` command.
    The default "clu" writes the Clustal format, which is what `aln2df` in
    this module reads.

    Raises `subprocess.CalledProcessError` if clustalo exits with a non-zero
    status (because of `check=True`).
    """
    # if the output directory does not exist, create one
    output_aln = Path(output_aln)
    output_aln.parent.mkdir(parents=True, exist_ok=True)

    # Build the command as an argument list (no shell involved).
    # The format flag uses the `outfmt` argument; it was previously hard-coded
    # to "clu", which silently ignored the caller's choice.
    cmd = [
        "clustalo", "-i", str(input_fasta),
        "-o", str(output_aln),
        "--force", f"--outfmt={outfmt}",
    ]
    subprocess.run(cmd, check=True)

# %% ../nbs/01_utils.ipynb 45
def aln2df(fname):
    "Read a Clustal alignment file into a dataframe with one row per record ID and one column per alignment position, numbered from 1."
    alignment = AlignIO.read(fname, "clustal")

    ids, rows = [], []
    for record in alignment:
        ids.append(record.id)
        rows.append(list(str(record.seq)))  # one character per alignment column

    df = pd.DataFrame(rows, index=ids)
    # Shift the default 0-based column labels so positions start at 1.
    df.columns = df.columns + 1
    return df

# %% ../nbs/01_utils.ipynb 47
def get_aln_freq(df):
    "Compute, for each column of the aln2df output, the fraction of rows holding each character."
    def _count_chars(col):
        return col.value_counts()

    # Characters absent from a column come back as NaN, so treat them as zero counts.
    counts = df.apply(_count_chars, axis=0).fillna(0)
    column_totals = counts.sum(axis=0)
    return counts.div(column_totals, axis=1)
