"""
FlatBuffers Parser for Open-Meteo API responses.

This module parses Open-Meteo FlatBuffers responses, exposing numeric time
series as NumPy arrays that view the message bytes instead of decoding them
element by element, which keeps processing of large weather datasets efficient.

Usage:
    pip install openmeteo-python[fast]  # Install with FlatBuffers support

    from openmeteo import OpenMeteo
    client = OpenMeteo(format="flatbuffers")
"""

import struct
import logging
from typing import Dict, List, Any, Optional, Tuple, Union
from dataclasses import dataclass, field

logger = logging.getLogger(__name__)

# Check if flatbuffers and numpy are available
try:
    import flatbuffers
    import numpy as np
    FLATBUFFERS_AVAILABLE = True
except ImportError:
    FLATBUFFERS_AVAILABLE = False
    flatbuffers = None
    np = None

from openmeteo.flatbuffers_schema import Variable, Unit, Aggregation, VARIABLE_NAMES


def check_flatbuffers_available() -> None:
    """Raise ImportError unless the optional FlatBuffers extras are installed.

    ``flatbuffers`` and ``numpy`` are imported opportunistically when this
    module loads; this call turns a missing dependency into an actionable
    installation hint.
    """
    if FLATBUFFERS_AVAILABLE:
        return
    message = (
        "FlatBuffers support requires additional dependencies. "
        "Install with: pip install openmeteo-python[fast]"
    )
    raise ImportError(message)


@dataclass
class VariableData:
    """
    Parsed variable data with zero-copy numpy array access.

    Attributes:
        variable: Variable enum type
        unit: Unit enum type
        values: Numpy array of values (a view over the FlatBuffers message bytes)
        altitude: Altitude in meters (e.g., 2 for 2m temperature)
        aggregation: Aggregation method (for daily variables)
        pressure_level: Pressure level in hPa (for pressure level variables)
        depth: Depth in cm (for soil variables)
        ensemble_member: Ensemble member number (0 for the control run)
    """
    variable: Variable
    unit: Unit
    values: Any  # numpy array
    altitude: int = 0
    aggregation: Aggregation = Aggregation.none
    pressure_level: int = 0
    depth: int = 0
    ensemble_member: int = 0

    # Aggregation enum names that the JSON API abbreviates (``_max``/``_min``).
    # Unannotated on purpose, so it is a plain class constant, not a field.
    # NOTE(review): if the schema's Aggregation already uses "max"/"min" as
    # member names, this mapping is simply a no-op.
    _JSON_AGGREGATION_SUFFIXES = {"maximum": "max", "minimum": "min"}

    @property
    def name(self) -> str:
        """Return the JSON-style key for this variable.

        Suffixes are appended in this order when set: altitude
        (``temperature_2m``), pressure level (``temperature_850hPa``), depth
        (``soil_temperature_6cm``), aggregation (``temperature_2m_max``) and
        ensemble member (``temperature_2m_member01``). Member 0 gets no
        member suffix. Without it, every ensemble member would map to the
        same key and overwrite the others when converted to a dict.

        NOTE(review): an altitude or depth of exactly 0 adds no suffix, so a
        0 cm soil depth cannot be told apart from "no depth".
        """
        parts = [VARIABLE_NAMES.get(self.variable, str(self.variable))]

        if self.altitude > 0:
            parts.append(f"{self.altitude}m")

        if self.pressure_level > 0:
            parts.append(f"{self.pressure_level}hPa")

        if self.depth > 0:
            parts.append(f"{self.depth}cm")

        if self.aggregation != Aggregation.none:
            agg_name = self.aggregation.name
            parts.append(self._JSON_AGGREGATION_SUFFIXES.get(agg_name, agg_name))

        if self.ensemble_member > 0:
            parts.append(f"member{self.ensemble_member:02d}")

        return "_".join(parts)


