from __future__ import annotations

import json
import re
import time
import uuid
from pathlib import Path
from typing import Any, Dict, Optional

import boto3
from botocore.exceptions import ClientError

from omnibioai_tool_exec.execution.adapters.base import Adapter
from omnibioai_tool_exec.models.capabilities import ServerCapabilities, ToolCapability


class AwsBatchAdapter(Adapter):
    """
    AWS Batch adapter.

    Contract:
      - submit() returns Batch jobId
      - status() maps Batch jobStatus -> TES state
      - logs() pulls from CloudWatch Logs if configured
      - results() reads results.json from S3 (recommended)

    Minimal assumptions:
      - each tool_id maps to a Batch job definition (container image + entrypoint) that knows
        how to run the tool when given the TOOL_ID, INPUTS_JSON, RESOURCES_JSON, S3_RESULT_URI
        and RUN_ID environment variables.
      - the job writes its results.json to S3_RESULT_URI (s3://<bucket>/<prefix>/<run_id>/results.json).
    """

    def __init__(self, config: Dict[str, Any]) -> None:
        self.config = config or {}

        self.region = str(self.config.get("region") or "")
        if not self.region:
            raise ValueError("AwsBatchAdapter requires config.region")

        # Optional AWS profile
        profile = self.config.get("aws_profile")
        if profile:
            session = boto3.Session(profile_name=str(profile), region_name=self.region)
        else:
            session = boto3.Session(region_name=self.region)

        self.batch = session.client("batch")
        self.logs_client = session.client("logs")
        self.s3 = session.client("s3")

        self.job_queue = str(self.config.get("job_queue") or "")
        if not self.job_queue:
            raise ValueError("AwsBatchAdapter requires config.job_queue")

        self.job_definition_map = dict(self.config.get("job_definition_map") or {})
        if not self.job_definition_map:
            # allow empty for now, but submit() will fail unless tool_id mapped
            pass

        self.work_root = Path(self.config.get("work_root", "/tmp/omnibioai_tes_runs"))
        self.index_dir = self.work_root / "_aws_batch_index"
        self.index_dir.mkdir(parents=True, exist_ok=True)

        self.s3_bucket = None
        self.s3_prefix = None
        s3cfg = self.config.get("s3_results") or {}
        if isinstance(s3cfg, dict):
            self.s3_bucket = s3cfg.get("bucket")
            self.s3_prefix = s3cfg.get("prefix", "tes-runs")

    def adapter_type(self) -> str:
        return "aws_batch"

    # ----------------------------
    # Helpers
    # ----------------------------
    def _index_path(self, job_id: str) -> Path:
        return self.index_dir / f"{job_id}.json"

    def _write_index(self, job_id: str, meta: Dict[str, Any]) -> None:
        self._index_path(job_id).write_text(json.dumps(meta, indent=2))

    def _read_index(self, job_id: str) -> Dict[str, Any]:
        p = self._index_path(job_id)
        if not p.exists():
            return {}
        try:
            return json.loads(p.read_text())
        except Exception:
            return {}

    def _s3_uri_for(self, run_id: str) -> Optional[str]:
        if not self.s3_bucket:
            return None
        prefix = str(self.s3_prefix or "tes-runs").strip("/")

        # results.json
        return f"s3://{self.s3_bucket}/{prefix}/{run_id}/results.json"

    @staticmethod
    def _batch_to_tes_state(batch_status: str) -> str:
        # Batch statuses: SUBMITTED, PENDING, RUNNABLE, STARTING, RUNNING, SUCCEEDED, FAILED
        s = (batch_status or "").upper()
        if s in ("SUBMITTED", "PENDING", "RUNNABLE", "STARTING"):
            return "QUEUED"
        if s == "RUNNING":
            return "RUNNING"
        if s == "SUCCEEDED":
            return "COMPLETED"
        if s == "FAILED":
            return "FAILED"
        # fallback
        return "RUNNING"

    # ----------------------------
    # Adapter methods
    # ----------------------------
    def handshake(self) -> ServerCapabilities:
        # For MVP: advertise tools from job_definition_map
        tools = []
        for tool_id, jd in sorted(self.job_definition_map.items()):
            tools.append(ToolCapability(tool_id=tool_id, version=str(jd), features={}))

        return ServerCapabilities(
            engines=["aws_batch"],
            tools=tools,
            resources=self.config.get("resources", {}) or {},
            storage={"s3_results_bucket": self.s3_bucket, "s3_results_prefix": self.s3_prefix},
            policies=self.config.get("policies", {}) or {},
        )

    def validate(self, tool_id: str, inputs: Dict[str, Any], resources: Dict[str, Any]) -> Dict[str, Any]:
        errors = []

        if tool_id not in self.job_definition_map:
            errors.append({"field": "tool_id", "message": f"No job definition mapped for tool_id={tool_id}"})

        # light validation only; deeper validation should happen in the tool container itself too
        if not isinstance(inputs, dict):
            errors.append({"field": "inputs", "message": "inputs must be an object"})
        if not isinstance(resources, dict):
            errors.append({"field": "resources", "message": "resources must be an object"})

        # require S3 results for MVP results()
        if not self.s3_bucket:
            errors.append({"field": "config.s3_results.bucket", "message": "s3_results.bucket required for results()"})

        return {"ok": len(errors) == 0, "errors": errors, "warnings": []}

    def submit(self, tool_id: str, inputs: Dict[str, Any], resources: Dict[str, Any]) -> str:
        jd = self.job_definition_map.get(tool_id)
        if not jd:
            raise RuntimeError(f"No job definition mapped for tool_id={tool_id}")

        # Make a stable run_id for S3 path: reuse TES run_id? (Not available here)
        # We'll use a timestamp-based id tied to job submission.
        run_id = f"aws_{int(time.time() * 1000)}"
        s3_uri = self._s3_uri_for(run_id)
        if not s3_uri:
            raise RuntimeError("s3_results.bucket not configured; cannot compute results URI")

        # Provide inputs/resources via env vars (simple, works for small payloads)
        # If payload becomes big, switch to S3 inputs.
        env = [
            {"name": "TOOL_ID", "value": tool_id},
            {"name": "INPUTS_JSON", "value": json.dumps(inputs)},
            {"name": "RESOURCES_JSON", "value": json.dumps(resources)},
            {"name": "S3_RESULT_URI", "value": s3_uri},
            {"name": "RUN_ID", "value": run_id},
        ]

        job_name = f"tes-{tool_id}-{run_id}"

        # Optional: override vcpus/memory if your job definition supports it (Batch has limits)
        # Here we keep it minimal; you can add resourceRequirements later.

        resp = self.batch.submit_job(
            jobName=job_name,
            jobQueue=self.job_queue,
            jobDefinition=str(jd),
            containerOverrides={
                "environment": env,
            },
        )

        job_id = resp["jobId"]

        # persist index for restart-proof logs/results
        meta = {
            "job_id": job_id,
            "job_name": job_name,
            "tool_id": tool_id,
            "run_id": run_id,
            "submitted_epoch": int(time.time()),
            "s3_result_uri": s3_uri,
        }
        self._write_index(job_id, meta)
        return job_id

    def status(self, remote_run_id: str) -> Dict[str, Any]:
        job_id = str(remote_run_id).strip()
        if not job_id:
            return {"state": "FAILED", "message": "missing job id"}

        try:
            resp = self.batch.describe_jobs(jobs=[job_id])
        except ClientError as e:
            return {"state": "FAILED", "message": str(e)}

        jobs = resp.get("jobs") or []
        if not jobs:
            return {"state": "FAILED", "message": "job not found"}

        j = jobs[0]
        batch_status = j.get("status", "")
        tes_state = self._batch_to_tes_state(batch_status)

        return {
            "state": tes_state,
            "batch_status": batch_status,
            "status_reason": j.get("statusReason"),
            "created_at": j.get("createdAt"),
            "started_at": j.get("startedAt"),
            "stopped_at": j.get("stoppedAt"),
        }

    def logs(self, remote_run_id: str, tail: int = 200) -> str:
        job_id = str(remote_run_id).strip()
        if not job_id:
            return "missing job id"

        # Need logStreamName from describe_jobs -> container.logStreamName
        try:
            resp = self.batch.describe_jobs(jobs=[job_id])
            jobs = resp.get("jobs") or []
            if not jobs:
                return f"[{job_id}] job not found"

            j = jobs[0]
            container = (j.get("container") or {})
            stream = container.get("logStreamName")
            if not stream:
                return f"[{job_id}] no CloudWatch log stream yet"
        except Exception as e:
            return f"[{job_id}] error fetching job info: {e}"

        # log group is usually from job definition (e.g. /aws/batch/job)
        # Many setups use "/aws/batch/job". Allow override.
        log_group = str(self.config.get("log_group", "/aws/batch/job"))

        try:
            # pull latest events
            out = self.logs_client.get_log_events(
                logGroupName=log_group,
                logStreamName=stream,
                startFromHead=False,
                limit=max(1, int(tail)),
            )
            events = out.get("events") or []
            lines = [e.get("message", "") for e in events]
            return "\n".join(lines[-tail:]) if tail and tail > 0 else "\n".join(lines)
        except Exception as e:
            return f"[{job_id}] error reading logs: {e}"

    def results(self, remote_run_id: str) -> Dict[str, Any]:
        job_id = str(remote_run_id).strip()
        if not job_id:
            return {"ok": False, "error": {"code": "INVALID", "message": "missing job id"}}

        st = self.status(job_id)
        if st.get("state") != "COMPLETED":
            return {"ok": False, "error": {"code": "NOT_READY", "message": f"state={st.get('state')}"}, "status": st}

        meta = self._read_index(job_id)
        s3_uri = meta.get("s3_result_uri")
        if not s3_uri:
            return {"ok": False, "error": {"code": "NO_RESULTS_URI", "message": "missing s3_result_uri in index"}}

        # parse s3://bucket/key
        try:
            _, _, rest = s3_uri.partition("s3://")
            bucket, _, key = rest.partition("/")
            if not bucket or not key:
                raise ValueError("bad s3 uri")
        except Exception:
            return {"ok": False, "error": {"code": "BAD_S3_URI", "message": f"invalid s3 uri: {s3_uri}"}}

        try:
            obj = self.s3.get_object(Bucket=bucket, Key=key)
            body = obj["Body"].read().decode("utf-8", errors="replace")
            results_obj = json.loads(body)
        except ClientError as e:
            return {"ok": False, "error": {"code": "S3_READ_FAILED", "message": str(e)}, "s3_uri": s3_uri}
        except Exception as e:
            return {"ok": False, "error": {"code": "RESULTS_PARSE_FAILED", "message": str(e)}, "s3_uri": s3_uri}

        return {"ok": True, "job_id": job_id, "results": results_obj, "meta": meta}
