import calendar
import logging
import math
import threading
import time
import typing
from dataclasses import dataclass, field
from urllib.parse import parse_qsl, urlparse

from pycognito.utils import RequestsSrpAuth

from ideas.utils import api_types, http

# How close we should get to an expiry (in seconds) before requesting new presigned URLs
GRACE_PERIOD = 60

# Module-scoped logger (rather than the root logger) so applications can tune this
# module's verbosity by name; records still propagate to the root logger's handlers.
logger = logging.getLogger(__name__)


def parse_presigned_url(url: str) -> typing.Tuple[int, int, str]:
    """
    Given a presigned multipart upload URL, parse it and return a tuple of the part number, the
    expiry time (as a unix timestamp), and the original URL (for convenience).

    The expiry is read from the absolute ``Expires`` query param when present. Otherwise it is
    derived from ``X-Amz-Expires`` (a lifetime in seconds), counted from the signing time in
    ``X-Amz-Date``. If that date is missing, unparseable, or later than the local clock, the
    current time is used instead, so the result is never later than ``now + X-Amz-Expires``.

    Raises KeyError if ``partNumber`` is missing, or if neither ``Expires`` nor
    ``X-Amz-Expires`` is present.

    >>> parse_presigned_url('/upload/tenant/1/user/<uuid>?partNumber=1&Expires=1678207852')
    (1, 1678207852, '/upload/tenant/1/user/<uuid>?partNumber=1&Expires=1678207852')
    """
    # This treats each query param as a single value, and doesn't support lists
    # like ?foo=bar&foo=baz
    querystring = urlparse(url).query
    query_params = dict(parse_qsl(querystring))

    try:
        expiry = int(query_params["Expires"])
    except KeyError:
        # Some regions don't give us an Expires timestamp. X-Amz-Expires is a lifetime relative
        # to the signing time (X-Amz-Date, e.g. 20230307T163052Z), so anchor it there rather
        # than on "now", which would overstate the expiry by the age of the URL.
        lifetime = int(query_params["X-Amz-Expires"])
        now = time.time()
        signed_at = now
        amz_date = query_params.get("X-Amz-Date")
        if amz_date is not None:
            try:
                signed_at = calendar.timegm(time.strptime(amz_date, "%Y%m%dT%H%M%SZ"))
            except ValueError:
                # Deliberate best-effort: an unexpected date format falls back to "now".
                signed_at = now
        # Never trust a signing time ahead of our own clock; stay conservative.
        expiry = math.floor(min(signed_at, now) + lifetime)

    part_number = int(query_params["partNumber"])
    return (part_number, expiry, url)


def get_presigned_urls(base_url, headers, auth, file_id):
    """
    Retrieves the presigned multipart upload URLs for ``file_id`` from the API and
    remaps them into a dict where the keys are the part number, and the values are
    a tuple of

        (
            the expiry time, as a unix timestamp in seconds,
            the original URL
        )

    Each URL is parsed with ``parse_presigned_url``. If the response contained more
    than one URL for the same part number, the last one wins.
    """
    url_get_presigned_urls = (
        f"{base_url}/api/{http.IDEAS_API_VERSION}/drs/files/{file_id}/upload_urls/"
    )
    response = typing.cast(
        api_types.FileUploadUrls, http.get(url_get_presigned_urls, headers, auth)
    )
    presigned_urls = response["PresignedUrls"]
    return {
        part_number: (expiry, url)
        for part_number, expiry, url in map(parse_presigned_url, presigned_urls)
    }


@dataclass
class PresignedUrlsManager:
    """
    Manages a set of presigned multipart upload URLs associated with a given
    `file_id`, including handling refreshing the URLs when they are close to
    expiring (within `GRACE_PERIOD` seconds) or after `invalidate()` was called.

    The initial set of URLs is fetched eagerly when the instance is constructed.
    """

    base_url: str
    headers: dict
    auth: RequestsSrpAuth
    file_id: str

    _invalidated: bool = False

    # A simple exclusive lock for reading/writing from the cache of presigned URLs. While this is a
    # read-heavy, write-light operation that might benefit from something like a readers-write lock,
    # the implementation overhead of designing something like that (which must be reentrant) is too
    # much for how little overhead this simple locking mechanism provides.
    #
    # In a test scenario with 6 upload threads, the lock for readers (aka, just getting a presigned
    # URL for their part that hasn't expired) takes about 3-10 ms to acquire.
    #
    # The lock is a per-instance dataclass field. An unannotated `lock = threading.Lock()` would be
    # a class attribute shared by every manager, so one file's URL refresh (a network call) would
    # block URL lookups for every other file being uploaded.
    lock: threading.Lock = field(
        default_factory=threading.Lock, init=False, repr=False, compare=False
    )

    def __post_init__(self):
        # Populate the cache up front so get_url_for_part always has entries to check.
        self._refresh_presigned_urls()

    def _refresh_presigned_urls(self):
        """
        Re-fetch all presigned URLs for `file_id` and clear the invalidated flag.

        Called from `__post_init__` during construction and from `get_url_for_part` while
        `self.lock` is held; errors from the fetch propagate to the caller.
        """
        # Internal cache of presigned URLs, to be refreshed when the expiry is reached
        self._presigned_urls = get_presigned_urls(
            self.base_url,
            self.headers,
            self.auth,
            self.file_id,
        )
        self._invalidated = False

    def invalidate(self) -> None:
        """Mark the cached URLs as stale so the next `get_url_for_part` call refreshes them."""
        self._invalidated = True

    def get_url_for_part(self, part: int) -> str:
        """
        Given a part number, return a non-expired multipart URL corresponding
        to that part.

        The whole cache is refreshed first if this part's URL expires within
        `GRACE_PERIOD` seconds, or if `invalidate()` has been called since the last
        refresh.

        Raises KeyError if `part` is not one of the cached part numbers.
        """
        with self.lock:
            now = time.time()
            expiry, _ = self._presigned_urls[part]
            if now + GRACE_PERIOD > expiry:
                logger.info(
                    "Upload link will expire within %s seconds, refreshing.",
                    GRACE_PERIOD,
                )
                self._refresh_presigned_urls()
            elif self._invalidated:
                logger.info("Cached URLs are invalid, refreshing.")
                self._refresh_presigned_urls()

            # Re-retrieve the presigned URL for this part from our local cache
            # in case we refreshed.
            _, presigned_url = self._presigned_urls[part]

        return presigned_url
