Skip to content

Commit

Permalink
Make keygen a two stage process
Browse files Browse the repository at this point in the history
The motivation is to simplify critical parts of the implementation,
especially signing, verifying and tweak application.

Where you generate an ordinary key first, apply plain tweaks, then
convert it to a BIP340 xonly key and can apply xonly tweaks.
  • Loading branch information
LLFourn committed Jul 1, 2022
1 parent 3a4f168 commit c343dc2
Showing 1 changed file with 89 additions and 66 deletions.
155 changes: 89 additions & 66 deletions bip-musig2/reference.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,11 +159,13 @@ def cpoint_extended(x: bytes) -> Optional[Point]:
else:
return cpoint(x)

KeyGenContext = namedtuple('KeyGenContext', ['Q', 'gacc', 'tacc'])
KeyGenContext = namedtuple('KeyGenContext', ['Q', 'tacc'])
XOnlyKeyGenContext = namedtuple('XOnlyKeyGenContext', ['Q', 'gacc', 'tacc'])

def get_pk(keygen_ctx: KeyGenContext) -> bytes:
Q, _, _ = keygen_ctx
return bytes_from_point(Q)
# HACK: return the public key for either context
def get_pk(keygen_ctx):
Q, *_ = keygen_ctx
return Q

def key_agg(pubkeys: List[bytes]) -> KeyGenContext:
pk2 = get_second_key(pubkeys)
Expand All @@ -177,9 +179,15 @@ def key_agg(pubkeys: List[bytes]) -> KeyGenContext:
Q = point_add(Q, point_mul(P_i, a_i))
# Q is not the point at infinity except with negligible probability.
assert(Q is not None)
gacc = 1
tacc = 0
return KeyGenContext(Q, gacc, tacc)
return KeyGenContext(Q, tacc)

def to_bip340_context(keygen_ctx: KeyGenContext) -> XOnlyKeyGenContext:
Q, tacc = keygen_ctx
gacc = 1 if has_even_y(Q) else n - 1
_tacc = gacc * tacc % n
_Q = bytes_from_point(Q)
return XOnlyKeyGenContext(_Q, gacc, _tacc)

def hash_keys(pubkeys: List[bytes]) -> bytes:
return tagged_hash('KeyAgg list', b''.join(pubkeys))
Expand All @@ -201,23 +209,36 @@ def key_agg_coeff_internal(pubkeys: List[bytes], pk_: bytes, pk2: bytes) -> int:
return 1
return int_from_bytes(tagged_hash('KeyAgg coefficient', L + pk_)) % n

def apply_tweak(keygen_ctx: KeyGenContext, tweak: bytes, is_xonly: bool) -> KeyGenContext:
def apply_plain_tweak(keygen_ctx: KeyGenContext, tweak: bytes) -> KeyGenContext:
Q, tacc = keygen_ctx
if len(tweak) != 32:
raise ValueError('The tweak must be a 32-byte array.')
t = int_from_bytes(tweak)
if t >= n:
raise ValueError('The tweak must be less than n.')

Q_ = point_add(Q, point_mul(G, t))
if Q_ is None:
raise ValueError('The result of tweaking cannot be infinity.')
tacc_ = t + tacc % n
return KeyGenContext(Q_, tacc_)


def apply_xonly_tweak(keygen_ctx: XOnlyKeyGenContext, tweak: bytes) -> XOnlyKeyGenContext:
Q, gacc, tacc = keygen_ctx
if is_xonly and not has_even_y(Q):
g = n - 1
else:
g = 1
if len(tweak) != 32:
raise ValueError('The tweak must be a 32-byte array.')
t = int_from_bytes(tweak)
if t >= n:
raise ValueError('The tweak must be less than n.')
Q_ = point_add(point_mul(Q, g), point_mul(G, t))
Q_ = point_add(lift_x(Q), point_mul(G, t))
if Q_ is None:
raise ValueError('The result of tweaking cannot be infinity.')
g = 1 if has_even_y(Q_) else n - 1
gacc_ = g * gacc % n
tacc_ = (t + g * tacc) % n
return KeyGenContext(Q_, gacc_, tacc_)
tacc_ = g * (tacc + t) % n
return XOnlyKeyGenContext(bytes_from_point(Q_), gacc_, tacc_)


