"""
Module to assist in verifying a signed header.
"""
import six

from Crypto.Hash import HMAC
from Crypto.PublicKey import RSA
from Crypto.Signature import PKCS1_v1_5
from base64 import b64decode

from .sign import Signer
from .utils import *


class Verifier(Signer):
    """
    Verifies signed text against a secret.

    For HMAC, the secret is the shared secret.
    For RSA, the secret is the PUBLIC key.
    """
    def _verify(self, data, signature):
        """
        Verifies the data matches a signed version with the given signature.

        `data` is the message to verify; text is ASCII-encoded before use.
        `signature` is a base64-encoded signature to verify against `data`;
            text is ASCII-encoded before decoding.

        Returns the result of the algorithm-specific check: for 'hmac' a bool,
        for 'rsa' whatever the configured RSA verifier's `verify` returns.
        Raises HttpSigException if `self.sign_algorithm` is neither 'rsa' nor
        'hmac'.
        """
        # Stdlib constant-time comparison; imported locally because the
        # module-level imports of this file do not bring in `hmac`.
        from hmac import compare_digest

        if isinstance(data, six.string_types):
            data = data.encode("ascii")
        if isinstance(signature, six.string_types):
            signature = signature.encode("ascii")

        if self.sign_algorithm == 'rsa':
            h = self._hash.new()
            h.update(data)
            return self._rsa.verify(h, b64decode(signature))

        elif self.sign_algorithm == 'hmac':
            h = self._sign_hmac(data)
            s = b64decode(signature)
            # Constant-time comparison: a plain `==` may return as soon as a
            # byte differs, leaking how much of a forged MAC is correct.
            return compare_digest(h, s)

        else:
            raise HttpSigException("Unsupported algorithm.")


class HeaderVerifier(Verifier):
    """
    Verifies an HTTP signature from given headers.
    """
    def __init__(self, headers, secret, required_headers=None, method=None, path=None, host=None):
        """
        `headers` is a mapping of the request's HTTP headers. It must contain
            an Authorization header, which is looked up through
            CaseInsensitiveDict (presumably case-insensitive, as HTTP header
            names are).
        `secret` is passed to Verifier: the shared secret for HMAC, the
            public key for RSA.
        `required_headers` lists header names the signature must cover. Names
            are lower-cased. Defaults to ['date'] when None or empty.
        `method`, `path` and `host` are passed to generate_message when the
            signing string is rebuilt in verify().

        Raises HttpSigException if the parsed Authorization header does not
        have exactly two parts.
        """
        required_headers = required_headers or ['date']

        # Wrap the headers before reading Authorization. The lookup then
        # matches the case-insensitive access used for the signing string,
        # and a plain dict keyed 'Authorization' no longer raises KeyError.
        self.headers = CaseInsensitiveDict(headers)

        auth = parse_authorization_header(self.headers['authorization'])
        if len(auth) == 2:
            # Parsed signature parameters: 'algorithm', 'signature' and,
            # optionally, 'headers' are read from it below.
            self.auth_dict = auth[1]
        else:
            raise HttpSigException("Invalid authorization header.")

        self.required_headers = [s.lower() for s in required_headers]
        self.method = method
        self.path = path
        self.host = host

        super(HeaderVerifier, self).__init__(secret, algorithm=self.auth_dict['algorithm'])

    def verify(self):
        """
        Check that the signature covers every required header and matches
        the request.

        The signed header names come from the space-separated 'headers'
        parameter of the Authorization header. When that parameter is
        absent, they default to 'date'.

        Returns the result of Verifier._verify for the rebuilt signing string.
        Raises Exception naming any required header the signature does not
        cover.
        """
        auth_headers = self.auth_dict.get('headers', 'date').split(' ')

        missing = set(self.required_headers) - set(auth_headers)
        if missing:
            # Sorted so the error message is deterministic.
            raise Exception('{} is a required header(s)'.format(', '.join(sorted(missing))))

        signing_str = generate_message(auth_headers, self.headers, self.host, self.method, self.path)

        return self._verify(signing_str, self.auth_dict['signature'])
