Skip to content

Commit 54417da

Browse files
committed
Add AESKey and DIRKey implementations
1 parent 7f22486 commit 54417da

File tree

8 files changed

+677
-11
lines changed

8 files changed

+677
-11
lines changed

jose/backends/__init__.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,17 @@
2929
except ImportError:
3030
from jose.backends.ecdsa_backend import ECDSAECKey as ECKey # noqa: F401
3131

32+
try:
33+
from jose.backends.cryptography_backend import CryptographyAESKey as AESKey # noqa: F401
34+
except ImportError:
35+
try:
36+
from jose.backends.pycrypto_backend import AESKey # noqa: F401
37+
except ImportError:
38+
AESKey = None
39+
3240
try:
3341
from jose.backends.cryptography_backend import CryptographyHMACKey as HMACKey # noqa: F401
3442
except ImportError:
3543
from jose.backends.native import HMACKey # noqa: F401
44+
45+
from .base import DIRKey # noqa: F401

jose/backends/base.py

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,8 @@
1+
import six
2+
3+
from ..utils import base64url_encode
4+
5+
16
class Key(object):
27
"""
38
A simple interface for implementing JWK keys.
@@ -19,3 +24,67 @@ def to_pem(self):
1924

2025
def to_dict(self):
2126
raise NotImplementedError()
27+
28+
def encrypt(self, plain_text, aad=None):
29+
"""
30+
Encrypt the plain text and generate an auth tag if appropriate
31+
32+
Args:
33+
plain_text (bytes): Data to encrypt
34+
aad (bytes, optional): Authenticated Additional Data if key's algorithm supports auth mode
35+
36+
Returns:
37+
(bytes, bytes, bytes): IV, cipher text, and auth tag
38+
"""
39+
raise NotImplementedError()
40+
41+
def decrypt(self, cipher_text, iv=None, aad=None, tag=None):
42+
"""
43+
Decrypt the cipher text and validate the auth tag if present
44+
Args:
45+
cipher_text (bytes): Cipher text to decrypt
46+
iv (bytes): IV if block mode
47+
aad (bytes): Additional Authenticated Data to verify if auth mode
48+
tag (bytes): Authentication tag if auth mode
49+
50+
Returns:
51+
bytes: Decrypted value
52+
"""
53+
raise NotImplementedError()
54+
55+
def wrap_key(self, key_data):
56+
"""
57+
Wrap the the plain text key data
58+
59+
Args:
60+
key_data (bytes): Key data to wrap
61+
62+
Returns:
63+
bytes: Wrapped key
64+
"""
65+
raise NotImplementedError()
66+
67+
def unwrap_key(self, wrapped_key):
68+
"""
69+
Unwrap the the wrapped key data
70+
71+
Args:
72+
wrapped_key (bytes): Wrapped key data to unwrap
73+
74+
Returns:
75+
bytes: Unwrapped key
76+
"""
77+
raise NotImplementedError()
78+
79+
80+
class DIRKey(Key):
81+
def __init__(self, key_data, algorithm):
82+
self._key = six.ensure_binary(key_data)
83+
self._alg = algorithm
84+
85+
def to_dict(self):
86+
return {
87+
'alg': self._alg,
88+
'kty': 'oct',
89+
'k': base64url_encode(self._key),
90+
}

jose/backends/cryptography_backend.py

Lines changed: 158 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,14 +8,17 @@
88
from .base import Key
99
from ..utils import base64_to_long, long_to_base64, base64url_decode, base64url_encode
1010
from ..constants import ALGORITHMS
11-
from ..exceptions import JWKError
11+
from ..exceptions import JWKError, JWEError
1212

13-
from cryptography.exceptions import InvalidSignature
13+
from cryptography.exceptions import InvalidSignature, InvalidTag
1414
from cryptography.hazmat.backends import default_backend
1515
from cryptography.hazmat.bindings.openssl.binding import Binding
1616
from cryptography.hazmat.primitives import hashes, serialization, hmac
1717
from cryptography.hazmat.primitives.asymmetric import ec, rsa, padding
1818
from cryptography.hazmat.primitives.asymmetric.utils import decode_dss_signature, encode_dss_signature
19+
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, aead, modes
20+
from cryptography.hazmat.primitives.keywrap import aes_key_wrap, aes_key_unwrap, InvalidUnwrap
21+
from cryptography.hazmat.primitives.padding import PKCS7
1922
from cryptography.hazmat.primitives.serialization import load_pem_private_key, load_pem_public_key
2023
from cryptography.utils import int_from_bytes, int_to_bytes
2124
from cryptography.x509 import load_pem_x509_certificate
@@ -223,6 +226,10 @@ class CryptographyRSAKey(Key):
223226
SHA384 = hashes.SHA384
224227
SHA512 = hashes.SHA512
225228

229+
RSA1_5 = padding.PKCS1v15()
230+
RSA_OAEP = padding.OAEP(padding.MGF1(hashes.SHA1()), hashes.SHA1(), None)
231+
RSA_OAEP_256 = padding.OAEP(padding.MGF1(hashes.SHA256()), hashes.SHA256(), None)
232+
226233
def __init__(self, key, algorithm, cryptography_backend=default_backend):
227234
if algorithm not in ALGORITHMS.RSA:
228235
raise JWKError('hash_alg: %s is not a valid hash algorithm' % algorithm)
@@ -234,6 +241,12 @@ def __init__(self, key, algorithm, cryptography_backend=default_backend):
234241
}.get(algorithm)
235242
self._algorithm = algorithm
236243

244+
self.padding = {
245+
ALGORITHMS.RSA1_5: self.RSA1_5,
246+
ALGORITHMS.RSA_OAEP: self.RSA_OAEP,
247+
ALGORITHMS.RSA_OAEP_256: self.RSA_OAEP_256
248+
}.get(algorithm)
249+
237250
self.cryptography_backend = cryptography_backend
238251

239252
# if it conforms to RSAPublicKey interface
@@ -396,6 +409,149 @@ def to_dict(self):
396409

397410
return data
398411

412+
def wrap_key(self, key_data):
413+
try:
414+
wrapped_key = self.prepared_key.encrypt(key_data, self.padding)
415+
except Exception as e:
416+
raise JWEError(e)
417+
418+
return wrapped_key
419+
420+
def unwrap_key(self, wrapped_key):
421+
try:
422+
unwrapped_key = self.prepared_key.decrypt(
423+
wrapped_key,
424+
self.padding
425+
)
426+
return unwrapped_key
427+
except Exception as e:
428+
raise JWEError(e)
429+
430+
431+
class CryptographyAESKey(Key):
432+
KEY_128 = (ALGORITHMS.A128GCM, ALGORITHMS.A128GCMKW, ALGORITHMS.A128KW,
433+
ALGORITHMS.A128CBC)
434+
KEY_192 = (ALGORITHMS.A192GCM, ALGORITHMS.A192GCMKW, ALGORITHMS.A192KW,
435+
ALGORITHMS.A192CBC)
436+
KEY_256 = (ALGORITHMS.A256GCM, ALGORITHMS.A256GCMKW, ALGORITHMS.A256KW,
437+
ALGORITHMS.A128CBC_HS256, ALGORITHMS.A256CBC)
438+
KEY_384 = (ALGORITHMS.A192CBC_HS384,)
439+
KEY_512 = (ALGORITHMS.A256CBC_HS512,)
440+
441+
AES_KW_ALGS = (ALGORITHMS.A128KW, ALGORITHMS.A192KW, ALGORITHMS.A256KW)
442+
443+
MODES = {
444+
ALGORITHMS.A128GCM: modes.GCM,
445+
ALGORITHMS.A192GCM: modes.GCM,
446+
ALGORITHMS.A256GCM: modes.GCM,
447+
ALGORITHMS.A128CBC_HS256: modes.CBC,
448+
ALGORITHMS.A192CBC_HS384: modes.CBC,
449+
ALGORITHMS.A256CBC_HS512: modes.CBC,
450+
ALGORITHMS.A128CBC: modes.CBC,
451+
ALGORITHMS.A192CBC: modes.CBC,
452+
ALGORITHMS.A256CBC: modes.CBC,
453+
ALGORITHMS.A128GCMKW: modes.GCM,
454+
ALGORITHMS.A192GCMKW: modes.GCM,
455+
ALGORITHMS.A256GCMKW: modes.GCM,
456+
ALGORITHMS.A128KW: None,
457+
ALGORITHMS.A192KW: None,
458+
ALGORITHMS.A256KW: None
459+
}
460+
461+
def __init__(self, key, algorithm):
462+
if algorithm not in ALGORITHMS.AES:
463+
raise JWKError('%s is not a valid AES algorithm' % algorithm)
464+
if algorithm not in ALGORITHMS.SUPPORTED.union(ALGORITHMS.AES_PSEUDO):
465+
raise JWKError('%s is not a supported algorithm' % algorithm)
466+
467+
self._algorithm = algorithm
468+
self._mode = self.MODES.get(self._algorithm)
469+
470+
if algorithm in self.KEY_128 and len(key) != 16:
471+
raise JWKError("Key must be 128 bit for alg {}".format(algorithm))
472+
elif algorithm in self.KEY_192 and len(key) != 24:
473+
raise JWKError("Key must be 192 bit for alg {}".format(algorithm))
474+
elif algorithm in self.KEY_256 and len(key) != 32:
475+
raise JWKError("Key must be 256 bit for alg {}".format(algorithm))
476+
elif algorithm in self.KEY_384 and len(key) != 48:
477+
raise JWKError("Key must be 384 bit for alg {}".format(algorithm))
478+
elif algorithm in self.KEY_512 and len(key) != 64:
479+
raise JWKError("Key must be 512 bit for alg {}".format(algorithm))
480+
481+
self._key = key
482+
483+
def to_dict(self):
484+
data = {
485+
'alg': self._algorithm,
486+
'kty': 'oct',
487+
'k': base64url_encode(self._key)
488+
}
489+
return data
490+
491+
def encrypt(self, plain_text, aad=None):
492+
plain_text = six.ensure_binary(plain_text)
493+
try:
494+
iv = get_random_bytes(algorithms.AES.block_size//8)
495+
mode = self._mode(iv)
496+
if mode.name == "GCM":
497+
cipher = aead.AESGCM(self._key)
498+
cipher_text_and_tag = cipher.encrypt(iv, plain_text, aad)
499+
cipher_text = cipher_text_and_tag[:len(cipher_text_and_tag) - 16]
500+
auth_tag = cipher_text_and_tag[-16:]
501+
else:
502+
cipher = Cipher(algorithms.AES(self._key), mode,
503+
backend=default_backend())
504+
encryptor = cipher.encryptor()
505+
padder = PKCS7(algorithms.AES.block_size).padder()
506+
padded_data = padder.update(plain_text)
507+
padded_data += padder.finalize()
508+
cipher_text = encryptor.update(padded_data) + encryptor.finalize()
509+
auth_tag = None
510+
return iv, cipher_text, auth_tag
511+
except Exception as e:
512+
raise JWEError(e)
513+
514+
def decrypt(self, cipher_text, iv=None, aad=None, tag=None):
515+
cipher_text = six.ensure_binary(cipher_text)
516+
try:
517+
iv = six.ensure_binary(iv)
518+
mode = self._mode(iv)
519+
if mode.name == "GCM":
520+
if tag is None:
521+
raise ValueError("tag cannot be None")
522+
cipher = aead.AESGCM(self._key)
523+
cipher_text_and_tag = cipher_text + tag
524+
try:
525+
plain_text = cipher.decrypt(iv, cipher_text_and_tag, aad)
526+
except InvalidTag:
527+
raise JWEError("Invalid JWE Auth Tag")
528+
else:
529+
cipher = Cipher(algorithms.AES(self._key), mode,
530+
backend=default_backend())
531+
decryptor = cipher.decryptor()
532+
padded_plain_text = decryptor.update(cipher_text)
533+
padded_plain_text += decryptor.finalize()
534+
unpadder = PKCS7(algorithms.AES.block_size).unpadder()
535+
plain_text = unpadder.update(padded_plain_text)
536+
plain_text += unpadder.finalize()
537+
538+
return plain_text
539+
except Exception as e:
540+
raise JWEError(e)
541+
542+
def wrap_key(self, key_data):
543+
key_data = six.ensure_binary(key_data)
544+
cipher_text = aes_key_wrap(self._key, key_data, default_backend())
545+
return cipher_text # IV, cipher text, auth tag
546+
547+
def unwrap_key(self, wrapped_key):
548+
wrapped_key = six.ensure_binary(wrapped_key)
549+
try:
550+
plain_text = aes_key_unwrap(self._key, wrapped_key, default_backend())
551+
except InvalidUnwrap as cause:
552+
raise JWEError(cause)
553+
return plain_text
554+
399555

400556
class CryptographyHMACKey(Key):
401557
"""

0 commit comments

Comments
 (0)