@@ -1005,6 +1005,12 @@ pub fn read_fixed_bitset<R: Read + ?Sized>(r: &mut R, size: usize) -> std::io::R
10051005}
10061006
10071007pub fn write_fixed_bitset < W : Write + ?Sized > ( w : & mut W , bits : & [ bool ] , size : usize ) -> io:: Result < usize > {
1008+ if bits. len ( ) < size {
1009+ return Err ( io:: Error :: new (
1010+ io:: ErrorKind :: InvalidInput ,
1011+ "Bits length is less than the specified size" ,
1012+ ) ) ;
1013+ }
10081014 // Define a reasonable maximum size to prevent excessive memory allocation
10091015 const MAX_BITSET_SIZE : usize = 1_000_000 ;
10101016 if size > MAX_BITSET_SIZE {
@@ -1491,56 +1497,65 @@ mod tests {
14911497 #[ test]
14921498 fn test_fixed_bitset_round_trip ( ) {
14931499 let test_cases = vec ! [
1494- ( vec![ ] , 0 ) ,
1495- ( vec![ true , false , true , false , true , false , true , false ] , 8 ) ,
1496- ( vec![ true ; 10 ] , 10 ) ,
1497- ( vec![ false ; 15 ] , 15 ) ,
1498- ( vec![ true , false , true ] , 16 ) , // size greater than bits.len()
1500+ ( vec![ ] , 0 , true ) , // (bits, size, expect_success)
1501+ ( vec![ true , false , true , false , true , false , true , false ] , 8 , true ) ,
1502+ ( vec![ true ; 10 ] , 10 , true ) ,
1503+ ( vec![ false ; 15 ] , 15 , true ) ,
1504+ ( vec![ true , false , true ] , 16 , false ) , // size greater than bits.len()
14991505 (
15001506 vec![
15011507 true , false , true , false , true , false , true , false , true , false , true , false ,
15021508 true , false , true , false , true , false , true , false , true , false , true , false ,
15031509 ] ,
15041510 24 ,
1511+ true ,
15051512 ) ,
15061513 ] ;
15071514
1508- for ( bits, size) in test_cases {
1515+ for ( bits, size, expect_success ) in test_cases {
15091516 let mut buffer = Vec :: new ( ) ;
1510- // Write the bitset to the buffer
1511- let bytes_written = write_fixed_bitset ( & mut buffer, & bits, size) . expect ( "Failed to write" ) ;
1512- // Calculate expected bytes written
1513- let expected_bytes = ( size + 7 ) / 8 ;
1514- assert_eq ! (
1515- bytes_written, expected_bytes,
1516- "Incorrect number of bytes written for bitset with size {}" ,
1517- size
1518- ) ;
1519-
1520- // Read the bitset back from the buffer
1521- let mut cursor = Cursor :: new ( & buffer) ;
1522- let read_bits = read_fixed_bitset ( & mut cursor, size) . expect ( "Failed to read" ) ;
1523-
1524- // Assert that the original bits match the deserialized bits
1525- // For bits beyond bits.len(), they should be false
1526- let expected_bits: Vec < bool > = ( 0 ..size)
1527- . map ( |i| bits. get ( i) . copied ( ) . unwrap_or ( false ) )
1528- . collect ( ) ;
1529-
1530- assert_eq ! (
1531- read_bits, expected_bits,
1532- "Deserialized bits do not match original for size {}" ,
1533- size
1534- ) ;
1535-
1536- // Ensure that we've consumed all bytes (no extra bytes left)
1537- let position = cursor. position ( ) ;
1538- assert_eq ! (
1539- position as usize ,
1540- buffer. len( ) ,
1541- "Not all bytes were consumed for size {}" ,
1542- size
1543- ) ;
1517+ // Attempt to write the bitset to the buffer
1518+ let result = write_fixed_bitset ( & mut buffer, & bits, size) ;
1519+
1520+ if expect_success {
1521+ // Expect the write to succeed
1522+ let bytes_written = result. expect ( "Failed to write" ) ;
1523+ // Calculate expected bytes written
1524+ let expected_bytes = ( size + 7 ) / 8 ;
1525+ assert_eq ! (
1526+ bytes_written, expected_bytes,
1527+ "Incorrect number of bytes written for bitset with size {}" ,
1528+ size
1529+ ) ;
1530+
1531+ // Read the bitset back from the buffer
1532+ let mut cursor = Cursor :: new ( & buffer) ;
1533+ let read_bits = read_fixed_bitset ( & mut cursor, size) . expect ( "Failed to read" ) ;
1534+
1535+ // Assert that the original bits match the deserialized bits
1536+ assert_eq ! (
1537+ read_bits, bits,
1538+ "Deserialized bits do not match original for size {}" ,
1539+ size
1540+ ) ;
1541+
1542+ // Ensure that we've consumed all bytes (no extra bytes left)
1543+ let position = cursor. position ( ) ;
1544+ assert_eq ! (
1545+ position as usize ,
1546+ buffer. len( ) ,
1547+ "Not all bytes were consumed for size {}" ,
1548+ size
1549+ ) ;
1550+ } else {
1551+ // Expect the write to fail
1552+ assert ! (
1553+ result. is_err( ) ,
1554+ "Expected write to fail for bits.len() < size (size: {}, bits.len(): {})" ,
1555+ size,
1556+ bits. len( )
1557+ ) ;
1558+ }
15441559 }
15451560 }
15461561}
0 commit comments