from __future__ import annotations

import logging
from contextlib import ExitStack
from contextvars import ContextVar
from typing import Dict, List, Optional, Union

from fastapi import FastAPI
from sqlalchemy import MetaData, create_engine
from sqlalchemy.engine import Engine
from sqlalchemy.engine.url import URL
from sqlalchemy.orm import DeclarativeMeta, Session, sessionmaker
from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint
from starlette.requests import Request
from starlette.types import ASGIApp

from .exceptions import SQLAlchemyType
from .extensions import SQLAlchemy
from .extensions import db as db_


class DBStateMap:
    """Registry of session factories keyed by database URL.

    Entries may be added or replaced while ``initialized`` is ``False``.
    Once that flag has been set to a truthy value, any further item
    assignment is rejected with ``ValueError``. Reads are always allowed.
    """

    def __init__(self):
        # Starts unlocked and empty; nothing in this class flips the flag,
        # so whoever owns the map decides when it becomes read-only.
        self.initialized = False
        self.dbs: Dict[URL, sessionmaker] = {}

    def __getitem__(self, item: URL) -> sessionmaker:
        """Return the session factory stored under ``item``.

        Raises:
            KeyError: if ``item`` has not been registered.
        """
        registry = self.dbs
        return registry[item]

    def __setitem__(self, key: URL, value: sessionmaker) -> None:
        """Store ``value`` under ``key`` unless the map is already locked.

        Raises:
            ValueError: if ``initialized`` is truthy.
        """
        if self.initialized:
            raise ValueError("DBStateMap is already initialized")
        self.dbs[key] = value


class DBSessionMiddleware(BaseHTTPMiddleware):
    """HTTP middleware that wraps each request in the configured databases' contexts.

    The middleware is configured with one ``SQLAlchemy`` object, a list of
    them, or a ``db_url`` that is used to initialise the package-level
    default database object. For every request, each configured object is
    called and the result is entered as a context manager for the duration
    of the downstream handler.
    """

    def __init__(
        self,
        app: ASGIApp,
        db: Optional[Union[List[SQLAlchemy], SQLAlchemy]] = None,
        db_url: Optional[URL] = None,
        **options,
    ):
        """Resolve the database objects to manage and prepare each of them.

        Args:
            app: The ASGI application being wrapped.
            db: A ``SQLAlchemy`` instance or a list of them. When given (as
                an instance or a list) it takes precedence over ``db_url``.
            db_url: URL used to initialise the package-level default
                database object when no ``db`` is supplied.
            **options: Extra keyword arguments forwarded to the default
                database object's ``init`` together with ``db_url``. They
                are not used when ``db`` is supplied.

        Raises:
            SQLAlchemyType: If no usable database can be resolved, i.e.
                ``db`` is neither a ``SQLAlchemy`` instance nor a list and
                ``db_url`` cannot be used in its place.
        """
        super().__init__(app)
        self.state_map = DBStateMap()

        resolved = None
        if db_url and not db:
            # Fall back to the shared default object, initialising it only
            # once even if several middlewares are built with a URL.
            if not db_.initiated:
                db_.init(url=db_url, **options)
            resolved = [db_]
        # An explicitly supplied instance or list always wins over the URL.
        if isinstance(db, SQLAlchemy):
            resolved = [db]
        elif isinstance(db, list):
            resolved = db
        if resolved is None:
            # Covers both "nothing supplied" and "unsupported ``db`` type
            # alongside a URL"; the latter used to leave ``self.dbs`` unset
            # and fail later with an AttributeError.
            raise SQLAlchemyType()
        self.dbs = resolved

        for database in self.dbs:
            database.init()
            database.create_all()

    async def dispatch(self, request: Request, call_next: RequestResponseEndpoint):
        """Run the downstream handler inside every configured database context.

        Each object in ``self.dbs`` is called and the returned context
        manager is entered on a shared ``ExitStack``, so all of them are
        exited (in reverse order) once the handler has produced a response
        or raised.

        Args:
            request: The incoming request.
            call_next: Callable that forwards the request down the stack.

        Returns:
            The response produced by ``call_next``.
        """
        with ExitStack() as stack:
            for database in self.dbs:
                stack.enter_context(database())
            response = await call_next(request)
        return response
