#%%
# Run this in Jupyter to compare the three ways of applying the complex-ID filter below

import numpy as np
import pandas as pd
from pythonflex.utils import dload

dataset_name = "[CORUM] 19Q2"

pra = dload("pra", dataset_name)
mpr = dload("mpr", dataset_name)

# Complex IDs named by the "no_mtRibo_ETCI" filter; pairs touching them get flagged below.
filter_ids = set(mpr["filters"]["no_mtRibo_ETCI"])
print(f"Filter IDs: {filter_ids}")

# The complex-ID column may be called either "complex_id" or "complex_ids".
# Fail here with a clear message rather than with a bare KeyError later,
# when the column is first indexed.
if "complex_id" in pra.columns:
    cid_col = "complex_id"
elif "complex_ids" in pra.columns:
    cid_col = "complex_ids"
else:
    raise KeyError(
        "expected a 'complex_id' or 'complex_ids' column in pra; "
        f"found {list(pra.columns)}"
    )

# Sort by score descending. A stable sort keeps tied scores in their input
# order, so precision-at-rank does not depend on an arbitrary tie ordering.
pra_sorted = (
    pra.sort_values("score", ascending=False, kind="mergesort")
    .reset_index(drop=True)
)

def has_filter_id(cids, filter_ids):
    """Return True if any complex ID in ``cids`` is a member of ``filter_ids``.

    ``cids`` is one cell of the complex-ID column. It may be either:

    * a collection of IDs (list, tuple, set, frozenset, NumPy array or
      pandas Series), or
    * a single numeric ID (Python or NumPy int/float).

    Missing entries (None/NaN) are skipped. The remaining IDs are cast to
    ``int`` before the membership test, so ``filter_ids`` is expected to hold
    ints. Any other cell value (None, NaN, strings, booleans) yields False.
    """
    # bool is a subclass of int; a boolean cell is not a complex ID.
    if isinstance(cids, (bool, np.bool_)):
        return False
    if isinstance(cids, (int, float, np.integer, np.floating)):
        # Scalar cells (e.g. a single-ID "complex_id" column) are wrapped so
        # they are checked too, instead of being silently treated as unfiltered.
        cids = [cids]
    elif not isinstance(cids, (np.ndarray, list, tuple, set, frozenset, pd.Series)):
        return False
    return any(int(x) in filter_ids for x in cids if pd.notnull(x))

# Flag every pair whose complex-ID cell contains at least one filtered complex.
pra_sorted["should_filter"] = pra_sorted[cid_col].apply(has_filter_id, args=(filter_ids,))

# Summary counts: all pairs, flagged pairs, and flagged pairs with prediction == 1.
n_pairs_total = len(pra_sorted)
n_pairs_flagged = pra_sorted["should_filter"].sum()
n_tps_flagged = (pra_sorted["should_filter"] & (pra_sorted["prediction"] == 1)).sum()

print(f"\nTotal pairs: {n_pairs_total}")
print(f"Pairs to filter: {n_pairs_flagged}")
print(f"TPs to filter: {n_tps_flagged}")

# APPROACH 1: Mark as negative (per the original note, this mirrors the existing
# Python implementation). Every row stays in the ranking; flagged pairs with
# prediction == 1 are relabelled 0, so they count against precision.
banner_1 = "=" * 70
print(f"\n{banner_1}")
print("APPROACH 1: Mark filtered TPs as negatives (keep rows)")
print(banner_1)

df1 = pra_sorted.copy()
demote_mask_1 = df1["should_filter"] & (df1["prediction"] == 1)
df1["true_filtered"] = df1["prediction"].copy()
df1.loc[demote_mask_1, "true_filtered"] = 0

# Precision at rank k = (sum of relabelled predictions in the top k rows) / k.
tp_cum_1 = df1["true_filtered"].cumsum()
prec_1 = tp_cum_1 / np.arange(1, len(df1) + 1)

# For each target, report the first rank at which the running TP count reaches it.
print("\nPrecision at key TP counts:")
for target_tp in (10, 50, 100, 500, 1000):
    if target_tp <= tp_cum_1.max():
        idx = np.flatnonzero(tp_cum_1 >= target_tp)[0]
        print(f"   TP={target_tp}: precision={prec_1.iloc[idx]:.3f} (at rank {idx+1})")

# APPROACH 2: Remove rows entirely (what R does with replace=FALSE, per the
# original note). Flagged pairs are dropped whether positive or negative.
separator_2 = "=" * 70
print("\n" + separator_2)
print("APPROACH 2: Remove filtered rows entirely")
print(separator_2)

keep_mask_2 = ~pra_sorted["should_filter"]
df2 = pra_sorted[keep_mask_2].copy().reset_index(drop=True)

# Running precision over the reduced ranking (rank k -> positives in top k / k).
tp_cum_2 = df2["prediction"].cumsum()
rank_2 = np.arange(len(df2)) + 1
prec_2 = tp_cum_2 / rank_2

n_tps_left_2 = df2["prediction"].sum()
print(f"\nRows remaining after removal: {len(df2)}")
print(f"TPs remaining: {n_tps_left_2}")

print("\nPrecision at key TP counts:")
for target_tp in [10, 50, 100, 500, 1000]:
    if target_tp <= tp_cum_2.max():
        reached_2 = (tp_cum_2 >= target_tp).to_numpy()
        idx = reached_2.nonzero()[0][0]
        print(f"   TP={target_tp}: precision={prec_2.iloc[idx]:.3f} (at rank {idx+1})")

# APPROACH 3: Drop only flagged pairs with prediction == 1. Flagged negatives
# remain in the ranking and still add to the rank denominator.
heading_3 = "=" * 70
print("\n" + heading_3)
print("APPROACH 3: Remove only filtered TPs (keep filtered negatives)")
print(heading_3)

is_tp_3 = pra_sorted["prediction"] == 1
remove_mask = is_tp_3 & pra_sorted["should_filter"]
df3 = pra_sorted[~remove_mask].copy().reset_index(drop=True)

tp_cum_3 = df3["prediction"].cumsum()
prec_3 = tp_cum_3 / (np.arange(len(df3)) + 1)

print(f"\nRows remaining: {len(df3)}")
print(f"TPs remaining: {df3['prediction'].sum()}")

print("\nPrecision at key TP counts:")
targets_3 = [10, 50, 100, 500, 1000]
max_tp_3 = tp_cum_3.max()
for target_tp in targets_3:
    if target_tp <= max_tp_3:
        idx = np.where(tp_cum_3 >= target_tp)[0][0]
        precision_3 = prec_3.iloc[idx]
        print(f"   TP={target_tp}: precision={precision_3:.3f} (at rank {idx + 1})")

# Print a fixed textual summary of the three approaches.
banner = "=" * 70
print(f"\n{banner}\nCOMPARISON\n{banner}")
# NOTE(review): the last line below says the R code uses Approach 3, but the
# comment introducing Approach 2 in this script attributes full-row removal
# to R ("replace=FALSE"). Confirm which is correct before relying on either.
print("""
Approach 1 (mark as negative): Filtered TPs become FPs, lowering precision
Approach 2 (remove all filtered): Both TPs and negatives removed
Approach 3 (remove only TPs): Only filtered TPs removed, negatives kept

The R code uses Approach 3 (remove positive pairs that contain the filter ID).
""")
# %%
