"""
Comprehensive validation module for Open-Meteo API.

Provides validation for coordinates, dates, variables, and response data.
"""

from datetime import datetime, date
from typing import List, Optional, Dict, Any, Set, Union
import logging

# Module-level logger named after this module; the validators below use it
# for non-fatal warnings (unusual values, large ranges, unknown formats).
logger = logging.getLogger(__name__)


class ValidationError(ValueError):
    """Raised when a request parameter or response payload fails validation.

    Derives from ``ValueError``, so code that already handles ``ValueError``
    also handles this exception.
    """


class CoordinateValidator:
    """Validates geographic coordinates (latitude, longitude, elevation).

    Every method is static and returns ``None`` when the input is acceptable.
    Hard failures raise ``ValidationError``; an out-of-range elevation is only
    logged as a warning.
    """

    @staticmethod
    def _is_real_number(value: Any) -> bool:
        """Return True for ``int``/``float`` values, excluding ``bool``."""
        # bool is a subclass of int; without this exclusion True/False would
        # silently be accepted as the coordinates 1 and 0.
        return isinstance(value, (int, float)) and not isinstance(value, bool)

    @staticmethod
    def _validate_bounded(value: Any, param_name: str, low: float, high: float) -> None:
        """Check that ``value`` is a number within ``[low, high]`` (inclusive)."""
        if value is None:
            raise ValidationError(f"{param_name} is required")
        if not CoordinateValidator._is_real_number(value):
            raise ValidationError(
                f"{param_name} must be a number, got {type(value).__name__}"
            )
        # NaN fails both comparisons, so it is rejected here as well.
        if not (low <= value <= high):
            raise ValidationError(
                f"Invalid {param_name}: {value}. Must be between {low} and {high}."
            )

    @staticmethod
    def _as_list(value: Any, param_name: str) -> List[Any]:
        """Normalise a scalar or an iterable of coordinates into a list.

        Raises:
            ValidationError: if ``value`` is ``None``, a string/bytes object,
                or not iterable.
        """
        if value is None:
            raise ValidationError(f"{param_name} is required")
        # Scalars (bools included) are wrapped so the per-element check
        # produces the error message.
        if isinstance(value, (int, float)):
            return [value]
        if isinstance(value, (str, bytes)):
            raise ValidationError(
                f"{param_name} must be a number or a list of numbers, "
                f"got {type(value).__name__}"
            )
        try:
            return list(value)
        except TypeError:
            raise ValidationError(
                f"{param_name} must be a number or a list of numbers, "
                f"got {type(value).__name__}"
            ) from None

    @staticmethod
    def validate_latitude(lat: float, param_name: str = "latitude") -> None:
        """Validate a single latitude value.

        Args:
            lat: Latitude; must be a number between -90 and 90 inclusive.
            param_name: Name used for the value in error messages.

        Raises:
            ValidationError: if ``lat`` is ``None``, not a number, or out of range.
        """
        CoordinateValidator._validate_bounded(lat, param_name, -90, 90)

    @staticmethod
    def validate_longitude(lon: float, param_name: str = "longitude") -> None:
        """Validate a single longitude value.

        Args:
            lon: Longitude; must be a number between -180 and 180 inclusive.
            param_name: Name used for the value in error messages.

        Raises:
            ValidationError: if ``lon`` is ``None``, not a number, or out of range.
        """
        CoordinateValidator._validate_bounded(lon, param_name, -180, 180)

    @staticmethod
    def validate_coordinates(
        latitude: Union[float, List[float]],
        longitude: Union[float, List[float]]
    ) -> None:
        """Validate one coordinate pair or two parallel sequences of coordinates.

        Args:
            latitude: A single latitude or an iterable of latitudes.
            longitude: A single longitude or an iterable of longitudes.

        Raises:
            ValidationError: if either argument is missing or not a
                number/iterable, the sequences differ in length, no pair is
                given, or any individual value is invalid.
        """
        lats = CoordinateValidator._as_list(latitude, "latitude")
        lons = CoordinateValidator._as_list(longitude, "longitude")

        if len(lats) != len(lons):
            raise ValidationError(
                f"Latitude and longitude lists must have the same length. "
                f"Got {len(lats)} latitudes and {len(lons)} longitudes."
            )

        if not lats:
            raise ValidationError("At least one coordinate pair is required")

        for i, (lat, lon) in enumerate(zip(lats, lons)):
            try:
                CoordinateValidator.validate_latitude(lat, f"latitude[{i}]")
                CoordinateValidator.validate_longitude(lon, f"longitude[{i}]")
            except ValidationError as e:
                # Keep the underlying error attached as __cause__.
                raise ValidationError(f"Invalid coordinate at index {i}: {e}") from e

    @staticmethod
    def validate_elevation(elevation: float) -> None:
        """Validate an optional elevation value in metres.

        ``None`` is accepted. A non-numeric value raises ``ValidationError``;
        a numeric value outside [-500, 9000] is only logged as a warning.
        """
        if elevation is None:
            return
        if not CoordinateValidator._is_real_number(elevation):
            raise ValidationError(
                f"Elevation must be a number, got {type(elevation).__name__}"
            )
        # Mt. Everest is ~8849m, Dead Sea is ~-430m
        if not (-500 <= elevation <= 9000):
            logger.warning("Unusual elevation value: %sm", elevation)


