From 9dbefeeea27e42bcdb0c4feeb80a779b972f5446 Mon Sep 17 00:00:00 2001 From: Dain Nilsson Date: Tue, 22 Oct 2024 14:35:52 +0200 Subject: [PATCH] Code cleanups --- fido2/arkg.py | 173 +++++++++++++++++++++++++++++++------------------- 1 file changed, 107 insertions(+), 66 deletions(-) diff --git a/fido2/arkg.py b/fido2/arkg.py index 14ad4b2..dd07ba4 100644 --- a/fido2/arkg.py +++ b/fido2/arkg.py @@ -11,78 +11,85 @@ import fastecdsa.keys from dataclasses import dataclass -from typing import Tuple +from typing import Tuple, Sequence import struct -def strxor(a, b): +def strxor(a: bytes, b: bytes) -> bytes: c = bytearray(len(a)) for i in range(len(a)): c[i] = a[i] ^ b[i] return c -def expand_message_xmd(H, msg, dst, out_len): - ell = -(-out_len // H.digest_size) - dst_prime = dst + struct.pack(">B", len(dst)) - z_pad = b"\x00" * H.block_size - l_i_b_str = struct.pack(">H", out_len) - msg_prime = z_pad + msg + l_i_b_str + b"\x00" + dst_prime - d = Hash(H) - d.update(msg_prime) - b_0 = d.finalize() - b_xor = b_0 - uniform_bytes = bytearray() - for i in range(1, ell + 1): - d = Hash(H) - d.update(b_xor + struct.pack(">B", i) + dst_prime) - b_i = d.finalize() - uniform_bytes.extend(b_i) - b_xor = strxor(b_0, b_i) - return uniform_bytes[:out_len] - - -""" -hash_to_field(msg, count) - -Parameters: -- DST, a domain separation tag (see Section 3.1). -- F, a finite field of characteristic p and order q = p^m. -- p, the characteristic of F (see immediately above). -- m, the extension degree of F, m >= 1 (see immediately above). -- L = ceil((ceil(log2(p)) + k) / 8), where k is the security - parameter of the suite (e.g., k = 128). -- expand_message, a function that expands a byte string and - domain separation tag into a uniformly random byte string - (see Section 5.3). -""" - - -# m is always 1 -def hash_to_field(msg, count, dst, p, L, H): - elements = list() - uniform_bytes = expand_message_xmd(H, msg, dst, count * L) - for i in range(count): - offset = L * i - tv = uniform_bytes[offset : offset + L] - e_j = bytes2int(tv) % p - elements.append(e_j) - return elements - - @dataclass class HTF: + """ + hash_to_field(msg, count) + + Parameters: + - DST, a domain separation tag (see Section 3.1). + - F, a finite field of characteristic p and order q = p^m. + - p, the characteristic of F (see immediately above). + - m, the extension degree of F, m >= 1 (see immediately above). + - L = ceil((ceil(log2(p)) + k) / 8), where k is the security + parameter of the suite (e.g., k = 128). + - expand_message, a function that expands a byte string and + domain separation tag into a uniformly random byte string + (see Section 5.3). + """ + DST: bytes p: int # m: int - is always 1 L: int Hash: HashAlgorithm + # expand_message is always xmd + + def expand_message_xmd(self, msg: bytes, len_in_bytes: int): + """ + expand_message_xmd(msg, DST, len_in_bytes) + + Parameters: + - H, a hash function (see requirements above). + - b_in_bytes, b / 8 for b the output size of H in bits. + For example, for b = 256, b_in_bytes = 32. + - s_in_bytes, the input block size of H, measured in bytes (see + discussion above). For example, for SHA-256, s_in_bytes = 64. + + Input: + - msg, a byte string. + - DST, a byte string of at most 255 bytes. + See below for information on using longer DSTs. + - len_in_bytes, the length of the requested output in bytes, + not greater than the lesser of (255 * b_in_bytes) or 2^16-1. + + Output: + - uniform_bytes, a byte string. + + Steps: + 1. ell = ceil(len_in_bytes / b_in_bytes) + 2. ABORT if ell > 255 or len_in_bytes > 65535 or len(DST) > 255 + 3. DST_prime = DST || I2OSP(len(DST), 1) + 4. Z_pad = I2OSP(0, s_in_bytes) + 5. l_i_b_str = I2OSP(len_in_bytes, 2) + 6. msg_prime = Z_pad || msg || l_i_b_str || I2OSP(0, 1) || DST_prime + 7. b_0 = H(msg_prime) + 8. b_1 = H(b_0 || I2OSP(1, 1) || DST_prime) + 9. for i in (2, ..., ell): + 10. b_i = H(strxor(b_0, b_(i - 1)) || I2OSP(i, 1) || DST_prime) + 11. uniform_bytes = b_1 || ... || b_ell + 12. return substr(uniform_bytes, 0, len_in_bytes) + """ + b_in_bytes = self.Hash.digest_size + + ell = -(-len_in_bytes // b_in_bytes) + if ell > 255 or len_in_bytes > 65535 or len(self.DST) > 255: + raise ValueError("Invalid size of input/output") - def expand_message_xmd(self, msg, out_len): - ell = -(-out_len // self.Hash.digest_size) dst_prime = self.DST + struct.pack(">B", len(self.DST)) z_pad = b"\x00" * self.Hash.block_size - l_i_b_str = struct.pack(">H", out_len) + l_i_b_str = struct.pack(">H", len_in_bytes) msg_prime = z_pad + msg + l_i_b_str + b"\x00" + dst_prime d = Hash(self.Hash) d.update(msg_prime) @@ -95,11 +102,33 @@ def expand_message_xmd(self, msg, out_len): b_i = d.finalize() uniform_bytes.extend(b_i) b_xor = strxor(b_0, b_i) - return uniform_bytes[:out_len] + return uniform_bytes[:len_in_bytes] - def hash_to_field(self, msg, count): - elements = list() + def hash_to_field(self, msg: bytes, count: int) -> Sequence[int]: + """ + hash_to_field(msg, count) + + Input: + - msg, a byte string containing the message to hash. + - count, the number of elements of F to output. + + Output: + - (u_0, ..., u_(count - 1)), a list of field elements. + + Steps: + 1. len_in_bytes = count * m * L + 2. uniform_bytes = expand_message(msg, DST, len_in_bytes) + 3. for i in (0, ..., count - 1): + 4. for j in (0, ..., m - 1): + 5. elm_offset = L * (j + i * m) + 6. tv = substr(uniform_bytes, elm_offset, L) + 7. e_j = OS2IP(tv) mod p + 8. u_i = (e_0, ..., e_(m - 1)) + 9. return (u_0, ..., u_(count - 1)) + """ + # Only implemented for m = 1 uniform_bytes = self.expand_message_xmd(msg, count * self.L) + elements = list() for i in range(count): offset = self.L * i tv = uniform_bytes[offset : offset + self.L] @@ -132,9 +161,10 @@ def blind_public_key(self, pk: Point, tau: bytes, info: bytes) -> Point: """ dst = b"ARKG-BL-EC." + self.DST_ext + info htf = HTF(dst, self.crv.q, 48, self.Hash()) - tau_p = htf.hash_to_field(tau, 1)[0] - pk_tau = pk + (tau_p * self.crv.G) + tau_prime = htf.hash_to_field(tau, 1)[0] + + pk_tau = pk + (tau_prime * self.crv.G) return pk_tau @@ -170,6 +200,7 @@ def sub_kem_encaps(self, pk: Point, info: bytes) -> Tuple[bytes, bytes]: c = Elliptic-Curve-Point-to-Octet-String(pk') """ pk_prime, sk_prime = self.sub_kem_generate() + # TODO: Don't hardcode length k = int2bytes((pk * sk_prime).x, 32) c = SEC1Encoder().encode_public_key(pk_prime, compressed=False) @@ -208,13 +239,23 @@ def encaps(self, pk: Point, info: bytes) -> Tuple[bytes, bytes]: info_sub = b"ARKG-KEM-HMAC." + dst_ext + info k_prime, c_prime = self.sub_kem_encaps(pk, info_sub) - mk = HKDF(h, 32, None, b"ARKG-KEM-HMAC-mac." + dst_ext + info).derive(k_prime) + mk = HKDF( + h, + h.digest_size, + None, + b"ARKG-KEM-HMAC-mac." + dst_ext + info, + ).derive(k_prime) hmac = HMAC(mk, h) hmac.update(c_prime) - t = hmac.finalize()[:16] + t = hmac.finalize()[:16] # Truncate to 128-bit - k = HKDF(h, 32, None, b"ARKG-KEM-HMAC-shared." + dst_ext + info).derive(k_prime) + k = HKDF( + h, + len(k_prime), + None, + b"ARKG-KEM-HMAC-shared." + dst_ext + info, + ).derive(k_prime) c = t + c_prime @@ -289,10 +330,6 @@ def derive_public_key( DST_ext: 'ARKG-P256ADD-ECDH'. """ -arkg_p256_ecdh = ARKG( - bl=BL(crv=P256, Hash=SHA256, DST_ext=b"ARKG-P256ADD-ECDH"), - kem=KEM(crv=P256, Hash=SHA256, DST_ext=b"ARKG-P256ADD-ECDH"), -) def _cose2point(cose): @@ -302,6 +339,10 @@ def _cose2point(cose): class ARKG_P256ADD_ECDH(CoseKey): ALGORITHM = -65539 _HASH_ALG = SHA256() + _ARKG = ARKG( + bl=BL(crv=P256, Hash=SHA256, DST_ext=b"ARKG-P256ADD-ECDH"), + kem=KEM(crv=P256, Hash=SHA256, DST_ext=b"ARKG-P256ADD-ECDH"), + ) @property def blinding_key(self) -> CoseKey: @@ -312,7 +353,7 @@ def kem_key(self) -> CoseKey: return CoseKey.parse(self[-2]) def derive_public_key(self, info: bytes) -> Tuple[CoseKey, bytes]: - point, kh = arkg_p256_ecdh.derive_public_key( + point, kh = self._ARKG.derive_public_key( _cose2point(self.blinding_key), _cose2point(self.kem_key), info,