# src/omnibioai_tool_exec/execution/tools/blastn_real.py
from __future__ import annotations

import json
import subprocess
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple


# Tabular BLAST output (-outfmt 6) with an explicit column list.
# parse_outfmt6 unpacks rows positionally, so this order must stay in sync with it.
OUTFMT6 = " ".join(
    (
        "6",  # tabular format code
        "sacc",  # -> hit "accession"
        "stitle",  # -> hit "description"
        "evalue",  # -> hit "evalue"
        "bitscore",  # -> hit "bitscore"
        "pident",  # -> hit "pident"
        "length",  # -> hit "align_len"
    )
)


@dataclass(frozen=True)
class BlastnConfig:
    """
    Immutable settings for the real BLASTN runner.

    Attributes:
      blastn_path: program name/path used as the first element of the command.
      database_map: maps a logical database name to an absolute BLAST DB
        prefix (the path without any file extension), for example
        {"ecoli_demo": "/data/blastdb/ecoli_demo/ecoli_demo"}.
      allowed_databases: optional allow-list of logical names; when it is not
        None, a requested database must appear in it.
      defaults: fallback run parameters used when the caller's inputs omit
        them, for example {"max_hits": 10, "evalue": 1e-5, "dust": True}.
    """

    blastn_path: str = "blastn"
    # Fresh dict per instance (default_factory) so configs never share state.
    database_map: Dict[str, str] = field(default_factory=dict)
    allowed_databases: Optional[List[str]] = None
    defaults: Dict[str, Any] = field(
        default_factory=lambda: dict(max_hits=10, evalue=1e-5, dust=True)
    )


def resolve_db_path(cfg: BlastnConfig, db_name: str) -> str:
    """
    Map a logical database name to its configured BLAST DB prefix.

    Checks, in order:
      1. the name is non-empty;
      2. it passes ``cfg.allowed_databases`` (when that is not None);
      3. it is a key of ``cfg.database_map``;
      4. the mapped prefix starts with "/";
      5. the prefix's parent is an existing directory.

    Args:
      cfg: runner configuration providing ``allowed_databases`` and ``database_map``.
      db_name: logical database name requested by the caller.

    Returns:
      The DB prefix exactly as configured (no extension appended).

    Raises:
      ValueError: if any of the checks above fails.
    """
    if not db_name:
        raise ValueError("database is required")

    if cfg.allowed_databases is not None and db_name not in cfg.allowed_databases:
        raise ValueError(f"Database '{db_name}' is not allowed. Allowed: {cfg.allowed_databases}")

    if db_name not in cfg.database_map:
        raise ValueError(f"Unknown database '{db_name}'. Known: {sorted(cfg.database_map.keys())}")

    db_prefix = cfg.database_map[db_name]
    if not db_prefix.startswith("/"):
        raise ValueError(f"DB path must be absolute. Got: {db_prefix}")

    # Light sanity check only: BLAST DBs (v5 especially) span several files
    # with varying extensions, so we only require the containing directory.
    # is_dir() rather than exists(): a prefix nested under a regular file
    # cannot be a valid DB location and must be rejected here.
    parent = Path(db_prefix).parent
    if not parent.is_dir():
        raise ValueError(f"DB parent directory does not exist: {parent}")

    return db_prefix


def normalize_fasta(sequence: str) -> str:
    """
    Return *sequence* as FASTA text that ends with a newline.

    Surrounding whitespace is stripped first. Text that then begins with '>'
    is treated as FASTA already; anything else is a bare sequence and gets a
    synthetic '>q' header line.

    Raises:
      ValueError: if *sequence* is None, empty, or only whitespace.
    """
    text = sequence.strip() if sequence else ""
    if text == "":
        raise ValueError("sequence is empty")

    if text[0] != ">":
        return ">q\n" + text + "\n"
    # strip() above guarantees no trailing newline, so always append one.
    return text + "\n"


def write_query_fasta(run_dir: Path, sequence: str) -> Path:
    """
    Normalize *sequence* to FASTA and write it to ``run_dir/query.fa``.

    Creates *run_dir* (with parents) if needed and overwrites any existing
    query.fa.

    Returns:
      Path of the written FASTA file.

    Raises:
      ValueError: propagated from normalize_fasta for an empty sequence.
    """
    # Validate first so a bad sequence leaves no filesystem side effects.
    fasta_text = normalize_fasta(sequence)
    run_dir.mkdir(parents=True, exist_ok=True)
    qfa = run_dir / "query.fa"
    # Explicit encoding: the platform default is locale-dependent (e.g. cp1252
    # on Windows) and can fail on non-ASCII header text.
    qfa.write_text(fasta_text, encoding="utf-8")
    return qfa


