import os
from typing import Optional, Union

import pandas as pd

from snowflake import snowpark
from snowflake.ml.modeling._internal.constants import IN_ML_RUNTIME_ENV_VAR
from snowflake.ml.modeling._internal.local_implementations.pandas_handlers import (
    PandasTransformHandlers,
)
from snowflake.ml.modeling._internal.ml_runtime_implementations.ml_runtime_handlers import (
    MLRuntimeTransformHandlers,
)
from snowflake.ml.modeling._internal.snowpark_implementations.snowpark_handlers import (
    SnowparkTransformHandlers,
)
from snowflake.ml.modeling._internal.transformer_protocols import ModelTransformHandlers


class ModelTransformerBuilder:
    """
    A builder class that selects and creates the model transform handler for a given dataset.

    The handler is chosen from the dataset's type (pandas vs. Snowpark) and, for Snowpark
    datasets, from whether the ML runtime environment variable is set. This hides the choice
    of execution backend from callers, which only need a configured ModelTransformHandlers.
    """

    @classmethod
    def build(
        cls,
        dataset: Union[snowpark.DataFrame, pd.DataFrame],
        estimator: object,
        class_name: str,
        subproject: str,
        autogenerated: Optional[bool] = False,
    ) -> ModelTransformHandlers:
        """
        Builder method that creates the appropriate ModelTransformHandlers instance for the given params.

        These params determine where transforms are executed:
          - pandas.DataFrame -> PandasTransformHandlers (local execution).
          - snowpark.DataFrame with the ML runtime environment variable set to a non-empty
            value -> MLRuntimeTransformHandlers.
          - snowpark.DataFrame otherwise -> SnowparkTransformHandlers.

        Args:
            dataset: The dataset on which transforms will be executed.
            estimator: The estimator object used to execute transformations. Must support inference and scoring.
            class_name: class name to be used in telemetry.
            subproject: subproject to be used in telemetry.
            autogenerated: Whether the class was autogenerated from a template.

        Returns:
            A ModelTransformHandlers based on function inputs

        Raises:
            TypeError: Dataset is not one of the currently supported types(pd.DataFrame, snowpark.DataFrame)
        """
        # Each branch calls its handler constructor separately so that `dataset` is narrowed to
        # the concrete DataFrame type that handler expects.
        if isinstance(dataset, pd.DataFrame):
            return PandasTransformHandlers(
                dataset=dataset,
                estimator=estimator,
                class_name=class_name,
                subproject=subproject,
                autogenerated=autogenerated,
            )

        if isinstance(dataset, snowpark.DataFrame):
            # Truthiness check: an unset or empty-string env var falls back to the Snowpark handlers.
            if os.environ.get(IN_ML_RUNTIME_ENV_VAR):
                return MLRuntimeTransformHandlers(
                    dataset=dataset,
                    estimator=estimator,
                    class_name=class_name,
                    subproject=subproject,
                    autogenerated=autogenerated,
                )
            return SnowparkTransformHandlers(
                dataset=dataset,
                estimator=estimator,
                class_name=class_name,
                subproject=subproject,
                autogenerated=autogenerated,
            )

        # Note the trailing space in the first literal: the two literals are implicitly
        # concatenated, and without it the sentences would run together.
        raise TypeError(
            f"Unexpected dataset type: {type(dataset)}. "
            "Supported dataset types: snowpark.DataFrame, pandas.DataFrame."
        )
