@@ -24,6 +24,7 @@ use std::task::{Context, Poll};
2424use  std:: { any:: Any ,  vec} ; 
2525
2626use  crate :: error:: { DataFusionError ,  Result } ; 
27+ use  crate :: execution:: memory_pool:: { MemoryConsumer ,  MemoryReservation } ; 
2728use  crate :: physical_plan:: hash_utils:: create_hashes; 
2829use  crate :: physical_plan:: { 
2930    DisplayFormatType ,  EquivalenceProperties ,  ExecutionPlan ,  Partitioning ,  Statistics , 
@@ -50,14 +51,21 @@ use tokio::sync::mpsc::{self, UnboundedReceiver, UnboundedSender};
5051use  tokio:: task:: JoinHandle ; 
5152
5253type  MaybeBatch  = Option < ArrowResult < RecordBatch > > ; 
54+ type  SharedMemoryReservation  = Arc < Mutex < MemoryReservation > > ; 
5355
5456/// Inner state of [`RepartitionExec`]. 
5557#[ derive( Debug ) ]  
5658struct  RepartitionExecState  { 
5759    /// Channels for sending batches from input partitions to output partitions. 
5860/// Key is the partition number. 
59- channels : 
60-         HashMap < usize ,  ( UnboundedSender < MaybeBatch > ,  UnboundedReceiver < MaybeBatch > ) > , 
61+ channels :  HashMap < 
62+         usize , 
63+         ( 
64+             UnboundedSender < MaybeBatch > , 
65+             UnboundedReceiver < MaybeBatch > , 
66+             SharedMemoryReservation , 
67+         ) , 
68+     > , 
6169
6270    /// Helper that ensures that the background job is killed once it is no longer needed. 
6371abort_helper :  Arc < AbortOnDropMany < ( ) > > , 
@@ -338,7 +346,13 @@ impl ExecutionPlan for RepartitionExec {
338346                // for this would be to add spill-to-disk capabilities. 
339347                let  ( sender,  receiver)  =
340348                    mpsc:: unbounded_channel :: < Option < ArrowResult < RecordBatch > > > ( ) ; 
341-                 state. channels . insert ( partition,  ( sender,  receiver) ) ; 
349+                 let  reservation = Arc :: new ( Mutex :: new ( 
350+                     MemoryConsumer :: new ( format ! ( "RepartitionExec[{partition}]" ) ) 
351+                         . register ( context. memory_pool ( ) ) , 
352+                 ) ) ; 
353+                 state
354+                     . channels 
355+                     . insert ( partition,  ( sender,  receiver,  reservation) ) ; 
342356            } 
343357
344358            // launch one async task per *input* partition 
@@ -347,7 +361,9 @@ impl ExecutionPlan for RepartitionExec {
347361                let  txs:  HashMap < _ ,  _ >  = state
348362                    . channels 
349363                    . iter ( ) 
350-                     . map ( |( partition,  ( tx,  _rx) ) | ( * partition,  tx. clone ( ) ) ) 
364+                     . map ( |( partition,  ( tx,  _rx,  reservation) ) | { 
365+                         ( * partition,  ( tx. clone ( ) ,  Arc :: clone ( reservation) ) ) 
366+                     } ) 
351367                    . collect ( ) ; 
352368
353369                let  r_metrics = RepartitionMetrics :: new ( i,  partition,  & self . metrics ) ; 
@@ -366,7 +382,9 @@ impl ExecutionPlan for RepartitionExec {
366382                // (and pass along any errors, including panic!s) 
367383                let  join_handle = tokio:: spawn ( Self :: wait_for_task ( 
368384                    AbortOnDropSingle :: new ( input_task) , 
369-                     txs, 
385+                     txs. into_iter ( ) 
386+                         . map ( |( partition,  ( tx,  _reservation) ) | ( partition,  tx) ) 
387+                         . collect ( ) , 
370388                ) ) ; 
371389                join_handles. push ( join_handle) ; 
372390            } 
@@ -381,14 +399,17 @@ impl ExecutionPlan for RepartitionExec {
381399
382400        // now return stream for the specified *output* partition which will 
383401        // read from the channel 
402+         let  ( _tx,  rx,  reservation)  = state
403+             . channels 
404+             . remove ( & partition) 
405+             . expect ( "partition not used yet" ) ; 
384406        Ok ( Box :: pin ( RepartitionStream  { 
385407            num_input_partitions, 
386408            num_input_partitions_processed :  0 , 
387409            schema :  self . input . schema ( ) , 
388-             input :  UnboundedReceiverStream :: new ( 
389-                 state. channels . remove ( & partition) . unwrap ( ) . 1 , 
390-             ) , 
410+             input :  UnboundedReceiverStream :: new ( rx) , 
391411            drop_helper :  Arc :: clone ( & state. abort_helper ) , 
412+             reservation, 
392413        } ) ) 
393414    } 
394415
@@ -439,7 +460,7 @@ impl RepartitionExec {
439460async  fn  pull_from_input ( 
440461        input :  Arc < dyn  ExecutionPlan > , 
441462        i :  usize , 
442-         mut  txs :  HashMap < usize ,  UnboundedSender < Option < ArrowResult < RecordBatch > > > > , 
463+         mut  txs :  HashMap < usize ,  ( UnboundedSender < MaybeBatch > ,   SharedMemoryReservation ) > , 
443464        partitioning :  Partitioning , 
444465        r_metrics :  RepartitionMetrics , 
445466        context :  Arc < TaskContext > , 
@@ -467,11 +488,16 @@ impl RepartitionExec {
467488            } ; 
468489
469490            partitioner. partition ( batch,  |partition,  partitioned| { 
491+                 let  size = partitioned. get_array_memory_size ( ) ; 
492+ 
470493                let  timer = r_metrics. send_time . timer ( ) ; 
471494                // if there is still a receiver, send to it 
472-                 if  let  Some ( tx)  = txs. get_mut ( & partition)  { 
495+                 if  let  Some ( ( tx,  reservation) )  = txs. get_mut ( & partition)  { 
496+                     reservation. lock ( ) . try_grow ( size) ?; 
497+ 
473498                    if  tx. send ( Some ( Ok ( partitioned) ) ) . is_err ( )  { 
474499                        // If the other end has hung up, it was an early shutdown (e.g. LIMIT) 
500+                         reservation. lock ( ) . shrink ( size) ; 
475501                        txs. remove ( & partition) ; 
476502                    } 
477503                } 
@@ -546,6 +572,9 @@ struct RepartitionStream {
546572    /// Handle to ensure background tasks are killed when no longer needed. 
547573#[ allow( dead_code) ]  
548574    drop_helper :  Arc < AbortOnDropMany < ( ) > > , 
575+ 
576+     /// Memory reservation. 
577+ reservation :  SharedMemoryReservation , 
549578} 
550579
551580impl  Stream  for  RepartitionStream  { 
@@ -555,20 +584,35 @@ impl Stream for RepartitionStream {
555584        mut  self :  Pin < & mut  Self > , 
556585        cx :  & mut  Context < ' _ > , 
557586    )  -> Poll < Option < Self :: Item > >  { 
558-         match  self . input . poll_next_unpin ( cx)  { 
559-             Poll :: Ready ( Some ( Some ( v) ) )  => Poll :: Ready ( Some ( v) ) , 
560-             Poll :: Ready ( Some ( None ) )  => { 
561-                 self . num_input_partitions_processed  += 1 ; 
562-                 if  self . num_input_partitions  == self . num_input_partitions_processed  { 
563-                     // all input partitions have finished sending batches 
564-                     Poll :: Ready ( None ) 
565-                 }  else  { 
566-                     // other partitions still have data to send 
567-                     self . poll_next ( cx) 
587+         loop  { 
588+             match  self . input . poll_next_unpin ( cx)  { 
589+                 Poll :: Ready ( Some ( Some ( v) ) )  => { 
590+                     if  let  Ok ( batch)  = & v { 
591+                         self . reservation 
592+                             . lock ( ) 
593+                             . shrink ( batch. get_array_memory_size ( ) ) ; 
594+                     } 
595+ 
596+                     return  Poll :: Ready ( Some ( v) ) ; 
597+                 } 
598+                 Poll :: Ready ( Some ( None ) )  => { 
599+                     self . num_input_partitions_processed  += 1 ; 
600+ 
601+                     if  self . num_input_partitions  == self . num_input_partitions_processed  { 
602+                         // all input partitions have finished sending batches 
603+                         return  Poll :: Ready ( None ) ; 
604+                     }  else  { 
605+                         // other partitions still have data to send 
606+                         continue ; 
607+                     } 
608+                 } 
609+                 Poll :: Ready ( None )  => { 
610+                     return  Poll :: Ready ( None ) ; 
611+                 } 
612+                 Poll :: Pending  => { 
613+                     return  Poll :: Pending ; 
568614                } 
569615            } 
570-             Poll :: Ready ( None )  => Poll :: Ready ( None ) , 
571-             Poll :: Pending  => Poll :: Pending , 
572616        } 
573617    } 
574618} 
@@ -583,6 +627,8 @@ impl RecordBatchStream for RepartitionStream {
583627#[ cfg( test) ]  
584628mod  tests { 
585629    use  super :: * ; 
630+     use  crate :: execution:: context:: SessionConfig ; 
631+     use  crate :: execution:: runtime_env:: { RuntimeConfig ,  RuntimeEnv } ; 
586632    use  crate :: from_slice:: FromSlice ; 
587633    use  crate :: prelude:: SessionContext ; 
588634    use  crate :: test:: create_vec_batches; 
@@ -1078,4 +1124,41 @@ mod tests {
10781124        assert ! ( batch0. is_empty( )  || batch1. is_empty( ) ) ; 
10791125        Ok ( ( ) ) 
10801126    } 
1127+ 
1128+     #[ tokio:: test]  
1129+     async  fn  oom ( )  -> Result < ( ) >  { 
              // Checks that every output stream of a `RepartitionExec` yields an error
              // whose root cause is `DataFusionError::ResourcesExhausted` when the
              // runtime is created with `with_memory_limit(1, 1.0)`.
              // NOTE(review): presumably that is a 1-byte pool, so any reservation
              // growth fails -- confirm against `RuntimeConfig::with_memory_limit`.
1130+         // define input partitions: one partition built by `create_vec_batches(&schema, 50)` 
1131+         let  schema = test_schema ( ) ; 
1132+         let  partition = create_vec_batches ( & schema,  50 ) ; 
1133+         let  input_partitions = vec ! [ partition] ; 
1134+         let  partitioning = Partitioning :: RoundRobinBatch ( 4 ) ; 
1135+ 
1136+         // set up a session context whose runtime carries the memory limit 
1137+         let  session_ctx = SessionContext :: with_config_rt ( 
1138+             SessionConfig :: default ( ) , 
1139+             Arc :: new ( 
1140+                 RuntimeEnv :: new ( RuntimeConfig :: default ( ) . with_memory_limit ( 1 ,  1.0 ) ) 
1141+                     . unwrap ( ) , 
1142+             ) , 
1143+         ) ; 
1144+         let  task_ctx = session_ctx. task_ctx ( ) ; 
1145+ 
1146+         // create physical plan: in-memory input wrapped in the repartition operator 
1147+         let  exec = MemoryExec :: try_new ( & input_partitions,  schema. clone ( ) ,  None ) ?; 
1148+         let  exec = RepartitionExec :: try_new ( Arc :: new ( exec) ,  partitioning) ?; 
1149+ 
1150+         // pull the first item from every output partition; each one must be an error 
1151+         for  i in  0 ..exec. partitioning . partition_count ( )  { 
1152+             let  mut  stream = exec. execute ( i,  task_ctx. clone ( ) ) ?; 
                  // The first `unwrap` panics if the stream ends without an item, and
                  // `unwrap_err` panics if that item is `Ok`. The Arrow error is wrapped
                  // in `DataFusionError::ArrowError` so `find_root` can be called on it.
1153+             let  err =
1154+                 DataFusionError :: ArrowError ( stream. next ( ) . await . unwrap ( ) . unwrap_err ( ) ) ; 
1155+             let  err = err. find_root ( ) ; 
1156+             assert ! ( 
1157+                 matches!( err,  DataFusionError :: ResourcesExhausted ( _) ) , 
1158+                 "Wrong error type: {err}" , 
1159+             ) ; 
1160+         } 
1161+ 
1162+         Ok ( ( ) ) 
1163+     } 
10811164} 
0 commit comments