from __future__ import annotations

import contextvars
import logging
import time
from typing import Dict

from fastapi import FastAPI
from fastapi.responses import JSONResponse

from omnibioai_tool_exec.execution.adapters.aws_batch import AwsBatchAdapter
from omnibioai_tool_exec.execution.adapters.http_toolserver import HttpToolServerAdapter
from omnibioai_tool_exec.execution.adapters.local import LocalAdapter
from omnibioai_tool_exec.execution.adapters.slurm import SlurmAdapter
from omnibioai_tool_exec.execution.runner import Runner
from omnibioai_tool_exec.registry import (
    ServerRegistry,
    ToolRegistry,
    load_servers,
    load_tools,
)
from omnibioai_tool_exec.service.api import routes_runs, routes_servers, routes_tools
from omnibioai_tool_exec.store.run_store import RunStore


# Server id selected by the request currently being handled (set by the wrapped
# ``server_registry.get`` inside create_app). A ContextVar keeps concurrent
# requests from overwriting each other's selection, which a single shared
# attribute cannot do. Created once at module level, as the contextvars docs
# recommend; the value itself is per-context.
# NOTE(review): for sync endpoints this relies on the threadpool running calls in
# a copied context (anyio's to_thread does) — confirm for deployed versions.
_CURRENT_SERVER_ID = contextvars.ContextVar("omnibioai_tes_current_server_id", default=None)

_logger = logging.getLogger(__name__)


def _preload_capabilities(server, adapter) -> None:
    """Best-effort handshake: store the adapter's capabilities on ``server``.

    On success sets ``server.capabilities`` and
    ``server.capabilities_last_refreshed_epoch``. On any exception the failure is
    logged (with traceback) and the server is left without capabilities, so one
    unreachable or misconfigured backend does not prevent the app from starting.
    """
    try:
        capabilities = adapter.handshake()
    except Exception:
        _logger.warning(
            "capability handshake failed for server_id=%s (adapter_type=%s)",
            server.server_id,
            server.adapter_type,
            exc_info=True,
        )
        return
    server.capabilities = capabilities
    server.capabilities_last_refreshed_epoch = int(time.time())


