Skip to content

Commit 6468604

Browse files
committed
Use Threshold type in concrete::Policy::Thresh
Use the `Threshold` type in `policy::concrete::Policy::Thresh` to help maintain invariants on n and k.
1 parent 153b16a commit 6468604

File tree

4 files changed

+84
-71
lines changed

4 files changed

+84
-71
lines changed

src/iter/mod.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,7 @@ impl<'a, Pk: MiniscriptKey> TreeLike for &'a policy::Concrete<Pk> {
7777
| Ripemd160(_) | Hash160(_) => Tree::Nullary,
7878
And(ref subs) => Tree::Nary(subs.iter().map(Arc::as_ref).collect()),
7979
Or(ref v) => Tree::Nary(v.iter().map(|(_, p)| p.as_ref()).collect()),
80-
Thresh(_, ref subs) => Tree::Nary(subs.iter().map(Arc::as_ref).collect()),
80+
Thresh(thresh) => Tree::Nary(thresh.iter().map(Arc::as_ref).collect()),
8181
}
8282
}
8383
}
@@ -90,7 +90,7 @@ impl<'a, Pk: MiniscriptKey> TreeLike for Arc<policy::Concrete<Pk>> {
9090
| Ripemd160(_) | Hash160(_) => Tree::Nullary,
9191
And(ref subs) => Tree::Nary(subs.iter().map(Arc::clone).collect()),
9292
Or(ref v) => Tree::Nary(v.iter().map(|(_, p)| Arc::clone(p)).collect()),
93-
Thresh(_, ref subs) => Tree::Nary(subs.iter().map(Arc::clone).collect()),
93+
Thresh(thresh) => Tree::Nary(thresh.iter().map(Arc::clone).collect()),
9494
}
9595
}
9696
}

src/policy/compiler.rs

