from __future__ import annotations

import shlex
import subprocess
import time
import uuid
from pathlib import Path
from typing import Any, Dict, Optional

from omnibioai_tool_exec.execution.adapters.base import Adapter
from omnibioai_tool_exec.models.capabilities import ServerCapabilities, ToolCapability


class SlurmAdapter(Adapter):
    """
    Local Slurm adapter (sbatch/squeue) that does NOT require slurmdbd/sacct.

    Strategy:
      - On submit(): create a per-run directory under work_root, write sbatch script,
        set --output/--error to files in that directory.
      - Capture Slurm JobID from sbatch output.
      - Keep an in-memory map: job_id -> {run_dir, stdout_path, stderr_path}.
      - status(): use squeue to detect RUNNING vs COMPLETED (queue empty -> completed).
      - logs(): tail stdout file.
      - results(): return stdout/stderr content + basic exit_code inference.
        (The exit code is best-effort without sacct: a completed job whose stderr is
        empty or missing is reported as 0, anything else as 1.)
    """

    def __init__(self, config: Dict[str, Any]) -> None:
        self.config = config
        self._jobs: Dict[str, Dict[str, str]] = {}  # job_id -> paths

    def adapter_type(self) -> str:
        return "slurm"

    def handshake(self) -> ServerCapabilities:
        """
        Describe what this backend offers: one engine ("slurm"), the "echo_test" tool,
        fixed CPU/RAM ceilings and a maximum runtime policy.
        """
        # Advertise tools this backend can run. Keep it aligned with your tools YAML.
        supported_tools = [
            ToolCapability(tool_id="echo_test", version="1.0", features={}),
        ]
        resource_limits = {"max_cpu": 20, "max_ram_gb": 128}
        runtime_policies = {"max_runtime_minutes": 1440}

        return ServerCapabilities(
            engines=["slurm"],
            tools=supported_tools,
            resources=resource_limits,
            storage={},
            policies=runtime_policies,
        )

    def validate(self, tool_id: str, inputs: Dict[str, Any], resources: Dict[str, Any]) -> Dict[str, Any]:
        errors = []

        if tool_id != "echo_test":
            errors.append({"code": "UNSUPPORTED_TOOL", "message": f"{tool_id} not supported by SlurmAdapter"})
        else:
            msg = inputs.get("message")
            if not isinstance(msg, str) or not msg.strip():
                errors.append({"field": "message", "message": "message is required"})

        # Basic resource validation
        cpu = resources.get("cpu", 1)
        try:
            cpu_i = int(cpu)
            if cpu_i < 1:
                errors.append({"field": "resources.cpu", "message": "cpu must be >= 1"})
        except Exception:
            errors.append({"field": "resources.cpu", "message": "cpu must be an integer"})

        return {"ok": len(errors) == 0, "errors": errors, "warnings": []}

    def _run(self, cmd: list[str], check: bool = True) -> subprocess.CompletedProcess:
        return subprocess.run(cmd, capture_output=True, text=True, check=check)

    def submit(self, tool_id: str, inputs: Dict[str, Any], resources: Dict[str, Any]) -> str:
        work_root = Path(self.config.get("work_root", "/tmp/omnibioai_tes_runs"))
        work_root.mkdir(parents=True, exist_ok=True)

        # Unique local run directory (independent from TES run_id)
        run_dir = work_root / f"run_{int(time.time() * 1000)}"
        run_dir.mkdir(parents=True, exist_ok=True)

        partition = self.config.get("partition", "debug")
        cpu = int(resources.get("cpu", 1))

        stdout_path = run_dir / "stdout.txt"
        stderr_path = run_dir / "stderr.txt"
        script_path = run_dir / "job.sbatch"

        # Keep it simple and safe: echo a single message
        message = inputs["message"].replace('"', '\\"')

        script_path.write_text(
            "\n".join(
                [
                    "#!/bin/bash",
                    f"#SBATCH -p {partition}",
                    f"#SBATCH -c {cpu}",
                    f"#SBATCH --output={stdout_path}",
                    f"#SBATCH --error={stderr_path}",
                    "#SBATCH -J tes_echo_test",
                    "",
                    f'echo "{message}"',
                    "",
                ]
            )
        )

        # Submit
        r = self._run(["sbatch", str(script_path)], check=True)

        # Typical output: "Submitted batch job 123"
        parts = r.stdout.strip().split()
        job_id = parts[-1] if parts else ""
        if not job_id.isdigit():
            raise RuntimeError(f"Unexpected sbatch output: {r.stdout.strip()}")

        # Save mapping so logs/results work even without sacct
        self._jobs[job_id] = {
            "run_dir": str(run_dir),
            "stdout": str(stdout_path),
            "stderr": str(stderr_path),
            "script": str(script_path),
        }
        return job_id

    def status(self, remote_run_id: str) -> Dict[str, Any]:
        job_id = str(remote_run_id).strip()
        if not job_id:
            return {"state": "FAILED", "message": "missing job id"}

        # If job appears in squeue => still running/queued
        r = self._run(["squeue", "-j", job_id, "-h", "-o", "%T"], check=False)
        state = r.stdout.strip()

        if state:
            # squeue returns states like: PENDING, RUNNING, COMPLETING, etc.
            if state in ("PENDING", "CONFIGURING"):
                return {"state": "QUEUED", "slurm_state": state}
            return {"state": "RUNNING", "slurm_state": state}

        # Not in queue => assume completed (best effort without sacct)
        return {"state": "COMPLETED", "slurm_state": "NOT_IN_QUEUE"}

    def logs(self, remote_run_id: str, tail: int = 200) -> str:
        job_id = str(remote_run_id).strip()
        meta = self._jobs.get(job_id)
        if not meta:
            return f"[{job_id}] no local metadata (service restart?)."

        p = Path(meta["stdout"])
        if not p.exists():
            return f"[{job_id}] stdout not created yet."

        lines = p.read_text(errors="replace").splitlines()
        if tail and tail > 0:
            lines = lines[-tail:]
        return "\n".join(lines)

    def _read_file(self, path: Optional[str]) -> str:
        if not path:
            return ""
        p = Path(path)
        if not p.exists():
            return ""
        return p.read_text(errors="replace")

    def results(self, remote_run_id: str) -> Dict[str, Any]:
        job_id = str(remote_run_id).strip()
        meta = self._jobs.get(job_id, {})

        stdout = self._read_file(meta.get("stdout"))
        stderr = self._read_file(meta.get("stderr"))

        # Best-effort exit_code without sacct:
        # If job is not in queue and stderr is empty => 0, else 1 (you can improve later)
        st = self.status(job_id)
        if st.get("state") != "COMPLETED":
            return {"ok": False, "error": {"code": "NOT_READY", "message": f"state={st.get('state')}"}, "job_id": job_id}

        exit_code = 0 if stderr.strip() == "" else 1

        return {
            "ok": True,
            "job_id": job_id,
            "exit_code": exit_code,
            "stdout": stdout,
            "stderr": stderr,
            "paths": meta,
        }
