Skip to content

Commit 5ccbb7d

Browse files
committed
chore: Add safety check to CometBuffer
1 parent ac4223c commit 5ccbb7d

File tree

7 files changed

+56
-26
lines changed

7 files changed

+56
-26
lines changed

native/core/benches/parquet_read.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -213,6 +213,6 @@ impl Iterator for TestColumnReader {
213213
}
214214
self.total_num_values_read += total;
215215

216-
Some(self.inner.current_batch())
216+
Some(self.inner.current_batch().unwrap())
217217
}
218218
}

native/core/src/common/buffer.rs

Lines changed: 37 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
// under the License.
1717

1818
use crate::common::bit;
19+
use crate::execution::operators::ExecutionError;
1920
use arrow::buffer::Buffer as ArrowBuffer;
2021
use std::{
2122
alloc::{handle_alloc_error, Layout},
@@ -43,6 +44,8 @@ pub struct CometBuffer {
4344
capacity: usize,
4445
/// Whether this buffer owns the data it points to.
4546
owned: bool,
47+
/// The allocation instance for this buffer.
48+
allocation: Arc<CometBufferAllocation>,
4649
}
4750

4851
unsafe impl Sync for CometBuffer {}
@@ -63,6 +66,7 @@ impl CometBuffer {
6366
len: aligned_capacity,
6467
capacity: aligned_capacity,
6568
owned: true,
69+
allocation: Arc::new(CometBufferAllocation::new()),
6670
}
6771
}
6872
}
@@ -84,6 +88,7 @@ impl CometBuffer {
8488
len,
8589
capacity,
8690
owned: false,
91+
allocation: Arc::new(CometBufferAllocation::new()),
8792
}
8893
}
8994

