Skip to content

Commit

Permalink
Code cleanups
Browse files Browse the repository at this point in the history
  • Loading branch information
dainnilsson committed Oct 22, 2024
1 parent 9dc0019 commit 9dbefee
Showing 1 changed file with 107 additions and 66 deletions.
173 changes: 107 additions & 66 deletions fido2/arkg.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,78 +11,85 @@
import fastecdsa.keys

from dataclasses import dataclass
from typing import Tuple
from typing import Tuple, Sequence
import struct


def strxor(a, b):
def strxor(a: bytes, b: bytes) -> bytes:
c = bytearray(len(a))
for i in range(len(a)):
c[i] = a[i] ^ b[i]
return c


def expand_message_xmd(H, msg, dst, out_len):
ell = -(-out_len // H.digest_size)
dst_prime = dst + struct.pack(">B", len(dst))
z_pad = b"\x00" * H.block_size
l_i_b_str = struct.pack(">H", out_len)
msg_prime = z_pad + msg + l_i_b_str + b"\x00" + dst_prime
d = Hash(H)
d.update(msg_prime)
b_0 = d.finalize()
b_xor = b_0
uniform_bytes = bytearray()
for i in range(1, ell + 1):
d = Hash(H)
d.update(b_xor + struct.pack(">B", i) + dst_prime)
b_i = d.finalize()
uniform_bytes.extend(b_i)
b_xor = strxor(b_0, b_i)
return uniform_bytes[:out_len]


"""
hash_to_field(msg, count)
Parameters:
- DST, a domain separation tag (see Section 3.1).
- F, a finite field of characteristic p and order q = p^m.
- p, the characteristic of F (see immediately above).
- m, the extension degree of F, m >= 1 (see immediately above).
- L = ceil((ceil(log2(p)) + k) / 8), where k is the security
parameter of the suite (e.g., k = 128).
- expand_message, a function that expands a byte string and
domain separation tag into a uniformly random byte string
(see Section 5.3).
"""


# m is always 1
def hash_to_field(msg, count, dst, p, L, H):
elements = list()
uniform_bytes = expand_message_xmd(H, msg, dst, count * L)
for i in range(count):
offset = L * i
tv = uniform_bytes[offset : offset + L]
e_j = bytes2int(tv) % p
elements.append(e_j)
return elements


@dataclass
class HTF:
"""
hash_to_field(msg, count)
Parameters:
- DST, a domain separation tag (see Section 3.1).
- F, a finite field of characteristic p and order q = p^m.
- p, the characteristic of F (see immediately above).
- m, the extension degree of F, m >= 1 (see immediately above).
- L = ceil((ceil(log2(p)) + k) / 8), where k is the security
parameter of the suite (e.g., k = 128).
- expand_message, a function that expands a byte string and
domain separation tag into a uniformly random byte string
(see Section 5.3).
"""

DST: bytes
p: int
# m: int - is always 1
L: int
Hash: HashAlgorithm
# expand_message is always xmd

def expand_message_xmd(self, msg: bytes, len_in_bytes: int):
"""
expand_message_xmd(msg, DST, len_in_bytes)
Parameters:
- H, a hash function (see requirements above).
- b_in_bytes, b / 8 for b the output size of H in bits.
For example, for b = 256, b_in_bytes = 32.
- s_in_bytes, the input block size of H, measured in bytes (see
discussion above). For example, for SHA-256, s_in_bytes = 64.
Input:
- msg, a byte string.
- DST, a byte string of at most 255 bytes.
See below for information on using longer DSTs.
- len_in_bytes, the length of the requested output in bytes,
not greater than the lesser of (255 * b_in_bytes) or 2^16-1.
Output:
- uniform_bytes, a byte string.
Steps:
1. ell = ceil(len_in_bytes / b_in_bytes)
2. ABORT if ell > 255 or len_in_bytes > 65535 or len(DST) > 255
3. DST_prime = DST || I2OSP(len(DST), 1)
4. Z_pad = I2OSP(0, s_in_bytes)
5. l_i_b_str = I2OSP(len_in_bytes, 2)
6. msg_prime = Z_pad || msg || l_i_b_str || I2OSP(0, 1) || DST_prime
7. b_0 = H(msg_prime)
8. b_1 = H(b_0 || I2OSP(1, 1) || DST_prime)
9. for i in (2, ..., ell):
10. b_i = H(strxor(b_0, b_(i - 1)) || I2OSP(i, 1) || DST_prime)
11. uniform_bytes = b_1 || ... || b_ell
12. return substr(uniform_bytes, 0, len_in_bytes)
"""
b_in_bytes = self.Hash.digest_size

ell = -(-len_in_bytes // b_in_bytes)
if ell > 255 or len_in_bytes > 65535 or len(self.DST) > 255:
raise ValueError("Invalid size of input/output")

def expand_message_xmd(self, msg, out_len):
ell = -(-out_len // self.Hash.digest_size)
dst_prime = self.DST + struct.pack(">B", len(self.DST))
z_pad = b"\x00" * self.Hash.block_size
l_i_b_str = struct.pack(">H", out_len)
l_i_b_str = struct.pack(">H", len_in_bytes)
msg_prime = z_pad + msg + l_i_b_str + b"\x00" + dst_prime
d = Hash(self.Hash)
d.update(msg_prime)
Expand All @@ -95,11 +102,33 @@ def expand_message_xmd(self, msg, out_len):
b_i = d.finalize()
uniform_bytes.extend(b_i)
b_xor = strxor(b_0, b_i)
return uniform_bytes[:out_len]
return uniform_bytes[:len_in_bytes]

def hash_to_field(self, msg, count):
elements = list()
def hash_to_field(self, msg: bytes, count: int) -> Sequence[int]:
"""
hash_to_field(msg, count)
Input:
- msg, a byte string containing the message to hash.
- count, the number of elements of F to output.
Output:
- (u_0, ..., u_(count - 1)), a list of field elements.
Steps:
1. len_in_bytes = count * m * L
2. uniform_bytes = expand_message(msg, DST, len_in_bytes)
3. for i in (0, ..., count - 1):
4. for j in (0, ..., m - 1):
5. elm_offset = L * (j + i * m)
6. tv = substr(uniform_bytes, elm_offset, L)
7. e_j = OS2IP(tv) mod p
8. u_i = (e_0, ..., e_(m - 1))
9. return (u_0, ..., u_(count - 1))
"""
# Only implemented for m = 1
uniform_bytes = self.expand_message_xmd(msg, count * self.L)
elements = list()
for i in range(count):
offset = self.L * i
tv = uniform_bytes[offset : offset + self.L]
Expand Down Expand Up @@ -132,9 +161,10 @@ def blind_public_key(self, pk: Point, tau: bytes, info: bytes) -> Point:
"""
dst = b"ARKG-BL-EC." + self.DST_ext + info
htf = HTF(dst, self.crv.q, 48, self.Hash())
tau_p = htf.hash_to_field(tau, 1)[0]

pk_tau = pk + (tau_p * self.crv.G)
tau_prime = htf.hash_to_field(tau, 1)[0]

pk_tau = pk + (tau_prime * self.crv.G)
return pk_tau


Expand Down Expand Up @@ -170,6 +200,7 @@ def sub_kem_encaps(self, pk: Point, info: bytes) -> Tuple[bytes, bytes]:
c = Elliptic-Curve-Point-to-Octet-String(pk')
"""
pk_prime, sk_prime = self.sub_kem_generate()
# TODO: Don't hardcode length
k = int2bytes((pk * sk_prime).x, 32)
c = SEC1Encoder().encode_public_key(pk_prime, compressed=False)

Expand Down Expand Up @@ -208,13 +239,23 @@ def encaps(self, pk: Point, info: bytes) -> Tuple[bytes, bytes]:
info_sub = b"ARKG-KEM-HMAC." + dst_ext + info
k_prime, c_prime = self.sub_kem_encaps(pk, info_sub)

mk = HKDF(h, 32, None, b"ARKG-KEM-HMAC-mac." + dst_ext + info).derive(k_prime)
mk = HKDF(
h,
h.digest_size,
None,
b"ARKG-KEM-HMAC-mac." + dst_ext + info,
).derive(k_prime)

hmac = HMAC(mk, h)
hmac.update(c_prime)
t = hmac.finalize()[:16]
t = hmac.finalize()[:16] # Truncate to 128-bit

k = HKDF(h, 32, None, b"ARKG-KEM-HMAC-shared." + dst_ext + info).derive(k_prime)
k = HKDF(
h,
len(k_prime),
None,
b"ARKG-KEM-HMAC-shared." + dst_ext + info,
).derive(k_prime)

c = t + c_prime

Expand Down Expand Up @@ -289,10 +330,6 @@ def derive_public_key(
DST_ext: 'ARKG-P256ADD-ECDH'.
"""
arkg_p256_ecdh = ARKG(
bl=BL(crv=P256, Hash=SHA256, DST_ext=b"ARKG-P256ADD-ECDH"),
kem=KEM(crv=P256, Hash=SHA256, DST_ext=b"ARKG-P256ADD-ECDH"),
)


def _cose2point(cose):
Expand All @@ -302,6 +339,10 @@ def _cose2point(cose):
class ARKG_P256ADD_ECDH(CoseKey):
ALGORITHM = -65539
_HASH_ALG = SHA256()
_ARKG = ARKG(
bl=BL(crv=P256, Hash=SHA256, DST_ext=b"ARKG-P256ADD-ECDH"),
kem=KEM(crv=P256, Hash=SHA256, DST_ext=b"ARKG-P256ADD-ECDH"),
)

@property
def blinding_key(self) -> CoseKey:
Expand All @@ -312,7 +353,7 @@ def kem_key(self) -> CoseKey:
return CoseKey.parse(self[-2])

def derive_public_key(self, info: bytes) -> Tuple[CoseKey, bytes]:
point, kh = arkg_p256_ecdh.derive_public_key(
point, kh = self._ARKG.derive_public_key(
_cose2point(self.blinding_key),
_cose2point(self.kem_key),
info,
Expand Down

0 comments on commit 9dbefee

Please sign in to comment.