Skip to content

Commit 3a4be19

Browse files
committed
refactor: improve run boundary computation in RunArray casting
1 parent f83fd31 commit 3a4be19

File tree

1 file changed

+234
-15
lines changed

1 file changed

+234
-15
lines changed

arrow-cast/src/cast/run_array.rs

Lines changed: 234 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -15,10 +15,21 @@
1515
// specific language governing permissions and limitations
1616
// under the License.
1717

18-
use std::vec;
19-
2018
use crate::cast::*;
21-
use arrow_array::Array;
19+
use arrow_array::cast::AsArray;
20+
use arrow_array::types::{
21+
ArrowDictionaryKeyType, ArrowPrimitiveType, Date32Type, Date64Type, Decimal128Type,
22+
Decimal256Type, DurationMicrosecondType, DurationMillisecondType, DurationNanosecondType,
23+
DurationSecondType, Float16Type, Float32Type, Float64Type, Int8Type, Int16Type, Int32Type,
24+
Int64Type, IntervalDayTimeType, IntervalMonthDayNanoType, IntervalYearMonthType,
25+
Time32MillisecondType, Time32SecondType, Time64MicrosecondType, Time64NanosecondType,
26+
TimestampMicrosecondType, TimestampMillisecondType, TimestampNanosecondType,
27+
TimestampSecondType, UInt8Type, UInt16Type, UInt32Type, UInt64Type,
28+
};
29+
use arrow_array::{
30+
Array, ArrayRef, BinaryViewArray, BooleanArray, DictionaryArray, FixedSizeBinaryArray,
31+
GenericBinaryArray, GenericStringArray, OffsetSizeTrait, PrimitiveArray, StringViewArray,
32+
};
2233

2334
/// Attempts to cast a `RunArray` with index type K into
2435
/// `to_type` for supported types.
@@ -137,18 +148,7 @@ pub(crate) fn cast_to_run_end_encoded<K: RunEndIndexType>(
137148
}
138149

139150
// Identify run boundaries by comparing consecutive values
140-
let mut run_ends = Vec::new();
141-
let mut values_indexes = vec![0usize]; // Always include the first index
142-
let mut current_data = cast_array.slice(0, 1).to_data();
143-
for idx in 1..cast_array.len() {
144-
let next_data = cast_array.slice(idx, 1).to_data();
145-
if current_data != next_data {
146-
run_ends.push(idx);
147-
values_indexes.push(idx);
148-
current_data = next_data;
149-
}
150-
}
151-
run_ends.push(cast_array.len());
151+
let (run_ends, values_indexes) = compute_run_boundaries(cast_array);
152152

