#
#  Copyright MindBridge Analytics Inc. all rights reserved.
#
#  This material is confidential and may not be copied, distributed,
#  reversed engineered, decompiled or otherwise disseminated without
#  the prior written consent of MindBridge Analytics Inc.
#

from datetime import timedelta
from typing import Any, Dict, Generator, List, Type, Union
from pydantic import ConfigDict, Field, field_validator, model_validator
from pydantic_core import PydanticUndefined
from mindbridgeapi.analysis_period import AnalysisPeriod
from mindbridgeapi.analysis_source_item import AnalysisSourceItem
from mindbridgeapi.analysis_type_item import AnalysisTypeItem
from mindbridgeapi.common_validators import (
    _convert_userinfo_to_useritem,
    _warning_if_extra_fields,
)
from mindbridgeapi.enumerations.analysis_type import AnalysisType
from mindbridgeapi.exceptions import ItemError
from mindbridgeapi.generated_pydantic_model.model import (
    ApiAnalysisCreateOnly,
    ApiAnalysisRead,
    ApiAnalysisUpdate,
    ApiDataTableRead,
)
from mindbridgeapi.task_item import TaskItem


def _empty_analysis_sources() -> Generator[AnalysisSourceItem, None, None]:
    """Return a generator that yields no analysis sources.

    Serves as the ``default_factory`` for ``AnalysisItem.analysis_sources`` so that
    the field always holds an (empty) generator rather than ``None``.
    """
    return (source for source in ())


def _empty_data_tables() -> Generator[ApiDataTableRead, None, None]:
    """Return a generator that yields no data tables.

    Serves as the ``default_factory`` for ``AnalysisItem.data_tables`` so that the
    field always holds an (empty) generator rather than ``None``.
    """
    return (table for table in ())


def _empty_tasks() -> Generator[TaskItem, None, None]:
    """Return a generator that yields no tasks.

    Serves as the ``default_factory`` for ``AnalysisItem.tasks`` so that the field
    always holds an (empty) generator rather than ``None``.
    """
    return (task for task in ())


