1919
2020use  crate :: error:: { DataFusionError ,  Result } ; 
2121use  async_trait:: async_trait; 
22- use  hashbrown:: HashMap ; 
22+ use  hashbrown:: HashSet ; 
2323use  log:: debug; 
2424use  std:: fmt; 
2525use  std:: fmt:: { Debug ,  Display ,  Formatter } ; 
2626use  std:: sync:: atomic:: { AtomicUsize ,  Ordering } ; 
27- use  std:: sync:: { Arc ,  Condvar ,  Mutex ,   Weak } ; 
27+ use  std:: sync:: { Arc ,  Condvar ,  Mutex } ; 
2828
2929static  CONSUMER_ID :  AtomicUsize  = AtomicUsize :: new ( 0 ) ; 
3030
@@ -245,10 +245,10 @@ The memory management architecture is the following:
245245/// Manage memory usage during physical plan execution 
246246#[ derive( Debug ) ]  
247247pub  struct  MemoryManager  { 
248-     requesters :  Arc < Mutex < HashMap < MemoryConsumerId ,  Weak < dyn  MemoryConsumer > > > > , 
249-     trackers :  Arc < Mutex < HashMap < MemoryConsumerId ,  Weak < dyn  MemoryConsumer > > > > , 
248+     requesters :  Arc < Mutex < HashSet < MemoryConsumerId > > > , 
250249    pool_size :  usize , 
251250    requesters_total :  Arc < Mutex < usize > > , 
251+     trackers_total :  AtomicUsize , 
252252    cv :  Condvar , 
253253} 
254254
@@ -267,41 +267,47 @@ impl MemoryManager {
267267                ) ; 
268268
269269                Arc :: new ( Self  { 
270-                     requesters :  Arc :: new ( Mutex :: new ( HashMap :: new ( ) ) ) , 
271-                     trackers :  Arc :: new ( Mutex :: new ( HashMap :: new ( ) ) ) , 
270+                     requesters :  Arc :: new ( Mutex :: new ( HashSet :: new ( ) ) ) , 
272271                    pool_size, 
273272                    requesters_total :  Arc :: new ( Mutex :: new ( 0 ) ) , 
273+                     trackers_total :  AtomicUsize :: new ( 0 ) , 
274274                    cv :  Condvar :: new ( ) , 
275275                } ) 
276276            } 
277277        } 
278278    } 
279279
280280    fn  get_tracker_total ( & self )  -> usize  { 
281-         let  trackers = self . trackers . lock ( ) . unwrap ( ) ; 
282-         if  trackers. len ( )  > 0  { 
283-             trackers. values ( ) . fold ( 0usize ,  |acc,  y| match  y. upgrade ( )  { 
284-                 None  => acc, 
285-                 Some ( t)  => acc + t. mem_used ( ) , 
286-             } ) 
287-         }  else  { 
288-             0 
289-         } 
281+         self . trackers_total . load ( Ordering :: SeqCst ) 
290282    } 
291283
292-     /// Register a new memory consumer for memory usage tracking 
293- pub ( crate )  fn  register_consumer ( & self ,  consumer :  & Arc < dyn  MemoryConsumer > )  { 
294-         let  id = consumer. id ( ) . clone ( ) ; 
295-         match  consumer. type_ ( )  { 
296-             ConsumerType :: Requesting  => { 
297-                 let  mut  requesters = self . requesters . lock ( ) . unwrap ( ) ; 
298-                 requesters. insert ( id,  Arc :: downgrade ( consumer) ) ; 
299-             } 
300-             ConsumerType :: Tracking  => { 
301-                 let  mut  trackers = self . trackers . lock ( ) . unwrap ( ) ; 
302-                 trackers. insert ( id,  Arc :: downgrade ( consumer) ) ; 
303-             } 
304-         } 
284+     pub ( crate )  fn  grow_tracker_usage ( & self ,  delta :  usize )  { 
285+         self . trackers_total . fetch_add ( delta,  Ordering :: SeqCst ) ; 
286+     } 
287+ 
288+     pub ( crate )  fn  shrink_tracker_usage ( & self ,  delta :  usize )  { 
289+         let  update =
290+             self . trackers_total 
291+                 . fetch_update ( Ordering :: SeqCst ,  Ordering :: SeqCst ,  |x| { 
292+                     if  x >= delta { 
293+                         Some ( x - delta) 
294+                     }  else  { 
295+                         None 
296+                     } 
297+                 } ) ; 
298+         update. expect ( & * format ! ( 
299+             "Tracker total memory shrink by {} underflow, current value is " , 
300+             delta
301+         ) ) ; 
302+     } 
303+ 
304+     fn  get_requester_total ( & self )  -> usize  { 
305+         * self . requesters_total . lock ( ) . unwrap ( ) 
306+     } 
307+ 
308+     /// Register a new memory requester 
309+ pub ( crate )  fn  register_requester ( & self ,  requester_id :  & MemoryConsumerId )  { 
310+         self . requesters . lock ( ) . unwrap ( ) . insert ( requester_id. clone ( ) ) ; 
305311    } 
306312
307313    fn  max_mem_for_requesters ( & self )  -> usize  { 
@@ -317,7 +323,6 @@ impl MemoryManager {
317323
318324        let  granted; 
319325        loop  { 
320-             let  remaining = rqt_max - * rqt_current_used; 
321326            let  max_per_rqt = rqt_max / num_rqt; 
322327            let  min_per_rqt = max_per_rqt / 2 ; 
323328
@@ -326,6 +331,7 @@ impl MemoryManager {
326331                break ; 
327332            } 
328333
334+             let  remaining = rqt_max. checked_sub ( * rqt_current_used) . unwrap_or_default ( ) ; 
329335            if  remaining >= required { 
330336                granted = true ; 
331337                * rqt_current_used += required; 
@@ -347,46 +353,37 @@ impl MemoryManager {
347353
348354    fn  record_free_then_acquire ( & self ,  freed :  usize ,  acquired :  usize )  { 
349355        let  mut  requesters_total = self . requesters_total . lock ( ) . unwrap ( ) ; 
356+         assert ! ( * requesters_total >= freed) ; 
350357        * requesters_total -= freed; 
351358        * requesters_total += acquired; 
352359        self . cv . notify_all ( ) 
353360    } 
354361
355-     /// Drop a memory consumer from memory usage tracking  
356- pub ( crate )  fn  drop_consumer ( & self ,  id :  & MemoryConsumerId )  { 
362+     /// Drop a memory consumer and reclaim the memory  
363+ pub ( crate )  fn  drop_consumer ( & self ,  id :  & MemoryConsumerId ,   mem_used :   usize )  { 
357364        // find in requesters first 
358365        { 
359366            let  mut  requesters = self . requesters . lock ( ) . unwrap ( ) ; 
360-             if  requesters. remove ( id) . is_some ( )  { 
361-                 return ; 
367+             if  requesters. remove ( id)  { 
368+                 let  mut  total = self . requesters_total . lock ( ) . unwrap ( ) ; 
369+                 assert ! ( * total >= mem_used) ; 
370+                 * total -= mem_used; 
362371            } 
363372        } 
364-         let   mut  trackers =  self . trackers . lock ( ) . unwrap ( ) ; 
365-         trackers . remove ( id ) ; 
373+         self . shrink_tracker_usage ( mem_used ) ; 
374+         self . cv . notify_all ( ) ; 
366375    } 
367376} 
368377
369378impl  Display  for  MemoryManager  { 
370379    fn  fmt ( & self ,  f :  & mut  Formatter )  -> fmt:: Result  { 
371-         let  requesters =
372-             self . requesters 
373-                 . lock ( ) 
374-                 . unwrap ( ) 
375-                 . values ( ) 
376-                 . fold ( vec ! [ ] ,  |mut  acc,  consumer| match  consumer. upgrade ( )  { 
377-                     None  => acc, 
378-                     Some ( c)  => { 
379-                         acc. push ( format ! ( "{}" ,  c) ) ; 
380-                         acc
381-                     } 
382-                 } ) ; 
383-         let  tracker_mem = self . get_tracker_total ( ) ; 
384380        write ! ( f, 
385-                "MemoryManager usage statistics: total {}, tracker used {}, total {} requesters detail: \n  {}," , 
386-                 human_readable_size( self . pool_size) , 
387-                 human_readable_size( tracker_mem) , 
388-                 & requesters. len( ) , 
389-                requesters. join( "\n " ) ) 
381+                "MemoryManager usage statistics: total {}, trackers used {}, total {} requesters used: {}" , 
382+                human_readable_size( self . pool_size) , 
383+                human_readable_size( self . get_tracker_total( ) ) , 
384+                self . requesters. lock( ) . unwrap( ) . len( ) , 
385+                human_readable_size( self . get_requester_total( ) ) , 
386+         ) 
390387    } 
391388} 
392389
@@ -418,6 +415,8 @@ mod tests {
418415    use  super :: * ; 
419416    use  crate :: error:: Result ; 
420417    use  crate :: execution:: runtime_env:: { RuntimeConfig ,  RuntimeEnv } ; 
418+     use  crate :: execution:: MemoryConsumer ; 
419+     use  crate :: physical_plan:: metrics:: { ExecutionPlanMetricsSet ,  MemTrackingMetrics } ; 
421420    use  async_trait:: async_trait; 
422421    use  std:: sync:: atomic:: { AtomicUsize ,  Ordering } ; 
423422    use  std:: sync:: Arc ; 
@@ -487,6 +486,7 @@ mod tests {
487486
488487    impl  DummyTracker  { 
489488        fn  new ( partition :  usize ,  runtime :  Arc < RuntimeEnv > ,  mem_used :  usize )  -> Self  { 
489+             runtime. grow_tracker_usage ( mem_used) ; 
490490            Self  { 
491491                id :  MemoryConsumerId :: new ( partition) , 
492492                runtime, 
@@ -528,23 +528,29 @@ mod tests {
528528            . with_memory_manager ( MemoryManagerConfig :: try_new_limit ( 100 ,  1.0 ) . unwrap ( ) ) ; 
529529        let  runtime = Arc :: new ( RuntimeEnv :: new ( config) . unwrap ( ) ) ; 
530530
531-         let  tracker1 = Arc :: new ( DummyTracker :: new ( 0 ,  runtime. clone ( ) ,  5 ) ) ; 
532-         runtime. register_consumer ( & ( tracker1. clone ( )  as  Arc < dyn  MemoryConsumer > ) ) ; 
531+         DummyTracker :: new ( 0 ,  runtime. clone ( ) ,  5 ) ; 
533532        assert_eq ! ( runtime. memory_manager. get_tracker_total( ) ,  5 ) ; 
534533
535-         let  tracker2 = Arc :: new ( DummyTracker :: new ( 0 ,  runtime. clone ( ) ,  10 ) ) ; 
536-         runtime. register_consumer ( & ( tracker2. clone ( )  as  Arc < dyn  MemoryConsumer > ) ) ; 
534+         let  tracker1 = DummyTracker :: new ( 0 ,  runtime. clone ( ) ,  10 ) ; 
537535        assert_eq ! ( runtime. memory_manager. get_tracker_total( ) ,  15 ) ; 
538536
539-         let  tracker3 = Arc :: new ( DummyTracker :: new ( 0 ,  runtime. clone ( ) ,  15 ) ) ; 
540-         runtime. register_consumer ( & ( tracker3. clone ( )  as  Arc < dyn  MemoryConsumer > ) ) ; 
537+         DummyTracker :: new ( 0 ,  runtime. clone ( ) ,  15 ) ; 
541538        assert_eq ! ( runtime. memory_manager. get_tracker_total( ) ,  30 ) ; 
542539
543-         runtime. drop_consumer ( tracker2. id ( ) ) ; 
540+         runtime. drop_consumer ( tracker1. id ( ) ,  tracker1. mem_used ) ; 
541+         assert_eq ! ( runtime. memory_manager. get_tracker_total( ) ,  20 ) ; 
542+ 
543+         // MemTrackingMetrics as an easy way to track memory 
544+         let  ms = ExecutionPlanMetricsSet :: new ( ) ; 
545+         let  tracking_metric = MemTrackingMetrics :: new_with_rt ( & ms,  0 ,  runtime. clone ( ) ) ; 
546+         tracking_metric. init_mem_used ( 15 ) ; 
547+         assert_eq ! ( runtime. memory_manager. get_tracker_total( ) ,  35 ) ; 
548+ 
549+         drop ( tracking_metric) ; 
544550        assert_eq ! ( runtime. memory_manager. get_tracker_total( ) ,  20 ) ; 
545551
546-         let  requester1 = Arc :: new ( DummyRequester :: new ( 0 ,  runtime. clone ( ) ) ) ; 
547-         runtime. register_consumer ( & ( requester1. clone ( )   as   Arc < dyn   MemoryConsumer > ) ) ; 
552+         let  requester1 = DummyRequester :: new ( 0 ,  runtime. clone ( ) ) ; 
553+         runtime. register_requester ( requester1. id ( ) ) ; 
548554
549555        // first requester entered, should be able to use any of the remaining 80 
550556        requester1. do_with_mem ( 40 ) . await . unwrap ( ) ; 
@@ -553,8 +559,8 @@ mod tests {
553559        assert_eq ! ( requester1. mem_used( ) ,  50 ) ; 
554560        assert_eq ! ( * runtime. memory_manager. requesters_total. lock( ) . unwrap( ) ,  50 ) ; 
555561
556-         let  requester2 = Arc :: new ( DummyRequester :: new ( 0 ,  runtime. clone ( ) ) ) ; 
557-         runtime. register_consumer ( & ( requester2. clone ( )   as   Arc < dyn   MemoryConsumer > ) ) ; 
562+         let  requester2 = DummyRequester :: new ( 0 ,  runtime. clone ( ) ) ; 
563+         runtime. register_requester ( requester2. id ( ) ) ; 
558564
559565        requester2. do_with_mem ( 20 ) . await . unwrap ( ) ; 
560566        requester2. do_with_mem ( 30 ) . await . unwrap ( ) ; 
0 commit comments