Skip to content

Commit

Permalink
common: add tal_arr_eq helper.
Browse files Browse the repository at this point in the history
We do `memeq(a, tal_bytelen(a), b, tal_bytelen(b))` remarkably often...

Signed-off-by: Rusty Russell <rusty@rustcorp.com.au>
  • Loading branch information
rustyrussell authored and cdecker committed Feb 16, 2024
1 parent b6cc0ce commit df44431
Show file tree
Hide file tree
Showing 50 changed files with 98 additions and 178 deletions.
6 changes: 2 additions & 4 deletions common/psbt_open.c
Original file line number Diff line number Diff line change
Expand Up @@ -133,8 +133,7 @@ static bool input_identical(const struct wally_psbt *a,
const u8 *b_in = linearize_input(tmpctx,
&b->inputs[b_index]);

return memeq(a_in, tal_bytelen(a_in),
b_in, tal_bytelen(b_in));
return tal_arr_eq(a_in, b_in);
}

static bool output_identical(const struct wally_psbt *a,
Expand All @@ -146,8 +145,7 @@ static bool output_identical(const struct wally_psbt *a,
&a->outputs[a_index]);
const u8 *b_out = linearize_output(tmpctx,
&b->outputs[b_index]);
return memeq(a_out, tal_bytelen(a_out),
b_out, tal_bytelen(b_out));
return tal_arr_eq(a_out, b_out);
}

static void sort_inputs(struct wally_psbt *psbt)
Expand Down
3 changes: 1 addition & 2 deletions common/test/run-bolt11.c
Original file line number Diff line number Diff line change
Expand Up @@ -106,8 +106,7 @@ static void test_b11(const char *b11str,
list);
list_for_each(&b11->extra_fields, b11_extra, list) {
assert(expect_extra->tag == b11_extra->tag);
assert(memeq(expect_extra->data, tal_bytelen(expect_extra->data),
b11_extra->data, tal_bytelen(b11_extra->data)));
assert(tal_arr_eq(expect_extra->data, b11_extra->data));
expect_extra = list_next(&expect_b11->extra_fields,
expect_extra, list);
}
Expand Down
2 changes: 1 addition & 1 deletion common/test/run-cryptomsg.c
Original file line number Diff line number Diff line change
Expand Up @@ -205,7 +205,7 @@ int main(int argc, char *argv[])
tal_resize(&enc, tal_bytelen(enc) - CRYPTOMSG_HDR_SIZE);

dec = cryptomsg_decrypt_body(enc, &cs_in, enc);
assert(memeq(dec, tal_bytelen(dec), msg, tal_bytelen(msg)));
assert(tal_arr_eq(dec, (u8 *)msg));
}
common_shutdown();
return 0;
Expand Down
6 changes: 2 additions & 4 deletions common/test/run-features.c
Original file line number Diff line number Diff line change
Expand Up @@ -120,17 +120,15 @@ static void test_featurebits_or(void)
set_feature_bit(&control, i);
}
u8 *result = featurebits_or(tmpctx, take(f1), f2);
assert(
memeq(result, tal_bytelen(result), control, tal_bytelen(control)));
assert(tal_arr_eq(result, control));
}

static bool feature_set_eq(const struct feature_set *f1,
const struct feature_set *f2)
{
/* We assume minimal sizing */
for (size_t i = 0; i < ARRAY_SIZE(f1->bits); i++) {
if (!memeq(f1->bits[i], tal_bytelen(f1->bits[i]),
f2->bits[i], tal_bytelen(f2->bits[i])))
if (!tal_arr_eq(f1->bits[i], f2->bits[i]))
return false;
}
return true;
Expand Down
5 changes: 2 additions & 3 deletions common/test/run-onion-test-vector.c
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,7 @@ int main(int argc, char *argv[])

expected = json_tok_bin_from_hex(tmpctx, json, json_get_member(json, toks, "onion"));
actual = serialize_onionpacket(tmpctx, op);
assert(memeq(expected, tal_bytelen(expected), actual, tal_bytelen(actual)));
assert(tal_arr_eq(expected, actual));

/* Now decode! */
op = parse_onionpacket(tmpctx, actual, tal_bytelen(actual), NULL);
Expand All @@ -187,8 +187,7 @@ int main(int argc, char *argv[])
json_to_secret(json, t, &mykey);
test_ecdh(&op->ephemeralkey, &ss);
rs = process_onionpacket(tmpctx, op, &ss, assoc_data, tal_bytelen(assoc_data), true);
assert(memeq(rs->raw_payload, tal_bytelen(rs->raw_payload),
payloads[i], tal_bytelen(payloads[i])));
assert(tal_arr_eq(rs->raw_payload, payloads[i]));
if (rs->nextcase == ONION_FORWARD)
op = rs->next;
else
Expand Down
12 changes: 4 additions & 8 deletions common/test/run-route_blinding_onion_test.c
Original file line number Diff line number Diff line change
Expand Up @@ -158,8 +158,7 @@ int main(int argc, char *argv[])
json_tok_bin_from_hex,
&payload),
JSON_SCAN(json_to_pubkey, &ids[i])) == NULL);
assert(memeq(payload, tal_bytelen(payload),
onionhops[i], tal_bytelen(onionhops[i])));
assert(tal_arr_eq(payload, onionhops[i]));
}