class AnalysisItem(ApiAnalysisRead):
    """Client-side model of a MindBridge analysis.

    Extends ``ApiAnalysisRead`` with client-friendly defaults, validators that
    normalise incoming values, and helpers that build the JSON payloads for the
    create (``create_json``) and update (``update_json``) requests.
    """

    # The inherited default is cleared (PydanticUndefined) so the default_factory
    # applies: every instance starts with its own list holding one blank period.
    analysis_periods: List[AnalysisPeriod] = Field().merge_field_infos(
        ApiAnalysisRead.model_fields["analysis_periods"],
        default=PydanticUndefined,
        default_factory=lambda: [AnalysisPeriod()],
    )  # type: ignore[assignment]
    analysis_type_id: str = Field().merge_field_infos(
        ApiAnalysisRead.model_fields["analysis_type_id"],
        default=AnalysisTypeItem.GENERAL_LEDGER,
    )
    archived: bool = Field().merge_field_infos(
        ApiAnalysisRead.model_fields["archived"], default=False
    )
    converted: bool = Field().merge_field_infos(
        ApiAnalysisRead.model_fields["converted"], default=False
    )
    interim: bool = Field().merge_field_infos(
        ApiAnalysisRead.model_fields["interim"], default=False
    )
    periodic: bool = Field().merge_field_infos(
        ApiAnalysisRead.model_fields["periodic"], default=False
    )
    # Related collections. They default to empty generators and use exclude=True,
    # so model_dump() (and therefore create_json/update_json) never includes them.
    analysis_sources: Generator[AnalysisSourceItem, None, None] = Field(
        default_factory=_empty_analysis_sources, exclude=True
    )
    data_tables: Generator[ApiDataTableRead, None, None] = Field(
        default_factory=_empty_data_tables, exclude=True
    )
    tasks: Generator[TaskItem, None, None] = Field(
        default_factory=_empty_tasks, exclude=True
    )

    model_config = ConfigDict(
        extra="allow",
        validate_assignment=True,
        validate_default=True,
        # NOTE(review): in pydantic v2 validate_return only affects functions
        # wrapped with @validate_call; it is likely a no-op on this model.
        validate_return=True,
    )
    _a = model_validator(mode="after")(_warning_if_extra_fields)
    _b = field_validator("*")(_convert_userinfo_to_useritem)

    @field_validator("analysis_periods", mode="after")
    @classmethod
    def _sort_analysis_periods(cls, v: List[AnalysisPeriod]) -> List[AnalysisPeriod]:
        """Return the periods sorted by AnalysisPeriod's own ordering.

        NOTE(review): add_prior_periods treats the last element as the earliest
        period, which assumes AnalysisPeriod sorts latest-first — confirm.
        """
        return sorted(v)

    @field_validator("analysis_type_id", mode="before")
    @classmethod
    def _convert_analysis_type_id(cls, v: Any) -> Any:
        """Accept an ``AnalysisType`` enum member by unwrapping it to its value."""
        if isinstance(v, AnalysisType):
            return v.value

        return v

    @model_validator(mode="after")
    def _set_default_name(self) -> "AnalysisItem":
        """Give unnamed analyses a descriptive default name.

        Only the three general-ledger analysis types get a default; any other
        ``analysis_type_id`` leaves ``name`` as ``None``.
        """
        if self.name is None:
            # validate_assignment=True means each assignment below is validated
            # (and re-runs this after-validator, which then sees a non-None name).
            if self.analysis_type_id == AnalysisTypeItem.GENERAL_LEDGER:
                self.name = "General ledger analysis"
            elif (
                self.analysis_type_id == AnalysisTypeItem.NOT_FOR_PROFIT_GENERAL_LEDGER
            ):
                self.name = "Not for profit general ledger analysis"
            elif (
                self.analysis_type_id
                == AnalysisTypeItem.NOT_FOR_PROFIT_GENERAL_LEDGER_FUND
            ):
                self.name = "Not for profit general ledger with funds analysis"

        return self

    def _get_post_json(
        self, out_class: Type[Union[ApiAnalysisCreateOnly, ApiAnalysisUpdate]]
    ) -> Dict[str, Any]:
        """Build a request payload by re-validating this model as ``out_class``.

        The model is dumped to a dict, validated into ``out_class``, then dumped in
        JSON mode using field aliases with ``None`` values omitted.
        """
        in_class_dict = self.model_dump()
        out_class_object = out_class.model_validate(in_class_dict)
        return out_class_object.model_dump(
            mode="json", by_alias=True, exclude_none=True
        )

    @property
    def create_json(self) -> Dict[str, Any]:
        """JSON payload for creating this analysis (``ApiAnalysisCreateOnly``)."""
        return self._get_post_json(out_class=ApiAnalysisCreateOnly)

    @property
    def update_json(self) -> Dict[str, Any]:
        """JSON payload for updating this analysis (``ApiAnalysisUpdate``)."""
        return self._get_post_json(out_class=ApiAnalysisUpdate)

    def add_prior_periods(self, num_to_add: int = 1) -> None:
        """Adds prior periods (assumes 1 year period)

        Each new period ends the day before the current earliest period starts,
        and starts on the same month and day one year before that start date. A
        start date of February 29 maps to March 1 of the prior (non-leap) year.

        New periods are appended to ``analysis_periods`` in place, so the list is
        not re-validated or re-sorted. The last element is treated as the earliest
        period.

        Args:
            num_to_add: Number of prior periods to add; values below 1 add nothing.

        Raises:
            ItemError: If there are no analysis periods, or if the earliest period
                has no ``start_date``.
        """
        if not self.analysis_periods:
            raise ItemError("Analysis had no Analysis Periods")  # noqa: TRY003

        earliest_period = self.analysis_periods[-1]
        for _ in range(num_to_add):
            current_start = earliest_period.start_date
            if current_start is None:
                raise ItemError(  # noqa: TRY003
                    "Earliest Analysis Period had no start_date"
                )

            end_date = current_start - timedelta(days=1)
            # Shift the start date (not the end date) back one year. Deriving the
            # start from end_date produced a 366-day period whenever the current
            # start was March 1 and the prior February had 29 days (e.g. a period
            # ending 2021-02-28 started on 2020-02-29 instead of 2020-03-01).
            try:
                start_date = current_start.replace(year=current_start.year - 1)
            except ValueError:
                # February 29 has no counterpart in the previous year; start on the
                # day after February 28 (i.e. March 1) instead.
                start_date = current_start.replace(
                    year=current_start.year - 1, day=28
                ) + timedelta(days=1)

            earliest_period = AnalysisPeriod(start_date=start_date, end_date=end_date)  # type: ignore[call-arg]
            self.analysis_periods.append(earliest_period)
