1717
1818//! Support the coercion rule for aggregate function. 
1919
20- use  crate :: arrow:: datatypes:: Schema ; 
2120use  crate :: error:: { DataFusionError ,  Result } ; 
2221use  crate :: physical_plan:: aggregates:: AggregateFunction ; 
2322use  crate :: physical_plan:: expressions:: { 
@@ -26,6 +25,10 @@ use crate::physical_plan::expressions::{
2625} ; 
2726use  crate :: physical_plan:: functions:: { Signature ,  TypeSignature } ; 
2827use  crate :: physical_plan:: PhysicalExpr ; 
28+ use  crate :: { 
29+     arrow:: datatypes:: Schema , 
30+     physical_plan:: expressions:: is_approx_quantile_supported_arg_type, 
31+ } ; 
2932use  arrow:: datatypes:: DataType ; 
3033use  std:: ops:: Deref ; 
3134use  std:: sync:: Arc ; 
@@ -37,24 +40,9 @@ pub(crate) fn coerce_types(
3740    input_types :  & [ DataType ] , 
3841    signature :  & Signature , 
3942)  -> Result < Vec < DataType > >  { 
40-     match  signature. type_signature  { 
41-         TypeSignature :: Uniform ( agg_count,  _)  | TypeSignature :: Any ( agg_count)  => { 
42-             if  input_types. len ( )  != agg_count { 
43-                 return  Err ( DataFusionError :: Plan ( format ! ( 
44-                     "The function {:?} expects {:?} arguments, but {:?} were provided" , 
45-                     agg_fun, 
46-                     agg_count, 
47-                     input_types. len( ) 
48-                 ) ) ) ; 
49-             } 
50-         } 
51-         _ => { 
52-             return  Err ( DataFusionError :: Internal ( format ! ( 
53-                 "Aggregate functions do not support this {:?}" , 
54-                 signature
55-             ) ) ) ; 
56-         } 
57-     } ; 
43+     // Validate input_types matches (at least one of) the func signature. 
44+     check_arg_count ( agg_fun,  input_types,  & signature. type_signature ) ?; 
45+ 
5846    match  agg_fun { 
5947        AggregateFunction :: Count  | AggregateFunction :: ApproxDistinct  => { 
6048            Ok ( input_types. to_vec ( ) ) 
@@ -123,7 +111,75 @@ pub(crate) fn coerce_types(
123111            } 
124112            Ok ( input_types. to_vec ( ) ) 
125113        } 
114+         AggregateFunction :: ApproxQuantile  => { 
115+             if  !is_approx_quantile_supported_arg_type ( & input_types[ 0 ] )  { 
116+                 return  Err ( DataFusionError :: Plan ( format ! ( 
117+                     "The function {:?} does not support inputs of type {:?}." , 
118+                     agg_fun,  input_types[ 0 ] 
119+                 ) ) ) ; 
120+             } 
121+             if  !matches ! ( input_types[ 1 ] ,  DataType :: Float64 )  { 
122+                 return  Err ( DataFusionError :: Plan ( format ! ( 
123+                     "The quantile argument for {:?} must be Float64, not {:?}." , 
124+                     agg_fun,  input_types[ 1 ] 
125+                 ) ) ) ; 
126+             } 
127+             Ok ( input_types. to_vec ( ) ) 
128+         } 
129+     } 
130+ } 
131+ 
132+ /// Validate the length of `input_types` matches the `signature` for `agg_fun`. 
133+ /// 
134+ /// This method DOES NOT validate the argument types - only that (at least one, 
135+ /// in the case of [`TypeSignature::OneOf`]) signature matches the desired 
136+ /// number of input types. 
137+ fn  check_arg_count ( 
138+     agg_fun :  & AggregateFunction , 
139+     input_types :  & [ DataType ] , 
140+     signature :  & TypeSignature , 
141+ )  -> Result < ( ) >  { 
142+     match  signature { 
143+         TypeSignature :: Uniform ( agg_count,  _)  | TypeSignature :: Any ( agg_count)  => { 
144+             if  input_types. len ( )  != * agg_count { 
145+                 return  Err ( DataFusionError :: Plan ( format ! ( 
146+                     "The function {:?} expects {:?} arguments, but {:?} were provided" , 
147+                     agg_fun, 
148+                     agg_count, 
149+                     input_types. len( ) 
150+                 ) ) ) ; 
151+             } 
152+         } 
153+         TypeSignature :: Exact ( types)  => { 
154+             if  types. len ( )  != input_types. len ( )  { 
155+                 return  Err ( DataFusionError :: Plan ( format ! ( 
156+                     "The function {:?} expects {:?} arguments, but {:?} were provided" , 
157+                     agg_fun, 
158+                     types. len( ) , 
159+                     input_types. len( ) 
160+                 ) ) ) ; 
161+             } 
162+         } 
163+         TypeSignature :: OneOf ( variants)  => { 
164+             let  ok = variants
165+                 . iter ( ) 
166+                 . any ( |v| check_arg_count ( agg_fun,  input_types,  v) . is_ok ( ) ) ; 
167+             if  !ok { 
168+                 return  Err ( DataFusionError :: Plan ( format ! ( 
169+                     "The function {:?} does not accept {:?} function arguments." , 
170+                     agg_fun, 
171+                     input_types. len( ) 
172+                 ) ) ) ; 
173+             } 
174+         } 
175+         _ => { 
176+             return  Err ( DataFusionError :: Internal ( format ! ( 
177+                 "Aggregate functions do not support this {:?}" , 
178+                 signature
179+             ) ) ) ; 
180+         } 
126181    } 
182+     Ok ( ( ) ) 
127183} 
128184
129185fn  get_min_max_result_type ( input_types :  & [ DataType ] )  -> Result < Vec < DataType > >  { 
@@ -239,5 +295,25 @@ mod tests {
239295                assert_eq ! ( * input_type,  result. unwrap( ) ) ; 
240296            } 
241297        } 
298+ 
299+         // ApproxQuantile input types 
300+         let  input_types = vec ! [ 
301+             vec![ DataType :: Int8 ,  DataType :: Float64 ] , 
302+             vec![ DataType :: Int16 ,  DataType :: Float64 ] , 
303+             vec![ DataType :: Int32 ,  DataType :: Float64 ] , 
304+             vec![ DataType :: Int64 ,  DataType :: Float64 ] , 
305+             vec![ DataType :: UInt8 ,  DataType :: Float64 ] , 
306+             vec![ DataType :: UInt16 ,  DataType :: Float64 ] , 
307+             vec![ DataType :: UInt32 ,  DataType :: Float64 ] , 
308+             vec![ DataType :: UInt64 ,  DataType :: Float64 ] , 
309+             vec![ DataType :: Float32 ,  DataType :: Float64 ] , 
310+             vec![ DataType :: Float64 ,  DataType :: Float64 ] , 
311+         ] ; 
312+         for  input_type in  & input_types { 
313+             let  signature = aggregates:: signature ( & AggregateFunction :: ApproxQuantile ) ; 
314+             let  result =
315+                 coerce_types ( & AggregateFunction :: ApproxQuantile ,  input_type,  & signature) ; 
316+             assert_eq ! ( * input_type,  result. unwrap( ) ) ; 
317+         } 
242318    } 
243319} 
0 commit comments