/* Now, create onion! */
Expand All @@ -174,8 +173,7 @@ int main(int argc, char *argv[])
JSON_SCAN_TAL(tmpctx,
json_tok_bin_from_hex,
&expected_onion)) == NULL);
assert(memeq(expected_onion, tal_bytelen(expected_onion),
onion, tal_bytelen(onion)));
assert(tal_arr_eq(expected_onion, onion));

/* FIXME: unwrap and test! */
#if 0
Expand All @@ -196,8 +194,7 @@ int main(int argc, char *argv[])
JSON_SCAN_TAL(tmpctx, json_tok_bin_from_hex, &expected_onion))
== NULL);
serialized = serialize_onionpacket(tmpctx, op);
assert(memeq(expected_onion, tal_bytelen(expected_onion),
serialized, tal_bytelen(serialized)));
assert(tal_arr_eq(expected_onion, serialized));

if (blinding) {
assert(unblind_onion(blinding, test_ecdh,
Expand All @@ -207,8 +204,7 @@ int main(int argc, char *argv[])
}
rs = process_onionpacket(tmpctx, op, &ss, associated_data,
tal_bytelen(associated_data), true);
assert(memeq(rs->raw_payload, tal_bytelen(rs->raw_payload),
onionhops[i], tal_bytelen(onionhops[i])));
assert(tal_arr_eq(rs->raw_payload, onionhops[i]));
if (rs->nextcase == ONION_FORWARD)
op = rs->next;
else
Expand Down
6 changes: 2 additions & 4 deletions common/test/run-route_blinding_test.c
Original file line number Diff line number Diff line change
Expand Up @@ -183,8 +183,7 @@ int main(int argc, char *argv[])
json_get_member(json, t, "encoded_tlvs"));
expected = json_to_enctlvs(tmpctx, json,
json_get_member(json, t, "tlvs"));
assert(memeq(expected, tal_bytelen(expected),
enctlvs[i], tal_bytelen(enctlvs[i])));
assert(tal_arr_eq(expected, enctlvs[i]));
}

/* Now do the blinding. */
Expand Down Expand Up @@ -221,8 +220,7 @@ int main(int argc, char *argv[])
expected_enctlv = json_tok_bin_from_hex(tmpctx,json,
json_get_member(json, t,
"encrypted_data"));
assert(memeq(enctlv, tal_bytelen(enctlv),
expected_enctlv, tal_bytelen(expected_enctlv)));
assert(tal_arr_eq(enctlv, expected_enctlv));

json_to_pubkey(json, json_get_member(json, t, "blinded_node_id"),
&expected_alias);
Expand Down
2 changes: 1 addition & 1 deletion common/test/run-sphinx.c
Original file line number Diff line number Diff line change
Expand Up @@ -219,7 +219,7 @@ static void run_unit_tests(void)

oreply = unwrap_onionreply(tmpctx, ss, 5, reply, &origin_index);
printf("unwrapped %s\n", tal_hex(tmpctx, oreply));
assert(memeq(raw, tal_bytelen(raw), oreply, tal_bytelen(oreply)));
assert(tal_arr_eq(raw, oreply));
assert(origin_index == 4);
}

