# Copyright (C) 2011-2013 Versile AS
# 
# This file is part of Versile Python.
# 
# Versile Python is free software: you can redistribute it and/or
# modify it under the terms of the GNU Affero General Public License
# as published by the Free Software Foundation, either version 3 of
# the License, or (at your option) any later version.
# 
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
# Affero General Public License for more details.
# 
# You should have received a copy of the GNU Affero General Public
# License along with this program.  If not, see
# <http://www.gnu.org/licenses/>.
#
# Other Usage
# Alternatively, this file may be used in accordance with the terms
# and conditions contained in a signed written agreement between you
# and Versile AS.
#
# Versile Python implements Versile Platform which is a copyrighted
# specification that is not part of this software.  Modification of
# the software is subject to Versile Platform licensing, see
# https://versile.com/ for details. Distribution of unmodified
# versions released by Versile AS is not subject to Versile Platform
# licensing.
#

"""Locally implemented crypto provider."""


import hashlib

from versile.internal import _b2s, _s2b, _val2b, _vexport, _b_ord, _b_chr
from versile.internal import _pyver
from versile.common.iface import abstract
from versile.common.util import VObjectIdentifier
from versile.crypto import VCrypto, VHash, VCryptoException
from versile.crypto import VNumCipher, VNumTransform, VKeyFactory
from versile.crypto import VRSAKeyFactory
from versile.crypto import VAsymmetricKey, VBlockCipher, VBlockTransform, VKey
from versile.crypto import VDecentralIdentitySchemeA
from versile.crypto.algorithm.blowfish import Blowfish
from versile.crypto.math import is_prime, mod_inv
from versile.crypto.rand import VPseudoRandomHMAC

# Public names of this module, passed through _vexport
__all__ = _vexport(['VLocalCrypto'])


