from __future__ import annotations

import os
import subprocess
import tempfile
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple


@dataclass(frozen=True)
class BlastDefaults:
    """Fallback values for the optional blastn inputs (max_hits, evalue, dust)."""

    max_hits: int = 10  # passed to blastn as -max_target_seqs
    evalue: str = "1e-5"  # passed to blastn as -evalue
    dust: bool = True  # when False, blastn is called with "-dust no"


@dataclass(frozen=True)
class BlastConfig:
    """
    Per-server BLAST settings, normally built with ``from_server_config``.

    ``database_map`` maps a user-facing database name to the filesystem path prefix handed to
    ``blastn -db``. ``allowed_databases``, when non-empty, is an allowlist of names the
    validator accepts. Both are ``None`` when the dataclass is constructed directly without
    them; ``from_server_config`` always fills them in.
    """

    blastn_path: str = "/usr/bin/blastn"
    database_map: Optional[Dict[str, str]] = None  # name -> filesystem path prefix
    allowed_databases: Optional[List[str]] = None
    defaults: BlastDefaults = BlastDefaults()

    @classmethod
    def from_server_config(cls, server_cfg: Optional[Dict[str, Any]]) -> "BlastConfig":
        """
        Build a BlastConfig from a server's config dict.

        server_cfg is the server's config dict (the same thing exposed under /api/servers[].config).
        We expect:
          server_cfg["blast"]["database_map"] etc.

        Missing or wrongly-typed sections fall back to empty values / ``BlastDefaults()``.
        A ``None`` value for any ``blast.defaults`` key counts as missing. The result holds
        its own copies of ``database_map`` and ``allowed_databases``, with names and paths
        normalised to ``str``. Later mutation of ``server_cfg`` therefore does not leak into
        this frozen object.

        Raises:
            ValueError / TypeError: if ``blast.defaults.max_hits`` is set but not int-convertible.
        """
        blast = server_cfg.get("blast") if hasattr(server_cfg, "get") else None
        if not hasattr(blast, "get"):
            blast = {}

        raw_map = blast.get("database_map")
        # Keys are compared against str database names and values end up in the blastn argv,
        # so normalise both to str; entries without a path are dropped.
        if isinstance(raw_map, dict):
            database_map = {str(k): str(v) for k, v in raw_map.items() if v is not None}
        else:
            database_map = {}

        raw_allowed = blast.get("allowed_databases")
        allowed = [str(x) for x in raw_allowed] if isinstance(raw_allowed, list) else []

        defaults_raw = blast.get("defaults")
        if not isinstance(defaults_raw, dict):
            defaults_raw = {}
        fallback = BlastDefaults()

        max_hits_raw = defaults_raw.get("max_hits")
        evalue_raw = defaults_raw.get("evalue")
        dust_raw = defaults_raw.get("dust")
        if dust_raw is None:
            dust = fallback.dust
        elif isinstance(dust_raw, str):
            # bool("no") / bool("false") would be True; accept the usual spellings of "off".
            dust = dust_raw.strip().lower() not in ("", "0", "false", "no", "off")
        else:
            dust = bool(dust_raw)

        defaults = BlastDefaults(
            max_hits=fallback.max_hits if max_hits_raw is None else int(max_hits_raw),
            evalue=fallback.evalue if evalue_raw is None else str(evalue_raw),
            dust=dust,
        )

        blastn_path = str(blast.get("blastn_path") or "/usr/bin/blastn")

        return cls(
            blastn_path=blastn_path,
            database_map=database_map,
            allowed_databases=allowed,
            defaults=defaults,
        )


def _get_selected_server_cfg(context: Any) -> Dict[str, Any]:
    """
    This tries hard to extract the selected server config from whatever execution context your
    framework passes to tool validators/runners.

    Adapt the attribute names ONLY if your context differs.
    """
    # Common patterns:
    # context.server, context.server_record, context.selected_server, context.execution.server
    for attr in ("server", "selected_server", "server_record"):
        srv = getattr(context, attr, None)
        if srv:
            # srv may be dict-like or object-like
            if isinstance(srv, dict):
                return (srv.get("config") or {}) if isinstance(srv.get("config"), dict) else {}
            cfg = getattr(srv, "config", None)
            if isinstance(cfg, dict):
                return cfg

    # Another common pattern: context has server_id only; config is attached elsewhere
    cfg = getattr(context, "server_config", None)
    if isinstance(cfg, dict):
        return cfg

    return {}


