from functools import wraps
from hashlib import md5
from json import dumps, loads
from time import time
from base64 import b64encode


class RedisCache:
    """Memoize function results in Redis, keyed by a hash of the call arguments.

    Each result is stored under ``{prefix}:{namespace}:{args_hash}``. When a
    decorator is created with ``limit > 0``, the keys of its namespace are also
    tracked in the sorted set ``{prefix}:{namespace}:keys`` (scored by write
    time), and the oldest entries are deleted once more than ``limit`` are
    tracked.

    Args:
        redis_client: Client providing ``get(key)`` and
            ``register_script(source)`` (redis-py style).
        prefix: Prefix prepended to every key this cache writes.
        serializer: Callable returning ``str`` or ``bytes``. It is used both to
            encode cached results and to hash the call arguments.
        deserializer: Callable that decodes a value returned by
            ``redis_client.get`` back into a Python object.
    """

    # Lua script run atomically on the server:
    #   KEYS[1] = result key, KEYS[2] = namespace key-tracking sorted set
    #   ARGV[1] = serialized result, ARGV[2] = ttl seconds (<= 0: no expiry),
    #   ARGV[3] = max tracked keys for the namespace (<= 0: no limit)
    # NOTE(review): keys that expire via TTL are never removed from KEYS[2],
    # so they keep counting toward the limit until evicted as "oldest".
    _SET_CACHE_SCRIPT = """
local ttl = tonumber(ARGV[2])
local value
if ttl > 0 then
  value = redis.call('SETEX', KEYS[1], ttl, ARGV[1])
else
  value = redis.call('SET', KEYS[1], ARGV[1])
end
local limit = tonumber(ARGV[3])
if limit > 0 then
  -- Score with microsecond resolution: with whole seconds, keys written in
  -- the same second tie and ZPOPMIN breaks ties lexicographically, which can
  -- evict the key that was just written.
  local now = redis.call('TIME')
  local score = tonumber(now[1]) * 1000000 + tonumber(now[2])
  redis.call('ZADD', KEYS[2], score, KEYS[1])
  local over = tonumber(redis.call('ZCARD', KEYS[2])) - limit
  if over > 0 then
    -- ZPOPMIN removes the oldest members from the set itself and returns a
    -- flat list of member, score pairs; keep only the members.
    local popped = redis.call('ZPOPMIN', KEYS[2], over)
    local stale_keys = {}
    for i = 1, #popped, 2 do
      stale_keys[#stale_keys + 1] = popped[i]
    end
    redis.call('DEL', unpack(stale_keys))
  end
end
return value
"""

    def __init__(self, redis_client, prefix="rc", serializer=dumps, deserializer=loads):
        self.client = redis_client
        self.prefix = prefix
        self.serializer = serializer
        self.deserializer = deserializer
        # Registered lazily on the first cache miss (see ``cache``).
        self.set_cache = None

    @property
    def serialzer(self):
        """Misspelled alias of ``serializer``, kept for backward compatibility."""
        return self.serializer

    @serialzer.setter
    def serialzer(self, value):
        self.serializer = value

    def get_set_cache(self):
        """Register the store-and-evict Lua script and return the script object.

        The returned object is invoked as ``script(keys=[key, keys_key],
        args=[value, ttl, limit])``.
        """
        return self.client.register_script(self._SET_CACHE_SCRIPT)

    def _hash_args(self, args, kwargs):
        """Return the base64 MD5 digest of the serialized ``[args, kwargs]``."""
        raw = self.serializer([args, kwargs])
        # Accept serializers that produce bytes as well as str.
        if isinstance(raw, str):
            raw = raw.encode('utf-8')
        return b64encode(md5(raw).digest()).decode('utf-8')

    def cache(self, ttl=0, limit=0, namespace=None):
        """Return a decorator that caches the decorated function's results.

        Args:
            ttl: Expiry of each cached result in seconds (passed to
                ``SETEX``, which requires an integer); ``0`` or less stores
                without expiry.
            limit: Maximum number of keys tracked for the namespace; the
                oldest are deleted when it is exceeded. ``0`` disables it.
            namespace: Key namespace. Defaults to
                ``"{module}.{function_name}"`` of each decorated function.

        Returns:
            A decorator. A cache hit returns ``deserializer(stored_value)``;
            a miss calls the function, stores ``serializer(result)`` and
            returns ``result`` unchanged.
        """
        def decorator(fn):
            # Resolve the namespace per decorated function. Rebinding the
            # shared ``namespace`` variable (as a ``nonlocal``) would make
            # every later function wrapped by this same decorator reuse the
            # first function's namespace, and with it its cached results.
            fn_namespace = namespace or f'{fn.__module__}.{fn.__name__}'
            keys_key = f'{self.prefix}:{fn_namespace}:keys'

            @wraps(fn)
            def inner(*args, **kwargs):
                key = f'{self.prefix}:{fn_namespace}:{self._hash_args(args, kwargs)}'
                cached = self.client.get(key)
                # A miss is None; a stored value may legitimately be falsy
                # (e.g. b'' from a custom serializer) and is still a hit.
                if cached is not None:
                    return self.deserializer(cached)

                result = fn(*args, **kwargs)
                if self.set_cache is None:
                    self.set_cache = self.get_set_cache()
                self.set_cache(
                    keys=[key, keys_key],
                    args=[self.serializer(result), ttl, limit],
                )
                return result
            return inner
        return decorator