@dataclass
class TimeSeriesData:
    """
    Time series data with timestamps and variables.

    Attributes:
        time: Unix timestamps as numpy array
        time_end: End timestamps for intervals (optional)
        interval: Time interval in seconds
        variables: List of variable data
    """
    time: Any  # numpy array of int64
    time_end: Optional[Any] = None  # numpy array of int64
    interval: int = 0
    variables: List[VariableData] = field(default_factory=list)

    def get_variable(
        self,
        variable: Variable,
        altitude: int = 0,
        pressure_level: int = 0,
        aggregation: Optional[Aggregation] = None,
        depth: Optional[int] = None,
        ensemble_member: Optional[int] = None,
    ) -> Optional[VariableData]:
        """Return the first variable matching the given type and modifiers.

        ``variable``, ``altitude`` and ``pressure_level`` must always match
        exactly. ``aggregation``, ``depth`` and ``ensemble_member`` are only
        compared when given, and ``None`` (the default) matches any value.
        Pass them to select, for example, the daily maximum rather than
        whichever aggregate of the same variable happens to come first.

        Returns:
            The matching ``VariableData``, or ``None`` if nothing matches.
        """
        for var in self.variables:
            if not (var.variable == variable and
                    var.altitude == altitude and
                    var.pressure_level == pressure_level):
                continue
            if aggregation is not None and var.aggregation != aggregation:
                continue
            if depth is not None and var.depth != depth:
                continue
            if ensemble_member is not None and var.ensemble_member != ensemble_member:
                continue
            return var
        return None

    def to_dict(self) -> Dict[str, Any]:
        """Convert to a dictionary keyed by ``"time"``, ``"time_end"`` and variable names.

        ``time_end`` is only included when present. Variables are keyed by
        ``VariableData.name``; if two variables share a name, the later one wins.
        """
        result = {"time": self.time}
        if self.time_end is not None:
            result["time_end"] = self.time_end
        for var in self.variables:
            result[var.name] = var.values
        return result


@dataclass
class WeatherApiResponse:
    """Parsed Open-Meteo FlatBuffers API response (one message of the stream).

    Each time-series section is ``None`` when the response did not contain it.
    """

    latitude: float  # location latitude
    longitude: float  # location longitude
    elevation: float = 0.0  # elevation in meters
    utc_offset_seconds: int = 0  # UTC offset in seconds
    timezone: str = "UTC"  # timezone string
    timezone_abbreviation: str = "UTC"  # timezone abbreviation
    model: str = ""  # weather model used
    current: Optional[TimeSeriesData] = None  # current conditions data
    hourly: Optional[TimeSeriesData] = None  # hourly forecast data
    daily: Optional[TimeSeriesData] = None  # daily forecast data
    minutely_15: Optional[TimeSeriesData] = None  # 15-minute forecast data