class DateValidator:
    """Validates date strings, date ranges and day-count parameters.

    All methods are static. Failures raise ``ValidationError``.
    """

    DATE_FORMAT = "%Y-%m-%d"
    DATETIME_FORMAT = "%Y-%m-%dT%H:%M"

    # Data availability ranges for different APIs
    HISTORICAL_START = date(1940, 1, 1)
    FLOOD_START = date(1984, 1, 1)
    CLIMATE_START = date(1950, 1, 1)
    CLIMATE_END = date(2050, 12, 31)

    @staticmethod
    def parse_date(date_str: str, param_name: str = "date") -> date:
        """Parse a strict ``YYYY-MM-DD`` string into a ``date``.

        Args:
            date_str: The date string to parse.
            param_name: Name used for the value in error messages.

        Returns:
            The parsed ``datetime.date``.

        Raises:
            ValidationError: if the value is empty, not a string, not a real
                calendar date, or not zero-padded ``YYYY-MM-DD``.
        """
        if not date_str:
            raise ValidationError(f"{param_name} is required")

        if not isinstance(date_str, str):
            raise ValidationError(f"{param_name} must be a string in YYYY-MM-DD format")

        try:
            parsed = datetime.strptime(date_str, DateValidator.DATE_FORMAT).date()
        except ValueError:
            parsed = None

        # strptime is lenient ("2024-1-5" parses), so require the input to
        # round-trip through the canonical zero-padded ISO form.
        if parsed is None or parsed.isoformat() != date_str:
            raise ValidationError(
                f"Invalid {param_name}: '{date_str}'. Expected format: YYYY-MM-DD"
            )
        return parsed

    @staticmethod
    def validate_date_range(
        start_date: str,
        end_date: str,
        min_date: Optional[date] = None,
        max_date: Optional[date] = None
    ) -> None:
        """Validate that two date strings form an ordered, in-bounds range.

        Args:
            start_date: Range start in ``YYYY-MM-DD`` format.
            end_date: Range end in ``YYYY-MM-DD`` format.
            min_date: Optional earliest allowed start date.
            max_date: Optional latest allowed end date.

        Raises:
            ValidationError: if either date is malformed, the start is after
                the end, or the range falls outside ``min_date``/``max_date``.

        A range longer than five 365-day years is accepted but logged as a
        warning.
        """
        start = DateValidator.parse_date(start_date, "start_date")
        end = DateValidator.parse_date(end_date, "end_date")

        if start > end:
            raise ValidationError(
                f"start_date ({start_date}) must be before or equal to end_date ({end_date})"
            )

        if min_date is not None and start < min_date:
            raise ValidationError(
                f"start_date ({start_date}) is before available data ({min_date.isoformat()})"
            )

        if max_date is not None and end > max_date:
            raise ValidationError(
                f"end_date ({end_date}) is after available data ({max_date.isoformat()})"
            )

        # Warn for very large date ranges
        days_diff = (end - start).days
        if days_diff > 365 * 5:
            logger.warning(
                "Large date range requested: %s days. "
                "This may result in slow response times.",
                days_diff,
            )

    @staticmethod
    def _validate_day_count(days: Optional[int], max_days: int, param_name: str) -> None:
        """Check an optional day count is an int within ``[0, max_days]``."""
        if days is None:
            return
        # bool is a subclass of int; reject it explicitly.
        if isinstance(days, bool) or not isinstance(days, int):
            raise ValidationError(
                f"{param_name} must be an integer, got {type(days).__name__}"
            )
        if days < 0:
            raise ValidationError(f"{param_name} cannot be negative: {days}")
        if days > max_days:
            raise ValidationError(
                f"{param_name} ({days}) exceeds maximum allowed ({max_days})"
            )

    @staticmethod
    def validate_forecast_days(days: int, max_days: int = 16) -> None:
        """Validate the ``forecast_days`` parameter.

        ``None`` is accepted. Otherwise ``days`` must be an integer between
        0 and ``max_days`` inclusive, else ``ValidationError`` is raised.
        """
        DateValidator._validate_day_count(days, max_days, "forecast_days")

    @staticmethod
    def validate_past_days(days: int, max_days: int = 92) -> None:
        """Validate the ``past_days`` parameter.

        ``None`` is accepted. Otherwise ``days`` must be an integer between
        0 and ``max_days`` inclusive, else ``ValidationError`` is raised.
        """
        DateValidator._validate_day_count(days, max_days, "past_days")


