# src/omnibioai_tool_exec/execution/tools/blastn_real.py
from __future__ import annotations

import json
import os
import shlex
import subprocess
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple


@dataclass(frozen=True)
class BlastnConfig:
    """
    Immutable settings for the blastn helpers in this module.

    ``database_map`` and ``defaults`` may be omitted or passed as ``None``
    (or any other falsy value); after construction they are always dicts.
    """

    # Executable placed first in the generated command line.
    blastn_path: str = "blastn"
    # db name -> absolute DB prefix (no extension)
    database_map: Dict[str, str] = None  # type: ignore
    # Optional allow-list of db names; None means no allow-list check.
    allowed_databases: Optional[List[str]] = None
    # Fallback option values consulted when a run does not supply its own.
    defaults: Dict[str, Any] = None  # type: ignore

    def __post_init__(self):
        # The dataclass is frozen, so ordinary attribute assignment is
        # blocked; object.__setattr__ is the sanctioned way to normalise
        # fields during initialisation.
        for attr in ("database_map", "defaults"):
            supplied = getattr(self, attr)
            object.__setattr__(self, attr, supplied if supplied else {})


# Value passed to blastn's -outfmt flag by build_blastn_cmd. The names after
# the leading "6" are the columns parse_outfmt6 unpacks, in the same order.
OUTFMT6 = "6 sacc stitle evalue bitscore pident length"


def resolve_db_path(cfg: BlastnConfig, db_name: str) -> str:
    """
    Translate a logical database name into its configured DB prefix.

    Checks run in this order: the optional allow-list, membership in
    ``cfg.database_map``, and finally that the mapped path starts with "/".

    Raises:
        ValueError: if the name is not allowed, is unknown, or maps to a
            path that is not absolute.
    """
    allow_list = cfg.allowed_databases
    if allow_list is not None and db_name not in allow_list:
        raise ValueError(f"Database '{db_name}' is not allowed")

    known = cfg.database_map
    if db_name not in known:
        raise ValueError(f"Unknown database '{db_name}'. Known: {sorted(known.keys())}")

    db_path = known[db_name]
    if db_path.startswith("/"):
        return db_path

    raise ValueError(f"DB path must be absolute. Got: {db_path}")


def write_query_fasta(run_dir: Path, sequence: str) -> Path:
    """
    Write the query sequence to ``<run_dir>/query.fa`` and return that path.

    ``run_dir`` is created (including parents) if it does not exist. The
    sequence text is stripped of surrounding whitespace and written with
    exactly one trailing newline.
    """
    run_dir.mkdir(parents=True, exist_ok=True)
    qfa = run_dir / "query.fa"
    # strip() has already removed any trailing newline, so always add one
    # back; checking the unstripped input here would leave the file
    # unterminated precisely when the caller supplied a newline.
    qfa.write_text(sequence.strip() + "\n", encoding="utf-8")
    return qfa


def build_blastn_cmd(
    cfg: BlastnConfig,
    query_fa: Path,
    db_prefix: str,
    max_hits: int,
    evalue: float,
    cpu: int,
    dust: bool,
    extra_args: Optional[List[str]] = None,
) -> List[str]:
    """
    Assemble the blastn argument vector as a list of strings.

    The thread count is clamped to at least 1, ``dust`` is rendered as
    "yes"/"no", and any ``extra_args`` are appended unchanged at the end.
    """
    options = (
        ("-query", str(query_fa)),
        ("-db", db_prefix),
        ("-max_target_seqs", str(max_hits)),
        ("-evalue", str(evalue)),
        ("-num_threads", str(max(1, cpu))),
        ("-outfmt", OUTFMT6),
        # optional low complexity filter
        ("-dust", "yes" if dust else "no"),
    )

    cmd: List[str] = [cfg.blastn_path]
    for flag, value in options:
        cmd.extend((flag, value))

    if extra_args:
        cmd.extend(extra_args)

    return cmd