class FlatBuffersParser:
    """
    Parser for Open-Meteo FlatBuffers responses.

    Numeric vectors are returned as NumPy arrays that view the message bytes
    directly instead of being decoded element by element. This makes parsing
    significantly faster than JSON for large datasets.

    NOTE(review): the vtable slots passed to the readers below (4, 6, 8, ...)
    are hard-coded rather than derived from ``openmeteo.flatbuffers_schema``.
    Confirm them against the schema the API actually serves (upstream:
    open-meteo/sdk ``weather_api.fbs``). For example, check which slot each
    time-series table occupies and whether ``time`` is a vector or a scalar
    start timestamp.
    """

    # FlatBuffers field offsets (determined from schema)
    # WeatherApiResponse table offsets
    VTABLE_OFFSET = 4

    def __init__(self):
        """Initialize the parser.

        Raises:
            ImportError: If the optional ``flatbuffers``/``numpy`` dependencies
                are not installed.
        """
        check_flatbuffers_available()

    def parse(self, data: bytes) -> List[WeatherApiResponse]:
        """
        Parse FlatBuffers response data.

        The Open-Meteo API returns size-prefixed FlatBuffers messages
        when multiple locations are requested: each message is preceded by
        a 4-byte little-endian length.

        Args:
            data: Raw bytes from API response

        Returns:
            List of parsed WeatherApiResponse objects. Messages that fail to
            parse are skipped (and logged). Parsing stops at the first
            truncated message.
        """
        responses = []
        offset = 0
        total = len(data)

        while offset < total:
            if offset + 4 > total:
                # 1-3 leftover bytes cannot even hold a size prefix.
                logger.warning(
                    "Ignoring %d trailing byte(s) after the last FlatBuffers message",
                    total - offset,
                )
                break

            (size,) = struct.unpack_from('<I', data, offset)
            offset += 4

            if offset + size > total:
                logger.warning("Truncated FlatBuffers message at offset %d", offset)
                break

            # Slicing copies this one message; arrays parsed from it are
            # NumPy views into that copy.
            response = self._parse_single_response(data[offset:offset + size])
            if response is not None:
                responses.append(response)

            offset += size

        return responses

    def _parse_single_response(self, data: bytes) -> Optional[WeatherApiResponse]:
        """Parse one FlatBuffers message (size prefix already stripped).

        Returns:
            The parsed response, or ``None`` if the message is malformed. The
            failure is logged rather than raised, so one bad message does not
            stop ``parse`` from returning the others.
        """
        try:
            # A FlatBuffer starts with a uoffset pointing at its root table.
            root = flatbuffers.encode.Get(flatbuffers.packer.uoffset, data, 0)
            reader = _TableReader(data, root)

            return WeatherApiResponse(
                latitude=reader.read_float(4, 0.0),
                longitude=reader.read_float(6, 0.0),
                elevation=reader.read_float(8, 0.0),
                utc_offset_seconds=reader.read_int32(10, 0),
                timezone=reader.read_string(12, "UTC"),
                timezone_abbreviation=reader.read_string(14, "UTC"),
                current=self._parse_time_series(reader, 18),
                hourly=self._parse_time_series(reader, 22),
                daily=self._parse_time_series(reader, 20),
                minutely_15=self._parse_time_series(reader, 24),
            )
        except Exception as e:  # boundary: malformed input must not abort parse()
            logger.error("Failed to parse FlatBuffers response: %s", e)
            return None

    def _parse_time_series(
        self,
        reader: '_TableReader',
        field_offset: int
    ) -> Optional[TimeSeriesData]:
        """Parse the VariablesWithTime table referenced by ``field_offset``.

        Returns:
            The time series, or ``None`` if the field is absent or the table
            has no readable ``time`` vector.
        """
        table_offset = reader.read_table_offset(field_offset)
        if table_offset == 0:
            return None

        data = reader.data
        ts_reader = _TableReader(data, table_offset)

        time = self._read_int64_vector(data, ts_reader.read_vector_offset(4))
        if time is None:
            return None

        return TimeSeriesData(
            time=time,
            # Optional: None when the field is absent.
            time_end=self._read_int64_vector(data, ts_reader.read_vector_offset(6)),
            interval=ts_reader.read_int32(8, 0),
            variables=self._parse_variables(ts_reader, data, 10),
        )

    def _parse_variables(
        self,
        ts_reader: '_TableReader',
        data: bytes,
        field_offset: int
    ) -> List[VariableData]:
        """Parse a vector of VariableWithValues tables.

        Entries that fail to parse are skipped.

        Raises:
            ValueError: If the vector's declared length runs past ``data``.
        """
        variables = []

        vec_offset = ts_reader.read_vector_offset(field_offset)
        if vec_offset == 0:
            return variables

        (count,) = struct.unpack_from('<I', data, vec_offset)
        vec_start = vec_offset + 4
        if vec_start + 4 * count > len(data):
            raise ValueError(
                f"variables vector at offset {vec_offset} ({count} entries) "
                f"overruns the {len(data)}-byte buffer"
            )

        for index in range(count):
            # Each element is a uoffset relative to its own position.
            elem_pos = vec_start + 4 * index
            (elem_rel,) = struct.unpack_from('<I', data, elem_pos)
            var = self._parse_variable(data, elem_pos + elem_rel)
            if var is not None:
                variables.append(var)

        return variables

    def _parse_variable(self, data: bytes, offset: int) -> Optional[VariableData]:
        """Parse a single VariableWithValues table at ``offset``.

        Returns:
            The variable, or ``None`` (logged at debug level) if the table is
            malformed or has no ``values`` vector.
        """
        try:
            reader = _TableReader(data, offset)

            values = self._read_float_vector(data, reader.read_vector_offset(8))
            if values is None:
                return None

            return VariableData(
                variable=self._enum_or_default(
                    Variable, reader.read_uint8(4, 0), Variable.undefined),
                unit=self._enum_or_default(
                    Unit, reader.read_uint8(6, 0), Unit.undefined),
                values=values,
                altitude=reader.read_int16(10, 0),
                aggregation=self._enum_or_default(
                    Aggregation, reader.read_uint8(12, 0), Aggregation.none),
                pressure_level=reader.read_int16(14, 0),
                depth=reader.read_int16(16, 0),
                ensemble_member=reader.read_int16(18, 0),
            )

        except Exception as e:
            logger.debug("Failed to parse variable: %s", e)
            return None

    @staticmethod
    def _enum_or_default(enum_cls: Any, raw: int, default: Any) -> Any:
        """Convert a raw id to ``enum_cls``, falling back to ``default``.

        Ids that are not members of ``enum_cls`` map to ``default`` instead of
        raising. A non-contiguous enum therefore no longer causes the whole
        variable to be dropped.
        """
        try:
            return enum_cls(raw)
        except ValueError:
            return default

    def _read_float_vector(self, data: bytes, offset: int) -> Optional[Any]:
        """Read a float32 vector as a NumPy view over ``data`` (no copy)."""
        return self._read_numeric_vector(data, offset, '<f4')

    def _read_int64_vector(self, data: bytes, offset: int) -> Optional[Any]:
        """Read an int64 vector as a NumPy view over ``data`` (no copy)."""
        return self._read_numeric_vector(data, offset, '<i8')

    @staticmethod
    def _read_numeric_vector(data: bytes, offset: int, dtype: str) -> Optional[Any]:
        """Return the FlatBuffers vector at ``offset`` as a NumPy view.

        A vector is a little-endian uint32 element count followed by the
        elements. ``dtype`` carries an explicit ``<`` byte order because
        FlatBuffers data is little-endian regardless of the host. A
        native-order dtype would decode garbage on big-endian machines.

        Returns:
            The array, or ``None`` if ``offset`` is 0 (field absent) or the
            vector header or payload does not fit inside ``data``.
        """
        if offset <= 0 or offset + 4 > len(data):
            return None

        (count,) = struct.unpack_from('<I', data, offset)
        start = offset + 4
        item = np.dtype(dtype)
        if start + count * item.itemsize > len(data):
            return None

        return np.frombuffer(data, dtype=item, count=count, offset=start)