def bytes_xor(a: bytes, b: bytes) -> bytes:
return bytes(x ^ y for x, y in zip(a, b))
Expand Down Expand Up @@ -276,19 +297,24 @@ def nonce_agg(pubnonces: List[bytes]) -> bytes:
aggnonce += cbytes_extended(R_i)
return aggnonce

SessionContext = namedtuple('SessionContext', ['aggnonce', 'pubkeys', 'tweaks', 'is_xonly', 'msg'])
SessionContext = namedtuple('SessionContext', ['aggnonce', 'pubkeys', 'plain_tweaks', 'xonly_tweaks', 'msg'])

def key_agg_and_tweak(pubkeys: List[bytes], tweaks: List[bytes], is_xonly: List[bool]):
def key_agg_and_tweak(pubkeys: List[bytes], plain_tweaks: List[bytes], xonly_tweaks: List[bytes]) -> XOnlyKeyGenContext:
keygen_ctx = key_agg(pubkeys)
v = len(tweaks)
for i in range(v):
keygen_ctx = apply_tweak(keygen_ctx, tweaks[i], is_xonly[i])
for i in range(len(plain_tweaks)):
keygen_ctx = apply_plain_tweak(keygen_ctx, plain_tweaks[i])

keygen_ctx = to_bip340_context(keygen_ctx)

for i in range(len(xonly_tweaks)):
keygen_ctx = apply_xonly_tweak(keygen_ctx, xonly_tweaks[i])

return keygen_ctx

def get_session_values(session_ctx: SessionContext) -> Tuple[Point, int, int, int, Point, int]:
(aggnonce, pubkeys, tweaks, is_xonly, msg) = session_ctx
Q, gacc, tacc = key_agg_and_tweak(pubkeys, tweaks, is_xonly)
b = int_from_bytes(tagged_hash('MuSig/noncecoef', aggnonce + bytes_from_point(Q) + msg)) % n
(aggnonce, pubkeys, plain_tweaks, xonly_tweaks, msg) = session_ctx
Q, gacc, tacc = key_agg_and_tweak(pubkeys, plain_tweaks, xonly_tweaks)
b = int_from_bytes(tagged_hash('MuSig/noncecoef', aggnonce + Q + msg)) % n
try:
R_1 = cpoint_extended(aggnonce[0:33])
R_2 = cpoint_extended(aggnonce[33:66])
Expand All @@ -298,7 +324,7 @@ def get_session_values(session_ctx: SessionContext) -> Tuple[Point, int, int, in
R_ = point_add(R_1, point_mul(R_2, b))
R = R_ if not is_infinite(R_) else G
assert R is not None
e = int_from_bytes(tagged_hash('BIP0340/challenge', bytes_from_point(R) + bytes_from_point(Q) + msg)) % n
e = int_from_bytes(tagged_hash('BIP0340/challenge', bytes_from_point(R) + Q + msg)) % n
return (Q, gacc, tacc, b, R, e)

def get_session_key_agg_coeff(session_ctx: SessionContext, P: Point) -> int:
Expand All @@ -323,8 +349,7 @@ def sign(secnonce: bytes, sk: bytes, session_ctx: SessionContext) -> bytes:
assert P is not None
a = get_session_key_agg_coeff(session_ctx, P)
gp = 1 if has_even_y(P) else n - 1
g = 1 if has_even_y(Q) else n - 1
d = g * gacc * gp * d_ % n
d = gacc * gp * d_ % n
s = (k_1 + b * k_2 + e * a * d) % n
psig = bytes_from_int(s)
R_1_ = point_mul(G, k_1_)
Expand All @@ -336,9 +361,9 @@ def sign(secnonce: bytes, sk: bytes, session_ctx: SessionContext) -> bytes:
assert partial_sig_verify_internal(psig, pubnonce, bytes_from_point(P), session_ctx)
return psig

def partial_sig_verify(psig: bytes, pubnonces: List[bytes], pubkeys: List[bytes], tweaks: List[bytes], is_xonly: List[bool], msg: bytes, i: int) -> bool:
def partial_sig_verify(psig: bytes, pubnonces: List[bytes], pubkeys: List[bytes], plain_tweaks: List[bytes], xonly_tweaks: List[bytes], msg: bytes, i: int) -> bool:
aggnonce = nonce_agg(pubnonces)
session_ctx = SessionContext(aggnonce, pubkeys, tweaks, is_xonly, msg)
session_ctx = SessionContext(aggnonce, pubkeys, plain_tweaks, xonly_tweaks, msg)
return partial_sig_verify_internal(psig, pubnonces[i], pubkeys[i], session_ctx)

def partial_sig_verify_internal(psig: bytes, pubnonce: bytes, pk_: bytes, session_ctx: SessionContext) -> bool:
Expand All @@ -350,9 +375,7 @@ def partial_sig_verify_internal(psig: bytes, pubnonce: bytes, pk_: bytes, sessio
R_2_ = cpoint(pubnonce[33:66])
R__ = point_add(R_1_, point_mul(R_2_, b))
R_ = R__ if has_even_y(R) else point_negate(R__)
g = 1 if has_even_y(Q) else n - 1
g_ = g * gacc % n
P = point_mul(lift_x(pk_), g_)
P = point_mul(lift_x(pk_), gacc)
if P is None:
return False
a = get_session_key_agg_coeff(session_ctx, P)
Expand All @@ -367,8 +390,7 @@ def partial_sig_agg(psigs: List[bytes], session_ctx: SessionContext) -> bytes:
if s_i >= n:
raise InvalidContributionError(i, "psig")
s = (s + s_i) % n
g = 1 if has_even_y(Q) else n - 1
s = (s + e * g * tacc) % n
s = (s + e * tacc) % n
return bytes_from_point(R) + bytes_from_int(s)
#
# The following code is only used for testing.
Expand Down Expand Up @@ -403,13 +425,13 @@ def test_key_agg_vectors():
])