def build_blastn_cmd(
    cfg: BlastnConfig,
    query_fa: Path,
    db_prefix: str,
    max_hits: int,
    evalue: float,
    cpu: int,
    dust: bool,
    out_tsv: Optional[Path] = None,
    extra_args: Optional[List[str]] = None,
) -> List[str]:
    """
    Assemble the argv list for a blastn invocation (no shell involved).

    The command is ``cfg.blastn_path`` followed by:
      - query, db, max_target_seqs, evalue, num_threads (clamped to >= 1),
        outfmt (OUTFMT6) and dust (yes/no);
      - ``-out <out_tsv>`` when *out_tsv* is given;
      - *extra_args* appended verbatim, last.
    """
    # Clamp the thread count so 0 or negative cpu values still yield 1.
    threads = cpu if cpu > 1 else 1

    options = [
        ("-query", str(query_fa)),
        ("-db", db_prefix),
        ("-max_target_seqs", str(max_hits)),
        ("-evalue", str(evalue)),
        ("-num_threads", str(threads)),
        ("-outfmt", OUTFMT6),
        ("-dust", "yes" if dust else "no"),
    ]
    if out_tsv is not None:
        options.append(("-out", str(out_tsv)))

    cmd = [cfg.blastn_path]
    for flag, value in options:
        cmd.extend((flag, value))
    cmd.extend(extra_args or ())
    return cmd


def _outfmt6_number(cell: str) -> Optional[float]:
    """Convert one numeric outfmt6 cell to float; None for empty/placeholder cells."""
    value = cell.strip()
    # "N/A" is the placeholder BLAST+ prints for unavailable fields; "NA" and ""
    # were already accepted before and are kept for compatibility.
    if value in ("", "NA", "N/A"):
        return None
    return float(value)


def parse_outfmt6(tsv_text: str) -> List[Dict[str, Any]]:
    """
    Parse tabular BLAST output produced with the OUTFMT6 column list.

    Expected columns (tab-separated, in this order):
      sacc stitle evalue bitscore pident length

    Returns one dict per row with keys accession, description, evalue,
    bitscore, pident (floats or None) and align_len (int or None). Blank lines
    and rows with fewer than 6 columns are skipped. Extra trailing columns
    are ignored. A None or empty input yields [].

    Raises:
      ValueError: if a numeric cell is neither a number nor a known placeholder.
    """
    hits: List[Dict[str, Any]] = []
    for line in (tsv_text or "").splitlines():
        if not line.strip():
            continue
        # Split the raw line: strip()-ing it first would also remove leading or
        # trailing tab separators, shifting columns or dropping rows whose first
        # or last cell is empty.
        parts = line.split("\t")
        if len(parts) < 6:
            continue

        sacc, stitle, evalue, bitscore, pident, length = parts[:6]
        align_len = _outfmt6_number(length)
        hits.append(
            {
                "accession": sacc.strip(),
                "description": stitle,
                "evalue": _outfmt6_number(evalue),
                "bitscore": _outfmt6_number(bitscore),
                "pident": _outfmt6_number(pident),
                "align_len": int(align_len) if align_len is not None else None,
            }
        )
    return hits


def _get_int(d: Dict[str, Any], key: str, default: int) -> int:
    v = d.get(key, None)
    if v is None or v == "":
        return default
    return int(v)


def _get_float(d: Dict[str, Any], key: str, default: float) -> float:
    v = d.get(key, None)
    if v is None or v == "":
        return default
    return float(v)


def _get_bool(d: Dict[str, Any], key: str, default: bool) -> bool:
    v = d.get(key, None)
    if v is None or v == "":
        return default
    if isinstance(v, bool):
        return v
    s = str(v).strip().lower()
    return s in ("1", "true", "yes", "y", "on")


