@@ -30,8 +30,13 @@ use datafusion::{
3030 execution:: memory_pool:: { MemoryPool , MemoryReservation } ,
3131} ;
3232
33- use crate :: jvm_bridge:: { jni_call, JVMClasses } ;
33+ use crate :: {
34+ errors:: CometResult ,
35+ jvm_bridge:: { jni_call, JVMClasses } ,
36+ } ;
3437
38+ /// A DataFusion `MemoryPool` implementation for Comet. Internally this is
39+ /// implemented via delegating calls to [`crate::jvm_bridge::CometTaskMemoryManager`].
3540pub struct CometMemoryPool {
3641 task_memory_manager_handle : Arc < GlobalRef > ,
3742 used : AtomicUsize ,
@@ -52,44 +57,56 @@ impl CometMemoryPool {
5257 used : AtomicUsize :: new ( 0 ) ,
5358 }
5459 }
60+
61+ fn acquire ( & self , additional : usize ) -> CometResult < i64 > {
62+ let mut env = JVMClasses :: get_env ( ) ;
63+ let handle = self . task_memory_manager_handle . as_obj ( ) ;
64+ unsafe {
65+ jni_call ! ( & mut env,
66+ comet_task_memory_manager( handle) . acquire_memory( additional as i64 ) -> i64 )
67+ }
68+ }
69+
70+ fn release ( & self , size : usize ) -> CometResult < ( ) > {
71+ let mut env = JVMClasses :: get_env ( ) ;
72+ let handle = self . task_memory_manager_handle . as_obj ( ) ;
73+ unsafe {
74+ jni_call ! ( & mut env, comet_task_memory_manager( handle) . release_memory( size as i64 ) -> ( ) )
75+ }
76+ }
5577}
5678
5779unsafe impl Send for CometMemoryPool { }
5880unsafe impl Sync for CometMemoryPool { }
5981
6082impl MemoryPool for CometMemoryPool {
6183 fn grow ( & self , _: & MemoryReservation , additional : usize ) {
84+ self . acquire ( additional)
85+ . unwrap_or_else ( |_| panic ! ( "Failed to acquire {} bytes" , additional) ) ;
6286 self . used . fetch_add ( additional, Relaxed ) ;
6387 }
6488
6589 fn shrink ( & self , _: & MemoryReservation , size : usize ) {
66- let mut env = JVMClasses :: get_env ( ) ;
67- let handle = self . task_memory_manager_handle . as_obj ( ) ;
68- unsafe {
69- jni_call ! ( & mut env, comet_task_memory_manager( handle) . release_memory( size as i64 ) -> ( ) )
70- . unwrap ( ) ;
71- }
90+ self . release ( size)
91+ . unwrap_or_else ( |_| panic ! ( "Failed to release {} bytes" , size) ) ;
7292 self . used . fetch_sub ( size, Relaxed ) ;
7393 }
7494
7595 fn try_grow ( & self , _: & MemoryReservation , additional : usize ) -> Result < ( ) , DataFusionError > {
7696 if additional > 0 {
77- let mut env = JVMClasses :: get_env ( ) ;
78- let handle = self . task_memory_manager_handle . as_obj ( ) ;
79- unsafe {
80- let acquired = jni_call ! ( & mut env,
81- comet_task_memory_manager( handle) . acquire_memory( additional as i64 ) -> i64 ) ?;
97+ let acquired = self . acquire ( additional) ?;
98+ // If the number of bytes we acquired is less than the requested, return an error,
99+ // and hopefully will trigger spilling from the caller side.
100+ if acquired < additional as i64 {
101+ // Release the acquired bytes before throwing error
102+ self . release ( acquired as usize ) ?;
82103
83- // If the number of bytes we acquired is less than the requested, return an error,
84- // and hopefully will trigger spilling from the caller side.
85- if acquired < additional as i64 {
86- return Err ( DataFusionError :: Execution ( format ! (
87- "Failed to acquire {} bytes, only got {}. Reserved: {}" ,
88- additional,
89- acquired,
90- self . reserved( ) ,
91- ) ) ) ;
92- }
104+ return Err ( DataFusionError :: Execution ( format ! (
105+ "Failed to acquire {} bytes, only got {}. Reserved: {}" ,
106+ additional,
107+ acquired,
108+ self . reserved( ) ,
109+ ) ) ) ;
93110 }
94111 self . used . fetch_add ( additional, Relaxed ) ;
95112 }
0 commit comments