# Vector 1
assert get_pk(key_agg([X[0], X[1], X[2]])) == expected[0]
assert get_pk(to_bip340_context(key_agg([X[0], X[1], X[2]]))) == expected[0]
# Vector 2
assert get_pk(key_agg([X[2], X[1], X[0]])) == expected[1]
assert get_pk(to_bip340_context(key_agg([X[2], X[1], X[0]]))) == expected[1]
# Vector 3
assert get_pk(key_agg([X[0], X[0], X[0]])) == expected[2]
assert get_pk(to_bip340_context(key_agg([X[0], X[0], X[0]]))) == expected[2]
# Vector 4
assert get_pk(key_agg([X[0], X[0], X[1], X[1]])) == expected[3]
assert get_pk(to_bip340_context(key_agg([X[0], X[0], X[1], X[1]]))) == expected[3]

# Vector 5: Invalid public key
invalid_pk = bytes.fromhex('0000000000000000000000000000000000000000000000000000000000000005')
Expand All @@ -424,13 +446,13 @@ def test_key_agg_vectors():
# Vector 7: Tweak is out of range
invalid_tweak = bytes.fromhex('FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFEBAAEDCE6AF48A03BBFD25E8CD0364141')
assertRaises(ValueError,
lambda: key_agg_and_tweak([X[0], X[1]], [invalid_tweak], [True]),
lambda: key_agg_and_tweak([X[0], X[1]], [], [invalid_tweak]),
lambda e: str(e) == 'The tweak must be less than n.')
# Vector 8: Intermediate tweaking result is point at infinity
G_ = bytes_from_point(G)
coeff = bytes_from_int(n - key_agg_coeff([G_], G_))
assertRaises(ValueError,
lambda: key_agg_and_tweak([G_], [coeff], [False]),
lambda: key_agg_and_tweak([G_], [coeff], []),
lambda e: str(e) == 'The result of tweaking cannot be infinity.')

def test_nonce_gen_vectors():
Expand Down Expand Up @@ -672,31 +694,31 @@ def test_tweak_vectors():
pk = bytes_from_point(point_mul(G, int_from_bytes(sk)))

# Vector 1: A single x-only tweak
session_ctx = SessionContext(aggnonce, [X[0], X[1], pk], tweaks[:1], [True], msg)
session_ctx = SessionContext(aggnonce, [X[0], X[1], pk], [], tweaks[:1], msg)
assert sign(secnonce, sk, session_ctx) == expected[0]
# WARNING: An actual implementation should clear the secnonce after use,
# e.g. by setting secnonce = bytes(64) after usage. Reusing the secnonce, as
# we do here for testing purposes, can leak the secret key.
assert partial_sig_verify(expected[0], [pnonce[1], pnonce[2], pnonce[0]], [X[0], X[1], pk], tweaks[:1], [True], msg, 2)
assert partial_sig_verify(expected[0], [pnonce[1], pnonce[2], pnonce[0]], [X[0], X[1], pk], [], tweaks[:1], msg, 2)