def _capability_databases(context: Any) -> List[str]:
    """
    If your server capabilities are attached to the context, use those as a fallback list.
    """
    for attr in ("server", "selected_server", "server_record"):
        srv = getattr(context, attr, None)
        if not srv:
            continue

        caps = None
        if isinstance(srv, dict):
            caps = srv.get("capabilities")
        else:
            caps = getattr(srv, "capabilities", None)

        if isinstance(caps, dict):
            tools = caps.get("tools") or []
            if isinstance(tools, list):
                for t in tools:
                    if isinstance(t, dict) and t.get("tool_id") == "blastn":
                        feats = t.get("features") or {}
                        dbs = feats.get("databases") or []
                        if isinstance(dbs, list):
                            return [str(x) for x in dbs]

    return []


def validate(inputs: Dict[str, Any], context: Any) -> Tuple[bool, List[Dict[str, str]], List[Dict[str, str]]]:
    """
    Check blastn inputs against the selected server's BLAST configuration.

    Checks run in this order:
      1. ``sequence`` is present and non-blank.
      2. ``database`` is present and non-blank; it is whitespace-stripped, as ``run`` does.
      3. The optional ``max_hits`` and ``evalue`` inputs, when truthy, convert cleanly.
         ``max_hits`` must be a positive int and ``evalue`` a non-negative number.
      4. The database is in the server's ``allowed_databases``, if that list is non-empty.
      5. The database is a key of ``blast.database_map``.

    The function returns as soon as one of these checks fails.

    Returns:
        ``(ok, errors, warnings)``; each error is a ``{"field": ..., "message": ...}`` dict.
    """
    errors: List[Dict[str, str]] = []
    warnings: List[Dict[str, str]] = []
    inputs = inputs or {}

    seq = inputs.get("sequence")
    db_raw = inputs.get("database")

    if not seq or not str(seq).strip():
        errors.append({"field": "sequence", "message": "Missing FASTA sequence"})
        return False, errors, warnings

    if not db_raw or not str(db_raw).strip():
        errors.append({"field": "database", "message": "Missing database name"})
        return False, errors, warnings

    # run() strips the name before looking it up; validate the same normalised form.
    db_name = str(db_raw).strip()

    # run() replaces falsy max_hits/evalue with the server defaults and otherwise converts
    # them with int()/str(). Reject values that would make run() raise or that are not a
    # positive count / non-negative number.
    max_hits = inputs.get("max_hits")
    if max_hits:
        try:
            hits_ok = int(max_hits) >= 1
        except (TypeError, ValueError):
            hits_ok = False
        if not hits_ok:
            errors.append({"field": "max_hits", "message": f"max_hits must be a positive integer, got {max_hits!r}"})

    evalue = inputs.get("evalue")
    if evalue:
        try:
            evalue_ok = float(str(evalue)) >= 0
        except ValueError:
            evalue_ok = False
        if not evalue_ok:
            errors.append({"field": "evalue", "message": f"evalue must be a non-negative number, got {evalue!r}"})

    if errors:
        return False, errors, warnings

    server_cfg = _get_selected_server_cfg(context)
    cfg = BlastConfig.from_server_config(server_cfg)
    database_map = cfg.database_map or {}

    # map(str) keeps sorting safe even if the config carries non-str names.
    known_from_cfg = sorted(map(str, database_map.keys()))
    known_from_caps = sorted(_capability_databases(context))
    known = known_from_cfg or known_from_caps

    # allowlist check
    if cfg.allowed_databases:
        allowed_sorted = sorted(map(str, cfg.allowed_databases))
        if db_name not in set(allowed_sorted):
            errors.append(
                {
                    "field": "database",
                    "message": f"Database '{db_name}' is not in allowed_databases. Allowed: {allowed_sorted}",
                }
            )
            return False, errors, warnings

    # mapping check: only blast.database_map can satisfy run(), which needs a db path.
    if db_name in database_map:
        return True, errors, warnings

    # database_map is empty but the server advertises databases: this is a hard error, since
    # run() still needs a db path. It usually means the server config was not wired through.
    if not database_map and known_from_caps:
        errors.append(
            {
                "field": "database",
                "message": (
                    f"Server advertised databases {known_from_caps} but blast.database_map is empty in server config. "
                    f"Fix server config injection into blastn_real.py."
                ),
            }
        )
        return False, errors, warnings

    errors.append({"field": "database", "message": f"Unknown database '{db_name}'. Known: {known}"})
    return False, errors, warnings


