1616// under the License.
1717
1818use crate :: common:: bit;
19+ use crate :: execution:: operators:: ExecutionError ;
1920use arrow:: buffer:: Buffer as ArrowBuffer ;
2021use std:: {
2122 alloc:: { handle_alloc_error, Layout } ,
@@ -43,6 +44,8 @@ pub struct CometBuffer {
4344 capacity : usize ,
4445 /// Whether this buffer owns the data it points to.
4546 owned : bool ,
47+ /// The allocation instance for this buffer.
48+ allocation : Arc < CometBufferAllocation > ,
4649}
4750
4851unsafe impl Sync for CometBuffer { }
@@ -63,6 +66,7 @@ impl CometBuffer {
6366 len : aligned_capacity,
6467 capacity : aligned_capacity,
6568 owned : true ,
69+ allocation : Arc :: new ( CometBufferAllocation :: new ( ) ) ,
6670 }
6771 }
6872 }
@@ -84,6 +88,7 @@ impl CometBuffer {
8488 len,
8589 capacity,
8690 owned : false ,
91+ allocation : Arc :: new ( CometBufferAllocation :: new ( ) ) ,
8792 }
8893 }
8994
@@ -163,11 +168,28 @@ impl CometBuffer {
163168 /// because of the iterator-style pattern, the content of the original mutable buffer will only
164169 /// be updated once upstream operators fully consumed the previous output batch. For breaking
165170 /// operators, they are responsible for copying content out of the buffers.
166- pub unsafe fn to_arrow ( & self ) -> ArrowBuffer {
171+ pub unsafe fn to_arrow ( & self ) -> Result < ArrowBuffer , ExecutionError > {
167172 let ptr = NonNull :: new_unchecked ( self . data . as_ptr ( ) ) ;
168- // Uses a dummy `Arc::new(0)` as `Allocation` to ensure the memory region pointed by
169- // `ptr` won't be freed when the returned `ArrowBuffer` goes out of scope.
170- ArrowBuffer :: from_custom_allocation ( ptr, self . len , Arc :: new ( 0 ) )
173+ self . check_reference ( ) ?;
174+ Ok ( ArrowBuffer :: from_custom_allocation (
175+ ptr,
176+ self . len ,
177+ self . allocation . clone ( ) ,
178+ ) )
179+ }
180+
181+ /// Checks if this buffer is exclusively owned by Comet. If not, an error is returned.
182+ /// We run this check when we want to update the buffer. If the buffer is also shared by
183+ /// other components, e.g. one DataFusion operator stores the buffer, Comet cannot safely
184+ /// modify the buffer.
185+ pub fn check_reference ( & self ) -> Result < ( ) , ExecutionError > {
186+ if Arc :: strong_count ( & self . allocation ) > 1 {
187+ Err ( ExecutionError :: GeneralError (
188+ "Error on modifying a buffer which is not exclusively owned by Comet" . to_string ( ) ,
189+ ) )
190+ } else {
191+ Ok ( ( ) )
192+ }
171193 }
172194
173195 /// Resets this buffer by filling all bytes with zeros.
@@ -242,13 +264,6 @@ impl PartialEq for CometBuffer {
242264 }
243265}
244266
// NOTE(review): every line of this impl carries a `-` diff marker, i.e. the conversion
// is deleted by this change.
// The removed conversion asserted that the Arrow buffer's length equals its capacity and
// then built a `CometBuffer` via `CometBuffer::from_ptr` from the Arrow buffer's pointer,
// length and capacity.
245- impl From < & ArrowBuffer > for CometBuffer {
246- fn from ( value : & ArrowBuffer ) -> Self {
247- assert_eq ! ( value. len( ) , value. capacity( ) ) ;
248- CometBuffer :: from_ptr ( value. as_ptr ( ) , value. len ( ) , value. capacity ( ) )
249- }
250- }
251-
252267impl std:: ops:: Deref for CometBuffer {
253268 type Target = [ u8 ] ;
254269
@@ -264,6 +279,15 @@ impl std::ops::DerefMut for CometBuffer {
264279 }
265280}
266281
/// Field-less marker type used as the allocation handle of a `CometBuffer`.
///
/// Each `CometBuffer` stores one of these behind an `Arc`. The handle is cloned into the
/// `ArrowBuffer` produced by `to_arrow`, and `check_reference` inspects the `Arc`'s strong
/// count to tell whether the buffer is shared.
#[derive(Debug)]
struct CometBufferAllocation {}

impl CometBufferAllocation {
    /// Creates a new, empty allocation marker.
    fn new() -> Self {
        CometBufferAllocation {}
    }
}
290+
267291#[ cfg( test) ]
268292mod tests {
269293 use super :: * ;
@@ -319,7 +343,7 @@ mod tests {
319343 assert_eq ! ( b"aaaa bbbb cccc dddd" , & buf. as_slice( ) [ 0 ..str . len( ) ] ) ;
320344
321345 unsafe {
322- let immutable_buf: ArrowBuffer = buf. to_arrow ( ) ;
346+ let immutable_buf: ArrowBuffer = buf. to_arrow ( ) . unwrap ( ) ;
323347 assert_eq ! ( 64 , immutable_buf. len( ) ) ;
324348 assert_eq ! ( str , & immutable_buf. as_slice( ) [ 0 ..str . len( ) ] ) ;
325349 }
@@ -335,7 +359,7 @@ mod tests {
335359 assert_eq ! ( b"hello comet" , & buf. as_slice( ) [ 0 ..11 ] ) ;
336360
337361 unsafe {
338- let arrow_buf2 = buf. to_arrow ( ) ;
362+ let arrow_buf2 = buf. to_arrow ( ) . unwrap ( ) ;
339363 assert_eq ! ( arrow_buf, arrow_buf2) ;
340364 }
341365 }
0 commit comments