from __future__ import annotations

import json
import shutil
import subprocess
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Dict, Optional, Tuple


@dataclass(frozen=True)
class BcftoolsConfig:
    """Settings for running ``bcftools view`` on the local machine.

    Every field has a default, so ``BcftoolsConfig()`` is a usable
    configuration on its own.
    """

    # Absolute path to the binary (a value starting with "/"), or a bare name
    # that is looked up on PATH.
    bcftools_path: str = "bcftools"

    # Output formats callers may request, e.g. ["vcf", "bcf"]. The factory
    # builds a fresh list for every instance.
    # NOTE(review): a list-valued field makes instances unhashable even though
    # the dataclass is frozen; a tuple would fix that but would change how the
    # value renders in validation error messages.
    allowed_formats: list[str] = field(default_factory=lambda: ["vcf", "bcf"])

    # Format used when no output_format is supplied: "vcf" or "bcf".
    default_format: str = "vcf"

    # Byte cap for captured stdout/stderr, protecting callers from huge output.
    max_stdout_bytes: int = 2_000_000

    # When True, a region request requires a .tbi/.csi index next to the input.
    require_index_for_region: bool = True


def _which_or_error(path: str) -> str:
    # If absolute path was given, require it exists
    if path.startswith("/"):
        p = Path(path)
        if not p.exists():
            raise FileNotFoundError(f"bcftools_path not found: {p}")
        return str(p)

    w = shutil.which(path)
    if not w:
        raise RuntimeError("bcftools not found on PATH (and bcftools_path not absolute)")
    return w


def _tail_bytes(s: str, limit: int) -> str:
    b = (s or "").encode("utf-8", errors="replace")
    if len(b) <= limit:
        return s or ""
    return b[-limit:].decode("utf-8", errors="replace")


def _has_index_for(path: Path) -> bool:
    # For .vcf.gz => .tbi OR .csi
    # For .bcf    => .csi (often) but allow .csi/.tbi as a loose check
    return (path.with_suffix(path.suffix + ".tbi").exists() or
            path.with_suffix(path.suffix + ".csi").exists() or
            Path(str(path) + ".tbi").exists() or
            Path(str(path) + ".csi").exists())


def _output_type_flag(fmt: str) -> str:
    # bcftools view -O v|z|b|u
    # We'll default to VCF text ("v") or BCF binary ("b").
    if fmt == "vcf":
        return "v"
    if fmt == "bcf":
        return "b"
    raise ValueError(f"invalid output_format: {fmt}")


def validate_inputs(cfg: BcftoolsConfig, inputs: Dict[str, Any], resolved_input_path: Path) -> Dict[str, Any]:
    """
    Check user-supplied ``bcftools view`` options before anything is run.

    Checks, in order:
      1. the input file exists;
      2. when ``region`` is set and ``cfg.require_index_for_region`` is on,
         a .tbi/.csi index exists next to the input;
      3. the (stripped, lower-cased) output format is in
         ``cfg.allowed_formats``;
      4. ``include`` and ``exclude`` are not both set.

    Returns:
        ``{"ok": bool, "errors": [{"field": ..., "message": ...}, ...],
        "warnings": []}``, where ``ok`` is True exactly when no errors
        were found.
    """
    problems: list = []

    def complain(field_name: str, message: str) -> None:
        problems.append({"field": field_name, "message": message})

    if not resolved_input_path.exists():
        complain("input_path", f"Input does not exist: {resolved_input_path}")

    # Region queries generally need an indexed, compressed VCF/BCF.
    wants_index = bool(inputs.get("region")) and cfg.require_index_for_region
    if wants_index and not _has_index_for(resolved_input_path):
        complain(
            "region",
            f"region requested but index not found for {resolved_input_path} (.tbi/.csi)",
        )

    requested_format = inputs.get("output_format") or cfg.default_format
    out_fmt = str(requested_format).strip().lower()
    if out_fmt not in cfg.allowed_formats:
        complain("output_format", f"must be one of {cfg.allowed_formats}")

    # bcftools view takes a single filter expression: -i or -e, not both.
    if inputs.get("include") and inputs.get("exclude"):
        complain("include/exclude", "provide only one of include (-i) or exclude (-e)")

    return {"ok": not problems, "errors": problems, "warnings": []}


