from __future__ import annotations
import functools
from typing import Concatenate, Iterable, cast, override, Type, TYPE_CHECKING, Any, Callable, Optional

from ormlambda.sql.types import ASTERISK, compileOptions

from ormlambda.sql.elements import ClauseElement

if TYPE_CHECKING:
    from ormlambda.engine.base import Engine
    from ormlambda.sql.types import AliasType
    from ormlambda import Table
    from ormlambda.statements.types import OrderTypes
    from ormlambda.sql.types import ColumnType
    from ormlambda.statements.types import SelectCols
    from ormlambda.statements.types import TypeExists
    from ormlambda.statements.types import WhereTypes
    from ormlambda.dialects import Dialect

from ormlambda.statements.interfaces import IStatements
from ormlambda.statements.base_statement import ClusterResponse

from ormlambda import OrderType, Table
from ormlambda.common.enums import JoinType, UnionEnum
from ormlambda.sql.clauses.join import JoinContext, TupleJoinType

from ormlambda.common.global_checker import GlobalChecker
from .query_builder import QueryBuilder

from ormlambda.sql import clauses
from ormlambda.sql import functions as func
from ormlambda.common import errors as error


# COMMENT: Resetting the query builder after every decorated call is essential; otherwise clauses registered for one statement (e.g. by a previous test) would be retained and leak into the next query.
@staticmethod
def clear_list[T, **P](f: Callable[Concatenate[Statements, P], T]) -> Callable[P, T]:
    def wrapper(self: Statements, *args: P.args, **kwargs: P.kwargs) -> T:
        try:
            return f(self, *args, **kwargs)

        except error.ColumnError as err:
            err.set_clause(f.__name__)
            raise

        except Exception:
            raise
        finally:
            self._query_builder.clear()

    return wrapper


