|
| 1 | +// SPDX-License-Identifier: Apache-2.0 |
| 2 | +// SPDX-FileCopyrightText: Copyright the Vortex contributors |
| 3 | + |
| 4 | +use std::ops::Range; |
| 5 | + |
| 6 | +use datafusion_datasource::FileRange; |
| 7 | +use vortex::ArrayRef; |
| 8 | +use vortex::scan::ScanBuilder; |
| 9 | + |
| 10 | +/// If the file has a [`FileRange`](datafusion::datasource::listing::FileRange), we translate it into a row range in the file for the scan. |
| 11 | +pub(crate) fn apply_byte_range( |
| 12 | + file_range: FileRange, |
| 13 | + total_size: u64, |
| 14 | + row_count: u64, |
| 15 | + scan_builder: ScanBuilder<ArrayRef>, |
| 16 | +) -> ScanBuilder<ArrayRef> { |
| 17 | + let row_range = byte_range_to_row_range( |
| 18 | + file_range.start as u64..file_range.end as u64, |
| 19 | + row_count, |
| 20 | + total_size, |
| 21 | + ); |
| 22 | + |
| 23 | + scan_builder.with_row_range(row_range) |
| 24 | +} |
| 25 | + |
/// Translates a byte range of a file into an approximate row range.
///
/// The translation assumes that every row occupies `total_size / row_count` bytes
/// (integer division). Because that average is rounded down, the computed row indices can
/// overshoot, so both bounds are clamped to `row_count`. The returned range is therefore
/// always well-formed (`start <= end <= row_count`), and consecutive byte ranges map to
/// consecutive row ranges.
///
/// A file without rows (`row_count == 0`) always yields the empty range `0..0`.
///
/// # Panics
///
/// Panics if `row_count` is non-zero and `total_size < row_count`, i.e. when the average
/// row would be smaller than one byte.
fn byte_range_to_row_range(byte_range: Range<u64>, row_count: u64, total_size: u64) -> Range<u64> {
    // Nothing to scan in an empty file; returning early also avoids dividing by zero below.
    if row_count == 0 {
        return 0..0;
    }

    let average_row = total_size / row_count;
    assert!(average_row > 0, "A row must always have at least one byte");

    // `average_row` is rounded down, so either bound might overshoot `row_count`.
    let end_row = u64::min(row_count, byte_range.end / average_row);
    // Clamping the start to the (already clamped) end keeps the range from being inverted.
    let start_row = u64::min(end_row, byte_range.start / average_row);

    start_row..end_row
}
| 36 | + |
#[cfg(test)]
mod tests {
    use std::ops::Range;

    use itertools::Itertools;
    use rstest::rstest;

    use crate::convert::ranges::byte_range_to_row_range;

    /// Checks individual byte range to row range translations, including ranges whose end
    /// overshoots the row count and ranges smaller than a single average row.
    #[rstest]
    #[case(0..100, 100, 100, 0..100)]
    #[case(0..105, 100, 105, 0..100)]
    #[case(0..50, 100, 105, 0..50)]
    #[case(50..105, 100, 105, 50..100)]
    #[case(0..1, 4, 8, 0..0)]
    #[case(1..8, 4, 8, 0..4)]
    fn test_range_translation(
        #[case] byte_range: Range<u64>,
        #[case] row_count: u64,
        #[case] total_size: u64,
        #[case] expected: Range<u64>,
    ) {
        let actual = byte_range_to_row_range(byte_range, row_count, total_size);
        assert_eq!(actual, expected);
    }

    /// Splits a file into three adjacent byte ranges and checks that the resulting row
    /// ranges tile the whole file without gaps or overlaps.
    #[test]
    fn test_consecutive_ranges() {
        let row_count = 100;
        let total_size = 429;

        let row_ranges = [0..143, 143..286, 286..429]
            .map(|bytes| byte_range_to_row_range(bytes, row_count, total_size));

        // Number of rows assigned to each of the three byte ranges.
        let lengths = row_ranges
            .iter()
            .map(|rows| rows.end - rows.start)
            .collect::<Vec<u64>>();
        assert_eq!(lengths, [35, 36, 29]);

        // The first range starts at the first row and the last one ends at the last row.
        assert_eq!(row_ranges[0].start, 0);
        assert_eq!(row_ranges[2].end, 100);

        // Every range begins exactly where its predecessor ended.
        for (left, right) in row_ranges.iter().tuple_windows() {
            assert_eq!(left.end, right.start);
        }
    }
}
0 commit comments