Skip to content

Commit 2c2f225

Browse files
authored
Return an error on overflow in do_append_val_inner (#16201)
* Return an error on overflow in `do_append_val_inner`
* Refactor
1 parent 72f4eab commit 2c2f225

File tree

4 files changed

+121
-81
lines changed

4 files changed

+121
-81
lines changed

datafusion/physical-plan/src/aggregates/group_values/multi_group_by/bytes.rs

Lines changed: 65 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ use arrow::array::{
2424
use arrow::buffer::{OffsetBuffer, ScalarBuffer};
2525
use arrow::datatypes::{ByteArrayType, DataType, GenericBinaryType};
2626
use datafusion_common::utils::proxy::VecAllocExt;
27+
use datafusion_common::{DataFusionError, Result};
2728
use datafusion_physical_expr_common::binary_map::{OutputType, INITIAL_BUFFER_CAPACITY};
2829
use itertools::izip;
2930
use std::mem::size_of;
@@ -80,7 +81,7 @@ where
8081
self.do_equal_to_inner(lhs_row, array, rhs_row)
8182
}
8283

83-
fn append_val_inner<B>(&mut self, array: &ArrayRef, row: usize)
84+
fn append_val_inner<B>(&mut self, array: &ArrayRef, row: usize) -> Result<()>
8485
where
8586
B: ByteArrayType,
8687
{
@@ -92,8 +93,10 @@ where
9293
self.offsets.push(O::usize_as(offset));
9394
} else {
9495
self.nulls.append(false);
95-
self.do_append_val_inner(arr, row);
96+
self.do_append_val_inner(arr, row)?;
9697
}
98+
99+
Ok(())
97100
}
98101

99102
fn vectorized_equal_to_inner<B>(
@@ -123,7 +126,11 @@ where
123126
}
124127
}
125128

126-
fn vectorized_append_inner<B>(&mut self, array: &ArrayRef, rows: &[usize])
129+
fn vectorized_append_inner<B>(
130+
&mut self,
131+
array: &ArrayRef,
132+
rows: &[usize],
133+
) -> Result<()>
127134
where
128135
B: ByteArrayType,
129136
{
@@ -141,22 +148,14 @@ where
141148
match all_null_or_non_null {
142149
None => {
143150
for &row in rows {
144-
if arr.is_null(row) {
145-
self.nulls.append(true);
146-
// nulls need a zero length in the offset buffer
147-
let offset = self.buffer.len();
148-
self.offsets.push(O::usize_as(offset));
149-
} else {
150-
self.nulls.append(false);
151-
self.do_append_val_inner(arr, row);
152-
}
151+
self.append_val_inner::<B>(array, row)?
153152
}
154153
}
155154

156155
Some(true) => {
157156
self.nulls.append_n(rows.len(), false);
158157
for &row in rows {
159-
self.do_append_val_inner(arr, row);
158+
self.do_append_val_inner(arr, row)?;
160159
}
161160
}
162161

@@ -168,6 +167,8 @@ where
168167
self.offsets.resize(new_len, O::usize_as(offset));
169168
}
170169
}
170+
171+
Ok(())
171172
}
172173

173174
fn do_equal_to_inner<B>(
@@ -188,20 +189,26 @@ where
188189
self.value(lhs_row) == (array.value(rhs_row).as_ref() as &[u8])
189190
}
190191

191-
fn do_append_val_inner<B>(&mut self, array: &GenericByteArray<B>, row: usize)
192+
fn do_append_val_inner<B>(
193+
&mut self,
194+
array: &GenericByteArray<B>,
195+
row: usize,
196+
) -> Result<()>
192197
where
193198
B: ByteArrayType,
194199
{
195200
let value: &[u8] = array.value(row).as_ref();
196201
self.buffer.append_slice(value);
197202

198-
assert!(
199-
self.buffer.len() <= self.max_buffer_size,
200-
"offset overflow, buffer size > {}",
201-
self.max_buffer_size
202-
);
203+
if self.buffer.len() > self.max_buffer_size {
204+
return Err(DataFusionError::Execution(format!(
205+
"offset overflow, buffer size > {}",
206+
self.max_buffer_size
207+
)));
208+
}
203209

204210
self.offsets.push(O::usize_as(self.buffer.len()));
211+
Ok(())
205212
}
206213

