From ae1a906952557f616706f79c66030fd812e48cdf Mon Sep 17 00:00:00 2001 From: Andrey Semakin Date: Mon, 4 Nov 2019 18:35:23 +0500 Subject: [PATCH] Add more type hints --- rsa/_compat.py | 2 +- rsa/cli.py | 10 +++++----- rsa/common.py | 2 +- rsa/core.py | 2 +- rsa/key.py | 28 ++++++++++++++++------------ rsa/parallel.py | 3 ++- rsa/pem.py | 6 +++--- rsa/pkcs1.py | 2 +- rsa/pkcs1_v2.py | 2 +- 9 files changed, 31 insertions(+), 26 deletions(-) diff --git a/rsa/_compat.py b/rsa/_compat.py index d9cd8d8..050e81b 100644 --- a/rsa/_compat.py +++ b/rsa/_compat.py @@ -17,7 +17,7 @@ from struct import pack -def byte(num: int): +def byte(num: int) -> bytes: """ Converts a number between 0 and 255 (both inclusive) to a base-256 (byte) representation. diff --git a/rsa/cli.py b/rsa/cli.py index 2bba47f..3166150 100644 --- a/rsa/cli.py +++ b/rsa/cli.py @@ -110,7 +110,7 @@ def __init__(self) -> None: @abc.abstractmethod def perform_operation(self, indata: bytes, key: rsa.key.AbstractKey, - cli_args: Indexable): + cli_args: Indexable) -> typing.Any: """Performs the program's operation. Implement in a subclass. @@ -201,7 +201,7 @@ class EncryptOperation(CryptoOperation): operation_progressive = 'encrypting' def perform_operation(self, indata: bytes, pub_key: rsa.key.AbstractKey, - cli_args: Indexable = ()): + cli_args: Indexable = ()) -> bytes: """Encrypts files.""" assert isinstance(pub_key, rsa.key.PublicKey) return rsa.encrypt(indata, pub_key) @@ -219,7 +219,7 @@ class DecryptOperation(CryptoOperation): key_class = rsa.PrivateKey def perform_operation(self, indata: bytes, priv_key: rsa.key.AbstractKey, - cli_args: Indexable = ()): + cli_args: Indexable = ()) -> bytes: """Decrypts files.""" assert isinstance(priv_key, rsa.key.PrivateKey) return rsa.decrypt(indata, priv_key) @@ -242,7 +242,7 @@ class SignOperation(CryptoOperation): 'to stdout if this option is not present.') def perform_operation(self, indata: bytes, priv_key: rsa.key.AbstractKey, - cli_args: Indexable): + cli_args: Indexable) -> bytes: """Signs files.""" assert isinstance(priv_key, rsa.key.PrivateKey) @@ -269,7 +269,7 @@ class VerifyOperation(CryptoOperation): has_output = False def perform_operation(self, indata: bytes, pub_key: rsa.key.AbstractKey, - cli_args: Indexable): + cli_args: Indexable) -> None: """Verifies files.""" assert isinstance(pub_key, rsa.key.PublicKey) diff --git a/rsa/common.py b/rsa/common.py index c5b647d..e7df21d 100644 --- a/rsa/common.py +++ b/rsa/common.py @@ -18,7 +18,7 @@ class NotRelativePrimeError(ValueError): - def __init__(self, a, b, d, msg=''): + def __init__(self, a: int, b: int, d: int, msg: str = '') -> None: super().__init__(msg or "%d and %d are not relatively prime, divider=%i" % (a, b, d)) self.a = a self.b = b diff --git a/rsa/core.py b/rsa/core.py index d6e146f..23032e3 100644 --- a/rsa/core.py +++ b/rsa/core.py @@ -19,7 +19,7 @@ """ -def assert_int(var: int, name: str): +def assert_int(var: int, name: str) -> None: if isinstance(var, int): return diff --git a/rsa/key.py b/rsa/key.py index 94a14b1..b1e2030 100644 --- a/rsa/key.py +++ b/rsa/key.py @@ -94,7 +94,7 @@ def _save_pkcs1_der(self) -> bytes: """ @classmethod - def load_pkcs1(cls, keyfile: bytes, format='PEM') -> 'AbstractKey': + def load_pkcs1(cls, keyfile: bytes, format: str = 'PEM') -> 'AbstractKey': """Loads a key in PKCS#1 DER or PEM format. :param keyfile: contents of a DER- or PEM-encoded file that contains @@ -128,7 +128,7 @@ def _assert_format_exists(file_format: str, methods: typing.Mapping[str, typing. raise ValueError('Unsupported format: %r, try one of %s' % (file_format, formats)) - def save_pkcs1(self, format='PEM') -> bytes: + def save_pkcs1(self, format: str = 'PEM') -> bytes: """Saves the key in PKCS#1 DER or PEM format. :param format: the format to save; 'PEM' or 'DER' @@ -203,7 +203,7 @@ class PublicKey(AbstractKey): __slots__ = ('n', 'e') - def __getitem__(self, key): + def __getitem__(self, key: str) -> int: return getattr(self, key) def __repr__(self) -> str: @@ -378,7 +378,7 @@ def __init__(self, n: int, e: int, d: int, p: int, q: int) -> None: self.exp2 = int(d % (q - 1)) self.coef = rsa.common.inverse(q, p) - def __getitem__(self, key): + def __getitem__(self, key: str) -> int: return getattr(self, key) def __repr__(self) -> str: @@ -388,7 +388,7 @@ def __getstate__(self) -> typing.Tuple[int, int, int, int, int, int, int, int]: """Returns the key as tuple for pickling.""" return self.n, self.e, self.d, self.p, self.q, self.exp1, self.exp2, self.coef - def __setstate__(self, state: typing.Tuple[int, int, int, int, int, int, int, int]): + def __setstate__(self, state: typing.Tuple[int, int, int, int, int, int, int, int]) -> None: """Sets the key from tuple.""" self.n, self.e, self.d, self.p, self.q, self.exp1, self.exp2, self.coef = state @@ -574,7 +574,9 @@ def _save_pkcs1_pem(self) -> bytes: return rsa.pem.save_pem(der, b'RSA PRIVATE KEY') -def find_p_q(nbits: int, getprime_func=rsa.prime.getprime, accurate=True) -> typing.Tuple[int, int]: +def find_p_q(nbits: int, + getprime_func: typing.Callable[[int], int] = rsa.prime.getprime, + accurate: bool = True) -> typing.Tuple[int, int]: """Returns a tuple of two different primes of nbits bits each. The resulting p * q has exacty 2 * nbits bits, and the returned p and q @@ -619,7 +621,7 @@ def find_p_q(nbits: int, getprime_func=rsa.prime.getprime, accurate=True) -> typ log.debug('find_p_q(%i): Finding q', nbits) q = getprime_func(qbits) - def is_acceptable(p, q): + def is_acceptable(p: int, q: int) -> bool: """Returns True iff p and q are acceptable: - p and q differ @@ -697,8 +699,8 @@ def calculate_keys(p: int, q: int) -> typing.Tuple[int, int]: def gen_keys(nbits: int, getprime_func: typing.Callable[[int], int], - accurate=True, - exponent=DEFAULT_EXPONENT) -> typing.Tuple[int, int, int, int]: + accurate: bool = True, + exponent: int = DEFAULT_EXPONENT) -> typing.Tuple[int, int, int, int]: """Generate RSA keys of nbits bits. Returns (p, q, e, d). Note: this can take a long time, depending on the key size. @@ -726,8 +728,10 @@ def gen_keys(nbits: int, return p, q, e, d -def newkeys(nbits: int, accurate=True, poolsize=1, exponent=DEFAULT_EXPONENT) \ - -> typing.Tuple[PublicKey, PrivateKey]: +def newkeys(nbits: int, + accurate: bool = True, + poolsize: int = 1, + exponent: int = DEFAULT_EXPONENT) -> typing.Tuple[PublicKey, PrivateKey]: """Generates public and private keys, and returns them as (pub, priv). The public key is also known as the 'encryption key', and is a @@ -763,7 +767,7 @@ def newkeys(nbits: int, accurate=True, poolsize=1, exponent=DEFAULT_EXPONENT) \ if poolsize > 1: from rsa import parallel - def getprime_func(nbits): + def getprime_func(nbits: int) -> int: return parallel.getprime(nbits, poolsize=poolsize) else: getprime_func = rsa.prime.getprime diff --git a/rsa/parallel.py b/rsa/parallel.py index 1c98442..f9afedb 100644 --- a/rsa/parallel.py +++ b/rsa/parallel.py @@ -23,12 +23,13 @@ """ import multiprocessing as mp +from multiprocessing.connection import Connection import rsa.prime import rsa.randnum -def _find_prime(nbits: int, pipe) -> None: +def _find_prime(nbits: int, pipe: Connection) -> None: while True: integer = rsa.randnum.read_random_odd_int(nbits) diff --git a/rsa/pem.py b/rsa/pem.py index 24edd90..1ffb446 100644 --- a/rsa/pem.py +++ b/rsa/pem.py @@ -49,7 +49,7 @@ def _pem_lines(contents: bytes, pem_start: bytes, pem_end: bytes) -> typing.Iter # Handle start marker if line == pem_start: if in_pem_part: - raise ValueError('Seen start marker "%s" twice' % pem_start) + raise ValueError('Seen start marker "%r" twice' % pem_start) in_pem_part = True seen_pem_start = True @@ -72,10 +72,10 @@ def _pem_lines(contents: bytes, pem_start: bytes, pem_end: bytes) -> typing.Iter # Do some sanity checks if not seen_pem_start: - raise ValueError('No PEM start marker "%s" found' % pem_start) + raise ValueError('No PEM start marker "%r" found' % pem_start) if in_pem_part: - raise ValueError('No PEM end marker "%s" found' % pem_end) + raise ValueError('No PEM end marker "%r" found' % pem_end) def load_pem(contents: FlexiText, pem_marker: FlexiText) -> bytes: diff --git a/rsa/pkcs1.py b/rsa/pkcs1.py index 10ee50b..8d77a97 100644 --- a/rsa/pkcs1.py +++ b/rsa/pkcs1.py @@ -157,7 +157,7 @@ def _pad_for_signing(message: bytes, target_length: int) -> bytes: message]) -def encrypt(message: bytes, pub_key: key.PublicKey): +def encrypt(message: bytes, pub_key: key.PublicKey) -> bytes: """Encrypts the given message using PKCS#1 v1.5 :param message: the message to encrypt. Must be a byte string no longer than diff --git a/rsa/pkcs1_v2.py b/rsa/pkcs1_v2.py index db94f87..f780aff 100644 --- a/rsa/pkcs1_v2.py +++ b/rsa/pkcs1_v2.py @@ -25,7 +25,7 @@ ) -def mgf1(seed: bytes, length: int, hasher='SHA-1') -> bytes: +def mgf1(seed: bytes, length: int, hasher: str = 'SHA-1') -> bytes: """ MGF1 is a Mask Generation Function based on a hash function.