Skip to content

Commit 153b16a

Browse files
committed
Add a Threshold<T> type
We have various enums in the codebase that include a `Thresh` variant, we have to explicitly check that invariants are maintained all over the place because these enums are public (eg, `policy::Concrete`). Add a `Threshold<T>` type that abstracts over a threshold and maintains the following invariants: - v.len() > 0 - k > 0 - k <= v.len()
1 parent 68d9994 commit 153b16a

File tree

2 files changed

+136
-2
lines changed

2 files changed

+136
-2
lines changed

src/lib.rs

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -126,6 +126,7 @@ pub mod miniscript;
126126
pub mod plan;
127127
pub mod policy;
128128
pub mod psbt;
129+
pub mod threshold;
129130

130131
#[cfg(test)]
131132
mod test_utils;
@@ -861,7 +862,7 @@ mod prelude {
861862
rc, slice,
862863
string::{String, ToString},
863864
sync,
864-
vec::Vec,
865+
vec::{self, Vec},
865866
};
866867
#[cfg(any(feature = "std", test))]
867868
pub use std::{
@@ -872,7 +873,7 @@ mod prelude {
872873
string::{String, ToString},
873874
sync,
874875
sync::Mutex,
875-
vec::Vec,
876+
vec::{self, Vec},
876877
};
877878

878879
#[cfg(all(not(feature = "std"), not(test)))]

src/threshold.rs

Lines changed: 133 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,133 @@
1+
// SPDX-License-Identifier: CC0-1.0
2+
3+
//! A generic (k,n)-threshold type.
4+
5+
use core::fmt;
6+
7+
use crate::prelude::{vec, Vec};
8+
9+
/// A (k, n)-threshold.
10+
///
11+
/// This type maintains the following invariants:
12+
/// - n > 0
13+
/// - k > 0
14+
/// - k <= n
15+
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
16+
pub struct Threshold<T> {
17+
k: usize,
18+
v: Vec<T>,
19+
}
20+
21+
impl<T> Threshold<T> {
22+
/// Creates a `Theshold<T>` after checking that invariants hold.
23+
pub fn new(k: usize, v: Vec<T>) -> Result<Threshold<T>, Error> {
24+
if v.len() == 0 {
25+
Err(Error::ZeroN)
26+
} else if k == 0 {
27+
Err(Error::ZeroK)
28+
} else if k > v.len() {
29+
Err(Error::BigK)
30+
} else {
31+
Ok(Threshold { k, v })
32+
}
33+
}
34+
35+
/// Creates a `Theshold<T>` without checking that invariants hold.
36+
#[cfg(test)]
37+
pub fn new_unchecked(k: usize, v: Vec<T>) -> Threshold<T> { Threshold { k, v } }
38+
39+
/// Returns `k`, the threshold value.
40+
pub fn k(&self) -> usize { self.k }
41+
42+
/// Returns `n`, the total number of elements in the threshold.
43+
pub fn n(&self) -> usize { self.v.len() }
44+
45+
/// Returns a read-only iterator over the threshold elements.
46+
pub fn iter(&self) -> core::slice::Iter<'_, T> { self.v.iter() }
47+
48+
/// Creates an iterator over the threshold elements.
49+
pub fn into_iter(self) -> vec::IntoIter<T> { self.v.into_iter() }
50+
51+
/// Creates an iterator over the threshold elements.
52+
pub fn iter_mut(&mut self) -> core::slice::IterMut<'_, T> { self.v.iter_mut() }
53+
54+
/// Returns the threshold elements, consuming self.
55+
pub fn into_elements(self) -> Vec<T> { self.v }
56+
57+
/// Creates a new (k, n)-threshold using a newly mapped vector.
58+
///
59+
/// Typically this function is called after collecting a vector that was
60+
/// created by iterating this threshold. E.g.,
61+
///
62+
/// `thresh.mapped((0..thresh.n()).map(|element| some_function(element)).collect())`
63+
///
64+
/// # Panics
65+
///
66+
/// Panics if the new vector is not the same length as the
67+
/// original i.e., `new.len() != self.n()`.
68+
pub(crate) fn mapped<U>(&self, new: Vec<U>) -> Threshold<U> {
69+
if self.n() != new.len() {
70+
panic!("cannot map to a different length vector")
71+
}
72+
Threshold { k: self.k(), v: new }
73+
}
74+
}
75+
76+
/// An error attempting to construct a `Threshold<T>`.
77+
#[derive(Debug, Clone, PartialEq, Eq)]
78+
#[non_exhaustive]
79+
pub enum Error {
80+
/// Threshold `n` value must be non-zero.
81+
ZeroN,
82+
/// Threshold `k` value must be non-zero.
83+
ZeroK,
84+
/// Threshold `k` value must be <= `n`.
85+
BigK,
86+
}
87+
88+
impl fmt::Display for Error {
89+
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
90+
use Error::*;
91+
92+
match *self {
93+
ZeroN => f.write_str("threshold `n` value must be non-zero"),
94+
ZeroK => f.write_str("threshold `k` value must be non-zero"),
95+
BigK => f.write_str("threshold `k` value must be <= `n`"),
96+
}
97+
}
98+
}
99+
100+
#[cfg(feature = "std")]
101+
impl std::error::Error for Error {
102+
fn cause(&self) -> Option<&dyn std::error::Error> {
103+
use Error::*;
104+
105+
match *self {
106+
ZeroN | ZeroK | BigK => None,
107+
}
108+
}
109+
}
110+
111+
#[cfg(test)]
112+
mod tests {
113+
use super::*;
114+
115+
#[test]
116+
fn threshold_constructor_valid() {
117+
let v = vec![1, 2, 3];
118+
let n = 3;
119+
120+
for k in 1..=3 {
121+
let thresh = Threshold::new(k, v.clone()).expect("failed to create threshold");
122+
assert_eq!(thresh.k(), k);
123+
assert_eq!(thresh.n(), n);
124+
}
125+
}
126+
127+
#[test]
128+
fn threshold_constructor_invalid() {
129+
let v = vec![1, 2, 3];
130+
assert!(Threshold::new(0, v.clone()).is_err());
131+
assert!(Threshold::new(4, v.clone()).is_err());
132+
}
133+
}

0 commit comments

Comments
 (0)