@@ -163,11 +168,28 @@ impl CometBuffer {
163168
/// because of the iterator-style pattern, the content of the original mutable buffer will only
164169
/// be updated once upstream operators fully consumed the previous output batch. For breaking
165170
/// operators, they are responsible for copying content out of the buffers.
166-
pub unsafe fn to_arrow(&self) -> ArrowBuffer {
171+
pub unsafe fn to_arrow(&self) -> Result<ArrowBuffer, ExecutionError> {
167172
let ptr = NonNull::new_unchecked(self.data.as_ptr());
168-
// Uses a dummy `Arc::new(0)` as `Allocation` to ensure the memory region pointed by
169-
// `ptr` won't be freed when the returned `ArrowBuffer` goes out of scope.
170-
ArrowBuffer::from_custom_allocation(ptr, self.len, Arc::new(0))
173+
self.check_reference()?;
174+
Ok(ArrowBuffer::from_custom_allocation(
175+
ptr,
176+
self.len,
177+
self.allocation.clone(),
178+
))
179+
}
180+
181+
/// Checks if this buffer is exclusively owned by Comet. If not, an error is returned.
182+
/// We run this check when we want to update the buffer. If the buffer is also shared by
183+
/// other components, e.g. one DataFusion operator stores the buffer, Comet cannot safely
184+
/// modify the buffer.
185+
pub fn check_reference(&self) -> Result<(), ExecutionError> {
186+
if Arc::strong_count(&self.allocation) > 1 {
187+
Err(ExecutionError::GeneralError(
188+
"Error on modifying a buffer which is not exclusively owned by Comet".to_string(),
189+
))
190+
} else {
191+
Ok(())
192+
}
171193
}
172194

173195
/// Resets this buffer by filling all bytes with zeros.
@@ -242,13 +264,6 @@ impl PartialEq for CometBuffer {
242264
}
243265
}
244266

245-
impl From<&ArrowBuffer> for CometBuffer {
246-
fn from(value: &ArrowBuffer) -> Self {
247-
assert_eq!(value.len(), value.capacity());
248-
CometBuffer::from_ptr(value.as_ptr(), value.len(), value.capacity())
249-
}
250-
}
251-
252267
impl std::ops::Deref for CometBuffer {
253268
type Target = [u8];
254269

@@ -264,6 +279,15 @@ impl std::ops::DerefMut for CometBuffer {
264279
}
265280
}
266281

282+
#[derive(Debug)]
283+
struct CometBufferAllocation {}
284+
285+
impl CometBufferAllocation {
286+
fn new() -> Self {
287+
Self {}
288+
}
289+
}
290+
267291
#[cfg(test)]
268292
mod tests {
269293
use super::*;
@@ -319,7 +343,7 @@ mod tests {
319343
assert_eq!(b"aaaa bbbb cccc dddd", &buf.as_slice()[0..str.len()]);
320344

321345
unsafe {
322-
let immutable_buf: ArrowBuffer = buf.to_arrow();
346+
let immutable_buf: ArrowBuffer = buf.to_arrow().unwrap();
323347
assert_eq!(64, immutable_buf.len());
324348
assert_eq!(str, &immutable_buf.as_slice()[0..str.len()]);
325349
}
@@ -335,7 +359,7 @@ mod tests {
335359
assert_eq!(b"hello comet", &buf.as_slice()[0..11]);
336360

337361
unsafe {
338-
let arrow_buf2 = buf.to_arrow();
362+
let arrow_buf2 = buf.to_arrow().unwrap();
339363
assert_eq!(arrow_buf, arrow_buf2);
340364
}
341365
}

native/core/src/execution/operators/copy.rs

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -258,7 +258,10 @@ fn copy_array(array: &dyn Array) -> ArrayRef {
258258
/// is a dictionary array, we will cast the dictionary array to primitive type
259259
/// (i.e., unpack the dictionary array) and copy the primitive array. If the input
260260
/// array is a primitive array, we simply copy the array.
261-
fn copy_or_unpack_array(array: &Arc<dyn Array>, mode: &CopyMode) -> Result<ArrayRef, ArrowError> {
261+
pub(crate) fn copy_or_unpack_array(
262+
array: &Arc<dyn Array>,
263+
mode: &CopyMode,
264+
) -> Result<ArrayRef, ArrowError> {
262265
match array.data_type() {
263266
DataType::Dictionary(_, value_type) => {
264267
let options = CastOptions::default();

native/core/src/parquet/mod.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -543,7 +543,7 @@ pub extern "system" fn Java_org_apache_comet_parquet_Native_currentBatch(
543543
try_unwrap_or_throw(&e, |_env| {
544544
let ctx = get_context(handle)?;
545545
let reader = &mut ctx.column_reader;
546-
let data = reader.current_batch();
546+
let data = reader.current_batch()?;
547547
data.move_to_spark(array_addr, schema_addr)
548548
.map_err(|e| e.into())
549549
})

native/core/src/parquet/mutable_vector.rs

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
use arrow::{array::ArrayData, datatypes::DataType as ArrowDataType};
1919

2020
use crate::common::{bit, CometBuffer};
21+
use crate::execution::operators::ExecutionError;
2122

2223
const DEFAULT_ARRAY_LEN: usize = 4;
2324

@@ -192,7 +193,7 @@ impl ParquetMutableVector {
192193
/// This method is highly unsafe since it calls `CometBuffer::to_arrow` which leaks raw
193194
/// pointer to the memory region that are tracked by `CometBuffer`. Please see comments on
194195
/// `to_arrow` buffer to understand the motivation.
195-
pub fn get_array_data(&mut self) -> ArrayData {
196+
pub fn get_array_data(&mut self) -> Result<ArrayData, ExecutionError> {
196197
unsafe {
197198
let data_type = if let Some(d) = &self.dictionary {
198199
ArrowDataType::Dictionary(
@@ -204,20 +205,19 @@ impl ParquetMutableVector {
204205
};
205206
let mut builder = ArrayData::builder(data_type)
206207
.len(self.num_values)
207-
.add_buffer(self.value_buffer.to_arrow())
208-
.null_bit_buffer(Some(self.validity_buffer.to_arrow()))
208+
.add_buffer(self.value_buffer.to_arrow()?)
209+
.null_bit_buffer(Some(self.validity_buffer.to_arrow()?))
209210
.null_count(self.num_nulls);
210211

211212
if Self::is_binary_type(&self.arrow_type) && self.dictionary.is_none() {
212213
let child = &mut self.children[0];
213-
builder = builder.add_buffer(child.value_buffer.to_arrow());
214+
builder = builder.add_buffer(child.value_buffer.to_arrow()?);
214215
}
215216

216217
if let Some(d) = &mut self.dictionary {
217-
builder = builder.add_child_data(d.get_array_data());
218+
builder = builder.add_child_data(d.get_array_data()?);
218219
}
219-
220-
builder.build_unchecked()
220+
Ok(builder.build_unchecked())
221221
}
222222
}
223223

native/core/src/parquet/read/column.rs

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@ use super::{
3939
};
4040

4141
use crate::common::{bit, bit::log2};
42+
use crate::execution::operators::ExecutionError;
4243

4344
/// Maximum number of decimal digits an i32 can represent
4445
const DECIMAL_MAX_INT_DIGITS: i32 = 9;
@@ -601,7 +602,7 @@ impl ColumnReader {
601602
}
602603

603604
#[inline]
604-
pub fn current_batch(&mut self) -> ArrayData {
605+
pub fn current_batch(&mut self) -> Result<ArrayData, ExecutionError> {
605606
make_func_mut!(self, current_batch)
606607
}
607608

@@ -684,7 +685,7 @@ impl<T: DataType> TypedColumnReader<T> {
684685
/// Note: the caller must make sure the returned Arrow vector is fully consumed before calling
685686
/// `read_batch` again.
686687
#[inline]
687-
pub fn current_batch(&mut self) -> ArrayData {
688+
pub fn current_batch(&mut self) -> Result<ArrayData, ExecutionError> {
688689
self.vector.get_array_data()
689690
}
690691

spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,8 +67,10 @@ class CometExecSuite extends CometTestBase {
6767
test("TopK operator should return correct results on dictionary column with nulls") {
6868
withSQLConf(SQLConf.USE_V1_SOURCE_LIST.key -> "") {
6969
withTable("test_data") {
70+
val data = (0 to 8000)
71+
.flatMap(_ => Seq((1, null, "A"), (2, "BBB", "B"), (3, "BBB", "B"), (4, "BBB", "B")))
7072
val tableDF = spark.sparkContext
71-
.parallelize(Seq((1, null, "A"), (2, "BBB", "B"), (3, "BBB", "B"), (4, "BBB", "B")), 3)
73+
.parallelize(data, 3)
7274
.toDF("c1", "c2", "c3")
7375
tableDF
7476
.coalesce(1)

0 commit comments

Comments (0)