153153
// Build the run_ends array
154154
for run_end in run_ends {
@@ -167,3 +167,222 @@ pub(crate) fn cast_to_run_end_encoded<K: RunEndIndexType>(
167167
let run_array = RunArray::<K>::try_new(&run_ends_array, values_array.as_ref())?;
168168
Ok(Arc::new(run_array))
169169
}
170+
171+
// Compute run boundaries for the given array.
172+
fn compute_run_boundaries(array: &ArrayRef) -> (Vec<usize>, Vec<usize>) {
173+
use arrow_schema::{DataType::*, IntervalUnit, TimeUnit};
174+
match array.data_type() {
175+
Null => runs_for_null(array.len()),
176+
Boolean => runs_for_boolean(array.as_boolean()),
177+
Int8 => runs_for_primitive(array.as_primitive::<Int8Type>()),
178+
Int16 => runs_for_primitive(array.as_primitive::<Int16Type>()),
179+
Int32 => runs_for_primitive(array.as_primitive::<Int32Type>()),
180+
Int64 => runs_for_primitive(array.as_primitive::<Int64Type>()),
181+
UInt8 => runs_for_primitive(array.as_primitive::<UInt8Type>()),
182+
UInt16 => runs_for_primitive(array.as_primitive::<UInt16Type>()),
183+
UInt32 => runs_for_primitive(array.as_primitive::<UInt32Type>()),
184+
UInt64 => runs_for_primitive(array.as_primitive::<UInt64Type>()),
185+
Float16 => runs_for_primitive(array.as_primitive::<Float16Type>()),
186+
Float32 => runs_for_primitive(array.as_primitive::<Float32Type>()),
187+
Float64 => runs_for_primitive(array.as_primitive::<Float64Type>()),
188+
Date32 => runs_for_primitive(array.as_primitive::<Date32Type>()),
189+
Date64 => runs_for_primitive(array.as_primitive::<Date64Type>()),
190+
Time32(TimeUnit::Second) => runs_for_primitive(array.as_primitive::<Time32SecondType>()),
191+
Time32(TimeUnit::Millisecond) => {
192+
runs_for_primitive(array.as_primitive::<Time32MillisecondType>())
193+
}
194+
Time64(TimeUnit::Microsecond) => {
195+
runs_for_primitive(array.as_primitive::<Time64MicrosecondType>())
196+
}
197+
Time64(TimeUnit::Nanosecond) => {
198+
runs_for_primitive(array.as_primitive::<Time64NanosecondType>())
199+
}
200+
Duration(TimeUnit::Second) => {
201+
runs_for_primitive(array.as_primitive::<DurationSecondType>())
202+
}
203+
Duration(TimeUnit::Millisecond) => {
204+
runs_for_primitive(array.as_primitive::<DurationMillisecondType>())
205+
}
206+
Duration(TimeUnit::Microsecond) => {
207+
runs_for_primitive(array.as_primitive::<DurationMicrosecondType>())
208+
}
209+
Duration(TimeUnit::Nanosecond) => {
210+
runs_for_primitive(array.as_primitive::<DurationNanosecondType>())
211+
}
212+
Timestamp(TimeUnit::Second, _) => {
213+
runs_for_primitive(array.as_primitive::<TimestampSecondType>())
214+
}
215+
Timestamp(TimeUnit::Millisecond, _) => {
216+
runs_for_primitive(array.as_primitive::<TimestampMillisecondType>())
217+
}
218+
Timestamp(TimeUnit::Microsecond, _) => {
219+
runs_for_primitive(array.as_primitive::<TimestampMicrosecondType>())
220+
}
221+
Timestamp(TimeUnit::Nanosecond, _) => {
222+
runs_for_primitive(array.as_primitive::<TimestampNanosecondType>())
223+
}
224+
Interval(IntervalUnit::YearMonth) => {
225+
runs_for_primitive(array.as_primitive::<IntervalYearMonthType>())
226+
}
227+
Interval(IntervalUnit::DayTime) => {
228+
runs_for_primitive(array.as_primitive::<IntervalDayTimeType>())
229+
}
230+
Interval(IntervalUnit::MonthDayNano) => {
231+
runs_for_primitive(array.as_primitive::<IntervalMonthDayNanoType>())
232+
}
233+
Decimal128(_, _) => runs_for_primitive(array.as_primitive::<Decimal128Type>()),
234+
Decimal256(_, _) => runs_for_primitive(array.as_primitive::<Decimal256Type>()),
235+
Utf8 => runs_for_string(array.as_string::<i32>()),
236+
LargeUtf8 => runs_for_string(array.as_string::<i64>()),
237+
Utf8View => runs_for_string_view(array.as_string_view()),
238+
Binary => runs_for_binary(array.as_binary::<i32>()),
239+
LargeBinary => runs_for_binary(array.as_binary::<i64>()),
240+
BinaryView => runs_for_binary_view(array.as_binary_view()),
241+
FixedSizeBinary(_) => runs_for_fixed_size_binary(array.as_fixed_size_binary()),
242+
Dictionary(key_type, _) => match key_type.as_ref() {
243+
Int8 => runs_for_dictionary::<Int8Type>(array.as_dictionary()),
244+
Int16 => runs_for_dictionary::<Int16Type>(array.as_dictionary()),
245+
Int32 => runs_for_dictionary::<Int32Type>(array.as_dictionary()),
246+
Int64 => runs_for_dictionary::<Int64Type>(array.as_dictionary()),
247+
UInt8 => runs_for_dictionary::<UInt8Type>(array.as_dictionary()),
248+
UInt16 => runs_for_dictionary::<UInt16Type>(array.as_dictionary()),
249+
UInt32 => runs_for_dictionary::<UInt32Type>(array.as_dictionary()),
250+
UInt64 => runs_for_dictionary::<UInt64Type>(array.as_dictionary()),
251+
_ => runs_generic(array),
252+
},
253+
_ => runs_generic(array),
254+
}
255+
}
256+
257+
/// Computes run boundaries for an all-null array of `len` slots.
///
/// Every slot is null, so a non-empty array forms exactly one run: the result
/// is `(vec![len], vec![0])`, i.e. a single run ending at `len` whose
/// representative value is slot 0.
///
/// An empty array has no runs at all, so both vectors are empty. Returning
/// `vec![0]` there would describe a zero-length run and a value index that
/// does not exist.
fn runs_for_null(len: usize) -> (Vec<usize>, Vec<usize>) {
    if len == 0 {
        return (Vec::new(), Vec::new());
    }
    (vec![len], vec![0])
}
260+
261+
fn runs_for_boolean(array: &BooleanArray) -> (Vec<usize>, Vec<usize>) {
262+
build_runs_with_nulls(
263+
array.len(),
264+
array.null_count() > 0,
265+
|idx| array.is_valid(idx),
266+
|idx| array.value(idx),
267+
)
268+
}
269+
270+
fn runs_for_primitive<T: ArrowPrimitiveType>(
271+
array: &PrimitiveArray<T>,
272+
) -> (Vec<usize>, Vec<usize>) {
273+
build_runs_with_nulls(
274+
array.len(),
275+
array.null_count() > 0,
276+
|idx| array.is_valid(idx),
277+
|idx| array.value(idx),
278+
)
279+
}
280+
281+
fn runs_for_binary<O: OffsetSizeTrait>(array: &GenericBinaryArray<O>) -> (Vec<usize>, Vec<usize>) {
282+
build_runs_with_nulls(
283+
array.len(),
284+
array.null_count() > 0,
285+
|idx| array.is_valid(idx),
286+
|idx| array.value(idx),
287+
)
288+
}
289+
290+
fn runs_for_string<O: OffsetSizeTrait>(array: &GenericStringArray<O>) -> (Vec<usize>, Vec<usize>) {
291+
build_runs_with_nulls(
292+
array.len(),
293+
array.null_count() > 0,
294+
|idx| array.is_valid(idx),
295+
|idx| array.value(idx),
296+
)
297+
}
298+
299+
fn runs_for_binary_view(array: &BinaryViewArray) -> (Vec<usize>, Vec<usize>) {
300+
build_runs_with_nulls(
301+
array.len(),
302+
array.null_count() > 0,
303+
|idx| array.is_valid(idx),
304+
|idx| array.value(idx),
305+
)
306+
}
307+
308+
fn runs_for_string_view(array: &StringViewArray) -> (Vec<usize>, Vec<usize>) {
309+
build_runs_with_nulls(
310+
array.len(),
311+
array.null_count() > 0,
312+
|idx| array.is_valid(idx),
313+
|idx| array.value(idx),
314+
)
315+
}
316+
317+
fn runs_for_fixed_size_binary(array: &FixedSizeBinaryArray) -> (Vec<usize>, Vec<usize>) {
318+
build_runs_with_nulls(
319+
array.len(),
320+
array.null_count() > 0,
321+
|idx| array.is_valid(idx),
322+
|idx| array.value(idx),
323+
)
324+
}
325+
326+
fn runs_for_dictionary<K: ArrowDictionaryKeyType>(
327+
array: &DictionaryArray<K>,
328+
) -> (Vec<usize>, Vec<usize>) {
329+
runs_for_primitive(array.keys())
330+
}
331+
332+
/// Splits the index range `0..len` into runs of consecutive equal slots.
///
/// `equal(run_start, idx)` is called once for each `idx` in `1..len`, always
/// with the first index of the run currently being extended as `run_start`.
/// When it returns `false`, the current run is closed at `idx` and a new run
/// starts there.
///
/// Returns `(run_ends, values_indexes)`:
/// * `run_ends[i]` is the exclusive end index of run `i` (the last entry is
///   always `len`);
/// * `values_indexes[i]` is the first index of run `i`.
///
/// Both vectors always have the same length. For `len == 0` there are no
/// runs, so both are empty; in particular no value index is reported for a
/// slot that does not exist.
fn build_runs(len: usize, mut equal: impl FnMut(usize, usize) -> bool) -> (Vec<usize>, Vec<usize>) {
    if len == 0 {
        return (Vec::new(), Vec::new());
    }

    let mut run_ends = Vec::new();
    // A non-empty input always has a first run starting at index 0.
    let mut values_indexes = vec![0usize];
    let mut current = 0usize;
    for idx in 1..len {
        if !equal(current, idx) {
            run_ends.push(idx);
            values_indexes.push(idx);
            current = idx;
        }
    }
    // Close the final run.
    run_ends.push(len);
    (run_ends, values_indexes)
}
349+
350+
fn build_runs_with_nulls<T: PartialEq>(
351+
len: usize,
352+
has_nulls: bool,
353+
mut is_valid: impl FnMut(usize) -> bool,
354+
mut value_at: impl FnMut(usize) -> T,
355+
) -> (Vec<usize>, Vec<usize>) {
356+
build_runs(len, |lhs, rhs| {
357+
if has_nulls {
358+
let lhs_valid = is_valid(lhs);
359+
let rhs_valid = is_valid(rhs);
360+
if lhs_valid && rhs_valid {
361+
value_at(lhs) == value_at(rhs)
362+
} else {
363+
lhs_valid == rhs_valid
364+
}
365+
} else {
366+
value_at(lhs) == value_at(rhs)
367+
}
368+
})
369+
}
370+
371+
fn runs_generic(array: &ArrayRef) -> (Vec<usize>, Vec<usize>) {
372+
let mut run_ends = Vec::new();
373+
let mut values_indexes = vec![0usize];
374+
if array.len() == 0 {
375+
return (run_ends, values_indexes);
376+
}
377+
let mut current_data = array.slice(0, 1).to_data();
378+
for idx in 1..array.len() {
379+
let next_data = array.slice(idx, 1).to_data();
380+
if current_data != next_data {
381+
run_ends.push(idx);
382+
values_indexes.push(idx);
383+
current_data = next_data;
384+
}
385+
}
386+
run_ends.push(array.len());
387+
(run_ends, values_indexes)
388+
}

0 commit comments

Comments
 (0)