Lines changed: 23 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -920,8 +920,9 @@ where
920920
compile_binary!(&mut l_comp[3], &mut r_comp[2], [lw, rw], Terminal::OrI);
921921
compile_binary!(&mut r_comp[3], &mut l_comp[2], [rw, lw], Terminal::OrI);
922922
}
923-
Concrete::Thresh(k, ref subs) => {
924-
let n = subs.len();
923+
Concrete::Thresh(ref thresh) => {
924+
let k = thresh.k();
925+
let n = thresh.n();
925926
let k_over_n = k as f64 / n as f64;
926927

927928
let mut sub_ast = Vec::with_capacity(n);
@@ -931,7 +932,7 @@ where
931932
let mut best_ws = Vec::with_capacity(n);
932933

933934
let mut min_value = (0, f64::INFINITY);
934-
for (i, ast) in subs.iter().enumerate() {
935+
for (i, ast) in thresh.iter().enumerate() {
935936
let sp = sat_prob * k_over_n;
936937
//Expressions must be dissatisfiable
937938
let dp = Some(dissat_prob.unwrap_or(0 as f64) + (1.0 - k_over_n) * sat_prob);
@@ -949,7 +950,7 @@ where
949950
}
950951
sub_ext_data.push(best_es[min_value.0].0);
951952
sub_ast.push(Arc::clone(&best_es[min_value.0].1.ms));
952-
for (i, _ast) in subs.iter().enumerate() {
953+
for (i, _ast) in thresh.iter().enumerate() {
953954
if i != min_value.0 {
954955
sub_ext_data.push(best_ws[i].0);
955956
sub_ast.push(Arc::clone(&best_ws[i].1.ms));
@@ -966,7 +967,7 @@ where
966967
insert_wrap!(ast_ext);
967968
}
968969

969-
let key_vec: Vec<Pk> = subs
970+
let key_vec: Vec<Pk> = thresh
970971
.iter()
971972
.filter_map(|s| {
972973
if let Concrete::Key(ref pk) = s.as_ref() {
@@ -978,16 +979,16 @@ where
978979
.collect();
979980

980981
match Ctx::sig_type() {
981-
SigType::Schnorr if key_vec.len() == subs.len() => {
982+
SigType::Schnorr if key_vec.len() == thresh.n() => {
982983
insert_wrap!(AstElemExt::terminal(Terminal::MultiA(k, key_vec)))
983984
}
984985
SigType::Ecdsa
985-
if key_vec.len() == subs.len() && subs.len() <= MAX_PUBKEYS_PER_MULTISIG =>
986+
if key_vec.len() == thresh.n() && thresh.n() <= MAX_PUBKEYS_PER_MULTISIG =>
986987
{
987988
insert_wrap!(AstElemExt::terminal(Terminal::Multi(k, key_vec)))
988989
}
989-
_ if k == subs.len() => {
990-
let mut it = subs.iter();
990+
_ if k == thresh.n() => {
991+
let mut it = thresh.iter();
991992
let mut policy = it.next().expect("No sub policy in thresh() ?").clone();
992993
policy =
993994
it.fold(policy, |acc, pol| Concrete::And(vec![acc, pol.clone()]).into());
@@ -1157,6 +1158,7 @@ mod tests {
11571158
use super::*;
11581159
use crate::miniscript::{Legacy, Segwitv0, Tap};
11591160
use crate::policy::Liftable;
1161+
use crate::threshold::Threshold;
11601162
use crate::{script_num_size, ToPublicKey};
11611163

11621164
type SPolicy = Concrete<String>;
@@ -1301,19 +1303,19 @@ mod tests {
13011303
let policy: BPolicy = Concrete::Or(vec![
13021304
(
13031305
127,
1304-
Arc::new(Concrete::Thresh(
1306+
Arc::new(Concrete::Thresh(Threshold::new_unchecked(
13051307
3,
13061308
key_pol[0..5].iter().map(|p| (p.clone()).into()).collect(),
1307-
)),
1309+
))),
13081310
),
13091311
(
13101312
1,
13111313
Arc::new(Concrete::And(vec![
13121314
Arc::new(Concrete::Older(Sequence::from_height(10000))),
1313-
Arc::new(Concrete::Thresh(
1315+
Arc::new(Concrete::Thresh(Threshold::new_unchecked(
13141316
2,
13151317
key_pol[5..8].iter().map(|p| (p.clone()).into()).collect(),
1316-
)),
1318+
))),
13171319
])),
13181320
),
13191321
]);
@@ -1430,7 +1432,7 @@ mod tests {
14301432
.iter()
14311433
.map(|pubkey| Arc::new(Concrete::Key(*pubkey)))
14321434
.collect();
1433-
let big_thresh = Concrete::Thresh(*k, pubkeys);
1435+
let big_thresh = Concrete::Thresh(Threshold::new_unchecked(*k, pubkeys));
14341436
let big_thresh_ms: SegwitMiniScript = big_thresh.compile().unwrap();
14351437
if *k == 21 {
14361438
// N * (PUSH + pubkey + CHECKSIGVERIFY)
@@ -1466,8 +1468,8 @@ mod tests {
14661468
.collect();
14671469

14681470
let thresh_res: Result<SegwitMiniScript, _> = Concrete::Or(vec![
1469-
(1, Arc::new(Concrete::Thresh(keys_a.len(), keys_a))),
1470-
(1, Arc::new(Concrete::Thresh(keys_b.len(), keys_b))),
1471+
(1, Arc::new(Concrete::Thresh(Threshold::new_unchecked(keys_a.len(), keys_a)))),
1472+
(1, Arc::new(Concrete::Thresh(Threshold::new_unchecked(keys_b.len(), keys_b)))),
14711473
])
14721474
.compile();
14731475
let script_size = thresh_res.clone().and_then(|m| Ok(m.script_size()));
@@ -1484,7 +1486,8 @@ mod tests {
14841486
.iter()
14851487
.map(|pubkey| Arc::new(Concrete::Key(*pubkey)))
14861488
.collect();
1487-
let thresh_res: Result<SegwitMiniScript, _> = Concrete::Thresh(keys.len(), keys).compile();
1489+
let thresh_res: Result<SegwitMiniScript, _> =
1490+
Concrete::Thresh(Threshold::new_unchecked(keys.len(), keys)).compile();
14881491
let n_elements = thresh_res
14891492
.clone()
14901493
.and_then(|m| Ok(m.max_satisfaction_witness_elements()));
@@ -1505,7 +1508,7 @@ mod tests {
15051508
.map(|pubkey| Arc::new(Concrete::Key(*pubkey)))
15061509
.collect();
15071510
let thresh_res: Result<SegwitMiniScript, _> =
1508-
Concrete::Thresh(keys.len() - 1, keys).compile();
1511+
Concrete::Thresh(Threshold::new_unchecked(keys.len() - 1, keys)).compile();
15091512
let ops_count = thresh_res.clone().and_then(|m| Ok(m.ext.ops.op_count()));
15101513
assert_eq!(
15111514
thresh_res,
@@ -1519,7 +1522,8 @@ mod tests {
15191522
.iter()
15201523
.map(|pubkey| Arc::new(Concrete::Key(*pubkey)))
15211524
.collect();
1522-
let thresh_res = Concrete::Thresh(keys.len() - 1, keys).compile::<Legacy>();
1525+
let thresh_res =
1526+
Concrete::Thresh(Threshold::new_unchecked(keys.len() - 1, keys)).compile::<Legacy>();
15231527
let ops_count = thresh_res.clone().and_then(|m| Ok(m.ext.ops.op_count()));
15241528
assert_eq!(
15251529
thresh_res,

src/policy/concrete.rs

Lines changed: 56 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ use crate::iter::TreeLike;
2727
use crate::miniscript::types::extra_props::TimelockInfo;
2828
use crate::prelude::*;
2929
use crate::sync::Arc;
30+
use crate::threshold::Threshold;
3031
#[cfg(all(doc, not(feature = "compiler")))]
3132
use crate::Descriptor;
3233
use crate::{errstr, AbsLockTime, Error, ForEachKey, MiniscriptKey, Translator};
@@ -67,7 +68,7 @@ pub enum Policy<Pk: MiniscriptKey> {
6768
/// relative probabilities for each one.
6869
Or(Vec<(usize, Arc<Policy<Pk>>)>),
6970
/// A set of descriptors, satisfactions must be provided for `k` of them.
70-
Thresh(usize, Vec<Arc<Policy<Pk>>>),
71+
Thresh(Threshold<Arc<Policy<Pk>>>),
7172
}
7273

7374
impl<Pk> Policy<Pk>
@@ -210,9 +211,10 @@ impl<Pk: MiniscriptKey> Policy<Pk> {
210211
})
211212
.collect::<Vec<_>>()
212213
}
213-
Policy::Thresh(k, ref subs) if *k == 1 => {
214-
let total_odds = subs.len();
215-
subs.iter()
214+
Policy::Thresh(thresh) if thresh.k() == 1 => {
215+
let total_odds = thresh.n();
216+
thresh
217+
.iter()
216218
.flat_map(|policy| policy.to_tapleaf_prob_vec(prob / total_odds as f64))
217219
.collect::<Vec<_>>()
218220
}
@@ -430,13 +432,16 @@ impl<Pk: MiniscriptKey> Policy<Pk> {
430432
.map(|(odds, pol)| (prob * *odds as f64 / total_odds as f64, pol.clone()))
431433
.collect::<Vec<_>>()
432434
}
433-
Policy::Thresh(k, subs) if *k == 1 => {
434-
let total_odds = subs.len();
435-
subs.iter()
435+
Policy::Thresh(thresh) if thresh.k() == 1 => {
436+
let total_odds = thresh.n();
437+
thresh
438+
.iter()
436439
.map(|pol| (prob / total_odds as f64, pol.clone()))
437440
.collect::<Vec<_>>()
438441
}
439-
Policy::Thresh(k, subs) if *k != subs.len() => generate_combination(subs, prob, *k),
442+
Policy::Thresh(thresh) if thresh.k() != thresh.n() => {
443+
generate_combination(thresh, prob)
444+
}
440445
pol => vec![(prob, Arc::new(pol.clone()))],
441446
}
442447
}
@@ -585,7 +590,7 @@ impl<Pk: MiniscriptKey> Policy<Pk> {
585590
.enumerate()
586591
.map(|(i, (prob, _))| (*prob, child_n(i)))
587592
.collect()),
588-
Thresh(ref k, ref subs) => Thresh(*k, (0..subs.len()).map(child_n).collect()),
593+
Thresh(ref thresh) => Thresh(thresh.mapped((0..thresh.n()).map(child_n).collect())),
589594
};
590595
translated.push(Arc::new(new_policy));
591596
}
@@ -611,7 +616,9 @@ impl<Pk: MiniscriptKey> Policy<Pk> {
611616
.enumerate()
612617
.map(|(i, (prob, _))| (*prob, child_n(i)))
613618
.collect())),
614-
Thresh(k, ref subs) => Some(Thresh(*k, (0..subs.len()).map(child_n).collect())),
619+
Thresh(ref thresh) => {
620+
Some(Thresh(thresh.mapped((0..thresh.n()).map(child_n).collect())))
621+
}
615622
_ => None,
616623
};
617624
match new_policy {
@@ -647,7 +654,7 @@ impl<Pk: MiniscriptKey> Policy<Pk> {
647654

648655
let num = match data.node {
649656
Or(subs) => (0..subs.len()).map(num_for_child_n).sum(),
650-
Thresh(k, subs) if *k == 1 => (0..subs.len()).map(num_for_child_n).sum(),
657+
Thresh(thresh) if thresh.k() == 1 => (0..thresh.n()).map(num_for_child_n).sum(),
651658
_ => 1,
652659
};
653660
nums.push(num);
@@ -730,9 +737,9 @@ impl<Pk: MiniscriptKey> Policy<Pk> {
730737
let iter = (0..subs.len()).map(info_for_child_n);
731738
TimelockInfo::combine_threshold(1, iter)
732739
}
733-
Thresh(ref k, subs) => {
734-
let iter = (0..subs.len()).map(info_for_child_n);
735-
TimelockInfo::combine_threshold(*k, iter)
740+
Thresh(ref thresh) => {
741+
let iter = (0..thresh.n()).map(info_for_child_n);
742+
TimelockInfo::combine_threshold(thresh.k(), iter)
736743
}
737744
_ => TimelockInfo::default(),
738745
};
@@ -773,11 +780,6 @@ impl<Pk: MiniscriptKey> Policy<Pk> {
773780
return Err(PolicyError::NonBinaryArgOr);
774781
}
775782
}
776-
Thresh(k, ref subs) => {
777-
if k == 0 || k > subs.len() {
778-
return Err(PolicyError::IncorrectThresh);
779-
}
780-
}
781783
_ => {}
782784
}
783785
}
@@ -817,16 +819,16 @@ impl<Pk: MiniscriptKey> Policy<Pk> {
817819
});
818820
(all_safe, atleast_one_safe && all_non_mall)
819821
}
820-
Thresh(k, ref subs) => {
821-
let (safe_count, non_mall_count) = (0..subs.len()).map(acc_for_child_n).fold(
822-
(0, 0),
823-
|(safe_count, non_mall_count), (safe, non_mall)| {
822+
Policy::Thresh(ref thresh) => {
823+
let (safe_count, non_mall_count) = thresh
824+
.iter()
825+
.map(|sub| sub.is_safe_nonmalleable())
826+
.fold((0, 0), |(safe_count, non_mall_count), (safe, non_mall)| {
824827
(safe_count + safe as usize, non_mall_count + non_mall as usize)
825-
},
826-
);
828+
});
827829
(
828-
safe_count >= (subs.len() - k + 1),
829-
non_mall_count == subs.len() && safe_count >= (subs.len() - k),
830+
safe_count >= (thresh.n() - thresh.k() + 1),
831+
non_mall_count == thresh.n() && safe_count >= (thresh.n() - thresh.k()),
830832
)
831833
}
832834
};
@@ -869,10 +871,10 @@ impl<Pk: MiniscriptKey> fmt::Debug for Policy<Pk> {
869871
}
870872
f.write_str(")")
871873
}
872-
Policy::Thresh(k, ref subs) => {
873-
write!(f, "thresh({}", k)?;
874-
for sub in subs {
875-
write!(f, ",{:?}", sub)?;
874+
Policy::Thresh(ref thresh) => {
875+
write!(f, "thresh({}", thresh.k())?;
876+
for policy in thresh.iter() {
877+
write!(f, ",{:?}", policy)?;
876878
}
877879
f.write_str(")")
878880
}
@@ -912,10 +914,10 @@ impl<Pk: MiniscriptKey> fmt::Display for Policy<Pk> {
912914
}
913915
f.write_str(")")
914916
}
915-
Policy::Thresh(k, ref subs) => {
916-
write!(f, "thresh({}", k)?;
917-
for sub in subs {
918-
write!(f, ",{}", sub)?;
917+
Policy::Thresh(ref thresh) => {
918+
write!(f, "thresh({}", thresh.k())?;
919+
for policy in thresh.iter() {
920+
write!(f, ",{}", policy)?;
919921
}
920922
f.write_str(")")
921923
}
@@ -1028,16 +1030,19 @@ impl_block_str!(
10281030
return Err(Error::PolicyError(PolicyError::IncorrectThresh));
10291031
}
10301032

1031-
let thresh = expression::parse_num(top.args[0].name)?;
1032-
if thresh >= nsubs || thresh == 0 {
1033+
let k = expression::parse_num(top.args[0].name)?;
1034+
if k >= nsubs || k == 0 {
10331035
return Err(Error::PolicyError(PolicyError::IncorrectThresh));
10341036
}
10351037

10361038
let mut subs = Vec::with_capacity(top.args.len() - 1);
10371039
for arg in &top.args[1..] {
10381040
subs.push(Policy::from_tree(arg)?);
10391041
}
1040-
Ok(Policy::Thresh(thresh as usize, subs.into_iter().map(Arc::new).collect()))
1042+
let v = subs.into_iter().map(Arc::new).collect();
1043+
1044+
let thresh = Threshold::new(k as usize, v).map_err(|_| PolicyError::IncorrectThresh)?;
1045+
Ok(Policy::Thresh(thresh))
10411046
}
10421047
_ => Err(errstr(top.name)),
10431048
}
@@ -1089,20 +1094,20 @@ fn with_huffman_tree<Pk: MiniscriptKey>(
10891094
/// any one of the conditions exclusively.
10901095
#[cfg(feature = "compiler")]
10911096
fn generate_combination<Pk: MiniscriptKey>(
1092-
policy_vec: &Vec<Arc<Policy<Pk>>>,
1097+
policy_thresh: &Threshold<Arc<Policy<Pk>>>,
10931098
prob: f64,
1094-
k: usize,
10951099
) -> Vec<(f64, Arc<Policy<Pk>>)> {
1096-
debug_assert!(k <= policy_vec.len());
1097-
10981100
let mut ret: Vec<(f64, Arc<Policy<Pk>>)> = vec![];
1099-
for i in 0..policy_vec.len() {
1100-
let policies: Vec<Arc<Policy<Pk>>> = policy_vec
1101+
let k = policy_thresh.k();
1102+
for i in 0..policy_thresh.n() {
1103+
let policies: Vec<Arc<Policy<Pk>>> = policy_thresh
11011104
.iter()
11021105
.enumerate()
11031106
.filter_map(|(j, sub)| if j != i { Some(Arc::clone(sub)) } else { None })
11041107
.collect();
1105-
ret.push((prob / policy_vec.len() as f64, Arc::new(Policy::Thresh(k, policies))));
1108+
if let Ok(thresh) = Threshold::new(k, policies) {
1109+
ret.push((prob / policy_thresh.n() as f64, Arc::new(Policy::Thresh(thresh))));
1110+
}
11061111
}
11071112
ret
11081113
}
@@ -1123,7 +1128,8 @@ mod compiler_tests {
11231128
.map(|p| Arc::new(p))
11241129
.collect();
11251130

1126-
let combinations = generate_combination(&policies, 1.0, 2);
1131+
let thresh = Threshold::new_unchecked(2, policies);
1132+
let combinations = generate_combination(&thresh, 1.0);
11271133

11281134
let comb_a: Vec<Policy<String>> = vec![
11291135
policy_str!("pk(B)"),
@@ -1150,7 +1156,10 @@ mod compiler_tests {
11501156
.map(|sub_pol| {
11511157
(
11521158
0.25,
1153-
Arc::new(Policy::Thresh(2, sub_pol.into_iter().map(|p| Arc::new(p)).collect())),
1159+
Arc::new(Policy::Thresh(Threshold::new_unchecked(
1160+
2,
1161+
sub_pol.into_iter().map(|p| Arc::new(p)).collect(),
1162+
))),
11541163
)
11551164
})
11561165
.collect::<Vec<_>>();

0 commit comments

Comments
 (0)