207214
/// return the current value of the specified row irrespective of null
@@ -238,25 +245,27 @@ where
238245
}
239246
}
240247

241-
fn append_val(&mut self, column: &ArrayRef, row: usize) {
248+
fn append_val(&mut self, column: &ArrayRef, row: usize) -> Result<()> {
242249
// Sanity array type
243250
match self.output_type {
244251
OutputType::Binary => {
245252
debug_assert!(matches!(
246253
column.data_type(),
247254
DataType::Binary | DataType::LargeBinary
248255
));
249-
self.append_val_inner::<GenericBinaryType<O>>(column, row)
256+
self.append_val_inner::<GenericBinaryType<O>>(column, row)?
250257
}
251258
OutputType::Utf8 => {
252259
debug_assert!(matches!(
253260
column.data_type(),
254261
DataType::Utf8 | DataType::LargeUtf8
255262
));
256-
self.append_val_inner::<GenericStringType<O>>(column, row)
263+
self.append_val_inner::<GenericStringType<O>>(column, row)?
257264
}
258265
_ => unreachable!("View types should use `ArrowBytesViewMap`"),
259266
};
267+
268+
Ok(())
260269
}
261270

262271
fn vectorized_equal_to(
@@ -296,24 +305,26 @@ where
296305
}
297306
}
298307