def _parse_optional_number(raw: str) -> Optional[float]:
    """Convert one numeric column to float; blank or "NA" becomes None."""
    text = raw.strip()
    if text in ("", "NA"):
        return None
    return float(text)


def parse_outfmt6(tsv_text: str) -> List[Dict[str, Any]]:
    """
    Parse blastn tabular output into a list of hit dictionaries.

    outfmt6 columns (as configured):
      sacc stitle evalue bitscore pident length

    Blank lines and lines with fewer than six tab-separated fields are
    skipped; fields beyond the sixth are ignored. Numeric columns that are
    empty or "NA" are reported as None.

    Returns:
        One dict per hit with keys ``accession``, ``description``,
        ``evalue``, ``bitscore``, ``pident`` and ``align_len``.

    Raises:
        ValueError: if a numeric column holds text that is neither empty,
            "NA", nor a parsable number.
    """
    hits: List[Dict[str, Any]] = []
    for line in tsv_text.splitlines():
        # Only test for blankness. Stripping the line itself would also eat
        # leading/trailing tabs, so a row with an empty first or last column
        # would lose a field and be discarded as too short.
        if not line.strip():
            continue
        parts = line.split("\t")
        if len(parts) < 6:
            continue
        sacc, stitle, evalue, bitscore, pident, length = parts[:6]
        align_len = _parse_optional_number(length)
        hits.append(
            {
                "accession": sacc.strip(),
                "description": stitle,
                "evalue": _parse_optional_number(evalue),
                "bitscore": _parse_optional_number(bitscore),
                "pident": _parse_optional_number(pident),
                # Go through float first so values such as "150.0" are accepted.
                "align_len": None if align_len is None else int(align_len),
            }
        )
    return hits


def run_blastn_local(
    cfg: BlastnConfig,
    run_dir: Path,
    sequence: str,
    database: str,
    resources: Dict[str, Any],
) -> Tuple[int, str, str, Dict[str, Any]]:
    """
    Local execution helper (optional). For Slurm you typically only need:
      - write_query_fasta
      - build_blastn_cmd
      - parse_outfmt6

    Resolves the run options, writes the query under ``run_dir``, runs
    blastn synchronously and parses its stdout.

    Option precedence for ``max_hits``, ``evalue`` and ``dust`` is:
    ``resources``, then ``cfg.defaults``, then the built-in fallback
    (10, 1e-5 and True respectively). ``cpu`` is read from ``resources``
    only and defaults to 1.

    Returns:
        ``(returncode, stdout, stderr, results)`` where ``results`` holds
        the parsed ``hits``, the ``outfmt`` string and the ``cmd`` list.
        The exit status is not checked here; hits are parsed from whatever
        stdout was produced, so callers should inspect ``returncode``.

    Raises:
        ValueError: propagated from resolve_db_path for a disallowed,
            unknown or non-absolute database, or from parse_outfmt6 for an
            unparsable numeric column.
    """
    cpu = int(resources.get("cpu") or 1)

    # A missing, None or zero request falls through to the configured
    # default before the hard-coded one, so cfg.defaults is never bypassed
    # just because the key is present with an empty value.
    max_hits = int(resources.get("max_hits") or cfg.defaults.get("max_hits") or 10)
    evalue = float(resources.get("evalue") or cfg.defaults.get("evalue") or 1e-5)

    # False is a meaningful choice for dust, so only None counts as "unset".
    dust_choice = resources.get("dust")
    if dust_choice is None:
        dust_choice = cfg.defaults.get("dust")
    dust = True if dust_choice is None else bool(dust_choice)

    db_prefix = resolve_db_path(cfg, database)
    qfa = write_query_fasta(run_dir, sequence)
    cmd = build_blastn_cmd(cfg, qfa, db_prefix, max_hits, evalue, cpu, dust)

    p = subprocess.run(cmd, capture_output=True, text=True)
    stdout = p.stdout or ""
    stderr = p.stderr or ""

    results = {"hits": parse_outfmt6(stdout), "outfmt": OUTFMT6, "cmd": cmd}
    return p.returncode, stdout, stderr, results