Expand Down
6 changes: 6 additions & 0 deletions common/utils.c
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#include "config.h"
#include <bitcoin/chainparams.h>
#include <ccan/list/list.h>
#include <ccan/mem/mem.h>
#include <ccan/str/hex/hex.h>
#include <ccan/tal/path/path.h>
#include <ccan/tal/str/str.h>
Expand Down Expand Up @@ -76,6 +77,11 @@ u8 *tal_hexdata(const tal_t *ctx, const void *str, size_t len)
return data;
}

bool tal_arr_eq_(const void *a, const void *b, size_t unused)
{
return memeq(a, tal_bytelen(a), b, tal_bytelen(b));
}

/* Use the POSIX C locale. */
void setup_locale(void)
{
Expand Down
6 changes: 6 additions & 0 deletions common/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,12 @@ u8 *tal_hexdata(const tal_t *ctx, const void *str, size_t len);
(*(p))[n_] = (s); \
} while(0)

/* Checks if two tal_arr are equal: sizeof checks types are the same */
#define tal_arr_eq(a, b) tal_arr_eq_((a), (b), sizeof((a) == (b)))

/* Avoids double-evaluation in macro */
bool tal_arr_eq_(const void *a, const void *b, size_t unused);

/**
* Remove an element from an array
*
Expand Down
3 changes: 1 addition & 2 deletions devtools/checkchannels.c
Original file line number Diff line number Diff line change
Expand Up @@ -205,8 +205,7 @@ int main(int argc, char *argv[])
wscript = bitcoin_redeem_2of2(ctx, &local_fundingkey, &remote_fundingkey);
expect_scriptpubkey = scriptpubkey_p2wsh(ctx, wscript);

if (!memeq(expect_scriptpubkey, tal_bytelen(expect_scriptpubkey),
scriptpubkey, tal_bytelen(scriptpubkey))) {
if (!tal_arr_eq(expect_scriptpubkey, scriptpubkey)) {
printf("*** FATAL *** outscript %s should be %s\n",
tal_hex(ctx, scriptpubkey),
tal_hex(ctx, expect_scriptpubkey));
Expand Down
3 changes: 1 addition & 2 deletions devtools/onion.c
Original file line number Diff line number Diff line change
Expand Up @@ -241,8 +241,7 @@ static void runtest(const char *filename)

if (oniontok) {
onion = json_tok_bin_from_hex(ctx, buffer, oniontok);
if (!memeq(onion, tal_bytelen(onion), serialized,
tal_bytelen(serialized)))
if (!tal_arr_eq(onion, serialized))
errx(1,
"Generated does not match the expected onion: \n"
"generated: %s\n"
Expand Down
3 changes: 1 addition & 2 deletions gossipd/gossmap_manage.c
Original file line number Diff line number Diff line change
Expand Up @@ -642,8 +642,7 @@ void gossmap_manage_handle_get_txout_reply(struct gossmap_manage *gm, const u8 *
goto bad;
}

if (!memeq(outscript, tal_bytelen(outscript),
pca->scriptpubkey, tal_bytelen(pca->scriptpubkey))) {
if (!tal_arr_eq(outscript, pca->scriptpubkey)) {
peer_warning(gm, pca->source_peer,
"channel_announcement: txout %s expected %s, got %s",
short_channel_id_to_str(tmpctx, &scid),
Expand Down
4 changes: 1 addition & 3 deletions openingd/dualopend.c
Original file line number Diff line number Diff line change
Expand Up @@ -500,9 +500,7 @@ static void handle_peer_shutdown(struct state *state, u8 *msg)
open_err_fatal(state, "Bad shutdown %s", tal_hex(msg, msg));

if (tal_count(state->upfront_shutdown_script[REMOTE])
&& !memeq(scriptpubkey, tal_count(scriptpubkey),
state->upfront_shutdown_script[REMOTE],
tal_count(state->upfront_shutdown_script[REMOTE])))
&& !tal_arr_eq(scriptpubkey, state->upfront_shutdown_script[REMOTE]))
open_err_fatal(state,
"scriptpubkey %s is not as agreed upfront (%s)",
tal_hex(state, scriptpubkey),
Expand Down
2 changes: 1 addition & 1 deletion tests/fuzz/fuzz-cryptomsg.c
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ static void test_encrypt_decrypt_equality(const u8 *msg)
tal_resize(&enc, tal_bytelen(enc) - CRYPTOMSG_HDR_SIZE);

dec = cryptomsg_decrypt_body(msg, &cs_in, enc);
assert(memeq(dec, tal_bytelen(dec), msg, tal_bytelen(msg)));
assert(tal_arr_eq(dec, msg));
}

/* Test header decryption of arbitrary bytes (should always fail). */
Expand Down
9 changes: 3 additions & 6 deletions tests/fuzz/fuzz-wire-accept_channel.c
Original file line number Diff line number Diff line change
Expand Up @@ -60,14 +60,11 @@ static bool equal(const struct accept_channel *x,

assert(x->tlvs && y->tlvs);

if (!memeq(x->tlvs->upfront_shutdown_script,
tal_bytelen(x->tlvs->upfront_shutdown_script),
y->tlvs->upfront_shutdown_script,
tal_bytelen(y->tlvs->upfront_shutdown_script)))
if (!tal_arr_eq(x->tlvs->upfront_shutdown_script,
y->tlvs->upfront_shutdown_script))
return false;

return memeq(x->tlvs->channel_type, tal_bytelen(x->tlvs->channel_type),
y->tlvs->channel_type, tal_bytelen(y->tlvs->channel_type));
return tal_arr_eq(x->tlvs->channel_type, y->tlvs->channel_type);
}