299-
fn vectorized_append(&mut self, column: &ArrayRef, rows: &[usize]) {
308+
fn vectorized_append(&mut self, column: &ArrayRef, rows: &[usize]) -> Result<()> {
300309
match self.output_type {
301310
OutputType::Binary => {
302311
debug_assert!(matches!(
303312
column.data_type(),
304313
DataType::Binary | DataType::LargeBinary
305314
));
306-
self.vectorized_append_inner::<GenericBinaryType<O>>(column, rows)
315+
self.vectorized_append_inner::<GenericBinaryType<O>>(column, rows)?
307316
}
308317
OutputType::Utf8 => {
309318
debug_assert!(matches!(
310319
column.data_type(),
311320
DataType::Utf8 | DataType::LargeUtf8
312321
));
313-
self.vectorized_append_inner::<GenericStringType<O>>(column, rows)
322+
self.vectorized_append_inner::<GenericStringType<O>>(column, rows)?
314323
}
315324
_ => unreachable!("View types should use `ArrowBytesViewMap`"),
316325
};
326+
327+
Ok(())
317328
}
318329

319330
fn len(&self) -> usize {
@@ -421,12 +432,12 @@ mod tests {
421432

422433
use crate::aggregates::group_values::multi_group_by::bytes::ByteGroupValueBuilder;
423434
use arrow::array::{ArrayRef, NullBufferBuilder, StringArray};
435+
use datafusion_common::DataFusionError;
424436
use datafusion_physical_expr::binary_map::OutputType;
425437

426438
use super::GroupColumn;
427439

428440
#[test]
429-
#[should_panic]
430441
fn test_byte_group_value_builder_overflow() {
431442
let mut builder = ByteGroupValueBuilder::<i32>::new(OutputType::Utf8);
432443

@@ -435,31 +446,36 @@ mod tests {
435446
let array =
436447
Arc::new(StringArray::from(vec![Some(large_string.as_str())])) as ArrayRef;
437448

438-
// Append items until our buffer length is 1 + i32::MAX as usize
439-
for _ in 0..2048 {
440-
builder.append_val(&array, 0);
449+
// Append items until our buffer length is i32::MAX as usize
450+
for _ in 0..2047 {
451+
builder.append_val(&array, 0).unwrap();
441452
}
442453

443-
assert_eq!(builder.value(2047), large_string.as_bytes());
454+
assert!(matches!(
455+
builder.append_val(&array, 0),
456+
Err(DataFusionError::Execution(e)) if e.contains("offset overflow")
457+
));
458+
459+
assert_eq!(builder.value(2046), large_string.as_bytes());
444460
}
445461

446462
#[test]
447463
fn test_byte_take_n() {
448464
let mut builder = ByteGroupValueBuilder::<i32>::new(OutputType::Utf8);
449465
let array = Arc::new(StringArray::from(vec![Some("a"), None])) as ArrayRef;
450466
// a, null, null
451-
builder.append_val(&array, 0);
452-
builder.append_val(&array, 1);
453-
builder.append_val(&array, 1);
467+
builder.append_val(&array, 0).unwrap();
468+
builder.append_val(&array, 1).unwrap();
469+
builder.append_val(&array, 1).unwrap();
454470

455471
// (a, null) remaining: null
456472
let output = builder.take_n(2);
457473
assert_eq!(&output, &array);
458474

459475
// null, a, null, a
460-
builder.append_val(&array, 0);
461-
builder.append_val(&array, 1);
462-
builder.append_val(&array, 0);
476+
builder.append_val(&array, 0).unwrap();
477+
builder.append_val(&array, 1).unwrap();
478+
builder.append_val(&array, 0).unwrap();
463479

464480
// (null, a) remaining: (null, a)
465481
let output = builder.take_n(2);
@@ -473,9 +489,9 @@ mod tests {
473489
])) as ArrayRef;
474490

475491
// null, a, longstringfortest, null, null
476-
builder.append_val(&array, 2);
477-
builder.append_val(&array, 1);
478-
builder.append_val(&array, 1);
492+
builder.append_val(&array, 2).unwrap();
493+
builder.append_val(&array, 1).unwrap();
494+
builder.append_val(&array, 1).unwrap();
479495

480496
// (null, a, longstringfortest, null) remaining: (null)
481497
let output = builder.take_n(4);
@@ -494,7 +510,7 @@ mod tests {
494510
builder_array: &ArrayRef,
495511
append_rows: &[usize]| {
496512
for &index in append_rows {
497-
builder.append_val(builder_array, index);
513+
builder.append_val(builder_array, index).unwrap();
498514
}
499515
};
500516

@@ -517,7 +533,9 @@ mod tests {
517533
let append = |builder: &mut ByteGroupValueBuilder<i32>,
518534
builder_array: &ArrayRef,
519535
append_rows: &[usize]| {
520-
builder.vectorized_append(builder_array, append_rows);
536+
builder
537+
.vectorized_append(builder_array, append_rows)
538+
.unwrap();
521539
};
522540

523541
let equal_to = |builder: &ByteGroupValueBuilder<i32>,
@@ -551,7 +569,9 @@ mod tests {
551569
None,
552570
None,
553571
])) as _;
554-
builder.vectorized_append(&all_nulls_input_array, &[0, 1, 2, 3, 4]);
572+
builder
573+
.vectorized_append(&all_nulls_input_array, &[0, 1, 2, 3, 4])
574+
.unwrap();
555575

556576
let mut equal_to_results = vec![true; all_nulls_input_array.len()];
557577
builder.vectorized_equal_to(
@@ -575,7 +595,9 @@ mod tests {
575595
Some("string4"),
576596
Some("string5"),
577597
])) as _;
578-
builder.vectorized_append(&all_not_nulls_input_array, &[0, 1, 2, 3, 4]);
598+
builder
599+
.vectorized_append(&all_not_nulls_input_array, &[0, 1, 2, 3, 4])
600+
.unwrap();
579601

580602
let mut equal_to_results = vec![true; all_not_nulls_input_array.len()];
581603
builder.vectorized_equal_to(

0 commit comments

Comments
 (0)