def _parse_outfmt6_hits(stdout: str) -> List[Dict[str, Any]]:
    """
    Parse ``-outfmt "6 sacc stitle evalue bitscore pident length"`` rows into hit dicts.

    ``sacc`` is the first column and the four numeric columns are the last four. Any extra
    tab inside ``stitle`` is therefore folded back into the description instead of shifting
    the numeric columns. Blank lines and rows with fewer than six columns are skipped.
    """
    hits: List[Dict[str, Any]] = []
    for line in (stdout or "").splitlines():
        if not line.strip():
            continue
        parts = line.split("\t")
        if len(parts) < 6:
            continue
        evalue, bitscore, pident, length = parts[-4:]
        hits.append(
            {
                "accession": parts[0],
                "description": "\t".join(parts[1:-4]),
                "evalue": float(evalue) if evalue else None,
                "bitscore": float(bitscore) if bitscore else None,
                "pident": float(pident) if pident else None,
                "align_len": int(length) if length else None,
            }
        )
    return hits


def run(inputs: Dict[str, Any], context: Any) -> Dict[str, Any]:
    """
    Execute blastn locally (on the slurm node / runner machine) and return parsed hits.

    Expected inputs:
      sequence (FASTA string)
      database (name in database_map)
      max_hits, evalue, dust (optional; missing/falsy values use the server's blast.defaults,
        and a string dust such as "no"/"false"/"off"/"0" means False)

    The query is written to a uniquely named temp file under ``context.work_root`` (default
    ``/tmp``). The file is removed once blastn has finished.

    Returns:
        ``{"hits": [...]}``, one dict per tabular row with keys accession, description,
        evalue, bitscore, pident, align_len.

    Raises:
        ValueError: if ``database`` is not a key of the server's ``blast.database_map``.
        RuntimeError: if blastn exits with a non-zero status.
    """
    inputs = inputs or {}
    server_cfg = _get_selected_server_cfg(context)
    cfg = BlastConfig.from_server_config(server_cfg)

    seq = str(inputs.get("sequence") or "").strip()
    db_name = str(inputs.get("database") or "").strip()

    # Apply defaults
    max_hits = int(inputs.get("max_hits") or cfg.defaults.max_hits)
    evalue = str(inputs.get("evalue") or cfg.defaults.evalue)
    dust_raw = inputs.get("dust")
    if dust_raw is None:
        dust = bool(cfg.defaults.dust)
    elif isinstance(dust_raw, str):
        # bool("no") / bool("false") would be True; accept the usual spellings of "off".
        dust = dust_raw.strip().lower() not in ("", "0", "false", "no", "off")
    else:
        dust = bool(dust_raw)

    if not cfg.database_map or db_name not in cfg.database_map:
        known = sorted((cfg.database_map or {}).keys())
        raise ValueError(
            f"Unknown database '{db_name}'. Known: {known}. "
            f"(blast.database_map missing or not injected from server config)"
        )

    db_path = str(cfg.database_map[db_name])

    # If your framework provides a workdir, use it; else /tmp.
    work_root = getattr(context, "work_root", None) or "/tmp"
    os.makedirs(work_root, exist_ok=True)
    # A fixed "query.fasta" would be clobbered by concurrent runs sharing work_root and is
    # open to symlink tricks in a world-writable dir; mkstemp creates a unique file exclusively.
    fd, query_path = tempfile.mkstemp(prefix="query_", suffix=".fasta", dir=work_root)
    try:
        with os.fdopen(fd, "w", encoding="utf-8") as f:
            # seq was stripped above, so it never already ends with a newline.
            f.write(seq + "\n")

        cmd = [
            cfg.blastn_path,
            "-db", db_path,
            "-query", query_path,
            "-max_target_seqs", str(max_hits),
            "-evalue", evalue,
            "-outfmt", "6 sacc stitle evalue bitscore pident length",
        ]
        if not dust:
            cmd += ["-dust", "no"]

        proc = subprocess.run(cmd, capture_output=True, text=True)
    finally:
        # Best-effort cleanup: the query file is only needed while blastn runs.
        try:
            os.remove(query_path)
        except OSError:
            pass

    if proc.returncode != 0:
        raise RuntimeError(f"blastn failed (rc={proc.returncode}): {proc.stderr.strip() or proc.stdout.strip()}")

    return {"hits": _parse_outfmt6_hits(proc.stdout)}
