Skip to content

Commit 2efc4ee

Browse files
committed
decide that bsatn byte > 1 is not a valid bool
1 parent 1992b8e commit 2efc4ee

File tree

3 files changed

+17
-3
lines changed

3 files changed

+17
-3
lines changed

crates/sats/src/bsatn.rs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -179,5 +179,11 @@ mod tests {
179179
let val_decoded = AlgebraicValue::decode(&ty, &mut &bytes[..]).unwrap();
180180
prop_assert_eq!(val, val_decoded);
181181
}
182+
183+
#[test]
184+
fn bsatn_non_zero_one_u8_aint_bool(val in 2u8..) {
185+
let bytes = [val];
186+
prop_assert!(AlgebraicValue::decode(&AlgebraicType::Bool, &mut &bytes[..]).is_err());
187+
}
182188
}
183189
}

crates/sats/src/bsatn/de.rs

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,12 @@ impl<'de, 'a, R: BufReader<'de>> de::Deserializer<'de> for Deserializer<'a, R> {
6060
}
6161

6262
fn deserialize_bool(self) -> Result<bool, Self::Error> {
63-
self.reader.get_u8().map(|x| x != 0)
63+
let byte = self.reader.get_u8()?;
64+
match byte {
65+
0 => Ok(false),
66+
1 => Ok(true),
67+
b => Err(DecodeError::InvalidBool(b)),
68+
}
6469
}
6570
fn deserialize_u8(self) -> Result<u8, DecodeError> {
6671
self.reader.get_u8()

crates/sats/src/buffer.rs

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,8 @@ pub enum DecodeError {
2020
InvalidTag { tag: u8, sum_name: Option<String> },
2121
/// Expected data to be UTF-8 but it wasn't.
2222
InvalidUtf8,
23+
/// Expected the byte to be 0 or 1 to be a valid bool.
24+
InvalidBool(u8),
2325
/// Custom error not in the other variants of `DecodeError`.
2426
Other(String),
2527
}
@@ -40,6 +42,7 @@ impl fmt::Display for DecodeError {
4042
)
4143
}
4244
DecodeError::InvalidUtf8 => f.write_str("invalid utf8"),
45+
DecodeError::InvalidBool(byte) => write!(f, "byte {byte} not valid as `bool` (must be 0 or 1)"),
4346
DecodeError::Other(err) => f.write_str(err),
4447
}
4548
}
@@ -158,7 +161,7 @@ pub trait BufReader<'de> {
158161
/// Reads and returns a byte slice of `.len() = size` advancing the cursor.
159162
#[inline]
160163
fn get_slice(&mut self, size: usize) -> Result<&'de [u8], DecodeError> {
161-
self.get_chunk(size).ok_or(DecodeError::BufferLength {
164+
self.get_chunk(size).ok_or_else(|| DecodeError::BufferLength {
162165
for_type: "[u8]",
163166
expected: size,
164167
given: self.remaining(),
@@ -168,7 +171,7 @@ pub trait BufReader<'de> {
168171
/// Reads an array of type `[u8; N]` from the input.
169172
#[inline]
170173
fn get_array<const N: usize>(&mut self) -> Result<&'de [u8; N], DecodeError> {
171-
self.get_array_chunk().ok_or(DecodeError::BufferLength {
174+
self.get_array_chunk().ok_or_else(|| DecodeError::BufferLength {
172175
for_type: "[u8; _]",
173176
expected: N,
174177
given: self.remaining(),

0 commit comments

Comments
 (0)