class VariableValidator:
    """Validates weather variable and model name parameters.

    The class attributes are the sets of names accepted for each category;
    the static methods check caller-supplied lists against such a set and
    raise ``ValidationError`` on any mismatch.
    """

    # Valid hourly variables for Forecast API
    FORECAST_HOURLY = {
        "temperature_2m", "relative_humidity_2m", "dew_point_2m", "apparent_temperature",
        "wet_bulb_temperature_2m", "pressure_msl", "surface_pressure", "cloud_cover",
        "cloud_cover_low", "cloud_cover_mid", "cloud_cover_high", "wind_speed_10m",
        "wind_speed_80m", "wind_speed_120m", "wind_speed_180m", "wind_direction_10m",
        "wind_direction_80m", "wind_direction_120m", "wind_direction_180m", "wind_gusts_10m",
        "shortwave_radiation", "direct_radiation", "direct_normal_irradiance",
        "diffuse_radiation", "global_tilted_irradiance", "terrestrial_radiation",
        "shortwave_radiation_instant", "direct_radiation_instant",
        "diffuse_radiation_instant", "direct_normal_irradiance_instant",
        "global_tilted_irradiance_instant", "terrestrial_radiation_instant",
        "precipitation", "snowfall", "precipitation_probability", "rain", "showers",
        "weather_code", "snow_depth", "freezing_level_height", "visibility",
        "cape", "lifted_index", "convective_inhibition", "soil_temperature_0cm",
        "soil_temperature_6cm", "soil_temperature_18cm", "soil_temperature_54cm",
        "soil_moisture_0_to_1cm", "soil_moisture_1_to_3cm", "soil_moisture_3_to_9cm",
        "soil_moisture_9_to_27cm", "soil_moisture_27_to_81cm", "uv_index",
        "uv_index_clear_sky", "is_day", "sunshine_duration", "vapour_pressure_deficit",
        "evapotranspiration", "et0_fao_evapotranspiration",
        "total_column_integrated_water_vapour", "boundary_layer_height",
    }

    # Valid daily variables for Forecast API
    FORECAST_DAILY = {
        "weather_code", "temperature_2m_max", "temperature_2m_min", "temperature_2m_mean",
        "apparent_temperature_max", "apparent_temperature_min", "apparent_temperature_mean",
        "sunrise", "sunset", "daylight_duration", "sunshine_duration", "uv_index_max",
        "uv_index_clear_sky_max", "precipitation_sum", "rain_sum", "showers_sum",
        "snowfall_sum", "precipitation_hours", "precipitation_probability_max",
        "precipitation_probability_min", "precipitation_probability_mean",
        "wind_speed_10m_max", "wind_gusts_10m_max", "wind_direction_10m_dominant",
        "shortwave_radiation_sum", "et0_fao_evapotranspiration",
    }

    # Air Quality variables
    AIR_QUALITY_HOURLY = {
        "pm10", "pm2_5", "carbon_monoxide", "nitrogen_dioxide", "sulphur_dioxide",
        "ozone", "ammonia", "aerosol_optical_depth", "dust", "uv_index",
        "uv_index_clear_sky", "alder_pollen", "birch_pollen", "grass_pollen",
        "mugwort_pollen", "olive_pollen", "ragweed_pollen", "european_aqi",
        "european_aqi_pm2_5", "european_aqi_pm10", "european_aqi_nitrogen_dioxide",
        "european_aqi_ozone", "european_aqi_sulphur_dioxide", "us_aqi",
        "us_aqi_pm2_5", "us_aqi_pm10", "us_aqi_nitrogen_dioxide",
        "us_aqi_carbon_monoxide", "us_aqi_ozone", "us_aqi_sulphur_dioxide",
        "methane", "formaldehyde", "glyoxal", "non_methane_volatile_organic_compounds",
        "carbon_dioxide",
    }

    # Marine variables
    MARINE_HOURLY = {
        "wave_height", "wave_direction", "wave_period", "wind_wave_height",
        "wind_wave_direction", "wind_wave_period", "wind_wave_peak_period",
        "swell_wave_height", "swell_wave_direction", "swell_wave_period",
        "swell_wave_peak_period", "secondary_swell_wave_height",
        "secondary_swell_wave_direction", "secondary_swell_wave_period",
        "secondary_swell_wave_peak_period", "tertiary_swell_wave_height",
        "tertiary_swell_wave_direction", "tertiary_swell_wave_period",
        "tertiary_swell_wave_peak_period", "ocean_current_velocity",
        "ocean_current_direction", "sea_surface_temperature",
        "sea_level_height_msl", "sea_level_height_absolute",
    }

    # Flood variables
    FLOOD_DAILY = {
        "river_discharge", "river_discharge_mean", "river_discharge_median",
        "river_discharge_max", "river_discharge_min", "river_discharge_p25",
        "river_discharge_p75",
    }

    # Climate models
    CLIMATE_MODELS = {
        "CMCC_CM2_VHR4", "FGOALS_f3_H", "HiRAM_SIT_HR", "MRI_AGCM3_2_S",
        "EC_Earth3P_HR", "MPI_ESM1_2_XR", "NICAM16_8S",
    }

    # Ensemble models
    ENSEMBLE_MODELS = {
        "icon_seamless", "icon_global", "icon_eu", "icon_d2",
        "gfs_seamless", "gfs025", "gfs05", "ecmwf_ifs04", "ecmwf_aifs025",
        "gem_global", "bom_access_global_ensemble", "ukmo_global_ensemble",
        "meteoswiss_icon_ch_ensemble",
    }

    @staticmethod
    def _require_string_items(items: Any, param_name: str) -> None:
        """Raise ``ValidationError`` unless ``items`` is a list/tuple of strings."""
        if not isinstance(items, (list, tuple)):
            raise ValidationError(f"{param_name} must be a list")
        for item in items:
            # Checked before any set/sort operation so unhashable or
            # non-string entries surface as ValidationError, not TypeError.
            if not isinstance(item, str):
                raise ValidationError(
                    f"{param_name} must contain only strings, "
                    f"got {type(item).__name__}"
                )

    @staticmethod
    def validate_variables(
        variables: Optional[List[str]],
        valid_set: Set[str],
        param_name: str = "variables"
    ) -> None:
        """Validate a list of variable names against a set of valid names.

        Args:
            variables: Names to check; ``None`` skips validation.
            valid_set: The set of accepted names.
            param_name: Name used for the parameter in error messages.

        Raises:
            ValidationError: if ``variables`` is not a list/tuple of strings
                or contains names missing from ``valid_set``. The message
                lists the invalid names and up to three valid names that
                contain, or are contained in, an invalid one
                (case-insensitive substring match).
        """
        if variables is None:
            return

        VariableValidator._require_string_items(variables, param_name)

        invalid = sorted(set(variables) - valid_set)
        if not invalid:
            return

        # Sorted iteration keeps the suggestions stable between runs; raw
        # set order depends on string hash randomisation.
        candidates = sorted(valid_set)
        suggestions: List[str] = []
        for name in invalid:
            needle = name.lower()
            for candidate in candidates:
                lowered = candidate.lower()
                if needle in lowered or lowered in needle:
                    if candidate not in suggestions:
                        suggestions.append(candidate)
                    break

        error_msg = f"Invalid {param_name}: {invalid}"
        if suggestions:
            error_msg += f". Did you mean: {suggestions[:3]}?"
        raise ValidationError(error_msg)

    @staticmethod
    def validate_models(
        models: Optional[List[str]],
        valid_set: Set[str],
        param_name: str = "models"
    ) -> None:
        """Validate a list of model names against a set of valid names.

        Args:
            models: Names to check; ``None`` skips validation.
            valid_set: The set of accepted names.
            param_name: Name used for the parameter in error messages.

        Raises:
            ValidationError: if ``models`` is not a list/tuple of strings, is
                empty, or contains names missing from ``valid_set``.
        """
        if models is None:
            return

        VariableValidator._require_string_items(models, param_name)

        if not models:
            raise ValidationError("At least one model must be specified")

        invalid = sorted(set(models) - valid_set)
        if invalid:
            raise ValidationError(
                f"Invalid {param_name}: {invalid}. "
                f"Valid options: {sorted(valid_set)}"
            )


