import csv
import warnings
from pathlib import Path
from typing import Type, TypeVar, List, Optional, Tuple, Dict

from pydantic import ValidationError, BaseModel
import pandas as pd
from openpyxl.reader.excel import load_workbook
from loguru import logger

from pjdev_sqlmodel.db_models import ModelBase
from pjdev_sqlmodel.service import session_context

# Type variable bounded to ModelBase; the loaders below take `Type[T]` so any
# ModelBase subclass can be passed as `model_type`.
T = TypeVar("T", bound=ModelBase)


def get_files_in_directory(directory: Path) -> List[Path]:
    """Recursively collect spreadsheet data files below ``directory``.

    All ``*.xlsx`` files come first, except those whose name starts with
    ``~$`` (presumably Office lock/owner files). Every ``*.csv`` file follows.
    Within each group, the order is whatever ``Path.glob`` yields.
    """

    def _is_lock_file(candidate: Path) -> bool:
        return candidate.name.startswith("~$")

    workbooks = [p for p in directory.glob("**/*.xlsx") if not _is_lock_file(p)]
    csv_files = list(directory.glob("**/*.csv"))
    return workbooks + csv_files


def get_csv_columns(file_path) -> List[str]:
    """Return the header row (column names) of a CSV file.

    The file is decoded with ``utf-8-sig``, so a leading UTF-8 byte-order mark
    is stripped instead of being glued to the first column name.

    Args:
        file_path: Path (or path string) of the CSV file to inspect.

    Returns:
        The fields of the first CSV record, or an empty list if the file
        contains no records at all.
    """
    # The csv module expects files opened with newline="" so that quoted
    # fields containing line breaks are passed through unaltered.
    with open(file_path, "r", encoding="utf-8-sig", newline="") as csvfile:
        # Default to [] rather than letting StopIteration escape on an empty
        # file: callers that compare against required columns then simply see
        # no match.
        return next(csv.reader(csvfile), [])


def get_excel_columns(file_path, header_ndx: int = 1, col_range: Optional[str] = None):
    """Return the header cell values from the active sheet of an ``.xlsx`` workbook.

    Args:
        file_path: Workbook to open (anything ``load_workbook`` accepts).
        header_ndx: Row number of the header row, as used by openpyxl
            indexing (callers pass pandas' 0-based header index + 1). Used only
            when ``col_range`` is None.
        col_range: Optional A1-style range (e.g. ``"A1:D1"``). When given, the
            values of the first row of that range are returned instead.

    Returns:
        List of the ``.value`` of each header cell.
    """
    # All warnings raised while loading the workbook are suppressed; the
    # filter is scoped to the load only.
    with warnings.catch_warnings():
        warnings.filterwarnings("ignore")
        wb = load_workbook(filename=file_path, read_only=True)

    # A read-only workbook keeps its file handle open until close() is called.
    # Values are copied into plain lists before the finally clause closes it.
    try:
        sheet = wb.active

        if col_range is not None:
            first_row = list(sheet[col_range])[0]
            return [cell.value for cell in first_row]

        return [cell.value for cell in sheet[header_ndx]]
    finally:
        wb.close()


def load_csv_data(model_type: Type[T], data_files: List[Path]) -> None:
    """Validate rows from matching CSV files into ``model_type`` and persist them.

    A CSV file is used only if its header contains every column the model
    needs. That is each field's alias (or its name when it has no alias),
    excluding the ``row_id`` field. Every row is validated with
    ``model_type.model_validate``. All validated instances are then added and
    committed in a single session.

    If any row of a file fails validation, the error is logged and that whole
    file is skipped. This way a table never receives a partially-loaded file.

    Args:
        model_type: ModelBase subclass describing the target table.
        data_files: Candidate files; entries not ending in ``.csv`` are ignored.
    """
    cols = [
        field.alias if field.alias is not None else name
        for name, field in model_type.model_fields.items()
        if name != "row_id"
    ]
    required = set(cols)  # built once, reused for every candidate file
    filtered_files = [
        f
        for f in data_files
        if f.name.endswith(".csv") and required.issubset(get_csv_columns(f))
    ]

    if not filtered_files:
        logger.warning(
            f"No CSV files found that matched the schema for {model_type.__name__}"
        )

    data: List[T] = []
    for file in filtered_files:
        df = pd.read_csv(file, usecols=cols, na_filter=False)

        # Collect this file's rows separately so a validation failure part-way
        # through does not leave the earlier rows of the file in `data`.
        file_rows: List[T] = []
        try:
            # to_dict("records") keeps each column's own dtype. iterrows() builds
            # one Series per row and can coerce values (e.g. ints to floats
            # when every column is numeric).
            for record in df.to_dict(orient="records"):
                file_rows.append(model_type.model_validate(record))
        except ValidationError as e:
            logger.error(f"Error when parsing {file.name}: {e}")
        else:
            data.extend(file_rows)

    with session_context() as session:
        session.add_all(data)
        session.commit()
    logger.info("Loaded {} rows for {} table".format(len(data), model_type.__name__))


