1515// specific language governing permissions and limitations
1616// under the License.
1717
18- use datafusion:: prelude:: SessionConfig ;
19- use datafusion_common:: utils:: get_available_parallelism;
18+ use std:: { num:: NonZeroUsize , sync:: Arc } ;
19+
20+ use datafusion:: {
21+ execution:: {
22+ disk_manager:: DiskManagerConfig ,
23+ memory_pool:: { FairSpillPool , GreedyMemoryPool , MemoryPool , TrackConsumersPool } ,
24+ runtime_env:: RuntimeEnvBuilder ,
25+ } ,
26+ prelude:: SessionConfig ,
27+ } ;
28+ use datafusion_common:: { utils:: get_available_parallelism, DataFusionError , Result } ;
2029use structopt:: StructOpt ;
2130
2231// Common benchmark options (don't use doc comments otherwise this doc
@@ -35,6 +44,20 @@ pub struct CommonOpt {
3544 #[ structopt( short = "s" , long = "batch-size" , default_value = "8192" ) ]
3645 pub batch_size : usize ,
3746
47+ /// The memory pool type to use, should be one of "fair" or "greedy"
48+ #[ structopt( long = "mem-pool-type" , default_value = "fair" ) ]
49+ pub mem_pool_type : String ,
50+
51+ /// Memory limit (e.g. '100M', '1.5G'). If not specified, run all pre-defined memory limits for given query
52+ /// if there's any, otherwise run with no memory limit.
53+ #[ structopt( long = "memory-limit" , parse( try_from_str = parse_memory_limit) ) ]
54+ pub memory_limit : Option < usize > ,
55+
56+ /// The amount of memory to reserve for sort spill operations. DataFusion's default value will be used
57+ /// if not specified.
58+ #[ structopt( long = "sort-spill-reservation-bytes" , parse( try_from_str = parse_memory_limit) ) ]
59+ pub sort_spill_reservation_bytes : Option < usize > ,
60+
3861 /// Activate debug mode to see more details
3962 #[ structopt( short, long) ]
4063 pub debug : bool ,
@@ -48,10 +71,81 @@ impl CommonOpt {
4871
4972 /// Modify the existing config appropriately
5073 pub fn update_config ( & self , config : SessionConfig ) -> SessionConfig {
51- config
74+ let mut config = config
5275 . with_target_partitions (
5376 self . partitions . unwrap_or ( get_available_parallelism ( ) ) ,
5477 )
55- . with_batch_size ( self . batch_size )
78+ . with_batch_size ( self . batch_size ) ;
79+ if let Some ( sort_spill_reservation_bytes) = self . sort_spill_reservation_bytes {
80+ config =
81+ config. with_sort_spill_reservation_bytes ( sort_spill_reservation_bytes) ;
82+ }
83+ config
84+ }
85+
86+ /// Return an appropriately configured `RuntimeEnvBuilder`
87+ pub fn runtime_env_builder ( & self ) -> Result < RuntimeEnvBuilder > {
88+ let mut rt_builder = RuntimeEnvBuilder :: new ( ) ;
89+ const NUM_TRACKED_CONSUMERS : usize = 5 ;
90+ if let Some ( memory_limit) = self . memory_limit {
91+ let pool: Arc < dyn MemoryPool > = match self . mem_pool_type . as_str ( ) {
92+ "fair" => Arc :: new ( TrackConsumersPool :: new (
93+ FairSpillPool :: new ( memory_limit) ,
94+ NonZeroUsize :: new ( NUM_TRACKED_CONSUMERS ) . unwrap ( ) ,
95+ ) ) ,
96+ "greedy" => Arc :: new ( TrackConsumersPool :: new (
97+ GreedyMemoryPool :: new ( memory_limit) ,
98+ NonZeroUsize :: new ( NUM_TRACKED_CONSUMERS ) . unwrap ( ) ,
99+ ) ) ,
100+ _ => {
101+ return Err ( DataFusionError :: Configuration ( format ! (
102+ "Invalid memory pool type: {}" ,
103+ self . mem_pool_type
104+ ) ) )
105+ }
106+ } ;
107+ rt_builder = rt_builder
108+ . with_memory_pool ( pool)
109+ . with_disk_manager ( DiskManagerConfig :: NewOs ) ;
110+ }
111+ Ok ( rt_builder)
112+ }
113+ }
114+
/// Parse a memory limit string into a number of bytes.
///
/// The input is a decimal number immediately followed by a single unit
/// character: `K` (x 1024), `M` (x 1024^2) or `G` (x 1024^3). Fractional
/// byte counts are truncated towards zero,
/// e.g. `"100K"` -> 102400 and `"1.5M"` -> 1572864.
///
/// # Errors
/// Returns a descriptive message when the input is empty, the numeric part
/// does not parse, the number is negative or not finite, the unit is not
/// one of `K`/`M`/`G`, or the resulting byte count does not fit in `usize`.
fn parse_memory_limit(limit: &str) -> Result<usize, String> {
    // Split off the last *character*, not the last byte: `split_at` panics on
    // a non-char-boundary index, and `len() - 1` underflows on empty input.
    let (unit_start, _) = limit
        .char_indices()
        .last()
        .ok_or_else(|| "Memory limit must not be empty".to_string())?;
    let (number, unit) = limit.split_at(unit_start);

    let number: f64 = number
        .parse()
        .map_err(|_| format!("Failed to parse number from memory limit '{}'", limit))?;

    let multiplier: f64 = match unit {
        "K" => 1024.0,
        "M" => 1024.0 * 1024.0,
        "G" => 1024.0 * 1024.0 * 1024.0,
        _ => {
            return Err(format!(
                "Unsupported unit '{}' in memory limit '{}'",
                unit, limit
            ))
        }
    };

    // `f64 as usize` saturates silently (negative/NaN -> 0, huge/inf -> MAX),
    // so reject those inputs explicitly instead of producing a bogus limit.
    if !number.is_finite() || number < 0.0 {
        return Err(format!(
            "Memory limit '{}' must be a finite, non-negative number",
            limit
        ));
    }
    let bytes = number * multiplier;
    if bytes >= usize::MAX as f64 {
        return Err(format!("Memory limit '{}' is too large", limit));
    }

    Ok(bytes as usize)
}
133+
#[cfg(test)]
mod tests {
    use super::*;

    /// Exercises `parse_memory_limit` with valid inputs for each supported
    /// unit, then with an unknown unit and a non-numeric value.
    #[test]
    fn test_parse_memory_limit_all() {
        // Valid inputs paired with their expected size in bytes.
        let accepted: &[(&str, usize)] = &[
            ("100K", 102400),
            ("1.5M", 1572864),
            ("2G", 2147483648),
        ];
        for (input, expected) in accepted {
            assert_eq!(parse_memory_limit(input).unwrap(), *expected);
        }

        // An unknown unit ("500X") and a non-numeric value ("abcM") must fail.
        let rejected: &[&str] = &["500X", "abcM"];
        for input in rejected {
            assert!(parse_memory_limit(input).is_err());
        }
    }
}
0 commit comments