from __future__ import annotations

import json
import shlex
import subprocess
import sys
import time
from pathlib import Path
from typing import Any, Dict, Optional

from omnibioai_tool_exec.execution.adapters.base import Adapter
from omnibioai_tool_exec.models.capabilities import ServerCapabilities, ToolCapability

# Real BLAST tool helpers (local parsing + config)
from omnibioai_tool_exec.execution.tools.blastn_real import (
    BlastnConfig,
    parse_outfmt6,
)


class SlurmAdapter(Adapter):
    """
    Local Slurm adapter (sbatch/squeue) that does NOT require slurmdbd/sacct.

    Supports:
      - echo_test (simple stdout)
      - blastn (real BLASTN using local blastn binary + local BLAST DB)

    Strategy:
      - submit(): create a unique run_dir, write an sbatch script, submit via sbatch,
        store metadata in memory. Every script records its own exit status in
        run_dir/exit_code.txt through a bash EXIT trap, because without sacct there
        is no other way to learn how a job ended.
      - status(): use squeue while the job is visible there; once it has left the
        queue, consult exit_code.txt (if present).
      - logs(): tail stdout file
      - results(): read files from run_dir (stdout/stderr + tool-specific artifacts)
        - blastn: read results.json / blast.outfmt6.tsv and meta.json if present

    BLAST settings come from the ``blast`` sub-section of the server config when it
    is present, otherwise from the top level of the config.

    Note: job metadata lives only in memory, so logs()/results() cannot locate the
    run directories of jobs submitted before a service restart.
    """

    # squeue %T values mapped onto this adapter's coarse states. Any state not
    # listed here (RUNNING, COMPLETING, SUSPENDED, ...) is reported as RUNNING.
    _QUEUED_SLURM_STATES = frozenset({"PENDING", "CONFIGURING", "REQUEUED"})
    _FAILED_SLURM_STATES = frozenset(
        {"FAILED", "CANCELLED", "TIMEOUT", "NODE_FAIL", "BOOT_FAIL", "DEADLINE", "OUT_OF_MEMORY"}
    )
    # File (inside run_dir) the job script's EXIT trap writes its exit status to.
    _EXIT_CODE_FILE = "exit_code.txt"
    _SUPPORTED_TOOLS = ("echo_test", "blastn")

    def __init__(self, config: Dict[str, Any]) -> None:
        # handshake() already tolerated a None config; apply that to every method.
        self.config: Dict[str, Any] = config if config is not None else {}
        # job_id -> metadata (paths, tool_id, etc.); in-memory only.
        self._jobs: Dict[str, Dict[str, str]] = {}

    def adapter_type(self) -> str:
        return "slurm"

    # ----------------------------
    # Config helpers
    # ----------------------------
    def _blast_section(self) -> Dict[str, Any]:
        """
        Return the mapping that holds the BLAST settings.

        handshake() and validate()/submit() must agree on where BLAST settings
        live, otherwise the advertised databases differ from the accepted ones.
        A non-empty ``blast`` mapping wins; otherwise the top-level config is used.
        """
        blast = self.config.get("blast")
        if isinstance(blast, dict) and blast:
            return blast
        return self.config

    def _blastn_cfg(self) -> BlastnConfig:
        """
        Build a BlastnConfig from the server config (servers.yaml), e.g.:

          blast:                              # optional nesting; top-level keys also work
            blastn_path: blastn
            database_map:
              ecoli_demo: /data/blastdb/ecoli_demo/ecoli_demo
            allowed_databases: [ecoli_demo]   # optional
            defaults:
              max_hits: 10
              evalue: 1e-5
              dust: true

        Missing or empty ``defaults`` fall back to max_hits=10, evalue=1e-5, dust=True.
        """
        section = self._blast_section()
        allowed = section.get("allowed_databases")
        defaults = dict(section.get("defaults") or {})

        return BlastnConfig(
            blastn_path=str(section.get("blastn_path", "blastn")),
            database_map=dict(section.get("database_map") or {}),
            allowed_databases=list(allowed) if isinstance(allowed, (list, tuple)) else None,
            defaults=defaults if defaults else {"max_hits": 10, "evalue": 1e-5, "dust": True},
        )

    def _run(self, cmd: list[str], check: bool = True) -> subprocess.CompletedProcess:
        """Run a command (no shell) capturing text stdout/stderr."""
        return subprocess.run(cmd, capture_output=True, text=True, check=check)

    def _read_file(self, path: Optional[str]) -> str:
        """Return the file's text, or "" when no path is given or it does not exist."""
        if not path:
            return ""
        p = Path(path)
        if not p.exists():
            return ""
        return p.read_text(errors="replace")

    def _read_exit_code(self, run_dir: Optional[str]) -> Optional[int]:
        """Return the exit status recorded by the job's EXIT trap, or None if unknown."""
        if not run_dir:
            return None
        text = self._read_file(str(Path(run_dir) / self._EXIT_CODE_FILE)).strip()
        try:
            return int(text)
        except ValueError:
            return None

    @staticmethod
    def _as_bool(value: Any) -> bool:
        """
        Interpret a flag that may arrive as a string (JSON/form/CLI input).

        Plain bool() treats every non-empty string, including "false" and "0", as
        True. Recognised false-like strings map to False; anything else keeps
        bool() semantics.
        """
        if isinstance(value, str) and value.strip().lower() in {"0", "false", "no", "off", "n", "f"}:
            return False
        return bool(value)

    @staticmethod
    def _make_run_dir(work_root: Path) -> Path:
        """
        Create and return a fresh run directory under an existing work_root.

        Names are run_<epoch-ms>; if that already exists (two submissions in the
        same millisecond, or a leftover directory) a numeric suffix is appended, so
        concurrent jobs never share, and overwrite, each other's files.
        """
        stamp = int(time.time() * 1000)
        attempt = 0
        while True:
            name = f"run_{stamp}" if attempt == 0 else f"run_{stamp}_{attempt}"
            candidate = work_root / name
            try:
                candidate.mkdir()
            except FileExistsError:
                attempt += 1
                continue
            return candidate

    @staticmethod
    def _script_header(
        job_name: str,
        partition: str,
        cpu: int,
        stdout_path: Path,
        stderr_path: Path,
        exit_code_path: Path,
    ) -> list[str]:
        """Common sbatch preamble: #SBATCH directives, strict mode, exit-status capture."""
        return [
            "#!/bin/bash",
            f"#SBATCH -p {partition}",
            f"#SBATCH -c {cpu}",
            f"#SBATCH --output={stdout_path}",
            f"#SBATCH --error={stderr_path}",
            f"#SBATCH -J {job_name}",
            "",
            "set -euo pipefail",
            # Record how the script ended (sacct is unavailable). SIGKILL cannot be
            # trapped, so the file may be missing for hard-killed jobs.
            f"TES_EXIT_CODE_FILE={shlex.quote(str(exit_code_path))}",
            "trap 'echo \"$?\" > \"$TES_EXIT_CODE_FILE\"' EXIT",
            "",
        ]

    # ----------------------------
    # Capabilities / validation
    # ----------------------------
    def handshake(self) -> ServerCapabilities:
        """Advertise tools, BLAST databases and resource/policy caps from the config."""
        cfg = self.config

        # Same source of truth as validate()/submit() (see _blast_section).
        blast_cfg = self._blast_section()
        db_map = blast_cfg.get("database_map") or {}
        allowed = blast_cfg.get("allowed_databases")

        # Prefer explicit allow-list; else advertise all db_map keys
        if isinstance(allowed, (list, tuple)) and len(allowed) > 0:
            dbs = list(allowed)
        else:
            dbs = sorted(db_map.keys())

        # allow overriding resource/policy caps from config too
        res_cfg = cfg.get("resources") or {}
        pol_cfg = cfg.get("policies") or {}

        max_cpu = int(res_cfg.get("max_cpu", 20) or 20)
        max_ram = int(res_cfg.get("max_ram_gb", 128) or 128)
        max_runtime = int(pol_cfg.get("max_runtime_minutes", 1440) or 1440)

        tools = [
            ToolCapability(tool_id="echo_test", version="1.0", features={}),
            # blastn is always advertised; an empty list means no database is configured.
            ToolCapability(tool_id="blastn", version="real", features={"databases": dbs}),
        ]

        return ServerCapabilities(
            engines=["slurm"],
            tools=tools,
            resources={"max_cpu": max_cpu, "max_ram_gb": max_ram},
            storage={},
            policies={"max_runtime_minutes": max_runtime},
        )

    def validate(self, tool_id: str, inputs: Dict[str, Any], resources: Dict[str, Any]) -> Dict[str, Any]:
        """
        Check a request before submit().

        Returns {"ok": bool, "errors": [...], "warnings": []}; each error is a dict
        with a "message" plus either a "field" or a "code".
        """
        conversion_errors = (TypeError, ValueError, OverflowError)
        errors: list[Dict[str, Any]] = []

        # Basic resource validation
        cpu = resources.get("cpu", 1)
        try:
            if int(cpu) < 1:
                errors.append({"field": "resources.cpu", "message": "cpu must be >= 1"})
        except conversion_errors:
            errors.append({"field": "resources.cpu", "message": "cpu must be an integer"})

        if tool_id == "echo_test":
            msg = inputs.get("message")
            if not isinstance(msg, str) or not msg.strip():
                errors.append({"field": "message", "message": "message is required"})

        elif tool_id == "blastn":
            seq = inputs.get("sequence")
            db = inputs.get("database")

            if not isinstance(seq, str) or not seq.strip():
                errors.append({"field": "sequence", "message": "sequence is required"})

            if not isinstance(db, str) or not db.strip():
                errors.append({"field": "database", "message": "database is required"})
            else:
                # Resolve the same (stripped) name submit() will use; resolve_db_path
                # raises if the name is not in the database map / allow-list.
                try:
                    from omnibioai_tool_exec.execution.tools.blastn_real import resolve_db_path
                    resolve_db_path(self._blastn_cfg(), db.strip())
                except Exception as e:
                    errors.append({"field": "database", "message": str(e)})

            # Optional input validation
            if "max_hits" in inputs:
                try:
                    if int(inputs["max_hits"]) < 1:
                        errors.append({"field": "max_hits", "message": "max_hits must be >= 1"})
                except conversion_errors:
                    errors.append({"field": "max_hits", "message": "max_hits must be an integer"})

            if "evalue" in inputs:
                try:
                    float(inputs["evalue"])
                except conversion_errors:
                    errors.append({"field": "evalue", "message": "evalue must be a number"})

        else:
            errors.append({"code": "UNSUPPORTED_TOOL", "message": f"{tool_id} not supported by SlurmAdapter"})

        return {"ok": len(errors) == 0, "errors": errors, "warnings": []}

    # ----------------------------
    # Submit
    # ----------------------------
    def submit(self, tool_id: str, inputs: Dict[str, Any], resources: Dict[str, Any]) -> str:
        """
        Write an sbatch script for the tool into a fresh run_dir and submit it.

        Returns the numeric Slurm job id as a string.
        Raises RuntimeError for unsupported tools or unparseable sbatch output;
        subprocess.CalledProcessError if sbatch exits non-zero.
        """
        # Reject early so no empty run directory is left behind.
        if tool_id not in self._SUPPORTED_TOOLS:
            raise RuntimeError(f"Unsupported tool_id for slurm submit: {tool_id}")

        work_root = Path(self.config.get("work_root", "/tmp/omnibioai_tes_runs"))
        work_root.mkdir(parents=True, exist_ok=True)
        run_dir = self._make_run_dir(work_root)

        partition = self.config.get("partition", "debug")
        cpu = int(resources.get("cpu", 1))

        stdout_path = run_dir / "stdout.txt"
        stderr_path = run_dir / "stderr.txt"
        script_path = run_dir / "job.sbatch"
        exit_code_path = run_dir / self._EXIT_CODE_FILE

        if tool_id == "echo_test":
            message = str(inputs.get("message", ""))
            header = self._script_header(
                "tes_echo_test", partition, cpu, stdout_path, stderr_path, exit_code_path
            )
            # The message is caller-supplied: single-quote it for bash so $(...),
            # backticks and backslashes are printed literally, never executed.
            body = [f"printf '%s\\n' {shlex.quote(message)}", ""]

        else:  # blastn
            cfg = self._blastn_cfg()
            defaults = cfg.defaults

            # Write query.fa now (so the sbatch script just runs)
            sequence = str(inputs.get("sequence", "")).strip()
            (run_dir / "query.fa").write_text(sequence + "\n")

            # Serialize blast config for the job
            cfg_json = {
                "blastn_path": cfg.blastn_path,
                "database_map": cfg.database_map,
                "allowed_databases": cfg.allowed_databases,
                "defaults": cfg.defaults,
            }
            (run_dir / "blast_cfg.json").write_text(json.dumps(cfg_json, indent=2))

            # Run parameters travel as JSON rather than generated Python literals,
            # so values such as evalue=inf round-trip instead of breaking the script.
            params = {
                "database": str(inputs.get("database", "")).strip(),
                "max_hits": int(inputs.get("max_hits", defaults.get("max_hits", 10))),
                "evalue": float(inputs.get("evalue", defaults.get("evalue", 1e-5))),
                "cpu": cpu,
                "dust": self._as_bool(inputs.get("dust", defaults.get("dust", True))),
            }
            (run_dir / "job_params.json").write_text(json.dumps(params, indent=2))

            header = self._script_header(
                "tes_blastn", partition, cpu, stdout_path, stderr_path, exit_code_path
            )
            # Job script runs the service's own interpreter (sys.executable) so the
            # job sees the same installed packages; quoted heredoc => no expansion.
            body = [
                f"cd {shlex.quote(str(run_dir))}",
                "",
                f"{shlex.quote(sys.executable)} - <<'PY'",
                "import json",
                "from pathlib import Path",
                "from omnibioai_tool_exec.execution.tools.blastn_real import BlastnConfig, run_blastn_in_dir",
                "",
                "run_dir = Path('.')",
                "cfg_json = json.loads((run_dir / 'blast_cfg.json').read_text())",
                "params = json.loads((run_dir / 'job_params.json').read_text())",
                "cfg = BlastnConfig(",
                "    blastn_path=cfg_json.get('blastn_path', 'blastn'),",
                "    database_map=cfg_json.get('database_map', {}) or {},",
                "    allowed_databases=cfg_json.get('allowed_databases', None),",
                "    defaults=cfg_json.get('defaults', {}) or {'max_hits':10, 'evalue':1e-5, 'dust':True},",
                ")",
                "",
                "results = run_blastn_in_dir(",
                "    cfg=cfg,",
                "    run_dir=run_dir,",
                "    sequence=(run_dir / 'query.fa').read_text(),",
                "    database=params['database'],",
                "    max_hits=int(params['max_hits']),",
                "    evalue=float(params['evalue']),",
                "    cpu=int(params['cpu']),",
                "    dust=bool(params['dust']),",
                ")",
                "",
                "# also dump a compact results json for adapter to read easily",
                "(run_dir / 'results.json').write_text(json.dumps({'hits': results.get('hits', [])}, indent=2))",
                "print('BLASTN completed. hits=', len(results.get('hits', [])))",
                "PY",
                "",
            ]

        script_path.write_text("\n".join(header + body))

        # Submit to Slurm; expected output is "Submitted batch job <id>".
        r = self._run(["sbatch", str(script_path)], check=True)

        parts = r.stdout.strip().split()
        job_id = parts[-1] if parts else ""
        if not job_id.isdigit():
            raise RuntimeError(f"Unexpected sbatch output: {r.stdout.strip()}")

        # Save mapping so logs/results work even without sacct
        self._jobs[job_id] = {
            "tool_id": tool_id,
            "run_dir": str(run_dir),
            "stdout": str(stdout_path),
            "stderr": str(stderr_path),
            "script": str(script_path),
        }
        return job_id

    # ----------------------------
    # Status / Logs / Results
    # ----------------------------
    def status(self, remote_run_id: str) -> Dict[str, Any]:
        """
        Map the job's Slurm state to QUEUED / RUNNING / COMPLETED / FAILED.

        While squeue still lists the job its %T state is translated (terminal
        states included). Once the job has left the queue it is COMPLETED unless
        the recorded exit status (exit_code.txt) is non-zero, in which case FAILED.
        """
        job_id = str(remote_run_id).strip()
        if not job_id:
            return {"state": "FAILED", "message": "missing job id"}

        r = self._run(["squeue", "-j", job_id, "-h", "-o", "%T"], check=False)
        lines = r.stdout.strip().splitlines()
        slurm_state = lines[0].strip() if lines else ""

        if slurm_state:
            if slurm_state in self._QUEUED_SLURM_STATES:
                return {"state": "QUEUED", "slurm_state": slurm_state}
            if slurm_state == "COMPLETED":
                return {"state": "COMPLETED", "slurm_state": slurm_state}
            if slurm_state in self._FAILED_SLURM_STATES:
                return {"state": "FAILED", "slurm_state": slurm_state}
            return {"state": "RUNNING", "slurm_state": slurm_state}

        out: Dict[str, Any] = {"state": "COMPLETED", "slurm_state": "NOT_IN_QUEUE"}
        exit_code = self._read_exit_code(self._jobs.get(job_id, {}).get("run_dir"))
        if exit_code is not None:
            out["exit_code"] = exit_code
            if exit_code != 0:
                out["state"] = "FAILED"
        if r.returncode != 0 and r.stderr.strip():
            # Surface squeue's complaint (e.g. purged job id, or squeue itself failing).
            out["message"] = r.stderr.strip()
        return out

    def logs(self, remote_run_id: str, tail: int = 200) -> str:
        """Return the last `tail` stdout lines (all lines when tail is falsy or <= 0)."""
        job_id = str(remote_run_id).strip()
        meta = self._jobs.get(job_id)
        if not meta:
            return f"[{job_id}] no local metadata (service restart?)."

        p = Path(meta["stdout"])
        if not p.exists():
            return f"[{job_id}] stdout not created yet."

        lines = p.read_text(errors="replace").splitlines()
        if tail and tail > 0:
            lines = lines[-tail:]
        return "\n".join(lines)

    @staticmethod
    def _read_blastn_outputs(run_dir: Path) -> tuple[list, Any]:
        """
        Load BLAST hits (results.json, else blast.outfmt6.tsv) and meta.json.

        Best-effort: an unreadable or corrupt artifact yields no hits / an empty
        meta rather than an exception.
        """
        results_json_path = run_dir / "results.json"
        out_tsv_path = run_dir / "blast.outfmt6.tsv"
        meta_json_path = run_dir / "meta.json"

        hits: list = []
        if results_json_path.exists():
            try:
                hits = json.loads(results_json_path.read_text()).get("hits") or []
            except Exception:
                hits = []

        # Fall back to parsing outfmt6 TSV
        if not hits and out_tsv_path.exists():
            try:
                hits = parse_outfmt6(out_tsv_path.read_text())
            except Exception:
                hits = []

        meta_obj: Any = {}
        if meta_json_path.exists():
            try:
                meta_obj = json.loads(meta_json_path.read_text())
            except Exception:
                meta_obj = {}
        return hits, meta_obj

    @staticmethod
    def _pick_exit_code(state: Optional[str], stderr: str, *known: Optional[int]) -> int:
        """
        Choose the exit code to report from the recorded candidates.

        The first non-zero known code wins (any failing stage means failure);
        otherwise 0 if any code is known. With nothing recorded, a FAILED job
        reports 1 and other jobs use the stderr heuristic (non-empty => 1).
        """
        codes = [c for c in known if c is not None]
        if codes:
            return next((c for c in codes if c != 0), 0)
        if state == "FAILED":
            return 1
        return 0 if stderr.strip() == "" else 1

    def results(self, remote_run_id: str) -> Dict[str, Any]:
        """
        Collect the outputs of a finished job.

        Returns {"ok": False, "error": {...}, "job_id": ...} while the job is not
        finished (code NOT_READY) or when this process has no metadata for it
        (code UNKNOWN_JOB). Otherwise ok=True with stdout, stderr, exit_code, a
        copy of the job metadata under "paths" and, for blastn, "results" (hits)
        plus "meta" (meta.json contents).
        """
        job_id = str(remote_run_id).strip()
        meta = self._jobs.get(job_id, {})
        tool_id = meta.get("tool_id", "")

        st = self.status(job_id)
        state = st.get("state")
        # A FAILED job we submitted has still finished; its logs are final and are
        # returned so callers can see why it failed.
        finished = state == "COMPLETED" or (state == "FAILED" and bool(meta))
        if not finished:
            return {
                "ok": False,
                "error": {"code": "NOT_READY", "message": f"state={state}"},
                "job_id": job_id,
            }
        if not meta:
            # Nothing to read; answering ok=True with empty output would be a false success.
            return {
                "ok": False,
                "error": {"code": "UNKNOWN_JOB", "message": f"[{job_id}] no local metadata (service restart?)."},
                "job_id": job_id,
            }

        run_dir = Path(meta["run_dir"]) if meta.get("run_dir") else None
        stdout = self._read_file(meta.get("stdout"))
        stderr = self._read_file(meta.get("stderr"))
        job_exit = self._read_exit_code(meta.get("run_dir"))

        # Tool-specific results
        if tool_id == "blastn" and run_dir is not None and run_dir.exists():
            hits, meta_obj = self._read_blastn_outputs(run_dir)

            # If blastn_real wrote exit_code in meta.json, consider it too.
            meta_exit: Optional[int] = None
            if isinstance(meta_obj, dict) and "exit_code" in meta_obj:
                try:
                    meta_exit = int(meta_obj["exit_code"])
                except (TypeError, ValueError, OverflowError):
                    meta_exit = None

            return {
                "ok": True,
                "job_id": job_id,
                "exit_code": self._pick_exit_code(state, stderr, meta_exit, job_exit),
                "results": {"hits": hits},
                "stdout": stdout,
                "stderr": stderr,
                "paths": dict(meta),  # copy: callers must not mutate adapter state
                "meta": meta_obj,
            }

        # echo_test and others
        return {
            "ok": True,
            "job_id": job_id,
            "exit_code": self._pick_exit_code(state, stderr, job_exit),
            "stdout": stdout,
            "stderr": stderr,
            "paths": dict(meta),  # copy: callers must not mutate adapter state
        }