class _TableReader:
    """Bounds-checked reader for the fields of one FlatBuffers table.

    ``field_id`` arguments are vtable byte offsets: 4 for the first field, 6
    for the second, and so on. A field that is missing from the vtable, or
    whose slot holds 0, yields the caller's ``default``. The offset-returning
    methods yield 0 in that case.

    Every read is checked against ``data``. A corrupt offset raises
    ``ValueError`` instead of being used as a Python slice index, where a
    negative value would silently read bytes from the end of the buffer.
    """

    def __init__(self, data: bytes, table_offset: int):
        """Locate the vtable of the table starting at ``table_offset``.

        Raises:
            ValueError: If the table or its vtable header lies outside ``data``.
        """
        self.data = data
        self.table_offset = table_offset

        # A table begins with a signed offset: vtable = table - soffset.
        vtable_offset_rel = self._unpack('<i', table_offset)
        self.vtable_offset = table_offset - vtable_offset_rel

        # The vtable's first uint16 is its own size in bytes (header included).
        self.vtable_size = self._unpack('<H', self.vtable_offset)

    def _unpack(self, fmt: str, offset: int) -> Any:
        """Unpack one ``fmt`` value at ``offset``, refusing out-of-range reads.

        Raises:
            ValueError: If ``offset`` is negative or the value would extend
                past the end of ``data``.
        """
        size = struct.calcsize(fmt)
        if offset < 0 or offset + size > len(self.data):
            raise ValueError(
                f"FlatBuffers read of {size} byte(s) at offset {offset} "
                f"is outside the {len(self.data)}-byte buffer"
            )
        return struct.unpack_from(fmt, self.data, offset)[0]

    def _get_field_offset(self, field_id: int) -> int:
        """Return the absolute position of a field, or 0 if it is absent."""
        # Slots beyond the vtable's declared size were not written: absent.
        if field_id + 2 > self.vtable_size:
            return 0

        rel = self._unpack('<H', self.vtable_offset + field_id)
        return self.table_offset + rel if rel else 0

    def _follow_uoffset(self, field_id: int) -> int:
        """Resolve a uoffset field (relative to its own position), or 0 if absent."""
        offset = self._get_field_offset(field_id)
        if offset == 0:
            return 0
        return offset + self._unpack('<I', offset)

    def read_float(self, field_id: int, default: float = 0.0) -> float:
        """Read a float32 field, or ``default`` if absent."""
        offset = self._get_field_offset(field_id)
        return self._unpack('<f', offset) if offset else default

    def read_int32(self, field_id: int, default: int = 0) -> int:
        """Read an int32 field, or ``default`` if absent."""
        offset = self._get_field_offset(field_id)
        return self._unpack('<i', offset) if offset else default

    def read_int16(self, field_id: int, default: int = 0) -> int:
        """Read an int16 field, or ``default`` if absent."""
        offset = self._get_field_offset(field_id)
        return self._unpack('<h', offset) if offset else default

    def read_uint8(self, field_id: int, default: int = 0) -> int:
        """Read a uint8 field, or ``default`` if absent."""
        offset = self._get_field_offset(field_id)
        return self._unpack('<B', offset) if offset else default

    def read_string(self, field_id: int, default: str = "") -> str:
        """Read a UTF-8 string field, or ``default`` if absent.

        Raises:
            ValueError: If the string's bytes extend past the buffer
                (``UnicodeDecodeError``, a ``ValueError`` subclass, if they
                are not valid UTF-8).
        """
        str_offset = self._follow_uoffset(field_id)
        if str_offset == 0:
            return default

        # Strings are a uint32 byte length followed by the bytes.
        str_len = self._unpack('<I', str_offset)
        str_start = str_offset + 4
        if str_start + str_len > len(self.data):
            raise ValueError(
                f"FlatBuffers string at offset {str_offset} ({str_len} bytes) "
                f"overruns the {len(self.data)}-byte buffer"
            )
        return bytes(self.data[str_start:str_start + str_len]).decode('utf-8')

    def read_table_offset(self, field_id: int) -> int:
        """Return the absolute position of a sub-table field, or 0 if absent."""
        return self._follow_uoffset(field_id)

    def read_vector_offset(self, field_id: int) -> int:
        """Return the absolute position of a vector field, or 0 if absent."""
        return self._follow_uoffset(field_id)


