88from .base import Key
99from ..utils import base64_to_long , long_to_base64 , base64url_decode , base64url_encode
1010from ..constants import ALGORITHMS
11- from ..exceptions import JWKError
11+ from ..exceptions import JWKError , JWEError
1212
13- from cryptography .exceptions import InvalidSignature
13+ from cryptography .exceptions import InvalidSignature , InvalidTag
1414from cryptography .hazmat .backends import default_backend
1515from cryptography .hazmat .bindings .openssl .binding import Binding
1616from cryptography .hazmat .primitives import hashes , serialization , hmac
1717from cryptography .hazmat .primitives .asymmetric import ec , rsa , padding
1818from cryptography .hazmat .primitives .asymmetric .utils import decode_dss_signature , encode_dss_signature
19+ from cryptography .hazmat .primitives .ciphers import Cipher , algorithms , aead , modes
20+ from cryptography .hazmat .primitives .keywrap import aes_key_wrap , aes_key_unwrap , InvalidUnwrap
21+ from cryptography .hazmat .primitives .padding import PKCS7
1922from cryptography .hazmat .primitives .serialization import load_pem_private_key , load_pem_public_key
2023from cryptography .utils import int_from_bytes , int_to_bytes
2124from cryptography .x509 import load_pem_x509_certificate
@@ -223,6 +226,10 @@ class CryptographyRSAKey(Key):
223226 SHA384 = hashes .SHA384
224227 SHA512 = hashes .SHA512
225228
229+ RSA1_5 = padding .PKCS1v15 ()
230+ RSA_OAEP = padding .OAEP (padding .MGF1 (hashes .SHA1 ()), hashes .SHA1 (), None )
231+ RSA_OAEP_256 = padding .OAEP (padding .MGF1 (hashes .SHA256 ()), hashes .SHA256 (), None )
232+
226233 def __init__ (self , key , algorithm , cryptography_backend = default_backend ):
227234 if algorithm not in ALGORITHMS .RSA :
228235 raise JWKError ('hash_alg: %s is not a valid hash algorithm' % algorithm )
@@ -234,6 +241,12 @@ def __init__(self, key, algorithm, cryptography_backend=default_backend):
234241 }.get (algorithm )
235242 self ._algorithm = algorithm
236243
244+ self .padding = {
245+ ALGORITHMS .RSA1_5 : self .RSA1_5 ,
246+ ALGORITHMS .RSA_OAEP : self .RSA_OAEP ,
247+ ALGORITHMS .RSA_OAEP_256 : self .RSA_OAEP_256
248+ }.get (algorithm )
249+
237250 self .cryptography_backend = cryptography_backend
238251
239252 # if it conforms to RSAPublicKey interface
@@ -396,6 +409,149 @@ def to_dict(self):
396409
397410 return data
398411
412+ def wrap_key (self , key_data ):
413+ try :
414+ wrapped_key = self .prepared_key .encrypt (key_data , self .padding )
415+ except Exception as e :
416+ raise JWEError (e )
417+
418+ return wrapped_key
419+
420+ def unwrap_key (self , wrapped_key ):
421+ try :
422+ unwrapped_key = self .prepared_key .decrypt (
423+ wrapped_key ,
424+ self .padding
425+ )
426+ return unwrapped_key
427+ except Exception as e :
428+ raise JWEError (e )
429+
430+
431+ class CryptographyAESKey (Key ):
432+ KEY_128 = (ALGORITHMS .A128GCM , ALGORITHMS .A128GCMKW , ALGORITHMS .A128KW ,
433+ ALGORITHMS .A128CBC )
434+ KEY_192 = (ALGORITHMS .A192GCM , ALGORITHMS .A192GCMKW , ALGORITHMS .A192KW ,
435+ ALGORITHMS .A192CBC )
436+ KEY_256 = (ALGORITHMS .A256GCM , ALGORITHMS .A256GCMKW , ALGORITHMS .A256KW ,
437+ ALGORITHMS .A128CBC_HS256 , ALGORITHMS .A256CBC )
438+ KEY_384 = (ALGORITHMS .A192CBC_HS384 ,)
439+ KEY_512 = (ALGORITHMS .A256CBC_HS512 ,)
440+
441+ AES_KW_ALGS = (ALGORITHMS .A128KW , ALGORITHMS .A192KW , ALGORITHMS .A256KW )
442+
443+ MODES = {
444+ ALGORITHMS .A128GCM : modes .GCM ,
445+ ALGORITHMS .A192GCM : modes .GCM ,
446+ ALGORITHMS .A256GCM : modes .GCM ,
447+ ALGORITHMS .A128CBC_HS256 : modes .CBC ,
448+ ALGORITHMS .A192CBC_HS384 : modes .CBC ,
449+ ALGORITHMS .A256CBC_HS512 : modes .CBC ,
450+ ALGORITHMS .A128CBC : modes .CBC ,
451+ ALGORITHMS .A192CBC : modes .CBC ,
452+ ALGORITHMS .A256CBC : modes .CBC ,
453+ ALGORITHMS .A128GCMKW : modes .GCM ,
454+ ALGORITHMS .A192GCMKW : modes .GCM ,
455+ ALGORITHMS .A256GCMKW : modes .GCM ,
456+ ALGORITHMS .A128KW : None ,
457+ ALGORITHMS .A192KW : None ,
458+ ALGORITHMS .A256KW : None
459+ }
460+
461+ def __init__ (self , key , algorithm ):
462+ if algorithm not in ALGORITHMS .AES :
463+ raise JWKError ('%s is not a valid AES algorithm' % algorithm )
464+ if algorithm not in ALGORITHMS .SUPPORTED .union (ALGORITHMS .AES_PSEUDO ):
465+ raise JWKError ('%s is not a supported algorithm' % algorithm )
466+
467+ self ._algorithm = algorithm
468+ self ._mode = self .MODES .get (self ._algorithm )
469+
470+ if algorithm in self .KEY_128 and len (key ) != 16 :
471+ raise JWKError ("Key must be 128 bit for alg {}" .format (algorithm ))
472+ elif algorithm in self .KEY_192 and len (key ) != 24 :
473+ raise JWKError ("Key must be 192 bit for alg {}" .format (algorithm ))
474+ elif algorithm in self .KEY_256 and len (key ) != 32 :
475+ raise JWKError ("Key must be 256 bit for alg {}" .format (algorithm ))
476+ elif algorithm in self .KEY_384 and len (key ) != 48 :
477+ raise JWKError ("Key must be 384 bit for alg {}" .format (algorithm ))
478+ elif algorithm in self .KEY_512 and len (key ) != 64 :
479+ raise JWKError ("Key must be 512 bit for alg {}" .format (algorithm ))
480+
481+ self ._key = key
482+
483+ def to_dict (self ):
484+ data = {
485+ 'alg' : self ._algorithm ,
486+ 'kty' : 'oct' ,
487+ 'k' : base64url_encode (self ._key )
488+ }
489+ return data
490+
491+ def encrypt (self , plain_text , aad = None ):
492+ plain_text = six .ensure_binary (plain_text )
493+ try :
494+ iv = get_random_bytes (algorithms .AES .block_size // 8 )
495+ mode = self ._mode (iv )
496+ if mode .name == "GCM" :
497+ cipher = aead .AESGCM (self ._key )
498+ cipher_text_and_tag = cipher .encrypt (iv , plain_text , aad )
499+ cipher_text = cipher_text_and_tag [:len (cipher_text_and_tag ) - 16 ]
500+ auth_tag = cipher_text_and_tag [- 16 :]
501+ else :
502+ cipher = Cipher (algorithms .AES (self ._key ), mode ,
503+ backend = default_backend ())
504+ encryptor = cipher .encryptor ()
505+ padder = PKCS7 (algorithms .AES .block_size ).padder ()
506+ padded_data = padder .update (plain_text )
507+ padded_data += padder .finalize ()
508+ cipher_text = encryptor .update (padded_data ) + encryptor .finalize ()
509+ auth_tag = None
510+ return iv , cipher_text , auth_tag
511+ except Exception as e :
512+ raise JWEError (e )
513+
514+ def decrypt (self , cipher_text , iv = None , aad = None , tag = None ):
515+ cipher_text = six .ensure_binary (cipher_text )
516+ try :
517+ iv = six .ensure_binary (iv )
518+ mode = self ._mode (iv )
519+ if mode .name == "GCM" :
520+ if tag is None :
521+ raise ValueError ("tag cannot be None" )
522+ cipher = aead .AESGCM (self ._key )
523+ cipher_text_and_tag = cipher_text + tag
524+ try :
525+ plain_text = cipher .decrypt (iv , cipher_text_and_tag , aad )
526+ except InvalidTag :
527+ raise JWEError ("Invalid JWE Auth Tag" )
528+ else :
529+ cipher = Cipher (algorithms .AES (self ._key ), mode ,
530+ backend = default_backend ())
531+ decryptor = cipher .decryptor ()
532+ padded_plain_text = decryptor .update (cipher_text )
533+ padded_plain_text += decryptor .finalize ()
534+ unpadder = PKCS7 (algorithms .AES .block_size ).unpadder ()
535+ plain_text = unpadder .update (padded_plain_text )
536+ plain_text += unpadder .finalize ()
537+
538+ return plain_text
539+ except Exception as e :
540+ raise JWEError (e )
541+
542+ def wrap_key (self , key_data ):
543+ key_data = six .ensure_binary (key_data )
544+ cipher_text = aes_key_wrap (self ._key , key_data , default_backend ())
545+ return cipher_text # IV, cipher text, auth tag
546+
547+ def unwrap_key (self , wrapped_key ):
548+ wrapped_key = six .ensure_binary (wrapped_key )
549+ try :
550+ plain_text = aes_key_unwrap (self ._key , wrapped_key , default_backend ())
551+ except InvalidUnwrap as cause :
552+ raise JWEError (cause )
553+ return plain_text
554+
399555
400556class CryptographyHMACKey (Key ):
401557 """
0 commit comments