def load_excel_data(
    model_type: Type[T],
    data_files: List[Path],
    header_ndx: int = 0,
    col_range: Optional[str] = None,
) -> None:
    """Validate rows from matching ``.xlsx`` files into ``model_type`` and persist them.

    A workbook is used only if its header contains every column the model
    needs. That is each field's alias (or its name when it has no alias),
    excluding the ``row_id`` field. Every row is validated with
    ``model_type.model_validate``. All validated instances are then added and
    committed in a single session.

    If any row of a file fails validation, the error is logged and that whole
    file is skipped. This way a table never receives a partially-loaded file.

    Args:
        model_type: ModelBase subclass describing the target table.
        data_files: Candidate files; entries not ending in ``.xlsx`` are ignored.
        header_ndx: 0-based header row passed to ``pd.read_excel(header=...)``.
            The header check reads row ``header_ndx + 1`` via openpyxl.
        col_range: Optional A1-style range forwarded to ``get_excel_columns``
            for the header check only.

    Raises:
        ValueError: If no workbook in ``data_files`` matches the model's
            columns. This is a subclass of the ``Exception`` raised
            previously, so existing ``except Exception`` handlers still apply.
    """
    cols = [
        field.alias if field.alias is not None else name
        for name, field in model_type.model_fields.items()
        if name != "row_id"
    ]
    required = set(cols)  # built once, reused for every candidate file

    filtered_files = [
        f
        for f in data_files
        if f.name.endswith(".xlsx")
        and required.issubset(get_excel_columns(f, header_ndx + 1, col_range))
    ]

    if not filtered_files:
        raise ValueError(
            f"No files found that matched the schema for {model_type.__name__}"
        )

    data: List[T] = []
    for file in filtered_files:
        # NOTE(review): the header check above inspects openpyxl's `wb.active`
        # sheet, while read_excel reads the first sheet by default. These can
        # differ for a multi-sheet workbook. `col_range` is also not applied
        # here, so it should sit on row `header_ndx + 1` — confirm.
        df = pd.read_excel(io=file, usecols=cols, na_filter=False, header=header_ndx)

        # Collect this file's rows separately so a validation failure part-way
        # through does not leave the earlier rows of the file in `data`.
        file_rows: List[T] = []
        try:
            # to_dict("records") keeps each column's own dtype. iterrows() builds
            # one Series per row and can coerce values (e.g. ints to floats
            # when every column is numeric).
            for record in df.to_dict(orient="records"):
                file_rows.append(model_type.model_validate(record))
        except ValidationError as e:
            logger.error(f"Error when parsing {file.name}: {e}")
        else:
            data.extend(file_rows)

    with session_context() as session:
        session.add_all(data)
        session.commit()
    logger.info("Loaded {} rows for {} table".format(len(data), model_type.__name__))


def convert_to_csv(
    data: List[BaseModel],
    col_mapping_tuple: Tuple[List[str], Dict[str, str], List[str]],
    filename: Path,
    index: bool = False,
) -> None:
    """Write a list of pydantic models to a CSV file.

    Args:
        data: Models to export. Each is dumped with ``by_alias=True``,
            restricted to the fields in ``include_set``.
        col_mapping_tuple: ``(include_set, col_mapping, cols)``. ``include_set``
            lists the fields to dump. ``cols`` is the ordered list of output
            columns passed to ``DataFrame.to_csv``; presumably these are the
            aliased keys produced by the dump. ``col_mapping`` is not used by
            this function.
        filename: Destination path of the CSV file.
        index: Whether to write the DataFrame index as a column.
    """
    # NOTE(review): the middle element is unpacked but never used — confirm
    # whether columns were meant to be renamed with it.
    include_set, _col_mapping, cols = col_mapping_tuple
    dict_data = [
        d.model_dump(by_alias=True, include=dict.fromkeys(include_set)) for d in data
    ]

    # pd.DataFrame([]) has no columns at all, so to_csv(columns=cols) would
    # raise KeyError. For no rows, start from an empty frame that already has
    # the requested columns, so a header-only CSV is written instead.
    df = pd.DataFrame(dict_data) if dict_data else pd.DataFrame(columns=cols)

    # Export DataFrame to CSV, selecting/ordering the columns by `cols`.
    df.to_csv(filename, index=index, columns=cols)
