import json
import logging
import os
import struct
import sys

import numpy as np
import pandas as pd
from beartype import beartype
from beartype.typing import List, Union

from ideas.analysis.types import NumpyFloatVector


def _hex_to_rgb(hex_color: str) -> tuple:
    """Converts a hex color to an RGB tuple.

    Args:
        hex_color (str): The hex color to convert.

    Returns:
        tuple: The RGB tuple.
    """
    hex_color = hex_color.lstrip("#")
    r = int(hex_color[0:2], 16)
    g = int(hex_color[2:4], 16)
    b = int(hex_color[4:6], 16)
    return (r, g, b)


@beartype
def _find_coord_start(
    row: Union[dict, pd.Series], max_poly_points: Union[int, float]
) -> int:
    """
    Find the highest coordinate index that is present in a row of a zones file.

    Coordinates are expected to be stored under keys formatted as
    "X 0", "Y 0", "X 1", "Y 1", etc. Only the "X <i>" keys are inspected.
    The search starts at ``max_poly_points`` and walks down towards 0,
    returning the first index whose "X <i>" entry is usable:

    - for a pandas Series the key must exist and its value must not be null;
    - for a dictionary the key only has to exist (its value is not inspected).

    The returned value is a 0-based index; this function does not add 1 to it.

    Args:
        row (pd.Series | dict): A row of a zones file.
        max_poly_points (int | float): The highest index to consider. It is
            truncated to an integer with ``int()``.

    Returns:
        int: The largest index ``i`` in ``[0, max_poly_points]`` for which
            "X i" is usable.

    Raises:
        ValueError: If ``max_poly_points`` is negative, or if no usable
            "X <i>" entry exists in the searched range.
        TypeError: If ``row`` is neither a pandas Series nor a dictionary.
    """
    max_poly_points = int(max_poly_points)
    if max_poly_points < 0:
        raise ValueError("max_poly_points must be a positive integer")

    if isinstance(row, pd.Series):

        def _has_point(key: str) -> bool:
            # A missing column and a null value both mean "no point here".
            try:
                return not pd.isnull(row[key])
            except KeyError:
                return False

    elif isinstance(row, dict):

        def _has_point(key: str) -> bool:
            return key in row

    else:
        raise TypeError("Input must be a pandas Series or a dictionary.")

    # Iterate instead of recursing so that a large max_poly_points combined
    # with a sparse row cannot exhaust the interpreter's recursion limit.
    for index in range(max_poly_points, -1, -1):
        if _has_point(f"X {index}"):
            return index

    raise ValueError(
        f"No 'X <i>' coordinate found in row for 0 <= i <= {max_poly_points}"
    )


@beartype
def isxd_type(file_path: str) -> str:
    """Infer the type of an ISXD file from the footer stored in the file.

    The footer's integer "type" field is translated to a type name.

    Args:
        file_path (str): Path to the ISXD file.

    Returns:
        str: The name associated with the footer's "type" value.

    Raises:
        KeyError: If the footer's "type" value has no known name.
    """

    type_code = _extract_footer(file_path)["type"]

    names_by_code = {
        0: "miniscope_movie",
        1: "cell_set",
        # not currently supported on IDEAS
        2: "isxd_behavioral_movie",
        3: "gpio_data",
        4: "miniscope_image",
        5: "neural_events",
        # not currently supported on IDEAS
        6: "isxd_metrics",
        7: "imu_data",
        8: "vessel_set",
    }

    if type_code in names_by_code:
        return names_by_code[type_code]

    raise KeyError(
        f"Unknown key: {type_code}. Expected it to be an +ve integer < 9"
    )


@beartype
def _footer_length(isxd_file: str) -> int:
    """Read the length of the footer, in bytes, from the end of an ISXD file.

    The last 8 bytes of the file are unpacked as two native-order ints
    (struct format "ii"); the first of the two is the value returned.

    Args:
        isxd_file (str): Path to the ISXD file.

    Returns:
        int: The footer length in bytes.
    """

    trailer_size = 8
    with open(isxd_file, mode="rb") as handle:
        handle.seek(-trailer_size, os.SEEK_END)
        trailer = handle.read()

    length, _ = struct.unpack("ii", trailer)
    return length


@beartype
def _extract_footer(isxd_file: str) -> dict:
    """Read and parse the JSON footer stored near the end of an ISXD file.

    Args:
        isxd_file (str): Path to the ISXD file.

    Returns:
        dict: The footer, decoded as UTF-8 and parsed as JSON.
    """

    n_bytes = _footer_length(isxd_file)
    # The footer precedes the 8-byte length field and one additional byte
    # (presumably a terminator -- confirm against the ISXD format spec).
    offset_from_end = -(n_bytes + 8 + 1)

    with open(isxd_file, mode="rb") as handle:
        handle.seek(offset_from_end, os.SEEK_END)
        raw = handle.read(n_bytes)

    return json.loads(raw.decode("utf-8"))


def _get_isxd_times(input_filename: str):
    """Build the timestamp of every sample of an isxd file from its metadata.

    The sampling period is read from the footer as a num/den fraction, and
    the timestamps are evenly spaced from 0 to ``(numTimes - 1) * period``.

    :param input_filename str: path to the input file (.isxd)
    :return: The timestamps of every sample in the isxd file
    """

    timing = _extract_footer(input_filename)["timingInfo"]

    step = timing["period"]
    period = step["num"] / step["den"]

    n_samples = timing["numTimes"]
    last_timestamp = (n_samples - 1) * period

    return np.linspace(0, last_timestamp, n_samples)


