Skip to content

Commit 1c4c5ed

Browse files
committed
Add a Threshold<T> type
We have various enums in the codebase that include a `Thresh` variant, we have to explicitly check that invariants are maintained all over the place because these enums are public (eg, `policy::Concrete`). Add a `Threshold<T>` type that abstracts over a threshold and maintains the following invariants: - v.len() > 0 - k > 0 - k <= v.len()
1 parent b60a702 commit 1c4c5ed

File tree

2 files changed

+84
-0
lines changed

2 files changed

+84
-0
lines changed

src/lib.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -126,6 +126,7 @@ pub mod miniscript;
126126
pub mod plan;
127127
pub mod policy;
128128
pub mod psbt;
129+
pub mod threshold;
129130

130131
#[cfg(test)]
131132
mod test_utils;

src/threshold.rs

Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,83 @@
1+
// SPDX-License-Identifier: CC0-1.0
2+
3+
//! A generic (k,n)-threshold type.
4+
5+
use core::fmt;
6+
7+
/// A (k, n)-threshold.
8+
///
9+
/// This type maintains the following invariants:
10+
/// - n > 0
11+
/// - k > 0
12+
/// - k <= n
13+
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
14+
pub struct Threshold<T> {
15+
k: usize,
16+
v: Vec<T>,
17+
}
18+
19+
impl<T> Threshold<T> {
20+
/// Creates a `Theshold<T>` after checking that invariants hold.
21+
pub fn new(k: usize, v: Vec<T>) -> Result<Threshold<T>, Error> {
22+
if v.len() == 0 {
23+
Err(Error::ZeroN)
24+
} else if k == 0 {
25+
Err(Error::ZeroK)
26+
} else if k > v.len() {
27+
Err(Error::BigK)
28+
} else {
29+
Ok(Threshold { k, v })
30+
}
31+
}
32+
33+
/// Creates a `Theshold<T>` without checking that invariants hold.
34+
pub fn new_unchecked(k: usize, v: Vec<T>) -> Threshold<T> { Threshold { k, v } }
35+
36+
/// Returns `k`, the threshold value.
37+
pub fn k(&self) -> usize { self.k }
38+
39+
/// Returns `n`, the total number of elements in the threshold.
40+
pub fn n(&self) -> usize { self.v.len() }
41+
42+
/// Returns a read-only iterator over the threshold elements.
43+
pub fn iter(&self) -> core::slice::Iter<'_, T> { self.v.iter() }
44+
45+
/// Returns the threshold elements, consuming self.
46+
// TODO: Find a better name for this functiion.
47+
pub fn into_elements(self) -> Vec<T> { self.v }
48+
}
49+
50+
/// An error attempting to construct a `Threshold<T>`.
51+
#[derive(Debug, Clone, PartialEq, Eq)]
52+
#[non_exhaustive]
53+
pub enum Error {
54+
/// Threshold `n` value must be non-zero.
55+
ZeroN,
56+
/// Threshold `k` value must be non-zero.
57+
ZeroK,
58+
/// Threshold `k` value must be <= `n`.
59+
BigK,
60+
}
61+
62+
impl fmt::Display for Error {
63+
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
64+
use Error::*;
65+
66+
match *self {
67+
ZeroN => f.write_str("threshold `n` value must be non-zero"),
68+
ZeroK => f.write_str("threshold `k` value must be non-zero"),
69+
BigK => f.write_str("threshold `k` value must be <= `n`"),
70+
}
71+
}
72+
}
73+
74+
#[cfg(feature = "std")]
75+
impl std::error::Error for Error {
76+
fn cause(&self) -> Option<&dyn std::error::Error> {
77+
use Error::*;
78+
79+
match *self {
80+
ZeroN | ZeroK | BigK => None,
81+
}
82+
}
83+
}

0 commit comments

Comments
 (0)