def response_to_dict(response: WeatherApiResponse) -> Dict[str, Any]:
    """
    Build a dictionary shaped like the JSON API response from a parsed
    FlatBuffers response, so callers can switch backends transparently.

    Location metadata is always included. Each time-series section is added
    only when present; ``current`` is converted in single-value mode so
    one-element arrays become scalars.
    """
    metadata_keys = (
        "latitude",
        "longitude",
        "elevation",
        "utc_offset_seconds",
        "timezone",
        "timezone_abbreviation",
    )
    result = {key: getattr(response, key) for key in metadata_keys}

    # (attribute/key name, convert in single-value mode?)
    sections = (
        ("current", True),
        ("hourly", False),
        ("daily", False),
        ("minutely_15", False),
    )
    for key, single in sections:
        series = getattr(response, key)
        if series:
            result[key] = _time_series_to_dict(series, single=single)

    return result


def _time_series_to_dict(ts: TimeSeriesData, single: bool = False) -> Dict[str, Any]:
    """Flatten a TimeSeriesData into plain Python lists and scalars.

    Timestamps are kept as Unix timestamps (ints); no ISO formatting is done.
    With ``single=True``, any one-element array is unwrapped to a scalar
    (``int`` for time, ``float`` for variable values). Otherwise arrays become
    lists via ``tolist()``. Variables are keyed by ``VariableData.name``; a
    later variable with the same name replaces an earlier one.
    """
    def unwrap(array, cast):
        # Scalar only when explicitly requested AND there is exactly one value.
        if single and len(array) == 1:
            return cast(array[0])
        return array.tolist()

    result = {}
    if ts.time is not None:
        result["time"] = unwrap(ts.time, int)

    result.update((var.name, unwrap(var.values, float)) for var in ts.variables)
    return result