# Vector 2: A single plain tweak
session_ctx = SessionContext(aggnonce, [X[0], X[1], pk], tweaks[:1], [False], msg)
session_ctx = SessionContext(aggnonce, [X[0], X[1], pk], tweaks[:1], [], msg)
assert sign(secnonce, sk, session_ctx) == expected[1]
assert partial_sig_verify(expected[1], [pnonce[1], pnonce[2], pnonce[0]], [X[0], X[1], pk], tweaks[:1], [False], msg, 2)
assert partial_sig_verify(expected[1], [pnonce[1], pnonce[2], pnonce[0]], [X[0], X[1], pk], tweaks[:1], [], msg, 2)

# Vector 3: A plain tweak followed by an x-only tweak
session_ctx = SessionContext(aggnonce, [X[0], X[1], pk], tweaks[:2], [False, True], msg)
session_ctx = SessionContext(aggnonce, [X[0], X[1], pk], tweaks[:1], tweaks[1:2], msg)
assert sign(secnonce, sk, session_ctx) == expected[2]
assert partial_sig_verify(expected[2], [pnonce[1], pnonce[2], pnonce[0]], [X[0], X[1], pk], tweaks[:2], [False, True], msg, 2)
assert partial_sig_verify(expected[2], [pnonce[1], pnonce[2], pnonce[0]], [X[0], X[1], pk], tweaks[:1], tweaks[1:2], msg, 2)

# Vector 4: Four tweaks: x-only, plain, x-only, plain
session_ctx = SessionContext(aggnonce, [X[0], X[1], pk], tweaks[:4], [True, False, True, False], msg)
assert sign(secnonce, sk, session_ctx) == expected[3]
assert partial_sig_verify(expected[3], [pnonce[1], pnonce[2], pnonce[0]], [X[0], X[1], pk], tweaks[:4], [True, False, True, False], msg, 2)
# # Vector 4: Four tweaks: x-only, plain, x-only, plain
# session_ctx = SessionContext(aggnonce, [X[0], X[1], pk], tweaks[:4], [True, False, True, False], msg)
# assert sign(secnonce, sk, session_ctx) == expected[3]
# assert partial_sig_verify(expected[3], [pnonce[1], pnonce[2], pnonce[0]], [X[0], X[1], pk], tweaks[:4], [True, False, True, False], msg, 2)

# Vector 5: Tweak is invalid because it exceeds group size
invalid_tweak = bytes.fromhex('FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFEBAAEDCE6AF48A03BBFD25E8CD0364141')
session_ctx = SessionContext(aggnonce, [X[0], X[1], pk], [invalid_tweak], [False], msg)
session_ctx = SessionContext(aggnonce, [X[0], X[1], pk], [invalid_tweak], [], msg)
assertRaises(ValueError,
lambda: sign(secnonce, sk, session_ctx),
lambda e: str(e) == 'The tweak must be less than n.')
Expand Down Expand Up @@ -774,29 +796,29 @@ def test_sig_agg_vectors():
session_ctx = SessionContext(aggnonce[0], [X[0], X[1]], [], [], msg)
sig = partial_sig_agg([psig[0], psig[1]], session_ctx)
assert sig == expected[0]
aggpk = get_pk(key_agg([X[0], X[1]]))
aggpk = get_pk(to_bip340_context(key_agg([X[0], X[1]])))
assert schnorr_verify(msg, aggpk, sig)

# Vector 2
session_ctx = SessionContext(aggnonce[1], [X[0], X[2]], [], [], msg)
sig = partial_sig_agg([psig[2], psig[3]], session_ctx)
assert sig == expected[1]
aggpk = get_pk(key_agg([X[0], X[2]]))
aggpk = get_pk(to_bip340_context(key_agg([X[0], X[2]])))
assert schnorr_verify(msg, aggpk, sig)

