Skip to content

Introduce cryptography module #526

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
111 changes: 12 additions & 99 deletions src/saml2/aes.py
Original file line number Diff line number Diff line change
@@ -1,105 +1,18 @@
import os
from base64 import b64decode
from base64 import b64encode
import warnings as _warnings

from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives.ciphers import Cipher
from cryptography.hazmat.primitives.ciphers import algorithms
from cryptography.hazmat.primitives.ciphers import modes
from saml2.cryptography.symmetric import AESCipher as _AESCipher


POSTFIX_MODE = {
'cbc': modes.CBC,
'cfb': modes.CFB,
}
_deprecation_msg = (
'{name} {type} is deprecated. '
'It will be removed in the next version. '
'Use saml2.cryptography.symmetric instead.'
).format(name=__name__, type='module')

AES_BLOCK_SIZE = int(algorithms.AES.block_size / 8)
_warnings.simplefilter('default')
_warnings.warn(_deprecation_msg, DeprecationWarning)


class AESCipher(object):
def __init__(self, key):
"""
:param key: The encryption key
:return: AESCipher instance
"""
self.key = key

def build_cipher(self, alg='aes_128_cbc'):
"""
:param alg: cipher algorithm
:return: A Cipher instance
"""
typ, bits, cmode = alg.lower().split('_')
bits = int(bits)
iv = os.urandom(AES_BLOCK_SIZE)

if len(iv) != AES_BLOCK_SIZE:
raise Exception('Wrong iv size: {}'.format(len(iv)))

if bits not in algorithms.AES.key_sizes:
raise Exception('Unsupported key length: {}'.format(bits))

if len(self.key) != bits / 8:
raise Exception('Wrong Key length: {}'.format(len(self.key)))

try:
mode = POSTFIX_MODE[cmode]
except KeyError:
raise Exception('Unsupported chaining mode: {}'.format(cmode))

cipher = Cipher(
algorithms.AES(self.key),
mode(iv),
backend=default_backend())

return cipher, iv

def encrypt(self, msg, alg='aes_128_cbc', padding='PKCS#7', b64enc=True,
block_size=AES_BLOCK_SIZE):
"""
:param key: The encryption key
:param msg: Message to be encrypted
:param padding: Which padding that should be used
:param b64enc: Whether the result should be base64encoded
:param block_size: If PKCS#7 padding which block size to use
:return: The encrypted message
"""

if padding == 'PKCS#7':
_block_size = block_size
elif padding == 'PKCS#5':
_block_size = 8
else:
_block_size = 0

if _block_size:
plen = _block_size - (len(msg) % _block_size)
c = chr(plen).encode()
msg += c * plen

cipher, iv = self.build_cipher(alg)
encryptor = cipher.encryptor()
cmsg = iv + encryptor.update(msg) + encryptor.finalize()

if b64enc:
enc_msg = b64encode(cmsg)
else:
enc_msg = cmsg

return enc_msg

def decrypt(self, msg, alg='aes_128_cbc', padding='PKCS#7', b64dec=True):
"""
:param key: The encryption key
:param msg: Base64 encoded message to be decrypted
:return: The decrypted message
"""
data = b64decode(msg) if b64dec else msg

cipher, iv = self.build_cipher(alg=alg)
decryptor = cipher.decryptor()
res = decryptor.update(data)[AES_BLOCK_SIZE:] + decryptor.finalize()
if padding in ['PKCS#5', 'PKCS#7']:
idx = bytearray(res)[-1]
res = res[:-idx]
return res
AESCipher = _AESCipher
POSTFIX_MODE = _AESCipher.POSTFIX_MODE
AES_BLOCK_SIZE = _AESCipher.AES_BLOCK_SIZE
19 changes: 15 additions & 4 deletions src/saml2/authn.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
import warnings
import logging
import six
import time
from saml2 import SAMLError
from saml2.aes import AESCipher
import saml2.cryptography.symmetric
from saml2.httputil import Response
from saml2.httputil import make_cookie
from saml2.httputil import Redirect
Expand All @@ -14,6 +15,7 @@
__author__ = 'rolandh'

logger = logging.getLogger(__name__)
warnings.simplefilter('default')


