from __future__ import annotations

import contextvars
import logging
import time

from fastapi import FastAPI
from fastapi.responses import JSONResponse

from omnibioai_tool_exec.registry import load_tools, load_servers, ToolRegistry, ServerRegistry
from omnibioai_tool_exec.execution.runner import Runner
from omnibioai_tool_exec.store.run_store import RunStore

from omnibioai_tool_exec.execution.adapters.local import LocalAdapter
from omnibioai_tool_exec.execution.adapters.http_toolserver import HttpToolServerAdapter

from omnibioai_tool_exec.service.api import routes_tools, routes_servers, routes_runs
from omnibioai_tool_exec.execution.adapters.slurm import SlurmAdapter


logger = logging.getLogger(__name__)

# Id of the server whose http_toolserver instance should serve the next
# multiplexed adapter call. A ContextVar is isolated per asyncio task (i.e. per
# request) and Starlette copies the context into its threadpool for sync
# endpoints, so concurrent requests cannot overwrite each other's selection the
# way a single app-wide attribute can. Defined at module level, as recommended
# by the contextvars documentation.
_CURRENT_SERVER_ID = contextvars.ContextVar(
    "omnibioai_tool_exec_current_server_id", default=None
)


def create_app(tools_path: str, servers_path: str) -> FastAPI:
    """Build and wire the omnibioai-tool-exec FastAPI application.

    Loads tool and server definitions, builds one execution adapter per
    adapter type ("local", "slurm", "http_toolserver"), runs a startup
    capability handshake for each server, exposes the registries / runner /
    store on ``app.state`` and mounts the API routers under ``/api`` plus a
    ``/health`` endpoint.

    Args:
        tools_path: Path handed to ``load_tools``.
        servers_path: Path handed to ``load_servers``.

    Returns:
        The configured FastAPI application.
    """
    app = FastAPI(title="omnibioai-tool-exec")

    tool_registry = ToolRegistry()
    server_registry = ServerRegistry()
    store = RunStore()

    for tool in load_tools(tools_path):
        tool_registry.register(tool)

    servers = load_servers(servers_path)

    # Runner dispatches on adapter_type -> adapter instance. "local" and
    # "slurm" are always present; "http_toolserver" is added below as a mux.
    adapter_instances = {
        "local": LocalAdapter(config={}),
        "slurm": SlurmAdapter(config={}),
    }

    # http_toolserver servers each have their own config (e.g. base_url), so
    # one adapter instance is kept per server_id and selected per call.
    http_toolservers_by_server_id = {}
    slurm_server_ids = []

    # Register servers and preload capabilities where possible.
    for server in servers:
        kind = server.adapter_type
        if kind == "local":
            server.capabilities = adapter_instances["local"].handshake()
            server.capabilities_last_refreshed_epoch = int(time.time())
        elif kind == "http_toolserver":
            adapter = HttpToolServerAdapter(server.config)
            http_toolservers_by_server_id[server.server_id] = adapter
            # Best effort: an unreachable server must not block startup;
            # capabilities can be refreshed later via the API.
            try:
                caps = adapter.handshake()
            except Exception:
                logger.warning(
                    "Startup handshake failed for http toolserver %r; "
                    "capabilities left empty until refreshed.",
                    server.server_id,
                    exc_info=True,
                )
            else:
                server.capabilities = caps
                server.capabilities_last_refreshed_epoch = int(time.time())
        elif kind == "slurm":
            # slurm adapter drives the local CLI (sbatch/squeue). Runner holds a
            # single slurm adapter, so only one slurm config can be active.
            if slurm_server_ids:
                logger.warning(
                    "Multiple slurm servers configured (%r then %r); they share "
                    "one SlurmAdapter and the last server's config is used.",
                    slurm_server_ids,
                    server.server_id,
                )
            slurm_server_ids.append(server.server_id)
            # Rebuild via the constructor rather than mutating .config, so any
            # state the adapter derives from its config is consistent.
            adapter_instances["slurm"] = SlurmAdapter(config=server.config)
            server.capabilities = adapter_instances["slurm"].handshake()
            server.capabilities_last_refreshed_epoch = int(time.time())
        else:
            logger.warning(
                "Server %r has unknown adapter_type %r; registered without capabilities.",
                server.server_id,
                kind,
            )

        server_registry.register(server)

    def _resolve_current_server_id():
        """Return the server id selected for this request context."""
        server_id = _CURRENT_SERVER_ID.get()
        if server_id is None:
            # Legacy fallback: code that assigned app.state._tes_current_server_id
            # directly, or ran outside the request's context.
            server_id = app.state._tes_current_server_id
        return server_id

    class _HttpToolServerMux(HttpToolServerAdapter):
        """Single http_toolserver adapter that forwards to per-server instances."""

        def __init__(self) -> None:
            # Placeholder config: the mux never issues requests itself.
            super().__init__({"base_url": "http://invalid"})

        def _get(self, server_id: str) -> HttpToolServerAdapter:
            try:
                return http_toolservers_by_server_id[server_id]
            except KeyError:
                raise KeyError(
                    f"http toolserver instance missing for server_id={server_id}"
                ) from None

        def _current(self) -> HttpToolServerAdapter:
            return self._get(_resolve_current_server_id())

        def handshake(self):
            return self._current().handshake()

        def validate(self, tool_id, inputs, resources):
            return self._current().validate(tool_id, inputs, resources)

        def submit(self, tool_id, inputs, resources):
            return self._current().submit(tool_id, inputs, resources)

        def status(self, remote_run_id):
            return self._current().status(remote_run_id)

        def logs(self, remote_run_id, tail=200):
            return self._current().logs(remote_run_id, tail=tail)

        def results(self, remote_run_id):
            return self._current().results(remote_run_id)

    adapter_instances["http_toolserver"] = _HttpToolServerMux()

    runner = Runner(adapters=adapter_instances)

    # Expose registries/runner/store.
    app.state.tool_registry = tool_registry
    app.state.server_registry = server_registry
    app.state.runner = runner
    app.state.store = store
    app.state._tes_current_server_id = None

    def set_current_server_id(server_id: str) -> None:
        """Select which http toolserver the mux adapter forwards to."""
        _CURRENT_SERVER_ID.set(server_id)
        # Kept in sync for backward compatibility with code reading app.state.
        app.state._tes_current_server_id = server_id

    app.state.set_current_server_id = set_current_server_id

    # Every successful server lookup selects that server for subsequent mux
    # calls. Patched BEFORE the routers are built so that routes which capture
    # server_registry.get at attach time still see the patched version.
    original_server_get = server_registry.get

    def _server_get(server_id: str):
        server = original_server_get(server_id)
        app.state.set_current_server_id(server_id)
        return server

    server_registry.get = _server_get  # type: ignore[method-assign]

    # Routes
    app.include_router(routes_tools.attach(tool_registry), prefix="/api")
    app.include_router(routes_servers.attach(server_registry, runner), prefix="/api")
    runs_router = routes_runs.attach(tool_registry, server_registry, runner, store)
    app.include_router(runs_router, prefix="/api")

    @app.get("/health")
    def health():
        return {"ok": True, "service": "omnibioai-tool-exec"}

    # Map lookups that fail with KeyError to a 404 JSON error.
    # NOTE(review): this also catches KeyErrors caused by internal bugs.
    @app.exception_handler(KeyError)
    async def _key_error(_request, exc: KeyError):
        # str(KeyError("msg")) renders as "'msg'" (the key's repr); return the
        # raw argument so clients get the plain message.
        message = str(exc.args[0]) if len(exc.args) == 1 else str(exc)
        return JSONResponse(
            status_code=404,
            content={"ok": False, "error": {"code": "NOT_FOUND", "message": message}},
        )

    return app