class ResponseValidator:
    """Checks the shape and quality of data returned by the API."""

    @staticmethod
    def validate_response_structure(data: Dict[str, Any], required_fields: List[str]) -> None:
        """Ensure ``data`` is a dict that contains every required field.

        Raises:
            ValidationError: if ``data`` is not a dict or any name in
                ``required_fields`` is absent from it.
        """
        if not isinstance(data, dict):
            raise ValidationError(f"Expected dict response, got {type(data).__name__}")

        absent = [name for name in required_fields if name not in data]
        if not absent:
            return
        raise ValidationError(f"Response missing required fields: {absent}")

    @staticmethod
    def validate_time_series_data(data: Dict[str, Any], time_key: str = "time") -> None:
        """Ensure every list-valued series is as long as the time axis.

        An empty time axis is logged as a warning and nothing else is
        checked. Non-list values are ignored.

        Raises:
            ValidationError: if ``time_key`` is missing, or a list under
                another key differs in length from the time list.
        """
        if time_key not in data:
            raise ValidationError(f"Missing '{time_key}' in response data")

        expected = len(data[time_key])
        if expected == 0:
            logger.warning("Response contains empty time series data")
            return

        for name, series in data.items():
            # Only sibling list columns are compared against the time axis.
            if name == time_key or not isinstance(series, list):
                continue
            if len(series) == expected:
                continue
            raise ValidationError(
                f"Inconsistent data length for '{name}': "
                f"expected {expected}, got {len(series)}"
            )

    @staticmethod
    def validate_numeric_range(
        values: List[Any],
        min_val: Optional[float] = None,
        max_val: Optional[float] = None,
        param_name: str = "values"
    ) -> None:
        """Log a warning for each value outside ``[min_val, max_val]``.

        ``None`` entries are skipped and a bound of ``None`` disables that
        side of the check. This method never raises for out-of-range values.
        """
        has_lower = min_val is not None
        has_upper = max_val is not None

        for index, item in enumerate(values):
            if item is None:
                continue
            if has_lower and item < min_val:
                logger.warning(
                    f"{param_name}[{index}] = {item} is below expected minimum ({min_val})"
                )
            if has_upper and item > max_val:
                logger.warning(
                    f"{param_name}[{index}] = {item} is above expected maximum ({max_val})"
                )

    @staticmethod
    def check_data_quality(hourly_data: Dict[str, Any]) -> Dict[str, Any]:
        """
        Summarise how complete the list-valued series in ``hourly_data`` are.

        Returns ``{"error": "No time data"}`` when there is no ``"time"`` key.
        Otherwise returns a dict with:
            - total_records: length of the ``"time"`` entry
            - null_counts: per-variable count of ``None`` values (only
              variables that have at least one)
            - completeness: percentage of non-null values over all list
              series other than ``"time"``, rounded to 2 decimals (0 when
              there are no values)
        """
        if "time" not in hourly_data:
            return {"error": "No time data"}

        record_total = len(hourly_data["time"])

        per_variable_nulls = {}
        value_total = 0
        null_total = 0
        for name, series in hourly_data.items():
            if name == "time" or not isinstance(series, list):
                continue
            value_total += len(series)
            # Identity test: only real None entries count as missing.
            missing = sum(entry is None for entry in series)
            if missing > 0:
                per_variable_nulls[name] = missing
                null_total += missing

        if value_total > 0:
            percent_present = (value_total - null_total) / value_total * 100
        else:
            percent_present = 0

        return {
            "total_records": record_total,
            "null_counts": per_variable_nulls,
            "completeness": round(percent_present, 2),
        }


