"""Core randomization-inference engine for ritest.

This module contains the internal `ritest()` implementation used by the
public API. It coordinates:

- configuration (DEFAULTS and user overrides),
- validation and preprocessing,
- permutation generation (full matrix or streamed),
- model evaluation (FastOLS or user-supplied stat_fn),
- p-value and p-value CI calculation,
- optional coefficient CI bounds/bands,
- packaging results into `RitestResult`.
"""

from __future__ import annotations

import os
import warnings
from concurrent.futures import ThreadPoolExecutor
from typing import Callable, Dict, Optional, Tuple, cast

import numpy as np
import pandas as pd

from .ci.coef_ci import Alt, coef_ci_band_fast, coef_ci_bounds_fast, coef_ci_bounds_generic
from .ci.pvalue_ci import _PValCIMethod, pvalue_ci
from .config import DEFAULTS

# FastOLS is the linear-path engine; NUMBA_OK tells us if its kernels are JITed.
try:
    from .engine.fast_ols import NUMBA_OK as FAST_OLS_NUMBA_OK  # type: ignore[attr-defined]
    from .engine.fast_ols import FastOLS
except Exception:  # pragma: no cover - defensive
    from .engine.fast_ols import FastOLS  # type: ignore[no-redef]

    FAST_OLS_NUMBA_OK = False  # conservative fallback
# Permutation providers: eager (full matrix) and streaming (chunked)
from .engine.shuffle import generate_permuted_matrix, iter_permuted_matrix
from .results import RitestResult
from .validation import validate_inputs

__all__ = ["ritest"]


def _coerce_n_jobs(val: int | None) -> int:
    """Normalise `n_jobs` to a sensible positive integer."""
    if val is None:
        return 1
    if val == -1:
        return max(os.cpu_count() or 1, 1)
    try:
        v = int(val)
    except Exception:
        return 1
    return max(v, 1)


def _coerce_ci_method(x: str | _PValCIMethod | None, fallback: _PValCIMethod) -> _PValCIMethod:
    """Normalise user input to core labels `'cp'` or `'normal'`."""
    if x is None:
        return fallback
    s = str(x).strip().lower()
    if s in {"cp", "clopper-pearson", "clopperpearson", "clopper", "exact"}:
        return "cp"  # type: ignore[return-value]
    if s in {"normal", "wald"}:
        return "normal"  # type: ignore[return-value]
    return fallback


def _bytes_per_row(n_obs: int, label_itemsize: int) -> int:
    """
    Return an upper-bound byte estimate for one *row* of the permutation block.

    Rules
    -----
    - With `FAST_OLS_NUMBA_OK` true, the estimate assumes labels are used
      as-is (int8), i.e. 1 byte per entry.
    - Otherwise it assumes 8 bytes per entry, a safe upper bound in case
      the NumPy fallbacks cast to float64 internally.
    - The per-entry width is never taken below `label_itemsize`, so wider
      user-provided labels are not underestimated.
    """
    assumed_width = 8 if not FAST_OLS_NUMBA_OK else 1
    # Take whichever is wider: the assumed working width or the actual label width.
    entry_bytes = label_itemsize if label_itemsize > assumed_width else assumed_width
    return entry_bytes * n_obs