void run(const u8 *data, size_t size)
Expand Down
9 changes: 3 additions & 6 deletions tests/fuzz/fuzz-wire-accept_channel2.c
Original file line number Diff line number Diff line change
Expand Up @@ -86,14 +86,11 @@ static bool equal(const struct accept_channel2 *x,

assert(x->tlvs && y->tlvs);

if (!memeq(x->tlvs->upfront_shutdown_script,
tal_bytelen(x->tlvs->upfront_shutdown_script),
y->tlvs->upfront_shutdown_script,
tal_bytelen(y->tlvs->upfront_shutdown_script)))
if (!tal_arr_eq(x->tlvs->upfront_shutdown_script,
y->tlvs->upfront_shutdown_script))
return false;

if (!memeq(x->tlvs->channel_type, tal_bytelen(x->tlvs->channel_type),
y->tlvs->channel_type, tal_bytelen(y->tlvs->channel_type)))
if (!tal_arr_eq(x->tlvs->channel_type, y->tlvs->channel_type))
return false;

if (!!x->tlvs->require_confirmed_inputs !=
Expand Down
3 changes: 1 addition & 2 deletions tests/fuzz/fuzz-wire-channel_announcement.c
Original file line number Diff line number Diff line change
Expand Up @@ -53,8 +53,7 @@ static bool equal(const struct channel_announcement *x,
if (!node_id_eq(&x->node_id_2, &y->node_id_2))
return false;

return memeq(x->features, tal_bytelen(x->features), y->features,
tal_bytelen(y->features));
return tal_arr_eq(x->features, y->features);
}

void run(const u8 *data, size_t size)
Expand Down
4 changes: 1 addition & 3 deletions tests/fuzz/fuzz-wire-channel_ready.c
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,7 @@ static bool equal(const struct channel_ready *x, const struct channel_ready *y)
return false;

assert(x->tlvs && y->tlvs);
return memeq(
x->tlvs->short_channel_id, tal_bytelen(x->tlvs->short_channel_id),
y->tlvs->short_channel_id, tal_bytelen(y->tlvs->short_channel_id));
return tal_arr_eq(x->tlvs->short_channel_id, y->tlvs->short_channel_id);
}

void run(const u8 *data, size_t size)
Expand Down
22 changes: 6 additions & 16 deletions tests/fuzz/fuzz-wire-channel_reestablish.c
Original file line number Diff line number Diff line change
Expand Up @@ -44,26 +44,16 @@ static bool equal(const struct channel_reestablish *x,

assert(x->tlvs && y->tlvs);

if (!memeq(x->tlvs->next_funding, tal_bytelen(x->tlvs->next_funding),
y->tlvs->next_funding, tal_bytelen(y->tlvs->next_funding)))
if (!tal_arr_eq(x->tlvs->next_funding, y->tlvs->next_funding))
return false;
if (!memeq(x->tlvs->next_to_send, tal_bytelen(x->tlvs->next_to_send),
y->tlvs->next_to_send, tal_bytelen(y->tlvs->next_to_send)))
if (!tal_arr_eq(x->tlvs->next_to_send, y->tlvs->next_to_send))
return false;
if (!memeq(x->tlvs->desired_channel_type,
tal_bytelen(x->tlvs->desired_channel_type),
y->tlvs->desired_channel_type,
tal_bytelen(y->tlvs->desired_channel_type)))
if (!tal_arr_eq(x->tlvs->desired_channel_type, y->tlvs->desired_channel_type))
return false;
if (!memeq(x->tlvs->current_channel_type,
tal_bytelen(x->tlvs->current_channel_type),
y->tlvs->current_channel_type,
tal_bytelen(y->tlvs->current_channel_type)))
if (!tal_arr_eq(x->tlvs->current_channel_type, y->tlvs->current_channel_type))
return false;
return memeq(x->tlvs->upgradable_channel_type,
tal_bytelen(x->tlvs->upgradable_channel_type),
y->tlvs->upgradable_channel_type,
tal_bytelen(y->tlvs->upgradable_channel_type));
return tal_arr_eq(x->tlvs->upgradable_channel_type,
y->tlvs->upgradable_channel_type);
}

void run(const u8 *data, size_t size)
Expand Down
3 changes: 1 addition & 2 deletions tests/fuzz/fuzz-wire-closing_signed.c
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,7 @@ static bool equal(const struct closing_signed *x,
return false;

assert(x->tlvs && y->tlvs);
return memeq(x->tlvs->fee_range, tal_bytelen(x->tlvs->fee_range),
y->tlvs->fee_range, tal_bytelen(y->tlvs->fee_range));
return tal_arr_eq(x->tlvs->fee_range, y->tlvs->fee_range);
}

void run(const u8 *data, size_t size)
Expand Down
6 changes: 2 additions & 4 deletions tests/fuzz/fuzz-wire-commitment_signed.c
Original file line number Diff line number Diff line change
Expand Up @@ -36,13 +36,11 @@ static bool equal(struct commitment_signed *x, struct commitment_signed *y)
if (memcmp(x, y, upto_htlc_signature) != 0)
return false;

if (!memeq(x->htlc_signature, tal_bytelen(x->htlc_signature),
y->htlc_signature, tal_bytelen(y->htlc_signature)))
if (!tal_arr_eq(x->htlc_signature, y->htlc_signature))
return false;

assert(x->tlvs && y->tlvs);
return memeq(x->tlvs->splice_info, tal_bytelen(x->tlvs->splice_info),
y->tlvs->splice_info, tal_bytelen(y->tlvs->splice_info));
return tal_arr_eq(x->tlvs->splice_info, y->tlvs->splice_info);
}