class AuthnFailure(SAMLError):
Expand Down Expand Up @@ -120,7 +122,16 @@ def __init__(self, srv, mako_template, template_lookup, pwd, return_to):
self.return_to = return_to
self.active = {}
self.query_param = "upm_answer"
self.aes = AESCipher(self.srv.symkey.encode())
self.symmetric = saml2.cryptography.symmetric.Default(srv.symkey)

@property
def aes(self):
_deprecation_msg = (
'This attribute is deprecated. '
'It will be removed in the next version. '
'Use self.symmetric instead.')
warnings.warn(_deprecation_msg, DeprecationWarning)
return self.symmetric

def __call__(self, cookie=None, policy_url=None, logo_url=None,
query="", **kwargs):
Expand Down Expand Up @@ -172,7 +183,7 @@ def verify(self, request, **kwargs):
self._verify(_dict["password"][0], _dict["login"][0])
timestamp = str(int(time.mktime(time.gmtime())))
msg = "::".join([_dict["login"][0], timestamp])
info = self.aes.encrypt(msg.encode())
info = self.symmetric.encrypt(msg.encode())
self.active[info] = timestamp
cookie = make_cookie(self.cookie_name, info, self.srv.seed)
return_to = create_return_url(self.return_to, _dict["query"][0],
Expand All @@ -192,7 +203,7 @@ def authenticated_as(self, cookie=None, **kwargs):
info, timestamp = parse_cookie(self.cookie_name,
self.srv.seed, cookie)
if self.active[info] == timestamp:
msg = self.aes.decrypt(info).decode()
msg = self.symmetric.decrypt(info).decode()
uid, _ts = msg.split("::")
if timestamp == _ts:
return {"uid": uid}
Expand Down
7 changes: 3 additions & 4 deletions src/saml2/cert.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,8 @@
from os.path import join
from os import remove

from cryptography.hazmat.backends import default_backend
from cryptography.x509 import load_pem_x509_certificate
import saml2.cryptography.pki

backend = default_backend()

class WrongInput(Exception):
pass
Expand Down Expand Up @@ -325,7 +323,8 @@ def verify(self, signing_cert_str, cert_str):
cert_algorithm = cert_algorithm.decode('ascii')
cert_str = cert_str.encode('ascii')

cert_crypto = load_pem_x509_certificate(cert_str, backend)
cert_crypto = saml2.cryptography.pki.load_pem_x509_certificate(
cert_str)

try:
crypto.verify(ca_cert, cert_crypto.signature,
Expand Down
1 change: 1 addition & 0 deletions src/saml2/cryptography/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
"""This module provides cryptographic elements needed by saml2."""
37 changes: 37 additions & 0 deletions src/saml2/cryptography/asymmetric.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
"""This module provides methods for asymmetric cryptography."""

import cryptography.hazmat.backends as _backends
import cryptography.hazmat.primitives.asymmetric as _asymmetric
import cryptography.hazmat.primitives.hashes as _hashes
import cryptography.hazmat.primitives.serialization as _serialization


def load_pem_private_key(data, password):
"""Load RSA PEM certificate."""
key = _serialization.load_pem_private_key(
data, password, _backends.default_backend())
return key


def key_sign(rsakey, message, digest):
"""Sign the given message with the RSA key."""
padding = _asymmetric.padding.PKCS1v15()
signature = rsakey.sign(message, padding, digest)
return signature


def key_verify(rsakey, signature, message, digest):
"""Verify the given signature with the RSA key."""
padding = _asymmetric.padding.PKCS1v15()
if isinstance(rsakey, _asymmetric.rsa.RSAPrivateKey):
rsakey = rsakey.public_key()

try:
rsakey.verify(signature, message, padding, digest)
except Exception as e:
return False
else:
return True


hashes = _hashes
9 changes: 9 additions & 0 deletions src/saml2/cryptography/pki.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
"""This module provides methods for PKI operations."""

import cryptography.hazmat.backends as _backends
import cryptography.x509 as _x509


def load_pem_x509_certificate(data):
"""Load X.509 PEM certificate."""
return _x509.load_pem_x509_certificate(data, _backends.default_backend())
Loading