@@ -45,6 +45,7 @@ public CometListVector(

@Override
public ColumnarArray getArray(int i) {
if (isNullAt(i)) return null;
int start = listVector.getOffsetBuffer().getInt(i * ListVector.OFFSET_WIDTH);
int end = listVector.getOffsetBuffer().getInt((i + 1) * ListVector.OFFSET_WIDTH);

@@ -65,6 +65,7 @@ public CometMapVector(

@Override
public ColumnarMap getMap(int i) {
if (isNullAt(i)) return null;
int start = mapVector.getOffsetBuffer().getInt(i * MapVector.OFFSET_WIDTH);
int end = mapVector.getOffsetBuffer().getInt((i + 1) * MapVector.OFFSET_WIDTH);

@@ -123,6 +123,7 @@ public double getDouble(int rowId) {

@Override
public UTF8String getUTF8String(int rowId) {
if (isNullAt(rowId)) return null;
if (!isBaseFixedWidthVector) {
BaseVariableWidthVector varWidthVector = (BaseVariableWidthVector) valueVector;
long offsetBufferAddress = varWidthVector.getOffsetBuffer().memoryAddress();
@@ -147,6 +148,7 @@ public UTF8String getUTF8String(int rowId) {

@Override
public byte[] getBinary(int rowId) {
if (isNullAt(rowId)) return null;
int offset;
int length;
if (valueVector instanceof BaseVariableWidthVector) {
@@ -85,6 +85,7 @@ public boolean isFixedLength() {

@Override
public Decimal getDecimal(int i, int precision, int scale) {
if (isNullAt(i)) return null;
if (!useDecimal128 && precision <= Decimal.MAX_INT_DIGITS() && type instanceof IntegerType) {
return createDecimal(getInt(i), precision, scale);
} else if (precision <= Decimal.MAX_LONG_DIGITS()) {
184 changes: 113 additions & 71 deletions native/spark-expr/src/array_funcs/array_insert.rs
Expand Up @@ -16,11 +16,10 @@
// under the License.

use arrow::array::{make_array, Array, ArrayRef, GenericListArray, Int32Array, OffsetSizeTrait};
use arrow::datatypes::{DataType, Field, Schema};
use arrow::datatypes::{DataType, Schema};
use arrow::{
array::{as_primitive_array, Capacities, MutableArrayData},
buffer::{NullBuffer, OffsetBuffer},
datatypes::ArrowNativeType,
record_batch::RecordBatch,
};
use datafusion::common::{
@@ -198,114 +197,124 @@ fn array_insert<O: OffsetSizeTrait>(
pos_array: &ArrayRef,
legacy_mode: bool,
) -> DataFusionResult<ColumnarValue> {
// The code is based on the implementation of the array_append from the Apache DataFusion
// https://github.com/apache/datafusion/blob/main/datafusion/functions-nested/src/concat.rs#L513
//
// This code is also based on the implementation of the array_insert from the Apache Spark
// https://github.com/apache/spark/blob/branch-3.5/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala#L4713
// Implementation aligned with Arrow's half-open offset ranges and Spark semantics.

let values = list_array.values();
let offsets = list_array.offsets();
let values_data = values.to_data();
let item_data = items_array.to_data();

// Estimate capacity (original values + inserted items upper bound)
let new_capacity = Capacities::Array(values_data.len() + item_data.len());

let mut mutable_values =
MutableArrayData::with_capacities(vec![&values_data, &item_data], true, new_capacity);

let mut new_offsets = vec![O::usize_as(0)];
let mut new_nulls = Vec::<bool>::with_capacity(list_array.len());
// New offsets and top-level list validity bitmap
let mut new_offsets = Vec::with_capacity(list_array.len() + 1);
new_offsets.push(O::usize_as(0));
let mut list_valid = Vec::<bool>::with_capacity(list_array.len());

let pos_data: &Int32Array = as_primitive_array(&pos_array); // Spark supports only i32 for positions
// Spark supports only Int32 position indices
let pos_data: &Int32Array = as_primitive_array(&pos_array);

for (row_index, offset_window) in offsets.windows(2).enumerate() {
let pos = pos_data.values()[row_index];
let start = offset_window[0].as_usize();
let end = offset_window[1].as_usize();
let is_item_null = items_array.is_null(row_index);
for (row_index, window) in offsets.windows(2).enumerate() {
let start = window[0].as_usize();
let end = window[1].as_usize();
let len = end - start;
let pos = pos_data.value(row_index);

if list_array.is_null(row_index) {
// In Spark if value of the array is NULL than nothing happens
mutable_values.extend_nulls(1);
new_offsets.push(new_offsets[row_index] + O::one());
new_nulls.push(false);
// Top-level list row is NULL: do not write any child values and do not advance offset
new_offsets.push(new_offsets[row_index]);
list_valid.push(false);
continue;
}

if pos == 0 {
return Err(DataFusionError::Internal(
"Position for array_insert should be greter or less than zero".to_string(),
"Position for array_insert should be greater or less than zero".to_string(),
));
}
Comment on lines +221 to 238

@coderabbitai coderabbitai bot Nov 11, 2025


⚠️ Potential issue | 🔴 Critical

Handle NULL positions before calling Int32Array::value

Int32Array::value(row_index) panics when the slot is NULL. Spark allows array_insert to receive NULL positions (the entire result should be NULL in that row), but this implementation dereferences the value before checking for validity, so any NULL index triggers an abort instead of returning a NULL row. Please guard pos_data in the same way we already guard the list itself.

-        let pos = pos_data.value(row_index);
-
-        if list_array.is_null(row_index) {
+        if list_array.is_null(row_index) {
             // Top-level list row is NULL: do not write any child values and do not advance offset
             new_offsets.push(new_offsets[row_index]);
             list_valid.push(false);
             continue;
         }
 
+        if pos_data.is_null(row_index) {
+            new_offsets.push(new_offsets[row_index]);
+            list_valid.push(false);
+            continue;
+        }
+
+        let pos = pos_data.value(row_index);
+
         if pos == 0 {
             return Err(DataFusionError::Internal(
                 "Position for array_insert should be greater or less than zero".to_string(),
             ));
         }
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
for (row_index, window) in offsets.windows(2).enumerate() {
let start = window[0].as_usize();
let end = window[1].as_usize();
let len = end - start;
let pos = pos_data.value(row_index);
if list_array.is_null(row_index) {
// In Spark if value of the array is NULL than nothing happens
mutable_values.extend_nulls(1);
new_offsets.push(new_offsets[row_index] + O::one());
new_nulls.push(false);
// Top-level list row is NULL: do not write any child values and do not advance offset
new_offsets.push(new_offsets[row_index]);
list_valid.push(false);
continue;
}
if pos == 0 {
return Err(DataFusionError::Internal(
"Position for array_insert should be greter or less than zero".to_string(),
"Position for array_insert should be greater or less than zero".to_string(),
));
}
for (row_index, window) in offsets.windows(2).enumerate() {
let start = window[0].as_usize();
let end = window[1].as_usize();
let len = end - start;
if list_array.is_null(row_index) {
// Top-level list row is NULL: do not write any child values and do not advance offset
new_offsets.push(new_offsets[row_index]);
list_valid.push(false);
continue;
}
if pos_data.is_null(row_index) {
new_offsets.push(new_offsets[row_index]);
list_valid.push(false);
continue;
}
let pos = pos_data.value(row_index);
if pos == 0 {
return Err(DataFusionError::Internal(
"Position for array_insert should be greater or less than zero".to_string(),
));
}
🤖 Prompt for AI Agents
In native/spark-expr/src/array_funcs/array_insert.rs around lines 221 to 238,
the code calls pos_data.value(row_index) before checking whether the position
slot is NULL which panics on NULL positions; instead, first test
pos_data.is_null(row_index) and if true push the current new_offsets entry, push
false to list_valid and continue (so the result row is NULL), only then call
pos_data.value(row_index) and proceed with the existing pos==0 check and
subsequent logic.

Owner Author


value:useful; category:bug; feedback: The CodeRabbit AI reviewer is correct that a check for null is needed before using PrimitiveArray::value(). The Arrow documentation notes: "This method does not check for nulls and the value is arbitrary if [is_null](Self::is_null) returns true for the index." (https://github.com/apache/arrow-rs/blob/a0db1985c3a0f3190cfc5166b428933a28c740f9/arrow-array/src/array/primitive_array.rs#L766-L767)


@martin-augment, thank you for confirming the issue and providing the Arrow documentation reference! You're absolutely right—the docs clearly state that value() doesn't perform null checks and returns arbitrary data for null slots.

The fix I suggested will ensure that when pos_data contains a NULL at any row, the code will properly produce a NULL result row instead of panicking or producing undefined behavior. This aligns with Spark's semantics where a NULL position should yield a NULL result.


✏️ Learnings added
Learnt from: martin-augment
Repo: martin-augment/datafusion-comet PR: 19
File: native/spark-expr/src/array_funcs/array_insert.rs:221-238
Timestamp: 2025-11-11T20:44:04.950Z
Learning: In Rust code using Apache Arrow arrays, always check `is_null(index)` before calling `value(index)` on `PrimitiveArray` types (such as `Int32Array`, `Int64Array`, etc.), because `value()` does not check for nulls and returns arbitrary values for null slots. This applies to functions in `native/spark-expr/src` that process Arrow arrays.

Note: Learnings are effective only in the context of similar code segments. To apply general coding guidelines, please configure review instructions. You can manage existing learnings in the app.


🧠 Learnings used
Learnt from: martin-augment
Repo: martin-augment/datafusion-comet PR: 7
File: native/spark-expr/src/math_funcs/abs.rs:201-302
Timestamp: 2025-11-04T14:26:48.750Z
Learning: In the abs function in native/spark-expr/src/math_funcs/abs.rs (Rust), NULL values for signed integers (Int8, Int16, Int32, Int64) and decimals (Decimal128, Decimal256) should return the argument as-is (e.g., ColumnarValue::Scalar(ScalarValue::Int8(None))) rather than panicking on unwrap().
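
As a side note, a minimal standalone sketch of the guard pattern captured in the learning above, assuming the arrow crate's Int32Array; the helper position_or_none is hypothetical and is not code from this PR:

use arrow::array::{Array, Int32Array};

// Read a position only after checking validity: PrimitiveArray::value()
// does not check nulls and returns an arbitrary value for null slots.
fn position_or_none(positions: &Int32Array, row: usize) -> Option<i32> {
    if positions.is_null(row) {
        None
    } else {
        Some(positions.value(row))
    }
}

fn main() {
    let positions = Int32Array::from(vec![Some(3), None, Some(-1)]);
    assert_eq!(position_or_none(&positions, 0), Some(3));
    assert_eq!(position_or_none(&positions, 1), None); // null slot handled, no arbitrary read
    assert_eq!(position_or_none(&positions, 2), Some(-1));
}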


if (pos > 0) || ((-pos).as_usize() < (end - start + 1)) {
let corrected_pos = if pos > 0 {
(pos - 1).as_usize()
} else {
end - start - (-pos).as_usize() + if legacy_mode { 0 } else { 1 }
};
let new_array_len = std::cmp::max(end - start + 1, corrected_pos);
if new_array_len > MAX_ROUNDED_ARRAY_LENGTH {
return Err(DataFusionError::Internal(format!(
"Max array length in Spark is {MAX_ROUNDED_ARRAY_LENGTH:?}, but got {new_array_len:?}"
)));
}
let final_len: usize;

if (start + corrected_pos) <= end {
mutable_values.extend(0, start, start + corrected_pos);
if pos > 0 {
// Positive index (1-based)
let pos1 = pos as usize;
if pos1 <= len + 1 {
// In-range insertion (including appending to end)
let corrected = pos1 - 1; // 0-based insertion point
mutable_values.extend(0, start, start + corrected);
mutable_values.extend(1, row_index, row_index + 1);
mutable_values.extend(0, start + corrected_pos, end);
new_offsets.push(new_offsets[row_index] + O::usize_as(new_array_len));
mutable_values.extend(0, start + corrected, end);
final_len = len + 1;
} else {
// Beyond end: pad with nulls then insert
let corrected = pos1 - 1;
let padding = corrected - len;
mutable_values.extend(0, start, end);
mutable_values.extend_nulls(new_array_len - (end - start));
mutable_values.extend_nulls(padding);
mutable_values.extend(1, row_index, row_index + 1);
// In that case spark actualy makes array longer than expected;
// For example, if pos is equal to 5, len is eq to 3, than resulted len will be 5
new_offsets.push(new_offsets[row_index] + O::usize_as(new_array_len) + O::one());
final_len = corrected + 1; // equals pos1
}
} else {
// This comment is takes from the Apache Spark source code as is:
// special case- if the new position is negative but larger than the current array size
// place the new item at start of array, place the current array contents at the end
// and fill the newly created array elements inbetween with a null
let base_offset = if legacy_mode { 1 } else { 0 };
let new_array_len = (-pos + base_offset).as_usize();
if new_array_len > MAX_ROUNDED_ARRAY_LENGTH {
return Err(DataFusionError::Internal(format!(
"Max array length in Spark is {MAX_ROUNDED_ARRAY_LENGTH:?}, but got {new_array_len:?}"
)));
}
mutable_values.extend(1, row_index, row_index + 1);
mutable_values.extend_nulls(new_array_len - (end - start + 1));
mutable_values.extend(0, start, end);
new_offsets.push(new_offsets[row_index] + O::usize_as(new_array_len));
}
if is_item_null {
if (start == end) || (values.is_null(row_index)) {
new_nulls.push(false)
// Negative index (1-based from the end)
let k = (-pos) as usize;

if k <= len {
// In-range negative insertion
// Non-legacy: -1 behaves like append to end (corrected = len - k + 1)
// Legacy: -1 behaves like insert before the last element (corrected = len - k)
let base_offset = if legacy_mode { 0 } else { 1 };
let corrected = len - k + base_offset;
mutable_values.extend(0, start, start + corrected);
mutable_values.extend(1, row_index, row_index + 1);
mutable_values.extend(0, start + corrected, end);
final_len = len + 1;
} else {
new_nulls.push(true)
// Negative index beyond the start (Spark-specific behavior):
// Place item first, then pad with nulls, then append the original array.
// Final length = k + base_offset, where base_offset = 1 in legacy mode, otherwise 0.
let base_offset = if legacy_mode { 1 } else { 0 };
let target_len = k + base_offset;
let padding = target_len.saturating_sub(len + 1);
mutable_values.extend(1, row_index, row_index + 1); // insert item first
mutable_values.extend_nulls(padding); // pad nulls
mutable_values.extend(0, start, end); // append original values
final_len = target_len;
}
} else {
new_nulls.push(true)
}

if final_len > MAX_ROUNDED_ARRAY_LENGTH {
return Err(DataFusionError::Internal(format!(
"Max array length in Spark is {MAX_ROUNDED_ARRAY_LENGTH}, but got {final_len}"
)));
}

let prev = new_offsets[row_index].as_usize();
new_offsets.push(O::usize_as(prev + final_len));
list_valid.push(true);
}

let data = make_array(mutable_values.freeze());
let data_type = match list_array.data_type() {
DataType::List(field) => field.data_type(),
DataType::LargeList(field) => field.data_type(),
let child = make_array(mutable_values.freeze());

// Reuse the original list element field (name/type/nullability)
let elem_field = match list_array.data_type() {
DataType::List(field) => Arc::clone(field),
DataType::LargeList(field) => Arc::clone(field),
_ => unreachable!(),
};
let new_array = GenericListArray::<O>::try_new(
Arc::new(Field::new("item", data_type.clone(), true)),

// Build the resulting list array
let new_list = GenericListArray::<O>::try_new(
elem_field,
OffsetBuffer::new(new_offsets.into()),
data,
Some(NullBuffer::new(new_nulls.into())),
child,
Some(NullBuffer::new(list_valid.into())),
)?;

Ok(ColumnarValue::Array(Arc::new(new_array)))
Ok(ColumnarValue::Array(Arc::new(new_list)))
}

impl Display for ArrayInsert {
@@ -442,4 +451,37 @@ mod test {

Ok(())
}

#[test]
fn test_array_insert_bug_repro_null_item_pos1_fixed() -> Result<()> {
use arrow::array::{Array, ArrayRef, Int32Array, ListArray};
use arrow::datatypes::Int32Type;

// row0 = [0, null, 0]
// row1 = [1, null, 1]
let list = ListArray::from_iter_primitive::<Int32Type, _, _>(vec![
Some(vec![Some(0), None, Some(0)]),
Some(vec![Some(1), None, Some(1)]),
]);

let positions = Int32Array::from(vec![1, 1]);
let items = Int32Array::from(vec![None, None]);

let ColumnarValue::Array(result) = array_insert(
&list,
&(Arc::new(items) as ArrayRef),
&(Arc::new(positions) as ArrayRef),
false, // legacy_mode = false
)?
else {
unreachable!()
};

let expected = ListArray::from_iter_primitive::<Int32Type, _, _>(vec![
Some(vec![None, Some(0), None, Some(0)]),
Some(vec![None, Some(1), None, Some(1)]),
]);
assert_eq!(&result.to_data(), &expected.to_data());
Ok(())
}
}
@@ -25,6 +25,7 @@ import org.apache.hadoop.fs.Path
import org.apache.spark.sql.CometTestBase
import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types.ArrayType

import org.apache.comet.CometSparkSessionExtensions.{isSpark35Plus, isSpark40Plus}
import org.apache.comet.DataTypeSupport.isComplexType
@@ -774,6 +775,30 @@ class CometArrayExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelp
}
}

test("array_reverse 2") {
// This test validates data correctness for array<binary> columns with nullable elements.
// See https://github.com/apache/datafusion-comet/issues/2612
withTempDir { dir =>
val path = new Path(dir.toURI.toString, "test.parquet")
val filename = path.toString
val random = new Random(42)
withSQLConf(CometConf.COMET_ENABLED.key -> "false") {
val schemaOptions =
SchemaGenOptions(generateArray = true, generateStruct = false, generateMap = false)
val dataOptions = DataGenOptions(allowNull = true, generateNegativeZero = false)
ParquetGenerator.makeParquetFile(random, spark, filename, 100, schemaOptions, dataOptions)
}
withTempView("t1") {
val table = spark.read.parquet(filename)
table.createOrReplaceTempView("t1")
for (field <- table.schema.fields.filter(_.dataType.isInstanceOf[ArrayType])) {
val sql = s"SELECT ${field.name}, reverse(${field.name}) FROM t1 ORDER BY ${field.name}"
checkSparkAnswer(sql)
}
}
}
}

// https://github.com/apache/datafusion-comet/issues/2612
test("array_reverse - fallback for binary array") {
val fallbackReason =