def create_app(tools_path: str, servers_path: str) -> FastAPI:
    """Build the omnibioai-tool-exec FastAPI application.

    Loads tool and server definitions, builds one execution adapter per
    adapter type (``http_toolserver`` and ``aws_batch`` get one real instance
    per server, reached through a forwarding mux), tries to preload every
    server's capabilities, and mounts the tools/servers/runs routers under
    ``/api`` plus a ``/health`` endpoint.

    Args:
        tools_path: Path handed to ``load_tools`` (YAML per the original comment).
        servers_path: Path handed to ``load_servers``.

    Returns:
        The configured application. The tool registry, server registry, runner
        and run store are exposed on ``app.state``.
    """
    app = FastAPI(title="omnibioai-tool-exec")

    tool_registry = ToolRegistry()
    server_registry = ServerRegistry()
    store = RunStore()

    # Load tools/servers from YAML
    for t in load_tools(tools_path):
        tool_registry.register(t)

    servers = load_servers(servers_path)

    # Runner dispatches on adapter_type -> adapter.
    adapter_instances: Dict[str, object] = {}

    # local + slurm are single shared instances whose config is overwritten by
    # each matching server below.
    adapter_instances["local"] = LocalAdapter(config={})
    adapter_instances["slurm"] = SlurmAdapter(config={})

    # http_toolserver and aws_batch get one real instance per server_id, built
    # from that server's own config (AwsBatchAdapter.__init__ needs real
    # region/job_queue/etc.); the muxes below forward to the right instance.
    http_toolservers_by_server_id: Dict[str, HttpToolServerAdapter] = {}
    aws_batch_by_server_id: Dict[str, AwsBatchAdapter] = {}

    # Register servers + preload capabilities if possible
    for s in servers:
        if s.adapter_type in ("local", "slurm"):
            shared = adapter_instances[s.adapter_type]
            # NOTE(review): with several servers of the same type, the last
            # server's config wins for every run on that adapter type.
            shared.config = s.config  # type: ignore[attr-defined]
            _preload_capabilities(s, shared)

        elif s.adapter_type == "http_toolserver":
            http_adapter = HttpToolServerAdapter(s.config)
            http_toolservers_by_server_id[s.server_id] = http_adapter
            _preload_capabilities(s, http_adapter)

        elif s.adapter_type == "aws_batch":
            batch_adapter = AwsBatchAdapter(s.config)
            aws_batch_by_server_id[s.server_id] = batch_adapter
            _preload_capabilities(s, batch_adapter)

        server_registry.register(s)

    def _current_server_id():
        """Server id for the in-flight request, else the last one selected app-wide."""
        server_id = _CURRENT_SERVER_ID.get()
        if server_id is None:
            # Outside a context that went through server_registry.get(), keep the
            # original app-wide behaviour.
            server_id = app.state._tes_current_server_id
        return server_id

    class _PerServerForwarder:
        """Forward the adapter API to the per-server instance for the current request.

        Concrete muxes set ``_instances`` (server_id -> real adapter) and
        ``_label`` (used in the error message). The resolver is deliberately not
        named ``_get`` so it cannot shadow a same-named helper on an adapter base.
        """

        _instances: Dict[str, object]
        _label: str

        def _target(self):
            server_id = _current_server_id()
            if server_id not in self._instances:
                # KeyError is mapped to a 404 by the handler registered below.
                raise KeyError(f"{self._label} instance missing for server_id={server_id}")
            return self._instances[server_id]

        def handshake(self):
            return self._target().handshake()

        def validate(self, tool_id, inputs, resources):
            return self._target().validate(tool_id, inputs, resources)

        def submit(self, tool_id, inputs, resources):
            return self._target().submit(tool_id, inputs, resources)

        def status(self, remote_run_id):
            return self._target().status(remote_run_id)

        def logs(self, remote_run_id, tail: int = 200):
            return self._target().logs(remote_run_id, tail=tail)

        def results(self, remote_run_id):
            return self._target().results(remote_run_id)

    class _HttpToolServerMux(_PerServerForwarder, HttpToolServerAdapter):
        _instances = http_toolservers_by_server_id
        _label = "http toolserver"

        def __init__(self) -> None:
            # The base __init__ needs a config; the forwarded methods never use
            # this placeholder.
            super().__init__({"base_url": "http://invalid"})

    class _AwsBatchMux(_PerServerForwarder, AwsBatchAdapter):
        _instances = aws_batch_by_server_id
        _label = "aws_batch"

        def __init__(self) -> None:
            # Dummy keys only satisfy AwsBatchAdapter.__init__'s guards; the
            # forwarded methods go to the real per-server instances.
            super().__init__(
                {
                    "region": "us-east-1",
                    "job_queue": "DUMMY",
                    "job_definition_map": {},
                    "s3_results": {"bucket": "DUMMY", "prefix": "DUMMY"},
                }
            )

    adapter_instances["http_toolserver"] = _HttpToolServerMux()
    adapter_instances["aws_batch"] = _AwsBatchMux()

    runner = Runner(adapters=adapter_instances)  # type: ignore[arg-type]

    # Expose registries/runner/store
    app.state.tool_registry = tool_registry
    app.state.server_registry = server_registry
    app.state.runner = runner
    app.state.store = store
    app.state._tes_current_server_id = None

    # Record the selected server whenever it is looked up, so the muxes know
    # where to forward. Installed before the routers are attached so that even a
    # router capturing ``server_registry.get`` at attach time gets the wrapper.
    orig_server_get = server_registry.get

    def _server_get(server_id: str):
        srv = orig_server_get(server_id)
        _CURRENT_SERVER_ID.set(server_id)
        # Still maintained for anything reading app.state directly and as the
        # fallback used by _current_server_id().
        app.state._tes_current_server_id = server_id
        return srv

    server_registry.get = _server_get  # type: ignore

    # Routes
    app.include_router(routes_tools.attach(tool_registry), prefix="/api")
    app.include_router(routes_servers.attach(server_registry, runner), prefix="/api")
    app.include_router(routes_runs.attach(tool_registry, server_registry, runner, store), prefix="/api")

    @app.get("/health")
    def health():
        return {"ok": True, "service": "omnibioai-tool-exec"}

    # Basic error mapping for KeyError (cleaner API)
    @app.exception_handler(KeyError)
    def _key_error(_, exc: KeyError):
        # NOTE(review): this turns *every* uncaught KeyError, including ones from
        # programming errors, into a 404 NOT_FOUND response.
        return JSONResponse(
            status_code=404,
            content={"ok": False, "error": {"code": "NOT_FOUND", "message": str(exc)}},
        )

    return app