void run(const u8 *data, size_t size)
Expand Down
3 changes: 1 addition & 2 deletions tests/fuzz/fuzz-wire-error.c
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,7 @@ static bool equal(const struct error *x, const struct error *y)
{
if (!channel_id_eq(&x->channel_id, &y->channel_id))
return false;
return memeq(x->data, tal_bytelen(x->data), y->data,
tal_bytelen(y->data));
return tal_arr_eq(x->data, y->data);
}

void run(const u8 *data, size_t size)
Expand Down
12 changes: 4 additions & 8 deletions tests/fuzz/fuzz-wire-init.c
Original file line number Diff line number Diff line change
Expand Up @@ -27,21 +27,17 @@ static struct init *decode(const tal_t *ctx, const void *p)

static bool equal(const struct init *x, const struct init *y)
{
if (!memeq(x->globalfeatures, tal_bytelen(x->globalfeatures),
y->globalfeatures, tal_bytelen(y->globalfeatures)))
if (!tal_arr_eq(x->globalfeatures, y->globalfeatures))
return false;
if (!memeq(x->features, tal_bytelen(x->features), y->features,
tal_bytelen(y->features)))
if (!tal_arr_eq(x->features, y->features))
return false;

assert(x->tlvs && y->tlvs);

if (!memeq(x->tlvs->networks, tal_bytelen(x->tlvs->networks),
y->tlvs->networks, tal_bytelen(y->tlvs->networks)))
if (!tal_arr_eq(x->tlvs->networks, y->tlvs->networks))
return false;

return memeq(x->tlvs->remote_addr, tal_bytelen(x->tlvs->remote_addr),
y->tlvs->remote_addr, tal_bytelen(y->tlvs->remote_addr));
return tal_arr_eq(x->tlvs->remote_addr, y->tlvs->remote_addr);
}

void run(const u8 *data, size_t size)
Expand Down
Loading

0 comments on commit df44431

Please sign in to comment.