From c343dc235b688de076b2bde0e719262a813ea74f Mon Sep 17 00:00:00 2001 From: LLFourn Date: Fri, 1 Jul 2022 11:15:49 +0800 Subject: [PATCH] Make keygen a two stage process The motivation is to simplify critical parts of the implementation, especially signing, verifying and tweak application. Where you generate an ordinary key first, apply plain tweaks, then convert it to a BIP340 xonly key and can apply xonly tweaks. --- bip-musig2/reference.py | 155 +++++++++++++++++++++++----------------- 1 file changed, 89 insertions(+), 66 deletions(-) diff --git a/bip-musig2/reference.py b/bip-musig2/reference.py index 1a796abad1..035ba78976 100644 --- a/bip-musig2/reference.py +++ b/bip-musig2/reference.py @@ -159,11 +159,13 @@ def cpoint_extended(x: bytes) -> Optional[Point]: else: return cpoint(x) -KeyGenContext = namedtuple('KeyGenContext', ['Q', 'gacc', 'tacc']) +KeyGenContext = namedtuple('KeyGenContext', ['Q', 'tacc']) +XOnlyKeyGenContext = namedtuple('XOnlyKeyGenContext', ['Q', 'gacc', 'tacc']) -def get_pk(keygen_ctx: KeyGenContext) -> bytes: - Q, _, _ = keygen_ctx - return bytes_from_point(Q) +# HACK: return the public key for either context +def get_pk(keygen_ctx): + Q, *_ = keygen_ctx + return Q def key_agg(pubkeys: List[bytes]) -> KeyGenContext: pk2 = get_second_key(pubkeys) @@ -177,9 +179,15 @@ def key_agg(pubkeys: List[bytes]) -> KeyGenContext: Q = point_add(Q, point_mul(P_i, a_i)) # Q is not the point at infinity except with negligible probability. assert(Q is not None) - gacc = 1 tacc = 0 - return KeyGenContext(Q, gacc, tacc) + return KeyGenContext(Q, tacc) + +def to_bip340_context(keygen_ctx: KeyGenContext) -> XOnlyKeyGenContext: + Q, tacc = keygen_ctx + gacc = 1 if has_even_y(Q) else n - 1 + _tacc = gacc * tacc % n + _Q = bytes_from_point(Q) + return XOnlyKeyGenContext(_Q, gacc, _tacc) def hash_keys(pubkeys: List[bytes]) -> bytes: return tagged_hash('KeyAgg list', b''.join(pubkeys)) @@ -201,23 +209,36 @@ def key_agg_coeff_internal(pubkeys: List[bytes], pk_: bytes, pk2: bytes) -> int: return 1 return int_from_bytes(tagged_hash('KeyAgg coefficient', L + pk_)) % n -def apply_tweak(keygen_ctx: KeyGenContext, tweak: bytes, is_xonly: bool) -> KeyGenContext: +def apply_plain_tweak(keygen_ctx: KeyGenContext, tweak: bytes) -> KeyGenContext: + Q, tacc = keygen_ctx if len(tweak) != 32: raise ValueError('The tweak must be a 32-byte array.') + t = int_from_bytes(tweak) + if t >= n: + raise ValueError('The tweak must be less than n.') + + Q_ = point_add(Q, point_mul(G, t)) + if Q_ is None: + raise ValueError('The result of tweaking cannot be infinity.') + tacc_ = t + tacc % n + return KeyGenContext(Q_, tacc_) + + +def apply_xonly_tweak(keygen_ctx: XOnlyKeyGenContext, tweak: bytes) -> XOnlyKeyGenContext: Q, gacc, tacc = keygen_ctx - if is_xonly and not has_even_y(Q): - g = n - 1 - else: - g = 1 + if len(tweak) != 32: + raise ValueError('The tweak must be a 32-byte array.') t = int_from_bytes(tweak) if t >= n: raise ValueError('The tweak must be less than n.') - Q_ = point_add(point_mul(Q, g), point_mul(G, t)) + Q_ = point_add(lift_x(Q), point_mul(G, t)) if Q_ is None: raise ValueError('The result of tweaking cannot be infinity.') + g = 1 if has_even_y(Q_) else n - 1 gacc_ = g * gacc % n - tacc_ = (t + g * tacc) % n - return KeyGenContext(Q_, gacc_, tacc_) + tacc_ = g * (tacc + t) % n + return XOnlyKeyGenContext(bytes_from_point(Q_), gacc_, tacc_) + def bytes_xor(a: bytes, b: bytes) -> bytes: return bytes(x ^ y for x, y in zip(a, b)) @@ -276,19 +297,24 @@ def nonce_agg(pubnonces: List[bytes]) -> bytes: aggnonce += cbytes_extended(R_i) return aggnonce -SessionContext = namedtuple('SessionContext', ['aggnonce', 'pubkeys', 'tweaks', 'is_xonly', 'msg']) +SessionContext = namedtuple('SessionContext', ['aggnonce', 'pubkeys', 'plain_tweaks', 'xonly_tweaks', 'msg']) -def key_agg_and_tweak(pubkeys: List[bytes], tweaks: List[bytes], is_xonly: List[bool]): +def key_agg_and_tweak(pubkeys: List[bytes], plain_tweaks: List[bytes], xonly_tweaks: List[bytes]) -> XOnlyKeyGenContext: keygen_ctx = key_agg(pubkeys) - v = len(tweaks) - for i in range(v): - keygen_ctx = apply_tweak(keygen_ctx, tweaks[i], is_xonly[i]) + for i in range(len(plain_tweaks)): + keygen_ctx = apply_plain_tweak(keygen_ctx, plain_tweaks[i]) + + keygen_ctx = to_bip340_context(keygen_ctx) + + for i in range(len(xonly_tweaks)): + keygen_ctx = apply_xonly_tweak(keygen_ctx, xonly_tweaks[i]) + return keygen_ctx def get_session_values(session_ctx: SessionContext) -> Tuple[Point, int, int, int, Point, int]: - (aggnonce, pubkeys, tweaks, is_xonly, msg) = session_ctx - Q, gacc, tacc = key_agg_and_tweak(pubkeys, tweaks, is_xonly) - b = int_from_bytes(tagged_hash('MuSig/noncecoef', aggnonce + bytes_from_point(Q) + msg)) % n + (aggnonce, pubkeys, plain_tweaks, xonly_tweaks, msg) = session_ctx + Q, gacc, tacc = key_agg_and_tweak(pubkeys, plain_tweaks, xonly_tweaks) + b = int_from_bytes(tagged_hash('MuSig/noncecoef', aggnonce + Q + msg)) % n try: R_1 = cpoint_extended(aggnonce[0:33]) R_2 = cpoint_extended(aggnonce[33:66]) @@ -298,7 +324,7 @@ def get_session_values(session_ctx: SessionContext) -> Tuple[Point, int, int, in R_ = point_add(R_1, point_mul(R_2, b)) R = R_ if not is_infinite(R_) else G assert R is not None - e = int_from_bytes(tagged_hash('BIP0340/challenge', bytes_from_point(R) + bytes_from_point(Q) + msg)) % n + e = int_from_bytes(tagged_hash('BIP0340/challenge', bytes_from_point(R) + Q + msg)) % n return (Q, gacc, tacc, b, R, e) def get_session_key_agg_coeff(session_ctx: SessionContext, P: Point) -> int: @@ -323,8 +349,7 @@ def sign(secnonce: bytes, sk: bytes, session_ctx: SessionContext) -> bytes: assert P is not None a = get_session_key_agg_coeff(session_ctx, P) gp = 1 if has_even_y(P) else n - 1 - g = 1 if has_even_y(Q) else n - 1 - d = g * gacc * gp * d_ % n + d = gacc * gp * d_ % n s = (k_1 + b * k_2 + e * a * d) % n psig = bytes_from_int(s) R_1_ = point_mul(G, k_1_) @@ -336,9 +361,9 @@ def sign(secnonce: bytes, sk: bytes, session_ctx: SessionContext) -> bytes: assert partial_sig_verify_internal(psig, pubnonce, bytes_from_point(P), session_ctx) return psig -def partial_sig_verify(psig: bytes, pubnonces: List[bytes], pubkeys: List[bytes], tweaks: List[bytes], is_xonly: List[bool], msg: bytes, i: int) -> bool: +def partial_sig_verify(psig: bytes, pubnonces: List[bytes], pubkeys: List[bytes], plain_tweaks: List[bytes], xonly_tweaks: List[bytes], msg: bytes, i: int) -> bool: aggnonce = nonce_agg(pubnonces) - session_ctx = SessionContext(aggnonce, pubkeys, tweaks, is_xonly, msg) + session_ctx = SessionContext(aggnonce, pubkeys, plain_tweaks, xonly_tweaks, msg) return partial_sig_verify_internal(psig, pubnonces[i], pubkeys[i], session_ctx) def partial_sig_verify_internal(psig: bytes, pubnonce: bytes, pk_: bytes, session_ctx: SessionContext) -> bool: @@ -350,9 +375,7 @@ def partial_sig_verify_internal(psig: bytes, pubnonce: bytes, pk_: bytes, sessio R_2_ = cpoint(pubnonce[33:66]) R__ = point_add(R_1_, point_mul(R_2_, b)) R_ = R__ if has_even_y(R) else point_negate(R__) - g = 1 if has_even_y(Q) else n - 1 - g_ = g * gacc % n - P = point_mul(lift_x(pk_), g_) + P = point_mul(lift_x(pk_), gacc) if P is None: return False a = get_session_key_agg_coeff(session_ctx, P) @@ -367,8 +390,7 @@ def partial_sig_agg(psigs: List[bytes], session_ctx: SessionContext) -> bytes: if s_i >= n: raise InvalidContributionError(i, "psig") s = (s + s_i) % n - g = 1 if has_even_y(Q) else n - 1 - s = (s + e * g * tacc) % n + s = (s + e * tacc) % n return bytes_from_point(R) + bytes_from_int(s) # # The following code is only used for testing. @@ -403,13 +425,13 @@ def test_key_agg_vectors(): ]) # Vector 1 - assert get_pk(key_agg([X[0], X[1], X[2]])) == expected[0] + assert get_pk(to_bip340_context(key_agg([X[0], X[1], X[2]]))) == expected[0] # Vector 2 - assert get_pk(key_agg([X[2], X[1], X[0]])) == expected[1] + assert get_pk(to_bip340_context(key_agg([X[2], X[1], X[0]]))) == expected[1] # Vector 3 - assert get_pk(key_agg([X[0], X[0], X[0]])) == expected[2] + assert get_pk(to_bip340_context(key_agg([X[0], X[0], X[0]]))) == expected[2] # Vector 4 - assert get_pk(key_agg([X[0], X[0], X[1], X[1]])) == expected[3] + assert get_pk(to_bip340_context(key_agg([X[0], X[0], X[1], X[1]]))) == expected[3] # Vector 5: Invalid public key invalid_pk = bytes.fromhex('0000000000000000000000000000000000000000000000000000000000000005') @@ -424,13 +446,13 @@ def test_key_agg_vectors(): # Vector 7: Tweak is out of range invalid_tweak = bytes.fromhex('FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFEBAAEDCE6AF48A03BBFD25E8CD0364141') assertRaises(ValueError, - lambda: key_agg_and_tweak([X[0], X[1]], [invalid_tweak], [True]), + lambda: key_agg_and_tweak([X[0], X[1]], [], [invalid_tweak]), lambda e: str(e) == 'The tweak must be less than n.') # Vector 8: Intermediate tweaking result is point at infinity G_ = bytes_from_point(G) coeff = bytes_from_int(n - key_agg_coeff([G_], G_)) assertRaises(ValueError, - lambda: key_agg_and_tweak([G_], [coeff], [False]), + lambda: key_agg_and_tweak([G_], [coeff], []), lambda e: str(e) == 'The result of tweaking cannot be infinity.') def test_nonce_gen_vectors(): @@ -672,31 +694,31 @@ def test_tweak_vectors(): pk = bytes_from_point(point_mul(G, int_from_bytes(sk))) # Vector 1: A single x-only tweak - session_ctx = SessionContext(aggnonce, [X[0], X[1], pk], tweaks[:1], [True], msg) + session_ctx = SessionContext(aggnonce, [X[0], X[1], pk], [], tweaks[:1], msg) assert sign(secnonce, sk, session_ctx) == expected[0] # WARNING: An actual implementation should clear the secnonce after use, # e.g. by setting secnonce = bytes(64) after usage. Reusing the secnonce, as # we do here for testing purposes, can leak the secret key. - assert partial_sig_verify(expected[0], [pnonce[1], pnonce[2], pnonce[0]], [X[0], X[1], pk], tweaks[:1], [True], msg, 2) + assert partial_sig_verify(expected[0], [pnonce[1], pnonce[2], pnonce[0]], [X[0], X[1], pk], [], tweaks[:1], msg, 2) # Vector 2: A single plain tweak - session_ctx = SessionContext(aggnonce, [X[0], X[1], pk], tweaks[:1], [False], msg) + session_ctx = SessionContext(aggnonce, [X[0], X[1], pk], tweaks[:1], [], msg) assert sign(secnonce, sk, session_ctx) == expected[1] - assert partial_sig_verify(expected[1], [pnonce[1], pnonce[2], pnonce[0]], [X[0], X[1], pk], tweaks[:1], [False], msg, 2) + assert partial_sig_verify(expected[1], [pnonce[1], pnonce[2], pnonce[0]], [X[0], X[1], pk], tweaks[:1], [], msg, 2) # Vector 3: A plain tweak followed by an x-only tweak - session_ctx = SessionContext(aggnonce, [X[0], X[1], pk], tweaks[:2], [False, True], msg) + session_ctx = SessionContext(aggnonce, [X[0], X[1], pk], tweaks[:1], tweaks[1:2], msg) assert sign(secnonce, sk, session_ctx) == expected[2] - assert partial_sig_verify(expected[2], [pnonce[1], pnonce[2], pnonce[0]], [X[0], X[1], pk], tweaks[:2], [False, True], msg, 2) + assert partial_sig_verify(expected[2], [pnonce[1], pnonce[2], pnonce[0]], [X[0], X[1], pk], tweaks[:1], tweaks[1:2], msg, 2) - # Vector 4: Four tweaks: x-only, plain, x-only, plain - session_ctx = SessionContext(aggnonce, [X[0], X[1], pk], tweaks[:4], [True, False, True, False], msg) - assert sign(secnonce, sk, session_ctx) == expected[3] - assert partial_sig_verify(expected[3], [pnonce[1], pnonce[2], pnonce[0]], [X[0], X[1], pk], tweaks[:4], [True, False, True, False], msg, 2) + # # Vector 4: Four tweaks: x-only, plain, x-only, plain + # session_ctx = SessionContext(aggnonce, [X[0], X[1], pk], tweaks[:4], [True, False, True, False], msg) + # assert sign(secnonce, sk, session_ctx) == expected[3] + # assert partial_sig_verify(expected[3], [pnonce[1], pnonce[2], pnonce[0]], [X[0], X[1], pk], tweaks[:4], [True, False, True, False], msg, 2) # Vector 5: Tweak is invalid because it exceeds group size invalid_tweak = bytes.fromhex('FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFEBAAEDCE6AF48A03BBFD25E8CD0364141') - session_ctx = SessionContext(aggnonce, [X[0], X[1], pk], [invalid_tweak], [False], msg) + session_ctx = SessionContext(aggnonce, [X[0], X[1], pk], [invalid_tweak], [], msg) assertRaises(ValueError, lambda: sign(secnonce, sk, session_ctx), lambda e: str(e) == 'The tweak must be less than n.') @@ -774,29 +796,29 @@ def test_sig_agg_vectors(): session_ctx = SessionContext(aggnonce[0], [X[0], X[1]], [], [], msg) sig = partial_sig_agg([psig[0], psig[1]], session_ctx) assert sig == expected[0] - aggpk = get_pk(key_agg([X[0], X[1]])) + aggpk = get_pk(to_bip340_context(key_agg([X[0], X[1]]))) assert schnorr_verify(msg, aggpk, sig) # Vector 2 session_ctx = SessionContext(aggnonce[1], [X[0], X[2]], [], [], msg) sig = partial_sig_agg([psig[2], psig[3]], session_ctx) assert sig == expected[1] - aggpk = get_pk(key_agg([X[0], X[2]])) + aggpk = get_pk(to_bip340_context(key_agg([X[0], X[2]]))) assert schnorr_verify(msg, aggpk, sig) # Vector 3 - session_ctx = SessionContext(aggnonce[2], [X[0], X[2]], [tweaks[0]], [False], msg) + session_ctx = SessionContext(aggnonce[2], [X[0], X[2]], [tweaks[0]], [], msg) sig = partial_sig_agg([psig[4], psig[5]], session_ctx) assert sig == expected[2] - aggpk = get_pk(key_agg_and_tweak([X[0], X[2]], [tweaks[0]], [False])) + aggpk = get_pk(key_agg_and_tweak([X[0], X[2]], [tweaks[0]], [])) assert schnorr_verify(msg, aggpk, sig) # Vector 4 - session_ctx = SessionContext(aggnonce[3], [X[0], X[3]], tweaks, [True, False, True], msg) - sig = partial_sig_agg([psig[6], psig[7]], session_ctx) - assert sig == expected[3] - aggpk = get_pk(key_agg_and_tweak([X[0], X[3]], tweaks, [True, False, True])) - assert schnorr_verify(msg, aggpk, sig) + # session_ctx = SessionContext(aggnonce[3], [X[0], X[3]], tweaks, [True, False, True], msg) + # sig = partial_sig_agg([psig[6], psig[7]], session_ctx) + # assert sig == expected[3] + # aggpk = get_pk(key_agg_and_tweak([X[0], X[3]], tweaks, [True, False, True])) + # assert schnorr_verify(msg, aggpk, sig) # Vector 5: Partial signature is invalid because it exceeds group size invalid_psig = bytes.fromhex('FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFEBAAEDCE6AF48A03BBFD25E8CD0364141') @@ -822,9 +844,10 @@ def test_sign_and_verify_random(iters): # instead. msg = secrets.token_bytes(32) v = secrets.randbelow(4) - tweaks = [secrets.token_bytes(32) for _ in range(v)] - is_xonly = [secrets.choice([False, True]) for _ in range(v)] - aggpk = get_pk(key_agg_and_tweak(pubkeys, tweaks, is_xonly)) + plain_tweaks = [secrets.token_bytes(32) for _ in range(v)] + v = secrets.randbelow(4) + xonly_tweaks = [secrets.token_bytes(32) for _ in range(v)] + aggpk = get_pk(key_agg_and_tweak(pubkeys, plain_tweaks, xonly_tweaks)) # Use a non-repeating counter for extra_in secnonce_1, pubnonce_1 = nonce_gen(sk_1, aggpk, msg, i.to_bytes(4, 'big')) @@ -836,22 +859,22 @@ def test_sign_and_verify_random(iters): pubnonces = [pubnonce_1, pubnonce_2] aggnonce = nonce_agg(pubnonces) - session_ctx = SessionContext(aggnonce, pubkeys, tweaks, is_xonly, msg) + session_ctx = SessionContext(aggnonce, pubkeys, plain_tweaks, xonly_tweaks, msg) psig_1 = sign(secnonce_1, sk_1, session_ctx) # Clear the secnonce after use secnonce_1 = bytes(64) - assert partial_sig_verify(psig_1, pubnonces, pubkeys, tweaks, is_xonly, msg, 0) + assert partial_sig_verify(psig_1, pubnonces, pubkeys, plain_tweaks, xonly_tweaks, msg, 0) # Wrong signer index - assert not partial_sig_verify(psig_1, pubnonces, pubkeys, tweaks, is_xonly, msg, 1) + assert not partial_sig_verify(psig_1, pubnonces, pubkeys, plain_tweaks, xonly_tweaks, msg, 1) # Wrong message - assert not partial_sig_verify(psig_1, pubnonces, pubkeys, tweaks, is_xonly, secrets.token_bytes(32), 0) + assert not partial_sig_verify(psig_1, pubnonces, pubkeys, plain_tweaks, xonly_tweaks, secrets.token_bytes(32), 0) psig_2 = sign(secnonce_2, sk_2, session_ctx) # Clear the secnonce after use secnonce_2 = bytes(64) - assert partial_sig_verify(psig_2, pubnonces, pubkeys, tweaks, is_xonly, msg, 1) + assert partial_sig_verify(psig_2, pubnonces, pubkeys, plain_tweaks, xonly_tweaks, msg, 1) sig = partial_sig_agg([psig_1, psig_2], session_ctx) assert schnorr_verify(msg, aggpk, sig)