def ritest(  # noqa: C901
    df: pd.DataFrame,
    *,
    permute_var: str,
    # Linear-model path
    formula: str | None = None,
    stat: str | None = None,
    # User-supplied statistic function
    stat_fn: Callable[[pd.DataFrame], float] | None = None,
    # Optional columns
    cluster: str | None = None,
    strata: str | None = None,
    weights: str | None = None,
    # Hypothesis direction
    alternative: Alt = "two-sided",
    # --- public controls (override DEFAULTS when provided) ---
    reps: int | None = None,
    seed: int | None = None,
    alpha: float | None = None,
    ci_method: str | _PValCIMethod | None = None,
    ci_mode: str | None = None,  # "none" | "bounds" | "grid"
    ci_range: float | None = None,
    ci_step: float | None = None,
    coef_ci_generic: bool | None = None,  # relevant only when stat_fn and grid
    n_jobs: int | None = None,
    # prebuilt permutations (matrix of permuted T labels)
    permutations: np.ndarray | None = None,
) -> RitestResult:
    """
    Core randomisation test.

    Exactly one of (`formula`, `stat`) or (`stat_fn`) must be provided.
    Any explicit control provided here overrides the corresponding value
    in `config.DEFAULTS`.

    Parameters
    ----------
    df
        Input data. It is passed to `validate_inputs`; on the `stat_fn`
        path a shallow copy with `permute_var` overwritten by the permuted
        labels is handed to `stat_fn` for every replication.
    permute_var
        Name of the treatment column whose labels are permuted.
    formula, stat
        Linear-model specification, forwarded to `validate_inputs`.
    stat_fn
        Callable mapping a DataFrame to a scalar statistic (generic path).
    cluster, strata, weights
        Optional column names, forwarded to `validate_inputs`.
    alternative
        `"two-sided"` compares absolute values, `"right"` counts
        `perm >= obs`; any other value counts `perm <= obs`.
    reps, seed, alpha, ci_method, ci_mode, ci_range, ci_step, coef_ci_generic, n_jobs
        Run controls; `None` means "take the value from `DEFAULTS`".
        `seed` seeds `np.random.default_rng`. `n_jobs == 1` evaluates
        serially; larger values use a thread pool of that size.
    permutations
        Optional prebuilt matrix of permuted treatment labels with shape
        `(reps, n)`. When given, its row count replaces `reps` and no
        permutations are generated.

    Returns
    -------
    RitestResult
        Observed statistic, p-value with its CI, the permutation
        statistics, optional coefficient CI bounds/band, and the settings
        actually used.

    Raises
    ------
    ValueError
        If `permutations` does not have shape `(reps, n)`, if the effective
        number of replications is smaller than 1, or if `ci_mode` is not
        one of `"none"`, `"bounds"`, `"grid"`.
    """

    # 0) Controls: pull from DEFAULTS, allow public overrides
    cfg = DEFAULTS
    reps = int(cfg["reps"]) if reps is None else int(reps)
    seed = int(cfg["seed"]) if seed is None else int(seed)
    alpha = float(cfg["alpha"]) if alpha is None else float(alpha)
    ci_method = _coerce_ci_method(ci_method, fallback=cfg["ci_method"])
    ci_mode = str(cfg["ci_mode"]) if ci_mode is None else str(ci_mode)
    ci_range = float(cfg["ci_range"]) if ci_range is None else float(ci_range)
    ci_step = float(cfg["ci_step"]) if ci_step is None else float(ci_step)
    ci_tol = float(cfg["ci_tol"])  # only forwarded to validate_inputs in this function
    coef_ci_generic = (
        bool(cfg["coef_ci_generic"]) if coef_ci_generic is None else bool(coef_ci_generic)
    )
    n_jobs = _coerce_n_jobs(int(cfg.get("n_jobs", 1)) if n_jobs is None else int(n_jobs))
    # Memory/chunking knobs (soft budget)
    perm_chunk_bytes = int(cfg.get("perm_chunk_bytes", 256 * 1024 * 1024))
    perm_chunk_min_rows = int(cfg.get("perm_chunk_min_rows", 64))

    rng = np.random.default_rng(seed)

    # 1) Validate & preprocess (enforces binary treatment and returns T as int8 0/1)
    v = validate_inputs(
        df,
        permute_var=permute_var,
        formula=formula,
        stat=stat,
        stat_fn=stat_fn,
        cluster=cluster,
        strata=strata,
        weights=weights,
        alternative=alternative,
        alpha=alpha,
        ci_method=ci_method,
        ci_mode=ci_mode,
        ci_range=ci_range,
        ci_step=ci_step,
        ci_tol=ci_tol,
        coef_ci_generic=coef_ci_generic,
    )

    # 2) Observed statistic
    linear_model = v.stat_fn is None
    if linear_model:
        ols = FastOLS(
            v.y, v.X, v.treat_idx, weights=v.weights, cluster=v.cluster, compute_vcov=True
        )
        obs_stat = float(ols.beta_hat)
        se_obs = float(ols.se)
        K_obs = float(ols.K)
        t_metric_lin = ols.t_metric  # branch-local, aligns with c_vector metric
    else:
        stat_fn_local = cast(Callable[[pd.DataFrame], float], v.stat_fn)
        obs_stat = float(stat_fn_local(df))
        se_obs = float("nan")
        K_obs = float("nan")
        t_metric_lin = None  # type: ignore[assignment]

    # 3) Resolve the effective number of replications *before* anything depends on it.
    #    A user-supplied matrix is authoritative: its row count is the true `reps`.
    n_obs = int(v.T.shape[0])
    perms: Optional[np.ndarray] = None
    if permutations is not None:
        # Expect shape (reps, n), where rows are permuted T-label vectors
        perms = np.asarray(permutations)
        if perms.ndim != 2 or perms.shape[1] != n_obs:
            raise ValueError(
                f"permutations must have shape (reps, n={n_obs}), got {perms.shape}"
            )
        reps = int(perms.shape[0])
    if reps < 1:
        # Without this guard the p-value step below would divide by zero.
        raise ValueError(f"reps must be a positive integer, got {reps}")

    # Warn for potentially slow generic grid band when opted-in
    if (not linear_model) and (ci_mode == "grid") and coef_ci_generic:
        grid_size = max(int(round((2.0 * ci_range) / max(ci_step, 1e-12))) + 1, 1)
        est_sec = float(v.warmup_time) * grid_size * reps
        if est_sec > 10:
            warnings.warn(
                f"CI bands for stat_fn may take ~{est_sec:.1f} sec "
                f"(warmup {v.warmup_time:.3f}s × grid≈{grid_size} × reps={reps}).",
                stacklevel=2,
            )

    # 4) Allocate outputs (small, O(reps))
    perm_stats = np.empty(reps, dtype=np.float64)
    K_perm_local = np.empty(reps, dtype=np.float64) if linear_model else None

    serial = n_jobs == 1
    # Serial linear runs reuse one design-matrix buffer instead of copying X per replication.
    # Parallel runs must not share it, so each task copies X (see `_fit_one`).
    X_work: Optional[np.ndarray] = v.X.copy() if (linear_model and serial) else None

    def _fit_one(args: Tuple[int, np.ndarray]) -> Tuple[int, float, float]:
        """Linear-path worker: (abs_index, T_perm_row) -> (abs_index, beta_r, Kr)."""
        r_abs, T_perm = args
        Xp = X_work if X_work is not None else v.X.copy()
        Xp[:, v.treat_idx] = T_perm  # int8 → float cast on assignment
        ols_r = FastOLS(
            v.y, Xp, v.treat_idx, weights=v.weights, cluster=v.cluster, compute_vcov=False
        )
        # Convert to plain floats here so nothing refers to the shared buffer afterwards.
        beta_r = float(ols_r.beta_hat)
        Kr = float(ols_r.c_vector @ t_metric_lin)  # type: ignore[arg-type]
        return r_abs, beta_r, Kr

    def _eval_one(args: Tuple[int, np.ndarray]) -> Tuple[int, float]:
        """Generic-path worker: (abs_index, T_perm_row) -> (abs_index, stat_r)."""
        r_abs, T_perm = args
        dfp = df.copy(deep=False)
        dfp[permute_var] = T_perm  # pandas will upcast to float when needed
        return r_abs, float(stat_fn_local(dfp))  # type: ignore[arg-type]

    # One pool for the entire run (only when parallel evaluation was requested).
    pool = None if serial else ThreadPoolExecutor(max_workers=n_jobs)

    def _evaluate_block(rows: np.ndarray, n_rows: int, offset: int) -> None:
        """Evaluate `rows[0:n_rows]` and store results at `offset + i` in the output arrays."""
        tasks = ((offset + i, rows[i]) for i in range(n_rows))
        # Builtin `map` evaluates lazily in this thread; `pool.map` fans out but
        # still yields results in submission order.
        run = map if pool is None else pool.map
        if linear_model:
            k_out = cast(np.ndarray, K_perm_local)
            for r_abs, beta_r, Kr in run(_fit_one, tasks):
                perm_stats[r_abs] = beta_r
                k_out[r_abs] = Kr
        else:
            for r_abs, z in run(_eval_one, tasks):
                perm_stats[r_abs] = z

    # 5) Permutations: prebuilt, eager, or chunked (deterministic in all cases)
    try:
        if perms is not None:
            # Process rows directly (no chunking for user-supplied matrix)
            _evaluate_block(perms, reps, 0)
        else:
            # Decide whether to allocate the full (reps, n) matrix or stream in chunks.
            itemsize = int(v.T.dtype.itemsize)  # int8 => 1
            bpr = _bytes_per_row(n_obs, itemsize)  # conservative when Numba is unavailable
            full_bytes = reps * bpr

            if full_bytes <= perm_chunk_bytes:
                # Eager path: build the full matrix in one go.
                T_perms = generate_permuted_matrix(
                    v.T, reps, cluster=v.cluster, strata=v.strata, rng=rng
                )
                _evaluate_block(T_perms, reps, 0)
            else:
                # Streaming path: generate blocks with bounded memory and process each in turn.
                # Choose chunk_rows from the budget, enforce the configured minimum,
                # and never go below one row even if the configuration is degenerate.
                chunk_rows = max(perm_chunk_min_rows, perm_chunk_bytes // max(bpr, 1), 1)

                r0 = 0  # absolute write position into perm_stats / K_perm_local
                for block in iter_permuted_matrix(
                    v.T,
                    reps,
                    cluster=v.cluster,
                    strata=v.strata,
                    rng=rng,
                    chunk_rows=int(chunk_rows),
                ):
                    m = int(block.shape[0])
                    _evaluate_block(block, m, r0)
                    r0 += m
    finally:
        if pool is not None:
            pool.shutdown(wait=True)

    # 6) P-value + CI for p
    if alternative == "two-sided":
        extreme = np.abs(perm_stats) >= abs(obs_stat)
    elif alternative == "right":
        extreme = perm_stats >= obs_stat
    else:
        extreme = perm_stats <= obs_stat

    c = int(extreme.sum())
    p_val = c / reps
    p_ci = pvalue_ci(c, reps, alpha=alpha, method=ci_method)

    # 7) (Optionally) compute coefficient CI artifacts
    coef_ci_bounds: Optional[tuple[float, float]] = None
    coef_ci_band = None
    band_valid_linear = linear_model  # True for linear path; False for generic

    if ci_mode != "none":
        if linear_model:
            if ci_mode in {"bounds", "grid"}:
                coef_ci_bounds = coef_ci_bounds_fast(
                    beta_obs=obs_stat,
                    beta_perm=perm_stats,
                    K_obs=K_obs,  # type: ignore[arg-type]
                    K_perm=cast(np.ndarray, K_perm_local),  # type: ignore[arg-type]
                    se=se_obs,
                    alpha=alpha,
                    ci_range=ci_range,
                    ci_step=ci_step,
                    alternative=alternative,
                )
            if ci_mode == "grid":
                coef_ci_band = coef_ci_band_fast(
                    beta_obs=obs_stat,
                    beta_perm=perm_stats,
                    K_obs=K_obs,  # type: ignore[arg-type]
                    K_perm=cast(np.ndarray, K_perm_local),  # type: ignore[arg-type]
                    se=se_obs,
                    ci_range=ci_range,
                    ci_step=ci_step,
                    alternative=alternative,
                )
        else:
            if coef_ci_generic:

                def _runner(beta0: float) -> float:
                    """Share of permutation statistics at least as extreme under shift `beta0`."""
                    shifted = obs_stat - beta0
                    if alternative == "two-sided":
                        return (np.abs(perm_stats - beta0) >= abs(shifted)).mean()
                    elif alternative == "right":
                        return (perm_stats - beta0 >= shifted).mean()
                    else:
                        return (perm_stats - beta0 <= shifted).mean()

                # Scale for the search range/grid: spread of the permutation statistics,
                # falling back to 1.0 when it is unavailable, zero, or non-finite.
                se_scale = float(np.nanstd(perm_stats, ddof=1)) if reps >= 2 else 1.0
                if not (se_scale > 0.0 and np.isfinite(se_scale)):
                    se_scale = 1.0

                if ci_mode in {"bounds", "grid"}:
                    coef_ci_bounds = coef_ci_bounds_generic(
                        obs_stat,
                        runner=_runner,
                        alpha=alpha,
                        ci_range=ci_range,
                        ci_step=ci_step,
                        se=se_scale,
                        alternative=alternative,
                    )
                if ci_mode == "grid":
                    grid = (
                        np.arange(
                            -ci_range * se_scale, ci_range * se_scale + 1e-12, ci_step * se_scale
                        )
                        + obs_stat
                    )
                    pvals = np.array([_runner(b0) for b0 in grid], dtype=np.float64)
                    coef_ci_band = (grid, pvals)
                    band_valid_linear = False

    # 8) STRICT GATING of outputs before building result
    _bounds = coef_ci_bounds
    _band = coef_ci_band
    _band_valid_linear = bool(band_valid_linear)

    if ci_mode == "none":
        _bounds = None
        _band = None
        _band_valid_linear = False
    elif ci_mode == "bounds":
        _band = None
        # KEEP the linear flag: True for linear, False for generic
    elif ci_mode == "grid":
        # NOTE(review): bounds are computed above in grid mode and dropped here;
        # kept as-is, confirm that reporting only the band is intended.
        _bounds = None
        if (stat_fn is not None) and (not coef_ci_generic):
            _band = None
            _band_valid_linear = False
    else:
        raise ValueError(f"Unknown ci_mode: {ci_mode!r}")

    # 9) Package results
    settings: Dict[str, object] = {
        "alpha": alpha,
        "seed": seed,
        "reps": reps,
        "ci_method": ci_method,
        "ci_mode": ci_mode,
        "ci_range": ci_range,
        "ci_step": ci_step,
        "alternative": alternative,
        "n_jobs": n_jobs,
        "coef_ci_generic": coef_ci_generic,
    }

    res = RitestResult(
        obs_stat=float(obs_stat),
        coef_ci_bounds=_bounds,
        pval=float(p_val),
        pval_ci=p_ci,
        reps=reps,
        c=c,
        alternative=alternative,
        stratified=v.has_strata,
        clustered=v.has_cluster,
        weights=v.weights is not None,
        coef_ci_band=_band,
        band_valid_linear=_band_valid_linear,
        settings=settings,
        perm_stats=perm_stats,
    )
    return res