class InputSanitizer:
    """Sanitizes free-form input parameters (strings, timezones, units)."""

    @staticmethod
    def sanitize_string(value: str, max_length: int = 1000) -> str:
        """Return ``value`` as a stripped string capped at ``max_length`` chars.

        Non-string input is converted with ``str()`` first and then trimmed
        and truncated in the same way, so the length limit always applies.
        """
        text = value if isinstance(value, str) else str(value)
        # Trim whitespace and limit length
        return text.strip()[:max_length]

    @staticmethod
    def sanitize_timezone(timezone: str) -> str:
        """Normalise a timezone string, defaulting to ``"UTC"``.

        Empty, ``None`` and whitespace-only input yield ``"UTC"``. Any other
        value is returned stripped. A value that is not ``UTC``/``GMT``/
        ``auto``, does not look like an IANA name (contains ``/`` and is
        shorter than 50 characters) and does not look like an offset (starts
        with ``+``/``-`` and contains ``:``) is still returned, but logged as
        a warning.
        """
        if not timezone:
            return "UTC"

        timezone = timezone.strip()
        # Whitespace-only input is treated the same as a missing timezone.
        if not timezone:
            return "UTC"

        # Common valid timezones
        well_known = ("UTC", "GMT", "auto")

        # IANA format, e.g. "America/New_York", "Europe/Berlin"
        looks_iana = "/" in timezone and len(timezone) < 50
        # Offset format, e.g. "+05:30", "-08:00"
        looks_offset = timezone.startswith(("+", "-")) and ":" in timezone

        if not (timezone in well_known or looks_iana or looks_offset):
            logger.warning("Unusual timezone format: %s", timezone)
        return timezone

    @staticmethod
    def sanitize_unit(unit: str, valid_units: List[str], default: str) -> str:
        """Return the lower-cased, stripped ``unit`` if valid, else ``default``.

        Empty, ``None`` and whitespace-only input yield ``default`` silently.
        A non-empty unit that is not in ``valid_units`` yields ``default``
        and is logged as a warning.
        """
        if not unit:
            return default

        unit = unit.lower().strip()
        # Whitespace-only input counts as "not provided", not as invalid.
        if not unit:
            return default

        if unit not in valid_units:
            logger.warning(
                "Invalid unit '%s', using default '%s'. Valid options: %s",
                unit, default, valid_units,
            )
            return default
        return unit