class VLocalCrypto(VCrypto):
    """Local crypto provider implementation.
    
    Provides the following cryptographic methods:
    
    +----------------+-------------------------------------------+
    | Domain         | Methods                                   |
    +================+===========================================+
    | Hash types     | sha1, sha224, sha256, sha384, sha512, md5 |
    +----------------+-------------------------------------------+
    | Block ciphers  | blowfish, blowfish128 (cbc/ofb); rsa      |
    +----------------+-------------------------------------------+
    | Num ciphers    | rsa                                       |
    +----------------+-------------------------------------------+
    | Decentral keys | dia                                       |
    +----------------+-------------------------------------------+

    .. warning::

        Cryptographic methods are implemented locally in pure python,
        and do not offer the same level of performance as optimized
        3rd party libraries.

    Though performance is lower than 3rd party alternatives for
    e.g. block ciphers, this provider has the major advantage that it
    is always available as part of :term:`VPy` and does not rely on
    3rd party libraries. 'Slow encryption' may be a preferable
    alternative to 'no encryption'\ .

    RSA key generation is performed by generating two primes of equal
    size. Each prime is generated by calling
    :meth:`versile.crypto.VRSAKeyFactory.extract_prime` with *advance*
    set to False, which effectively reduces the cryptographic strength
    of the generated prime by a few bits - see the documentation of
    that method for details.
    
    .. note::

        The *rsa* cipher accepts
        :meth:`versile.crypto.VKeyFactory.generate` without a *p*
        argument. If *p* is not set then a value of *p=64* is used by
        default. This corresponds to a probability <= 2^(-127)
        that the generated key is not generated from two prime
        numbers.
        
    """

    @property
    def hash_types(self):
        return ('sha1', 'sha224', 'sha256', 'sha384', 'sha512', 'md5')

    def hash_cls(self, hash_name):
        """Return a :class:`versile.crypto.VHash` subclass for *hash_name*.

        The returned class wraps the matching :mod:`hashlib`
        constructor. No OID is registered here for 'sha224' and 'md5',
        so :meth:`oid` returns None for those two hashes.

        :raises: :exc:`VCryptoException` if the hash is not supported

        """
        hash_oid = None
        if hash_name == 'sha1':
            _hash = hashlib.sha1
            hash_oid = VObjectIdentifier(1, 3, 14, 3, 2, 26)
        elif hash_name == 'sha224':
            _hash = hashlib.sha224
        elif hash_name == 'sha256':
            _hash = hashlib.sha256
            hash_oid = VObjectIdentifier(2, 16, 840, 1, 101, 3, 4, 2, 1)
        elif hash_name == 'sha384':
            _hash = hashlib.sha384
            hash_oid = VObjectIdentifier(2, 16, 840, 1, 101, 3, 4, 2, 2)
        elif hash_name == 'sha512':
            _hash = hashlib.sha512
            hash_oid = VObjectIdentifier(2, 16, 840, 1, 101, 3, 4, 2, 3)
        elif hash_name == 'md5':
            _hash = hashlib.md5
        else:
            raise VCryptoException('Hash method not implemented')

        class HashCls(VHash):
            @classmethod
            def name(cls):
                return hash_name
            @classmethod
            def oid(cls):
                return hash_oid
            @classmethod
            def digest_size(cls):
                return _hash().digest_size
            def __init__(self, data=None):
                self.__hash = _hash()
                super(HashCls, self).__init__(data=data)
            def update(self, data):
                # Python 2 hashlib wants native str, not the bytes type
                # used by the rest of the package
                if _pyver == 2:
                    self.__hash.update(_b2s(data))
                else:
                    self.__hash.update(data)
            def digest(self):
                if _pyver == 2:
                    return _s2b(self.__hash.digest())
                else:
                    return self.__hash.digest()

        return HashCls

    @property
    def block_ciphers(self):
        return ('blowfish', 'blowfish128', 'rsa')

    def block_cipher(self, cipher_name):
        """Return a block cipher object for *cipher_name*.

        'rsa' is provided as the block cipher of the rsa num cipher.

        :raises: :exc:`VCryptoException` if the cipher is not supported

        """
        if cipher_name == 'blowfish':
            return _VLocalBlowfish()
        elif cipher_name == 'blowfish128':
            return _VLocalBlowfish128()
        elif cipher_name == 'rsa':
            num_cipher = self.num_cipher('rsa')
            return num_cipher.block_cipher()
        else:
            raise VCryptoException('Cipher not supported by this provider')

    @property
    def num_ciphers(self):
        return ('rsa',)

    def num_cipher(self, cipher_name):
        """Return a num cipher object for *cipher_name* (only 'rsa').

        :raises: :exc:`VCryptoException` if the cipher is not supported

        """
        if cipher_name == 'rsa':
            return _VLocalRSANumCipher()
        else:
            raise VCryptoException('Cipher not supported by this provider')

    @property
    def decentral_key_schemes(self):
        return ('dia',)

    def decentral_key_scheme(self, name):
        """Return the decentral key scheme *name* (only 'dia').

        :raises: :exc:`VCryptoException` if the scheme is not supported

        """
        if name == 'dia':
            return _VDecentralIdentitySchemeA()
        else:
            raise VCryptoException('Scheme not implemented.')

    def import_ascii_key(self, keydata):
        """Import a key from its ASCII representation.

        Only keys whose decoded name starts with 'VERSILE RSA' are
        supported; they are imported by the rsa key factory.

        :raises: :exc:`VCryptoException` if the key type is not supported

        """
        name, _data = VKeyFactory._decode_ascii(keydata)
        # The decoded name is bytes (the RSA key factory compares it
        # with bytes literals), so the prefix must be bytes as well; a
        # str prefix makes bytes.startswith raise TypeError on Python 3.
        if name.startswith(b'VERSILE RSA'):
            return self.rsa.key_factory.import_ascii(keydata)
        else:
            raise VCryptoException()


## Blowfish classes

