Skip to content

Commit 590f506

Browse files
authored
Fix BitReader::get_batch zero extension (#1708) (#1722)
* Fix BitReader::get_batch zero extension (#1708)
* Fix tests
1 parent 4de6895 commit 590f506

File tree

1 file changed

+43
-33
lines changed

1 file changed

+43
-33
lines changed

parquet/src/util/bit_util.rs

Lines changed: 43 additions & 33 deletions
Original file line number | Diff line number | Diff line change
@@ -568,40 +568,35 @@ impl BitReader {
568568
}
569569
}
570570

571-
unsafe {
572-
let in_buf = &self.buffer.data()[self.byte_offset..];
573-
let mut in_ptr = in_buf as *const [u8] as *const u8 as *const u32;
574-
if size_of::<T>() == 4 {
575-
while values_to_read - i >= 32 {
576-
let out_ptr = &mut batch[i..] as *mut [T] as *mut T as *mut u32;
577-
in_ptr = unpack32(in_ptr, out_ptr, num_bits);
578-
self.byte_offset += 4 * num_bits;
579-
i += 32;
580-
}
581-
} else {
582-
let mut out_buf = [0u32; 32];
583-
let out_ptr = &mut out_buf as &mut [u32] as *mut [u32] as *mut u32;
584-
while values_to_read - i >= 32 {
585-
in_ptr = unpack32(in_ptr, out_ptr, num_bits);
586-
self.byte_offset += 4 * num_bits;
587-
for n in 0..32 {
588-
// We need to copy from smaller size to bigger size to avoid
589-
// overwriting other memory regions.
590-
if size_of::<T>() > size_of::<u32>() {
591-
std::ptr::copy_nonoverlapping(
592-
out_buf[n..].as_ptr() as *const u32,
593-
&mut batch[i] as *mut T as *mut u32,
594-
1,
595-
);
596-
} else {
597-
std::ptr::copy_nonoverlapping(
598-
out_buf[n..].as_ptr() as *const T,
599-
&mut batch[i] as *mut T,
600-
1,
601-
);
602-
}
603-
i += 1;
571+
let in_buf = &self.buffer.data()[self.byte_offset..];
572+
let mut in_ptr = in_buf as *const [u8] as *const u8 as *const u32;
573+
if size_of::<T>() == 4 {
574+
while values_to_read - i >= 32 {
575+
let out_ptr = &mut batch[i..] as *mut [T] as *mut T as *mut u32;
576+
in_ptr = unsafe { unpack32(in_ptr, out_ptr, num_bits) };
577+
self.byte_offset += 4 * num_bits;
578+
i += 32;
579+
}
580+
} else {
581+
let mut out_buf = [0u32; 32];
582+
let out_ptr = &mut out_buf as &mut [u32] as *mut [u32] as *mut u32;
583+
while values_to_read - i >= 32 {
584+
in_ptr = unsafe { unpack32(in_ptr, out_ptr, num_bits) };
585+
self.byte_offset += 4 * num_bits;
586+
587+
for out in out_buf {
588+
// Zero-allocate buffer
589+
let mut out_bytes = T::Buffer::default();
590+
let in_bytes = out.to_le_bytes();
591+
592+
{
593+
let out_bytes = out_bytes.as_mut();
594+
let len = out_bytes.len().min(in_bytes.len());
595+
(&mut out_bytes[..len]).copy_from_slice(&in_bytes[..len]);
604596
}
597+
598+
batch[i] = T::from_le_bytes(out_bytes);
599+
i += 1;
605600
}
606601
}
607602
}
@@ -1193,4 +1188,19 @@ mod tests {
11931188
);
11941189
});
11951190
}
1191+
1192+
#[test]
1193+
fn test_get_batch_zero_extend() {
1194+
let to_read = vec![0xFF; 4]; // 4 bytes of 0xFF: 32 set bits, i.e. 32 one-bit values
1195+
let mut reader = BitReader::new(ByteBufferPtr::new(to_read));
1196+
1197+
// Pre-fill the output with all-ones (u64::MAX) so any bits get_batch fails to overwrite stay visible
1198+
let mut output = [u64::MAX; 32];
1199+
reader.get_batch(&mut output, 1); // read into u64 slots with a bit width of 1
1200+
1201+
for v in output {
1202+
// Each slot must hold exactly 1: the 1-bit value zero-extended, with no stale high bits from the pre-fill
1203+
assert_eq!(v, 1);
1204+
}
1205+
}
11961206
}

0 commit comments

Comments
 (0)