# Vector 3
session_ctx = SessionContext(aggnonce[2], [X[0], X[2]], [tweaks[0]], [False], msg)
session_ctx = SessionContext(aggnonce[2], [X[0], X[2]], [tweaks[0]], [], msg)
sig = partial_sig_agg([psig[4], psig[5]], session_ctx)
assert sig == expected[2]
aggpk = get_pk(key_agg_and_tweak([X[0], X[2]], [tweaks[0]], [False]))
aggpk = get_pk(key_agg_and_tweak([X[0], X[2]], [tweaks[0]], []))
assert schnorr_verify(msg, aggpk, sig)

# Vector 4
session_ctx = SessionContext(aggnonce[3], [X[0], X[3]], tweaks, [True, False, True], msg)
sig = partial_sig_agg([psig[6], psig[7]], session_ctx)
assert sig == expected[3]
aggpk = get_pk(key_agg_and_tweak([X[0], X[3]], tweaks, [True, False, True]))
assert schnorr_verify(msg, aggpk, sig)
# session_ctx = SessionContext(aggnonce[3], [X[0], X[3]], tweaks, [True, False, True], msg)
# sig = partial_sig_agg([psig[6], psig[7]], session_ctx)
# assert sig == expected[3]
# aggpk = get_pk(key_agg_and_tweak([X[0], X[3]], tweaks, [True, False, True]))
# assert schnorr_verify(msg, aggpk, sig)

# Vector 5: Partial signature is invalid because it exceeds group size
invalid_psig = bytes.fromhex('FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFEBAAEDCE6AF48A03BBFD25E8CD0364141')
Expand All @@ -822,9 +844,10 @@ def test_sign_and_verify_random(iters):
# instead.
msg = secrets.token_bytes(32)
v = secrets.randbelow(4)
tweaks = [secrets.token_bytes(32) for _ in range(v)]
is_xonly = [secrets.choice([False, True]) for _ in range(v)]
aggpk = get_pk(key_agg_and_tweak(pubkeys, tweaks, is_xonly))
plain_tweaks = [secrets.token_bytes(32) for _ in range(v)]
v = secrets.randbelow(4)
xonly_tweaks = [secrets.token_bytes(32) for _ in range(v)]
aggpk = get_pk(key_agg_and_tweak(pubkeys, plain_tweaks, xonly_tweaks))

# Use a non-repeating counter for extra_in
secnonce_1, pubnonce_1 = nonce_gen(sk_1, aggpk, msg, i.to_bytes(4, 'big'))
Expand All @@ -836,22 +859,22 @@ def test_sign_and_verify_random(iters):
pubnonces = [pubnonce_1, pubnonce_2]
aggnonce = nonce_agg(pubnonces)

session_ctx = SessionContext(aggnonce, pubkeys, tweaks, is_xonly, msg)
session_ctx = SessionContext(aggnonce, pubkeys, plain_tweaks, xonly_tweaks, msg)
psig_1 = sign(secnonce_1, sk_1, session_ctx)
# Clear the secnonce after use
secnonce_1 = bytes(64)
assert partial_sig_verify(psig_1, pubnonces, pubkeys, tweaks, is_xonly, msg, 0)
assert partial_sig_verify(psig_1, pubnonces, pubkeys, plain_tweaks, xonly_tweaks, msg, 0)

# Wrong signer index
assert not partial_sig_verify(psig_1, pubnonces, pubkeys, tweaks, is_xonly, msg, 1)
assert not partial_sig_verify(psig_1, pubnonces, pubkeys, plain_tweaks, xonly_tweaks, msg, 1)

# Wrong message
assert not partial_sig_verify(psig_1, pubnonces, pubkeys, tweaks, is_xonly, secrets.token_bytes(32), 0)
assert not partial_sig_verify(psig_1, pubnonces, pubkeys, plain_tweaks, xonly_tweaks, secrets.token_bytes(32), 0)

psig_2 = sign(secnonce_2, sk_2, session_ctx)
# Clear the secnonce after use
secnonce_2 = bytes(64)
assert partial_sig_verify(psig_2, pubnonces, pubkeys, tweaks, is_xonly, msg, 1)
assert partial_sig_verify(psig_2, pubnonces, pubkeys, plain_tweaks, xonly_tweaks, msg, 1)

sig = partial_sig_agg([psig_1, psig_2], session_ctx)
assert schnorr_verify(msg, aggpk, sig)
Expand Down

0 comments on commit c343dc2

Please sign in to comment.