def run_bcftools_view_local(
    cfg: BcftoolsConfig,
    run_dir: Path,
    input_path: Path,
    inputs: Dict[str, Any],
    resources: Dict[str, Any],
) -> Tuple[int, str, str, Dict[str, Any]]:
    """
    Run ``bcftools view`` on ``input_path`` and write its output into ``run_dir``.

    Side effects: creates ``run_dir`` if needed. bcftools writes
    ``output.vcf`` or ``output.bcf`` into it, and this function writes
    ``meta.json`` and ``stderr.txt`` alongside.

    Args:
        cfg: Runner configuration. Supplies the binary location, the default
            output format, and the byte cap for captured output tails.
        run_dir: Directory that receives all outputs.
        input_path: VCF/BCF file handed to bcftools.
        inputs: User options. Recognised keys are ``region``, ``samples``,
            ``include``, ``exclude``, ``output_format`` and ``max_records``.
            ``region`` and ``samples`` may be a string or a list/tuple; a
            list/tuple is joined with commas. ``max_records`` is recorded in
            the metadata only and is not enforced.
        resources: Resource hints. ``cpu`` > 1 is passed as ``--threads``.

    Returns:
        ``(exit_code, stdout, stderr, results)``.
        ``stdout``/``stderr`` are returned in full. ``results`` contains:
          - ``output_path``: always reported, even when bcftools failed;
          - ``n_records``: ``None`` when bcftools failed or counting was not
            possible;
          - stdout/stderr tails capped at ``cfg.max_stdout_bytes``;
          - the metadata dict.

    Raises:
        FileNotFoundError, RuntimeError: the bcftools binary cannot be found.
        ValueError: ``output_format`` is not "vcf"/"bcf", or ``cpu`` is not
            an integer.
    """
    run_dir.mkdir(parents=True, exist_ok=True)

    bcftools = _which_or_error(cfg.bcftools_path)

    region = inputs.get("region")
    samples = inputs.get("samples")
    include = inputs.get("include")
    exclude = inputs.get("exclude")
    max_records = inputs.get("max_records")

    out_fmt = str(inputs.get("output_format") or cfg.default_format).strip().lower()
    out_flag = _output_type_flag(out_fmt)
    out_file = run_dir / ("output.vcf" if out_fmt == "vcf" else "output.bcf")

    cmd = [bcftools, "view", "-O", out_flag, "-o", str(out_file)]

    # Threads are supported for some operations; safe to pass if > 1.
    cpu = int(resources.get("cpu", 1) or 1)
    if cpu > 1:
        cmd += ["--threads", str(cpu)]

    if region:
        cmd += ["-r", _join_option_values(region)]
    if samples:
        cmd += ["-s", _join_option_values(samples)]
    if include:
        cmd += ["-i", str(include)]
    if exclude:
        cmd += ["-e", str(exclude)]

    # NOTE: bcftools view has no record-limit flag that works across all
    # versions, so max_records is only recorded in meta; it is NOT enforced.
    cmd.append(str(input_path))

    # Decode explicitly as UTF-8 with replacement, so unexpected bytes in
    # bcftools' messages cannot raise UnicodeDecodeError under another locale.
    p = subprocess.run(
        cmd,
        capture_output=True,
        text=True,
        encoding="utf-8",
        errors="replace",
        check=False,
    )
    stdout = p.stdout or ""
    stderr = p.stderr or ""

    # Only count a successful run's output. After a failure the file may be
    # partial, or left over from an earlier run in the same run_dir.
    n_records: Optional[int] = None
    if p.returncode == 0:
        n_records = _count_view_records(bcftools, out_file, out_fmt)

    meta = {
        "cmd": cmd,
        "inputs": {
            "region": region,
            "samples": samples,
            "include": include,
            "exclude": exclude,
            "output_format": out_fmt,
            "max_records": max_records,
        },
        "resources": {"cpu": cpu},
        "exit_code": p.returncode,
        "output_path": str(out_file),
        "n_records": n_records,
    }

    # Write as UTF-8 explicitly: the stderr tail may contain non-ASCII text
    # (including U+FFFD) that a non-UTF-8 locale encoding cannot represent.
    (run_dir / "meta.json").write_text(json.dumps(meta, indent=2), encoding="utf-8")
    (run_dir / "stderr.txt").write_text(
        _tail_bytes(stderr, cfg.max_stdout_bytes), encoding="utf-8"
    )

    results = {
        "output_path": str(out_file),
        "n_records": n_records,
        "stdout_tail": _tail_bytes(stdout, cfg.max_stdout_bytes),
        "stderr_tail": _tail_bytes(stderr, cfg.max_stdout_bytes),
        "meta": meta,
    }

    return p.returncode, stdout, stderr, results


def _join_option_values(value: Any) -> str:
    """Render an option value for bcftools: a list/tuple becomes "a,b,c", anything else str()."""
    if isinstance(value, (list, tuple)):
        return ",".join(str(item) for item in value)
    return str(value)


def _count_view_records(bcftools: str, out_file: Path, out_fmt: str) -> Optional[int]:
    """Best-effort count of non-header records in ``out_file``; None when it cannot be determined."""
    if not out_file.exists():
        return None
    try:
        if out_fmt == "vcf":
            # Stream line by line instead of loading the whole file. Header
            # lines start with "#", and blank lines are skipped.
            with out_file.open(encoding="utf-8", errors="replace") as fh:
                return sum(
                    1 for line in fh if line.rstrip("\r\n") and not line.startswith("#")
                )
        if out_fmt == "bcf":
            # BCF is binary, so let bcftools print the records without the
            # header (-H) and count its stdout lines as they stream, instead of
            # buffering the whole dump in memory.
            with subprocess.Popen(
                [bcftools, "view", "-H", str(out_file)],
                stdout=subprocess.PIPE,
                stderr=subprocess.DEVNULL,
            ) as proc:
                count = sum(1 for _ in proc.stdout) if proc.stdout is not None else 0
            # Exiting Popen's context manager waits for the process, so
            # returncode is populated here.
            return count if proc.returncode == 0 else None
    except (OSError, subprocess.SubprocessError):
        # Counting is informational only; it must never fail the run.
        return None
    return None