def run_blastn_local(
    cfg: BlastnConfig,
    run_dir: Path,
    inputs: Dict[str, Any],
    resources: Dict[str, Any],
    extra_args: Optional[List[str]] = None,
) -> Tuple[int, str, str, Dict[str, Any]]:
    """
    Run blastn on the local host.

    Parameter sources:
      - sequence, database, max_hits, evalue and dust come from *inputs*.
        The last three fall back to ``cfg.defaults``, then to 10 / 1e-5 / True.
      - cpu comes from *resources* (default 1).

    Files written in *run_dir*: query.fa, blast.outfmt6.tsv (via blastn's
    -out), meta.json and stderr.txt.

    Returns:
      ``(exit_code, stdout, stderr, results)``, where results is
      ``{"hits": [...], "outfmt": OUTFMT6, "meta": {...}}``. Hits are parsed
      even when blastn exits non-zero, so callers must check exit_code.

    Raises:
      ValueError: for invalid database/sequence/parameter values. These are
        raised by the helpers before blastn is started.
      OSError: if the blastn executable cannot be launched. This propagates
        from subprocess.run.
    """
    # `or ""` so an explicit None reaches validation as "missing", instead of
    # being stringified into the literal sequence/database name "None".
    sequence = inputs.get("sequence") or ""
    database = inputs.get("database") or ""

    cpu = _get_int(resources, "cpu", 1)

    max_hits = _get_int(inputs, "max_hits", int(cfg.defaults.get("max_hits", 10)))
    evalue = _get_float(inputs, "evalue", float(cfg.defaults.get("evalue", 1e-5)))
    dust = _get_bool(inputs, "dust", bool(cfg.defaults.get("dust", True)))

    db_prefix = resolve_db_path(cfg, str(database))
    qfa = write_query_fasta(run_dir, str(sequence))

    # Write output to a file (better for Slurm and reproducibility).
    out_tsv = run_dir / "blast.outfmt6.tsv"
    # Remove output left by an earlier run in the same directory. Otherwise a
    # blastn failure that never writes -out would be reported with stale hits.
    try:
        out_tsv.unlink()
    except FileNotFoundError:
        pass

    cmd = build_blastn_cmd(cfg, qfa, db_prefix, max_hits, evalue, cpu, dust, out_tsv=out_tsv, extra_args=extra_args)

    p = subprocess.run(cmd, capture_output=True, text=True, check=False)
    stdout = p.stdout or ""
    stderr = p.stderr or ""

    # Prefer the -out file; fall back to stdout if blastn did not create it.
    # errors="replace": a stray non-UTF-8 byte in a subject title should not
    # discard an otherwise successful run.
    if out_tsv.exists():
        tsv_text = out_tsv.read_text(encoding="utf-8", errors="replace")
    else:
        tsv_text = stdout

    hits = parse_outfmt6(tsv_text)

    # Persist meta for debugging.
    meta = {
        "cmd": cmd,
        "outfmt": OUTFMT6,
        "inputs": {"database": database, "max_hits": max_hits, "evalue": evalue, "dust": dust},
        "resources": {"cpu": cpu},
        "exit_code": p.returncode,
    }
    (run_dir / "meta.json").write_text(json.dumps(meta, indent=2), encoding="utf-8")
    (run_dir / "stderr.txt").write_text(stderr, encoding="utf-8")

    results = {"hits": hits, "outfmt": OUTFMT6, "meta": meta}
    return p.returncode, stdout, stderr, results


def run_blastn_in_dir(
    cfg: BlastnConfig,
    run_dir: Path,
    sequence: str,
    database: str,
    max_hits: int = 10,
    evalue: float = 1e-5,
    cpu: int = 1,
    dust: bool = True,
    extra_args: Optional[List[str]] = None,
) -> Dict[str, Any]:
    """
    Keyword-style wrapper around run_blastn_local, intended for Slurm job scripts.

    Ensures *run_dir* exists, packs the arguments into the inputs/resources
    dicts run_blastn_local expects, and returns its results dict (with "hits").
    query.fa, blast.outfmt6.tsv, stderr.txt and meta.json end up in *run_dir*.

    Raises:
      RuntimeError: if blastn exits non-zero. The message includes the last
        1200 characters of stderr.
    """
    run_dir.mkdir(parents=True, exist_ok=True)

    exit_code, _, err_text, results = run_blastn_local(
        cfg,
        run_dir,
        dict(
            sequence=sequence,
            database=database,
            max_hits=max_hits,
            evalue=evalue,
            dust=dust,
        ),
        {"cpu": cpu},
        extra_args=extra_args,
    )
    if exit_code == 0:
        return results
    raise RuntimeError(f"blastn failed (exit_code={exit_code}). stderr tail: {err_text[-1200:]}")