@beartype
def _sort_isxd_files_by_start_time(
    input_files: List[str],
) -> List:
    """Sort isxd files by their start time.

    The start time of each file is read from its footer as a num/den
    fraction of seconds since epoch. Files whose footer "type" is 5 (isxd
    events) keep it under "global times"; all other files keep it under
    "timingInfo" -> "start".

    Files with identical start times keep their relative input order.

    :param input_files: list of isxd file paths
    :return: sorted list of isxd file paths
    """
    start_times = []
    for file in input_files:
        isxd_metadata = _extract_footer(file)

        if isxd_metadata["type"] == 5:
            # isxd events
            epoch = isxd_metadata["global times"][0]["secsSinceEpoch"]
        else:
            # other isxd files
            epoch = isxd_metadata["timingInfo"]["start"]["secsSinceEpoch"]

        start_times.append(epoch["num"] / epoch["den"])

    # A stable sort guarantees a deterministic order for equal start times;
    # numpy's default sort kind makes no such promise.
    sorted_indices = np.argsort(start_times, kind="stable")
    return [input_files[index] for index in sorted_indices]


@beartype
def _add_suffix(names: List, suffix: str) -> List:
    """inserts a suffix before the extension of every filename in a
    list, useful for interacting with the isx API

    ### Arguments:

    - names: list of filenames (left unmodified)
    - suffix: text to insert between each name and its extension

    ### Returns:

    - new list of names with the suffix added

    """

    return [
        stem + suffix + extension
        for stem, extension in map(os.path.splitext, names)
    ]


@beartype
def check_file_extention_is(
    file_name: str,
    *,
    ext: str = ".isxd",
):
    """small util func to check the extension of a file
    and fail otherwise

    The comparison is case-insensitive on both sides, so "movie.ISXD"
    matches ext=".isxd" and "movie.isxd" matches ext=".ISXD".

    ### Arguments:

    - file_name: name or path of the file to check
    - ext: expected extension, including the leading dot

    ### Raises:

    - Exception: if the extension of file_name differs from ext
    """

    _, actual_ext = os.path.splitext(os.path.basename(file_name))
    # Normalise both sides: lower-casing only the file's extension would
    # make any expected extension containing upper-case letters unmatchable.
    if actual_ext.lower() != ext.lower():
        raise Exception(
            f"{file_name} does not have the extension: {ext}. It instead has {actual_ext}"
        )


@beartype
def subsample(x: NumpyFloatVector, bin_size: int = 50) -> NumpyFloatVector:
    """
    min-max resampler that shrinks a vector for plotting.
    The vector is cut into consecutive bins of bin_size samples and each
    bin is replaced by its minimum followed by its maximum. Samples left
    over after the last full bin are dropped.
    This destroys and alters data, and should only be used for plots!
    Args:
        x: a vector
        bin_size: max and min will be computed on this bin size (default 50)
    Returns:
        vector of length 2 * (number of full bins), laid out as
        min, max, min, max, ...
    """

    # Keep only the leading part of x that fills whole bins.
    usable_length = np.floor(x.shape[0] / bin_size).astype(int) * bin_size
    trimmed = x[0:usable_length]
    n_bins = int(len(trimmed) / bin_size)

    # One row per bin, so per-bin extrema are row-wise reductions.
    binned = np.reshape(trimmed, (n_bins, bin_size))
    bin_max = binned.max(axis=1)
    bin_min = binned.min(axis=1)

    envelope = np.zeros(bin_max.shape[0] * 2)
    envelope[0::2] = bin_min
    envelope[1::2] = bin_max

    return envelope


def get_file_size(file_path: str) -> str:
    """Return the size of a file on disk as a human-readable string.

    The byte count is divided by 1024 until it falls below 1024, then
    formatted with two decimals and the matching unit (e.g. "1.50 KB").

    :param file_path: path to the input file
    :return: the formatted file size
    """
    remaining = abs(os.path.getsize(file_path))
    units = ("B", "KB", "MB", "GB", "TB", "PB", "EB", "ZB", "YB")

    for unit in units:
        if remaining < 1024.0:
            return f"{remaining:.2f} {unit}"
        remaining /= 1024.0

    # Larger than the biggest known unit: report the raw byte count.
    return f"{os.path.getsize(file_path)} B"


def _set_up_logger():
    """Set up logger for IDEAS tools"""
    logger = logging.getLogger()
    logger.setLevel(logging.INFO)

    formatter = logging.Formatter(
        fmt="[%(asctime)s.%(msecs)03d][%(levelname)s] %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
    )

    stdout_handler = logging.StreamHandler(sys.stdout)
    stdout_handler.setLevel(logging.INFO)
    stdout_handler.addFilter(lambda r: r.levelno <= logging.ERROR)
    stdout_handler.setFormatter(formatter)

    stderr_handler = logging.StreamHandler(sys.stderr)
    stderr_handler.setLevel(logging.ERROR)
    stderr_handler.setFormatter(formatter)

    logger.addHandler(stdout_handler)
    logger.addHandler(stderr_handler)