class _VLocalBlowfish(VBlockCipher):
    """Locally implemented Blowfish block cipher in 'cbc' or 'ofb' mode.

    Keys must be :class:`VKey` objects whose cipher name equals this
    cipher's name ('blowfish' by default).

    """

    def __init__(self, name='blowfish'):
        super_init = super(_VLocalBlowfish, self).__init__
        super_init(name, ('cbc', 'ofb'), True)

    def blocksize(self, key=None):
        """Return the block size in bytes (always 8)."""
        return 8

    def c_blocksize(self, key=None):
        """Return the cipher-side block size in bytes (always 8)."""
        return 8

    def encrypter(self, key, iv=None, mode='cbc'):
        """Return a block transformer for encryption.

        :param key:  key object for this cipher
        :param iv:   initialization vector; defaults to all-zero bytes
        :param mode: 'cbc' or 'ofb'

        """
        keydata = self._keydata(key)
        if iv is None:
            iv = self.blocksize(key)*b'\x00'
        return self._transform(keydata, iv, mode, encrypt=True)

    def decrypter(self, key, iv=None, mode='cbc'):
        """Return a block transformer for decryption.

        :param key:  key object for this cipher
        :param iv:   initialization vector; defaults to all-zero bytes
        :param mode: 'cbc' or 'ofb'

        """
        keydata = self._keydata(key)
        if iv is None:
            iv = self.blocksize(key)*b'\x00'
        return self._transform(keydata, iv, mode, encrypt=False)

    @property
    def key_factory(self):
        """Return a key factory for this cipher."""
        return _VLocalBlowfishKeyFactory()

    def _transform(self, keydata, iv, mode, encrypt):
        return _VLocalBlowfishTransform(keydata, iv, mode, encrypt)

    def _keydata(self, key):
        """Return the native key data held by *key*.

        :raises: :exc:`TypeError` if *key* is not a :class:`VKey`;
                 :exc:`VCryptoException` if its cipher name differs
                 from this cipher's name

        """
        if isinstance(key, VKey):
            if key.cipher_name == self.name:
                return key.keydata
            else:
                raise VCryptoException('Key ciphername mismatch')
        # Raw bytes are not accepted here, only key objects; the
        # message previously (wrongly) claimed bytes were allowed.
        raise TypeError('Key must be a VKey object')
        

class _VLocalBlowfish128(_VLocalBlowfish):
    """Blowfish cipher registered under the name 'blowfish128'.

    Identical to the parent cipher except for its name and its key
    factory, which limits keys to at most 16 bytes.

    """

    def __init__(self):
        parent_init = super(_VLocalBlowfish128, self).__init__
        parent_init(name='blowfish128')

    @property
    def key_factory(self):
        """Return a key factory for this cipher."""
        factory = _VLocalBlowfish128KeyFactory()
        return factory


class _VLocalBlowfishTransform(VBlockTransform):
    """Blowfish transform chaining 8-byte blocks in CBC or OFB mode.

    The chaining value (initially *iv*) is kept on the object, so
    successive calls continue the same stream. Every chunk of input
    must be a multiple of 8 bytes long.

    """

    def __init__(self, keydata, iv, mode, encrypt):
        super(_VLocalBlowfishTransform, self).__init__(blocksize=8)

        self.__cipher = Blowfish(keydata)
        iv_is_valid = isinstance(iv, bytes) and len(iv) == self.blocksize
        if not iv_is_valid:
            raise VCryptoException('Invalid initialization vector')
        self.__iv = iv
        if mode == 'cbc':
            chain_func = self.__transform_cbc
        elif mode == 'ofb':
            chain_func = self.__transform_ofb
        else:
            raise VCryptoException('Mode not supported')
        self.__transform = chain_func
        self.__encrypt = bool(encrypt)

    def _transform(self, data):
        # Dispatch to the chaining function selected in __init__
        return self.__transform(data)

    @staticmethod
    def __xor_block(left, right):
        """Return the bytewise XOR of *left* and *right* (zip-truncated)."""
        if _pyver == 2:
            return b''.join([_s2b(_b_chr(_b_ord(x) ^ _b_ord(y)))
                             for x, y in zip(left, right)])
        return bytes([x ^ y for x, y in zip(left, right)])

    def __transform_cbc(self, data):
        total = len(data)
        if total % 8:
            raise VCryptoException('Input not aligned to blocksize')
        chunks = []
        for offset in range(0, total, 8):
            block = data[offset:offset+8]
            if self.__encrypt:
                # C_i = E(P_i xor C_{i-1})
                mixed = self.__xor_block(block, self.__iv)
                ciphertext = self.__cipher.encipher(mixed)
                self.__iv = ciphertext
                chunks.append(ciphertext)
            else:
                # P_i = D(C_i) xor C_{i-1}
                decoded = self.__cipher.decipher(block)
                plaintext = self.__xor_block(decoded, self.__iv)
                self.__iv = block
                chunks.append(plaintext)
        return b''.join(chunks)

    def __transform_ofb(self, data):
        total = len(data)
        if total % 8:
            raise VCryptoException('Input not aligned to blocksize')
        chunks = []
        for offset in range(0, total, 8):
            block = data[offset:offset+8]
            # Identical for encryption and decryption: XOR with the
            # next keystream block E(previous keystream block)
            keystream = self.__cipher.encipher(self.__iv)
            self.__iv = keystream
            chunks.append(self.__xor_block(block, keystream))
        return b''.join(chunks)


