Skip to content

Commit 603773d

Browse files
committed
safety check on write_fixed_bitset
1 parent fb58431 commit 603773d

File tree

1 file changed

+55
-40
lines changed

1 file changed

+55
-40
lines changed

dash/src/consensus/encode.rs

Lines changed: 55 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -1005,6 +1005,12 @@ pub fn read_fixed_bitset<R: Read + ?Sized>(r: &mut R, size: usize) -> std::io::R
10051005
}
10061006

10071007
pub fn write_fixed_bitset<W: Write + ?Sized>(w: &mut W, bits: &[bool], size: usize) -> io::Result<usize> {
1008+
if bits.len() < size {
1009+
return Err(io::Error::new(
1010+
io::ErrorKind::InvalidInput,
1011+
"Bits length is less than the specified size",
1012+
));
1013+
}
10081014
// Define a reasonable maximum size to prevent excessive memory allocation
10091015
const MAX_BITSET_SIZE: usize = 1_000_000;
10101016
if size > MAX_BITSET_SIZE {
@@ -1491,56 +1497,65 @@ mod tests {
14911497
#[test]
14921498
fn test_fixed_bitset_round_trip() {
14931499
let test_cases = vec![
1494-
(vec![], 0),
1495-
(vec![true, false, true, false, true, false, true, false], 8),
1496-
(vec![true; 10], 10),
1497-
(vec![false; 15], 15),
1498-
(vec![true, false, true], 16), // size greater than bits.len()
1500+
(vec![], 0, true), // (bits, size, expect_success)
1501+
(vec![true, false, true, false, true, false, true, false], 8, true),
1502+
(vec![true; 10], 10, true),
1503+
(vec![false; 15], 15, true),
1504+
(vec![true, false, true], 16, false), // size greater than bits.len()
14991505
(
15001506
vec![
15011507
true, false, true, false, true, false, true, false, true, false, true, false,
15021508
true, false, true, false, true, false, true, false, true, false, true, false,
15031509
],
15041510
24,
1511+
true,
15051512
),
15061513
];
15071514

1508-
for (bits, size) in test_cases {
1515+
for (bits, size, expect_success) in test_cases {
15091516
let mut buffer = Vec::new();
1510-
// Write the bitset to the buffer
1511-
let bytes_written = write_fixed_bitset(&mut buffer, &bits, size).expect("Failed to write");
1512-
// Calculate expected bytes written
1513-
let expected_bytes = (size + 7) / 8;
1514-
assert_eq!(
1515-
bytes_written, expected_bytes,
1516-
"Incorrect number of bytes written for bitset with size {}",
1517-
size
1518-
);
1519-
1520-
// Read the bitset back from the buffer
1521-
let mut cursor = Cursor::new(&buffer);
1522-
let read_bits = read_fixed_bitset(&mut cursor, size).expect("Failed to read");
1523-
1524-
// Assert that the original bits match the deserialized bits
1525-
// For bits beyond bits.len(), they should be false
1526-
let expected_bits: Vec<bool> = (0..size)
1527-
.map(|i| bits.get(i).copied().unwrap_or(false))
1528-
.collect();
1529-
1530-
assert_eq!(
1531-
read_bits, expected_bits,
1532-
"Deserialized bits do not match original for size {}",
1533-
size
1534-
);
1535-
1536-
// Ensure that we've consumed all bytes (no extra bytes left)
1537-
let position = cursor.position();
1538-
assert_eq!(
1539-
position as usize,
1540-
buffer.len(),
1541-
"Not all bytes were consumed for size {}",
1542-
size
1543-
);
1517+
// Attempt to write the bitset to the buffer
1518+
let result = write_fixed_bitset(&mut buffer, &bits, size);
1519+
1520+
if expect_success {
1521+
// Expect the write to succeed
1522+
let bytes_written = result.expect("Failed to write");
1523+
// Calculate expected bytes written
1524+
let expected_bytes = (size + 7) / 8;
1525+
assert_eq!(
1526+
bytes_written, expected_bytes,
1527+
"Incorrect number of bytes written for bitset with size {}",
1528+
size
1529+
);
1530+
1531+
// Read the bitset back from the buffer
1532+
let mut cursor = Cursor::new(&buffer);
1533+
let read_bits = read_fixed_bitset(&mut cursor, size).expect("Failed to read");
1534+
1535+
// Assert that the original bits match the deserialized bits
1536+
assert_eq!(
1537+
read_bits, bits,
1538+
"Deserialized bits do not match original for size {}",
1539+
size
1540+
);
1541+
1542+
// Ensure that we've consumed all bytes (no extra bytes left)
1543+
let position = cursor.position();
1544+
assert_eq!(
1545+
position as usize,
1546+
buffer.len(),
1547+
"Not all bytes were consumed for size {}",
1548+
size
1549+
);
1550+
} else {
1551+
// Expect the write to fail
1552+
assert!(
1553+
result.is_err(),
1554+
"Expected write to fail for bits.len() < size (size: {}, bits.len(): {})",
1555+
size,
1556+
bits.len()
1557+
);
1558+
}
15441559
}
15451560
}
15461561
}

0 commit comments

Comments
 (0)