@@ -26,6 +26,9 @@ use arrow::datatypes::{Schema, SchemaRef};
2626use arrow:: record_batch:: RecordBatch ;
2727
2828use crate :: execution:: context:: TaskContext ;
29+ use crate :: execution:: memory_pool:: MemoryConsumer ;
30+ use crate :: physical_plan:: common:: SharedMemoryReservation ;
31+ use crate :: physical_plan:: metrics:: { ExecutionPlanMetricsSet , MetricsSet } ;
2932use crate :: physical_plan:: {
3033 coalesce_batches:: concat_batches, coalesce_partitions:: CoalescePartitionsExec ,
3134 ColumnStatistics , DisplayFormatType , Distribution , EquivalenceProperties ,
@@ -35,12 +38,11 @@ use crate::physical_plan::{
3538use crate :: { error:: Result , scalar:: ScalarValue } ;
3639use async_trait:: async_trait;
3740use datafusion_common:: DataFusionError ;
38- use log:: debug;
39- use std:: time:: Instant ;
41+ use parking_lot:: Mutex ;
4042
4143use super :: utils:: {
42- adjust_right_output_partitioning, cross_join_equivalence_properties, OnceAsync ,
43- OnceFut ,
44+ adjust_right_output_partitioning, cross_join_equivalence_properties,
45+ BuildProbeJoinMetrics , OnceAsync , OnceFut ,
4446} ;
4547
4648/// Data of the left side
@@ -58,6 +60,8 @@ pub struct CrossJoinExec {
5860 schema : SchemaRef ,
5961 /// Build-side data
6062 left_fut : OnceAsync < JoinLeftData > ,
63+ /// Execution plan metrics
64+ metrics : ExecutionPlanMetricsSet ,
6165}
6266
6367impl CrossJoinExec {
@@ -79,6 +83,7 @@ impl CrossJoinExec {
7983 right,
8084 schema,
8185 left_fut : Default :: default ( ) ,
86+ metrics : ExecutionPlanMetricsSet :: default ( ) ,
8287 }
8388 }
8489
@@ -97,9 +102,9 @@ impl CrossJoinExec {
97102async fn load_left_input (
98103 left : Arc < dyn ExecutionPlan > ,
99104 context : Arc < TaskContext > ,
105+ metrics : BuildProbeJoinMetrics ,
106+ reservation : SharedMemoryReservation ,
100107) -> Result < JoinLeftData > {
101- let start = Instant :: now ( ) ;
102-
103108 // merge all left parts into a single stream
104109 let merge = {
105110 if left. output_partitioning ( ) . partition_count ( ) != 1 {
@@ -111,22 +116,28 @@ async fn load_left_input(
111116 let stream = merge. execute ( 0 , context) ?;
112117
113118 // Load all batches and count the rows
114- let ( batches, num_rows) = stream
115- . try_fold ( ( Vec :: new ( ) , 0usize ) , |mut acc, batch| async {
116- acc. 1 += batch. num_rows ( ) ;
117- acc. 0 . push ( batch) ;
118- Ok ( acc)
119- } )
119+ let ( batches, num_rows, _, _) = stream
120+ . try_fold (
121+ ( Vec :: new ( ) , 0usize , metrics, reservation) ,
122+ |mut acc, batch| async {
123+ let batch_size = batch. get_array_memory_size ( ) ;
124+ // Reserve memory for incoming batch
125+ acc. 3 . lock ( ) . try_grow ( batch_size) ?;
126+ // Update metrics
127+ acc. 2 . build_mem_used . add ( batch_size) ;
128+ acc. 2 . build_input_batches . add ( 1 ) ;
129+ acc. 2 . build_input_rows . add ( batch. num_rows ( ) ) ;
130+ // Update rowcount
131+ acc. 1 += batch. num_rows ( ) ;
132+ // Push batch to output
133+ acc. 0 . push ( batch) ;
134+ Ok ( acc)
135+ } ,
136+ )
120137 . await ?;
121138
122139 let merged_batch = concat_batches ( & left. schema ( ) , & batches, num_rows) ?;
123140
124- debug ! (
125- "Built build-side of cross join containing {} rows in {} ms" ,
126- num_rows,
127- start. elapsed( ) . as_millis( )
128- ) ;
129-
130141 Ok ( merged_batch)
131142}
132143
@@ -143,6 +154,10 @@ impl ExecutionPlan for CrossJoinExec {
143154 vec ! [ self . left. clone( ) , self . right. clone( ) ]
144155 }
145156
157+ fn metrics ( & self ) -> Option < MetricsSet > {
158+ Some ( self . metrics . clone_inner ( ) )
159+ }
160+
146161 /// Specifies whether this plan generates an infinite stream of records.
147162 /// If the plan does not support pipelining, but it its input(s) are
148163 /// infinite, returns an error to indicate this.
@@ -205,21 +220,29 @@ impl ExecutionPlan for CrossJoinExec {
205220 ) -> Result < SendableRecordBatchStream > {
206221 let stream = self . right . execute ( partition, context. clone ( ) ) ?;
207222
208- let left_fut = self
209- . left_fut
210- . once ( || load_left_input ( self . left . clone ( ) , context) ) ;
223+ let join_metrics = BuildProbeJoinMetrics :: new ( partition, & self . metrics ) ;
224+ let reservation = Arc :: new ( Mutex :: new (
225+ MemoryConsumer :: new ( format ! ( "CrossJoinStream[{partition}]" ) )
226+ . register ( context. memory_pool ( ) ) ,
227+ ) ) ;
228+
229+ let left_fut = self . left_fut . once ( || {
230+ load_left_input (
231+ self . left . clone ( ) ,
232+ context,
233+ join_metrics. clone ( ) ,
234+ reservation. clone ( ) ,
235+ )
236+ } ) ;
211237
212238 Ok ( Box :: pin ( CrossJoinStream {
213239 schema : self . schema . clone ( ) ,
214240 left_fut,
215241 right : stream,
216242 right_batch : Arc :: new ( parking_lot:: Mutex :: new ( None ) ) ,
217243 left_index : 0 ,
218- num_input_batches : 0 ,
219- num_input_rows : 0 ,
220- num_output_batches : 0 ,
221- num_output_rows : 0 ,
222- join_time : 0 ,
244+ join_metrics,
245+ reservation,
223246 } ) )
224247 }
225248
@@ -321,16 +344,10 @@ struct CrossJoinStream {
321344 left_index : usize ,
322345 /// Current batch being processed from the right side
323346 right_batch : Arc < parking_lot:: Mutex < Option < RecordBatch > > > ,
324- /// number of input batches
325- num_input_batches : usize ,
326- /// number of input rows
327- num_input_rows : usize ,
328- /// number of batches produced
329- num_output_batches : usize ,
330- /// number of rows produced
331- num_output_rows : usize ,
332- /// total time for joining probe-side batches to the build-side batches
333- join_time : usize ,
347+ /// join execution metrics
348+ join_metrics : BuildProbeJoinMetrics ,
349+ /// memory reservation
350+ reservation : SharedMemoryReservation ,
334351}
335352
336353impl RecordBatchStream for CrossJoinStream {
@@ -385,28 +402,30 @@ impl CrossJoinStream {
385402 & mut self ,
386403 cx : & mut std:: task:: Context < ' _ > ,
387404 ) -> std:: task:: Poll < Option < Result < RecordBatch > > > {
405+ let build_timer = self . join_metrics . build_time . timer ( ) ;
388406 let left_data = match ready ! ( self . left_fut. get( cx) ) {
389407 Ok ( left_data) => left_data,
390408 Err ( e) => return Poll :: Ready ( Some ( Err ( e) ) ) ,
391409 } ;
410+ build_timer. done ( ) ;
392411
393412 if left_data. num_rows ( ) == 0 {
394413 return Poll :: Ready ( None ) ;
395414 }
396415
397416 if self . left_index > 0 && self . left_index < left_data. num_rows ( ) {
398- let start = Instant :: now ( ) ;
417+ let join_timer = self . join_metrics . join_time . timer ( ) ;
399418 let right_batch = {
400419 let right_batch = self . right_batch . lock ( ) ;
401420 right_batch. clone ( ) . unwrap ( )
402421 } ;
403422 let result =
404423 build_batch ( self . left_index , & right_batch, left_data, & self . schema ) ;
405- self . num_input_rows += right_batch. num_rows ( ) ;
424+ self . join_metrics . input_rows . add ( right_batch. num_rows ( ) ) ;
406425 if let Ok ( ref batch) = result {
407- self . join_time += start . elapsed ( ) . as_millis ( ) as usize ;
408- self . num_output_batches += 1 ;
409- self . num_output_rows += batch. num_rows ( ) ;
426+ join_timer . done ( ) ;
427+ self . join_metrics . output_batches . add ( 1 ) ;
428+ self . join_metrics . output_rows . add ( batch. num_rows ( ) ) ;
410429 }
411430 self . left_index += 1 ;
412431 return Poll :: Ready ( Some ( result) ) ;
@@ -416,15 +435,15 @@ impl CrossJoinStream {
416435 . poll_next_unpin ( cx)
417436 . map ( |maybe_batch| match maybe_batch {
418437 Some ( Ok ( batch) ) => {
419- let start = Instant :: now ( ) ;
438+ let join_timer = self . join_metrics . join_time . timer ( ) ;
420439 let result =
421440 build_batch ( self . left_index , & batch, left_data, & self . schema ) ;
422- self . num_input_batches += 1 ;
423- self . num_input_rows += batch. num_rows ( ) ;
441+ self . join_metrics . input_batches . add ( 1 ) ;
442+ self . join_metrics . input_rows . add ( batch. num_rows ( ) ) ;
424443 if let Ok ( ref batch) = result {
425- self . join_time += start . elapsed ( ) . as_millis ( ) as usize ;
426- self . num_output_batches += 1 ;
427- self . num_output_rows += batch. num_rows ( ) ;
444+ join_timer . done ( ) ;
445+ self . join_metrics . output_batches . add ( 1 ) ;
446+ self . join_metrics . output_rows . add ( batch. num_rows ( ) ) ;
428447 }
429448 self . left_index = 1 ;
430449
@@ -434,15 +453,7 @@ impl CrossJoinStream {
434453 Some ( result)
435454 }
436455 other => {
437- debug ! (
438- "Processed {} probe-side input batches containing {} rows and \
439- produced {} output batches containing {} rows in {} ms",
440- self . num_input_batches,
441- self . num_input_rows,
442- self . num_output_batches,
443- self . num_output_rows,
444- self . join_time
445- ) ;
456+ self . reservation . lock ( ) . free ( ) ;
446457 other
447458 }
448459 } )
@@ -452,6 +463,26 @@ impl CrossJoinStream {
452463#[ cfg( test) ]
453464mod tests {
454465 use super :: * ;
466+ use crate :: assert_batches_sorted_eq;
467+ use crate :: common:: assert_contains;
468+ use crate :: execution:: runtime_env:: { RuntimeConfig , RuntimeEnv } ;
469+ use crate :: physical_plan:: common;
470+ use crate :: prelude:: { SessionConfig , SessionContext } ;
471+ use crate :: test:: { build_table_scan_i32, columns} ;
472+
473+ async fn join_collect (
474+ left : Arc < dyn ExecutionPlan > ,
475+ right : Arc < dyn ExecutionPlan > ,
476+ context : Arc < TaskContext > ,
477+ ) -> Result < ( Vec < String > , Vec < RecordBatch > ) > {
478+ let join = CrossJoinExec :: new ( left, right) ;
479+ let columns_header = columns ( & join. schema ( ) ) ;
480+
481+ let stream = join. execute ( 0 , context) ?;
482+ let batches = common:: collect ( stream) . await ?;
483+
484+ Ok ( ( columns_header, batches) )
485+ }
455486
456487 #[ tokio:: test]
457488 async fn test_stats_cartesian_product ( ) {
@@ -589,4 +620,70 @@ mod tests {
589620
590621 assert_eq ! ( result, expected) ;
591622 }
623+
624+ #[ tokio:: test]
625+ async fn test_join ( ) -> Result < ( ) > {
626+ let session_ctx = SessionContext :: new ( ) ;
627+ let task_ctx = session_ctx. task_ctx ( ) ;
628+
629+ let left = build_table_scan_i32 (
630+ ( "a1" , & vec ! [ 1 , 2 , 3 ] ) ,
631+ ( "b1" , & vec ! [ 4 , 5 , 6 ] ) ,
632+ ( "c1" , & vec ! [ 7 , 8 , 9 ] ) ,
633+ ) ;
634+ let right = build_table_scan_i32 (
635+ ( "a2" , & vec ! [ 10 , 11 ] ) ,
636+ ( "b2" , & vec ! [ 12 , 13 ] ) ,
637+ ( "c2" , & vec ! [ 14 , 15 ] ) ,
638+ ) ;
639+
640+ let ( columns, batches) = join_collect ( left, right, task_ctx) . await ?;
641+
642+ assert_eq ! ( columns, vec![ "a1" , "b1" , "c1" , "a2" , "b2" , "c2" ] ) ;
643+ let expected = vec ! [
644+ "+----+----+----+----+----+----+" ,
645+ "| a1 | b1 | c1 | a2 | b2 | c2 |" ,
646+ "+----+----+----+----+----+----+" ,
647+ "| 1 | 4 | 7 | 10 | 12 | 14 |" ,
648+ "| 1 | 4 | 7 | 11 | 13 | 15 |" ,
649+ "| 2 | 5 | 8 | 10 | 12 | 14 |" ,
650+ "| 2 | 5 | 8 | 11 | 13 | 15 |" ,
651+ "| 3 | 6 | 9 | 10 | 12 | 14 |" ,
652+ "| 3 | 6 | 9 | 11 | 13 | 15 |" ,
653+ "+----+----+----+----+----+----+" ,
654+ ] ;
655+
656+ assert_batches_sorted_eq ! ( expected, & batches) ;
657+
658+ Ok ( ( ) )
659+ }
660+
661+ #[ tokio:: test]
662+ async fn test_overallocation ( ) -> Result < ( ) > {
663+ let runtime_config = RuntimeConfig :: new ( ) . with_memory_limit ( 100 , 1.0 ) ;
664+ let runtime = Arc :: new ( RuntimeEnv :: new ( runtime_config) ?) ;
665+ let session_ctx =
666+ SessionContext :: with_config_rt ( SessionConfig :: default ( ) , runtime) ;
667+ let task_ctx = session_ctx. task_ctx ( ) ;
668+
669+ let left = build_table_scan_i32 (
670+ ( "a1" , & vec ! [ 1 , 2 , 3 , 4 , 5 , 6 , 7 , 8 , 9 , 0 ] ) ,
671+ ( "b1" , & vec ! [ 1 , 2 , 3 , 4 , 5 , 6 , 7 , 8 , 9 , 0 ] ) ,
672+ ( "c1" , & vec ! [ 1 , 2 , 3 , 4 , 5 , 6 , 7 , 8 , 9 , 0 ] ) ,
673+ ) ;
674+ let right = build_table_scan_i32 (
675+ ( "a2" , & vec ! [ 10 , 11 ] ) ,
676+ ( "b2" , & vec ! [ 12 , 13 ] ) ,
677+ ( "c2" , & vec ! [ 14 , 15 ] ) ,
678+ ) ;
679+
680+ let err = join_collect ( left, right, task_ctx) . await . unwrap_err ( ) ;
681+
682+ assert_contains ! (
683+ err. to_string( ) ,
684+ "External error: Resources exhausted: Failed to allocate additional"
685+ ) ;
686+
687+ Ok ( ( ) )
688+ }
592689}
0 commit comments