class _VLocalBlowfishKeyFactory(VKeyFactory):
    """Key factory for 'blowfish' keys of 1 to *max_len* bytes."""

    def __init__(self, max_len=56):
        parent = super(_VLocalBlowfishKeyFactory, self)
        parent.__init__(min_len=1, max_len=max_len, size_inc=1)

    def _bf_length_ok(self, num_bytes):
        """Return True if *num_bytes* satisfies the factory constraints."""
        low, high, step = self.constraints()
        return low <= num_bytes <= high and not num_bytes % step

    def generate(self, source, length=56, p=None):
        """Generate a key from *length* bytes read from *source*.

        *p* is accepted for interface compatibility and is not used.

        """
        if not self._bf_length_ok(length):
            raise VCryptoException('Invalid key length')
        return self._gen_key(source(length))

    def load(self, keydata):
        """Create a key from raw *keydata* bytes."""
        if not isinstance(keydata, bytes):
            raise VCryptoException('Keydata must be in bytes format')
        if not self._bf_length_ok(len(keydata)):
            raise VCryptoException('Invalid key length')
        return self._gen_key(keydata)

    def _gen_key(self, keydata):
        return _VLocalBlowfishKey(keydata)


class _VLocalBlowfish128KeyFactory(_VLocalBlowfishKeyFactory):
    """Key factory for 'blowfish128' keys (at most 16 bytes by default)."""

    def __init__(self, max_len=16):
        parent = super(_VLocalBlowfish128KeyFactory, self)
        parent.__init__(max_len=max_len)

    def generate(self, source, length=16, p=None):
        """Generate a key; as the parent, but defaulting to 16 bytes."""
        parent = super(_VLocalBlowfish128KeyFactory, self)
        return parent.generate(source=source, length=length, p=p)

    def _gen_key(self, keydata):
        return _VLocalBlowfish128Key(keydata)


class _VLocalBlowfishKey(VKey):
    """Key object for the local Blowfish cipher, tagged with *name*."""

    def __init__(self, keydata, name='blowfish'):
        parent_init = super(_VLocalBlowfishKey, self).__init__
        parent_init(name)
        self.__keydata = keydata

    @property
    def keydata(self):
        """Returns native key data (as passed to the constructor)."""
        stored = self.__keydata
        return stored


class _VLocalBlowfish128Key(_VLocalBlowfishKey):
    """Blowfish key tagged with the cipher name 'blowfish128'."""

    def __init__(self, keydata):
        parent = super(_VLocalBlowfish128Key, self)
        parent.__init__(keydata, name='blowfish128')


class _VLocalRSANumCipher(VNumCipher):
    """RSA num cipher producing :class:`_VLocalRSANumTransform` objects."""

    def __init__(self):
        super(_VLocalRSANumCipher, self).__init__(name='rsa', symmetric=False)

    def encrypter(self, key):
        """Return a transform applying the public exponent of *key*."""
        return _VLocalRSANumTransform(self._keydata(key), encrypt=True)

    def decrypter(self, key):
        """Return a transform applying the private exponent of *key*."""
        return _VLocalRSANumTransform(self._keydata(key), encrypt=False)

    @property
    def key_factory(self):
        return _VLocalRSAKeyFactory()

    def _keydata(self, key):
        """Return key data of an rsa :class:`VAsymmetricKey`, else raise."""
        is_rsa_key = (isinstance(key, VAsymmetricKey)
                      and key.cipher_name == 'rsa')
        if not is_rsa_key:
            raise VCryptoException('Invalid key data')
        return key.keydata
    