class Statements[T: Table](IStatements[T]):
    def __init__(self, model: T, engine: Engine) -> None:
        self._query_builder = QueryBuilder()
        self._engine = engine
        self._dialect = engine.dialect
        self._model: T = model[0] if isinstance(model, Iterable) else model

        if not issubclass(self._model, Table):
            # Deben heredar de Table ya que es la forma que tenemos para identificar si estamos pasando una instancia del tipo que corresponde o no cuando llamamos a insert o upsert.
            # Si no heredase de Table no sabriamos identificar el tipo de dato del que se trata porque al llamar a isinstance, obtendriamos el nombre de la clase que mapea a la tabla, Encargo, Edificio, Presupuesto y no podriamos crear una clase generica
            raise Exception(f"'{model}' class does not inherit from Table class")

    @property
    def dialect(self) -> Dialect:
        return self._dialect

    @property
    def engine(self) -> Engine:
        return self._engine

    @override
    def table_exists(self) -> bool:
        return self.engine.repository.table_exists(self._model.__table_name__)

    def __repr__(self):
        return f"<Model: {self.__class__.__name__}>"

    @property
    def model(self) -> Type[T]:
        return self._model

    @override
    def create_table(self, if_exists: TypeExists = "fail") -> None:
        name: str = self._model.__table_name__
        if self.engine.repository.table_exists(name):
            if if_exists == "replace":
                self.drop_table()

            elif if_exists == "fail":
                raise ValueError(f"Table '{self._model.__table_name__}' already exists")

            elif if_exists == "append":
                counter: int = 0
                char: str = ""
                while self.engine.repository.table_exists(name + char):
                    counter += 1
                    char = f"_{counter}"
                name += char

                new_model = self._model.copy(__table_name__=name)
                return new_model.create_table(self.dialect)

        query = self.model.create_table(self.dialect)
        self.engine.repository.execute(query)
        return None

    @override
    def drop_table(self) -> None:
        q = self.model.drop_table(self.dialect)
        self.engine.repository.execute(q)
        return None

    @override
    @clear_list
    def insert(self, instances: T | list[T]) -> None:
        insert = clauses.Insert(instances)
        query = insert.compile(self.dialect).string
        self.engine.repository.executemany_with_values(query, insert.cleaned_values)
        return None

    @override
    @clear_list
    def delete(self, instances: Optional[T | list[T]] = None) -> None:
        if instances and not isinstance(instances, Iterable):
            instances = (instances,)

        if instances:
            pks_values = []
            for instance in instances:
                pk = instance.get_pk()
                pks_values.append(instance[pk])

            self.where(lambda x: getattr(x, pk.column_name).contains(pks_values))

        delete = clauses.Delete(self.model, self._query_builder.where, instances)
        query = delete.compile(self.dialect).string
        self._engine.repository.execute(query)
        # not necessary to call self._query_builder.clear() because select() method already call it
        return None

    @override
    @clear_list
    def upsert(self, instances: T | list[T]) -> None:
        upsert = clauses.Upsert(instances)
        query = upsert.compile(self.dialect).string
        self._engine.repository.executemany_with_values(query, upsert.cleaned_values)
        return None

    @override
    @clear_list
    def update(self, dicc: dict[str, Any] | list[dict[str, Any]]) -> None:
        update = clauses.Update(self.model, self._query_builder.where, dicc)
        query = update.compile(self.dialect).string
        return self._engine.repository.execute_with_values(query, update.cleaned_values)

    @override
    def limit(self, number: int) -> IStatements[T]:
        # Only can be one LIMIT SQL parameter. We only use the last LimitQuery
        limit = clauses.Limit(number=number)
        self._query_builder.add_statement(limit)
        return self

    @override
    def offset(self, number: int) -> IStatements[T]:
        offset = clauses.Offset(number=number)
        self._query_builder.add_statement(offset)
        return self

    @override
    def count[TProp](
        self,
        selection: Optional[SelectCols[T, TProp] | str] = ASTERISK,
        alias: AliasType = "count",
    ) -> Optional[int]:
        if selection == ASTERISK:
            return self.select_one(lambda x: func.Count(x, alias), flavour=dict)[alias]

        # get first position because 'resolved_callback_object' return an, alway Iterable and we should only pass one column
        res = GlobalChecker.resolved_callback_object(self.model, selection)[0]
        return self.select_one(lambda x: func.Count(res, alias), flavour=dict)[alias]

    @override
    def where(self, conditions: WhereTypes[T], restrictive: bool = True) -> IStatements[T]:
        # FIXME [x]: I've wrapped self._model into tuple to pass it instance attr. Idk if it's correct

        restrictive = UnionEnum.AND if restrictive else UnionEnum.OR
        result = GlobalChecker.resolved_callback_object(self.model, conditions)

        self._query_builder.add_where(result, restrictive)
        return self

    @override
    def having(self, conditions: ColumnType, restrictive: bool = True) -> IStatements[T]:
        restrictive = UnionEnum.AND if restrictive else UnionEnum.OR
        result = GlobalChecker.resolved_callback_object(self.model, conditions)

        self._query_builder.add_having(result, restrictive)
        return self

    @override
    def order[TValue](self, columns: str | Callable[[T], TValue], order_type: OrderTypes = OrderType.ASC) -> IStatements[T]:
        if isinstance(columns, str):
            callable_func = lambda x: columns  # noqa: E731
        else:
            callable_func = columns

        res = GlobalChecker.resolved_callback_object(self.model, callable_func)
        deferred_op = clauses.Order(*res, order_type=order_type)
        self._query_builder.add_statement(deferred_op)

        return self

    @override
    def max[TProp](
        self,
        column: SelectCols[T, TProp],
        alias: AliasType = "max",
    ) -> int:
        res = GlobalChecker.resolved_callback_object(self.model, column)[0]

        return self.select_one(lambda x: func.Max(res, alias), flavour=dict)[alias]

    @override
    def min[TProp](
        self,
        column: SelectCols[T, TProp],
        alias: AliasType = "min",
    ) -> int:
        res = GlobalChecker.resolved_callback_object(self.model, column)[0]

        return self.select_one(lambda x: func.Min(res, alias), flavour=dict)[alias]

    @override
    def sum[TProp](
        self,
        column: SelectCols[T, TProp],
        alias: AliasType = "sum",
    ) -> int:
        res = GlobalChecker.resolved_callback_object(self.model, column)[0]

        return self.select_one(lambda x: func.Sum(res, alias), flavour=dict)[alias]

    @override
    def join[LTable: Table, LProp, RTable: Table, RProp](self, joins: tuple[TupleJoinType[LTable, LProp, RTable, RProp]]) -> JoinContext[tuple[*TupleJoinType[LTable, LProp, RTable, RProp]]]:
        return JoinContext(self, joins, self._query_builder._context, self._dialect)

    @override
    @clear_list
    def select[TValue, TFlavour, *Ts](
        self,
        selector: Optional[tuple[TValue, *Ts]] = None,
        *,
        flavour: Optional[Type[TFlavour]] = None,
        by: JoinType = JoinType.INNER_JOIN,
        alias: Optional[AliasType[T]] = None,
        avoid_duplicates: bool = False,
        only_query: bool = False,
        **kwargs,
    ):
        if selector is None:
            # COMMENT: if we do not specify any lambda function we assumed the user want to retreive only elements of the Model itself avoiding other models
            result = self.select(
                selector=lambda x: x,
                flavour=flavour,
                by=by,
                avoid_duplicates=avoid_duplicates,
                only_query=only_query,
                **kwargs,
            )
            return result
        select_clause = GlobalChecker.resolved_callback_object(self.model, selector)

        select = clauses.Select(
            table=self.model,
            columns=select_clause,
            alias=alias,
            avoid_duplicates=avoid_duplicates,
        )

        self._query_builder.add_statement(select)

        self._query_builder.by = by

        if only_query:
            return self.query(sep="\n")

        query = self.query()
        return ClusterResponse(select, self._engine, flavour, query).cluster_data()

    @override
    def select_one[TValue, TFlavour, *Ts](
        self,
        selector: Optional[tuple[TValue, *Ts]] = None,
        *,
        flavour: Optional[Type[TFlavour]] = None,
        by: JoinType = JoinType.INNER_JOIN,
        **kwargs,
    ):
        self.limit(1)

        response = self.select(selector=selector, flavour=flavour, by=by, **kwargs)

        if not isinstance(response, Iterable):
            return response
        if flavour:
            return response[0] if response else None

        # response var could be return more than one element when we work with models an we
        # select columns from different tables using a join query
        # FIXME [x]: before it was if len(response) == 1 and len(response[0]) == 1: return response[0][0]
        if len(response) == 1:
            return response[0]
        return response

    @override
    def first[TValue, TFlavour, *Ts](
        self,
        selector: Optional[tuple[TValue, *Ts]] = None,
        *,
        flavour: Optional[Type[TFlavour]] = None,
        by: JoinType = JoinType.INNER_JOIN,
        **kwargs,
    ):
        return self.select_one(
            selector=selector,
            flavour=flavour,
            by=by,
            **kwargs,
        )

    @override
    def groupby[TProp](self, column: ColumnType[TProp] | Callable[[T], Any]) -> IStatements[T]:
        result = GlobalChecker.resolved_callback_object(self.model, column)
        deferred_op = clauses.GroupBy(*result)
        self._query_builder.add_statement(deferred_op)
        return self

    def query(self, element: Optional[compileOptions] = None, sep: str = " ") -> str:
        if not element:
            return self._query_builder.query(sep, self._dialect).strip()

        return cast(ClauseElement, getattr(self._query_builder, element)).compile(self.dialect).string.strip()