class _VLocalRSANumTransform(VNumTransform):
    """RSA transform computing pow(num, exponent, n) for one key.

    Encryption uses the public exponent e and decryption uses the
    private exponent d from key data (n, e, d, ...). Input numbers
    must satisfy 0 <= num < n.

    """

    def __init__(self, keydata, encrypt):
        self.__keydata = keydata
        modulus, pub_exp, priv_exp = keydata[:3]

        if encrypt:
            if pub_exp is None:
                raise VCryptoException('Encrypt requires public key')
            exponent = pub_exp
        else:
            if priv_exp is None:
                raise VCryptoException('Decrypt requires private key')
            exponent = priv_exp

        def _apply(num):
            if not 0 <= num < modulus:
                raise VCryptoException('Number out of range')
            return pow(num, exponent, modulus)
        self.__transform = _apply

    def transform(self, num):
        """Transform the integer *num*; bools are normalized to int."""
        if not isinstance(num, int):
            raise VCryptoException('Transformed number must be int or long')
        return self.__transform(int(num))

    @property
    def max_number(self):
        """Largest number accepted by :meth:`transform` (n - 1)."""
        modulus = self.__keydata[0]
        return modulus - 1


class _VLocalRSAKeyFactory(VRSAKeyFactory):
    """Key factory for :class:`_VLocalRSAKey` objects.

    Keys can be generated from a random source, built from two primes,
    loaded from (n, e, d, p, q) key data, or imported from the
    'VERSILE RSA ...' ASCII format.

    """

    def __init__(self):
        super_init = super(_VLocalRSAKeyFactory, self).__init__
        super_init(min_len=2, max_len=None, size_inc=1)

    def generate(self, source, length, p=64, callback=None):
        """Generate a key pair from two primes extracted from *source*.

        :param length:   total length, split between the two primes as
                         length//2 and length//2 + length%2
        :param p:        primality test parameter passed to
                         :meth:`extract_prime`
        :param callback: passed through to :meth:`extract_prime`
        :returns:        key created by :meth:`from_primes`

        """
        pq = [None, None]
        pq_len = (length//2, length//2 + length%2)

        for i in (0, 1):
            # Only the prime (third element) is needed here
            _index, _offset, pq[i] = self.extract_prime(source, pq_len[i],
                                                        p, callback)
        return self.from_primes(*pq)

    @classmethod
    def from_primes(cls, p, q):
        """Create a key pair (n, e, d, p, q) from primes *p* and *q*.

        Uses e=65537 unless that is not usable for the given primes.

        :raises: :exc:`VCryptoException` if no private exponent can be
                 derived

        """
        n = p*q
        t = (p-1)*(q-1)
        e = 65537

        # Handle the (normally very, very low probability) special
        # case that n is a small number - this may be the case for
        # some test code using small primes. Floor division keeps e an
        # int; true division '/' would make it a float on Python 3.
        if e >= n:
            e = n//2 + n%2

        while t % e == 0:
            e += 1
            # The value 56 is hardcoded for a high probability test -
            # however - see versile.crypto.math._SMALL_PRIMES doc -
            # with that variable set properly, the test is expected to
            # nearly always be resolved based on deterministic Euler
            # sieve tests
            while not is_prime(e, 56): # HARDCODED
                e += 1
        d = mod_inv(e, t)
        if d is None:
            raise VCryptoException('Could not generate key')
        if (d*e) % t == 1:
            keydata = (n, e, d, p, q)
            return _VLocalRSAKey(keydata)
        else:
            raise VCryptoException('Could not generate key')

    def import_ascii(self, keydata):
        """Import a key from the 'VERSILE RSA ...' ASCII format.

        Accepts KEY PAIR (5 numbers), PUBLIC KEY (n, e) and PRIVATE KEY
        (n, d) encodings.

        :raises: :exc:`VCryptoException` on unknown or malformed data

        """
        prefix = b'VERSILE RSA '
        name, num_data = self._decode_ascii(keydata)
        if not name.startswith(prefix):
            raise VCryptoException()
        name = name[len(prefix):]
        numbers = self._decode_numbers(num_data)
        if name == b'KEY PAIR':
            if len(numbers) != 5:
                raise VCryptoException()
            keydata = tuple(numbers)
        elif name == b'PUBLIC KEY':
            if len(numbers) != 2:
                raise VCryptoException()
            # Key data must be a 5-tuple (n, e, d, p, q); a 4-tuple
            # here was always rejected by _VLocalRSAKey
            keydata = (numbers[0], numbers[1], None, None, None)
        elif name == b'PRIVATE KEY':
            if len(numbers) != 2:
                raise VCryptoException()
            keydata = (numbers[0], None, numbers[1], None, None)
        else:
            raise VCryptoException()
        return _VLocalRSAKey(keydata)

    def load(self, keydata):
        """Create a key from (n, e, d, p, q) *keydata*."""
        return _VLocalRSAKey(keydata)


class _VLocalRSAKey(VAsymmetricKey):
    """RSA key for the local crypto provider.

    Key data is a 5-tuple (n, e, d, p, q): modulus n, public exponent
    e, private exponent d and prime factors p, q. Any of e, d, p, q
    may be None. At least one of e and d must be set, and p and q must
    either both be set (with p*q == n) or both be None.

    """

    def __init__(self, keydata):
        super(_VLocalRSAKey, self).__init__('rsa')

        if not isinstance(keydata, (tuple, list)) or len(keydata) != 5:
            raise VCryptoException('RSA key data must be 5-tuple')
        for item in keydata:
            if not (item is None or
                    isinstance(item, int) and item >= 0):
                raise VCryptoException('Invalid key data')
        n, e, d, p, q = keydata
        if n is None:
            raise VCryptoException('RSA n parameter cannot be None')
        if e is None and d is None:
            raise VCryptoException('RSA e and d params cannot both be None')
        for _param in (e, d, p, q):
            if _param is not None and not 0 < _param < n:
                raise VCryptoException('Invalid parameter')
        # Parentheses are required: 'p is None != q is None' is a
        # chained comparison that is always False, which let a lone p
        # or q through to a TypeError in p*q below.
        if (p is None) != (q is None):
            raise VCryptoException('p and q must both be set or both be None')
        if p is not None and p*q != n:
            raise VCryptoException('p*q != n')
        self.__keydata = keydata
        # Parameters for X.509 encoding, access via properties
        self.__exp1 = self.__exp2 = self.__coeff = None

    @property
    def has_private(self):
        return (self.__keydata[2] is not None)

    @property
    def has_public(self):
        return (self.__keydata[1] is not None)

    @property
    def private(self):
        """Key holding only (n, d) of this key."""
        n, e, d = self.__keydata[:3]
        return _VLocalRSAKey((n, None, d, None, None))

    @property
    def public(self):
        """Key holding only (n, e) of this key."""
        n, e, d = self.__keydata[:3]
        return _VLocalRSAKey((n, e, None, None, None))

    def merge_with(self, key):
        """Return a key combining the components of this key and *key*.

        Both keys must be RSA keys with the same modulus, and exponents
        present in both keys must match. Each exponent is taken from
        whichever key has it (this key preferred). p and q are kept
        only if both are available from either key.

        :raises: :exc:`VCryptoException` on mismatch, or if the merged
                 key would lack e or d

        """
        # Keys expose their cipher via cipher_name (as checked
        # elsewhere in this module), not via a 'name' attribute
        if self.cipher_name != key.cipher_name:
            raise VCryptoException('Key cipher types do not match')
        if self.keydata[0] != key.keydata[0]:
            raise VCryptoException('Key modulos do not match')
        if self.has_public and key.has_public:
            if self.keydata[1] != key.keydata[1]:
                raise VCryptoException('Key public component mismatch')
        if self.has_private and key.has_private:
            if self.keydata[2] != key.keydata[2]:
                raise VCryptoException('Key private component mismatch')
        n = self.keydata[0]
        # Take each exponent from whichever key holds it, so that e.g.
        # a full key pair merged with a private-only key keeps its e
        e = self.keydata[1] if self.has_public else key.keydata[1]
        d = self.keydata[2] if self.has_private else key.keydata[2]
        if None in (n, e, d):
            raise VCryptoException('Incomplete key data, cannot merge')
        # Try to recover p, q factors
        p, q = self.keydata[3:]
        if key.keydata[3] is not None:
            p = key.keydata[3]
        if key.keydata[4] is not None:
            q = key.keydata[4]
        if p is None or q is None:
            p = q = None
        elif p*q != n:
            raise VCryptoException('Got p,q, however p*q != n')
        return _VLocalRSAKey((n, e, d, p, q))

    def export_ascii(self):
        """Export the key in the 'VERSILE RSA ...' ASCII format."""
        if self.has_private and self.has_public:
            name = b'KEY PAIR'
            # NOTE(review): exports all 5 key data items; if p and q are
            # None this passes None values to _encode_ascii - confirm
            # that such key pairs are handled as intended
            numbers = self.keydata
        elif self.has_private:
            name = b'PRIVATE KEY'
            numbers = (self.keydata[0], self.keydata[2])
        elif self.has_public:
            name = b'PUBLIC KEY'
            numbers = (self.keydata[0], self.keydata[1])
        else:
            raise VCryptoException()
        name = b'VERSILE RSA ' + name
        return self._encode_ascii(name, numbers)

    @property
    def keydata(self):
        return self.__keydata

    @property
    def _exp1(self):
        """d mod (p-1), computed on first access and cached."""
        if self.__exp1 is None:
            d, p = self.__keydata[2], self.__keydata[3]
            if d is None or p is None:
                raise VCryptoException()
            self.__exp1 = d % (p-1)
        return self.__exp1

    @property
    def _exp2(self):
        """d mod (q-1), computed on first access and cached."""
        if self.__exp2 is None:
            d, q = self.__keydata[2], self.__keydata[4]
            if d is None or q is None:
                raise VCryptoException()
            self.__exp2 = d % (q-1)
        return self.__exp2

    @property
    def _coeff(self):
        """Inverse of q modulo p, computed on first access and cached."""
        if self.__coeff is None:
            p, q = self.__keydata[3], self.__keydata[4]
            if p is None or q is None:
                raise VCryptoException()
            self.__coeff = mod_inv(q, p)
        return self.__coeff


class _VDecentralIdentitySchemeA(VDecentralIdentitySchemeA):
    """Implementation of the decentral key scheme DIA."""

    def raw_generate(self, bits, sec_id_data):
        """Generate an RSA key from secret identity data.

        The key is generated by the local rsa key factory from a
        :class:`VPseudoRandomHMAC` (sha256, empty secret) seeded with
        the UTF-8 encoded *sec_id_data* followed by b':Scheme:dia' and
        the encoded *bits*.

        :param bits:        key size in bits; >= 512 and a multiple of 8
        :param sec_id_data: secret identity data (text)
        :raises:            :exc:`VCryptoException` for invalid *bits* or
                            if *sec_id_data* cannot be UTF-8 encoded

        """
        if bits < 512 or bits % 8:
            raise VCryptoException('Invalid number of bits')
        try:
            hmac_input = sec_id_data.encode('utf8')
        except (AttributeError, TypeError, UnicodeError):
            # Narrowed from a bare 'except:' so that KeyboardInterrupt
            # and SystemExit are no longer converted. AttributeError
            # covers objects without encode() (e.g. bytes on Python 3),
            # UnicodeError covers data that cannot be encoded.
            raise VCryptoException('Could not UTF-8 encode sec_id_data')
        hmac_input += b':Scheme:dia' + _val2b(bits)

        crypto = VLocalCrypto()
        p_rand = VPseudoRandomHMAC(hash_cls=crypto.sha256, secret=b'',
                                   seed=hmac_input)
        return crypto.rsa.key_factory.generate(p_rand, bits//8)

    @classmethod
    def name(cls):
        return 'dia'
