|
20 | 20 | //! |
21 | 21 | //! see also https://www.postgresql.org/docs/current/functions-window.html |
22 | 22 |
|
23 | | -use super::expressions::{avg_return_type, sum_return_type}; |
24 | 23 | use super::{functions::Signature, type_coercion::data_types}; |
25 | 24 | use crate::error::{DataFusionError, Result}; |
26 | | -use crate::physical_plan::aggregates::AggregateFunction; |
27 | | -use arrow::datatypes::{DataType, TimeUnit}; |
| 25 | +use crate::physical_plan::{aggregates, aggregates::AggregateFunction}; |
| 26 | +use arrow::datatypes::DataType; |
28 | 27 | use std::{fmt, str::FromStr}; |
29 | 28 |
|
30 | 29 | /// WindowFunction |
@@ -90,59 +89,32 @@ pub fn return_type(fun: &WindowFunction, arg_types: &[DataType]) -> Result<DataT |
90 | 89 | data_types(arg_types, &signature(fun))?; |
91 | 90 |
|
92 | 91 | match fun { |
93 | | - WindowFunction::AggregateFunction(fun) => match fun { |
94 | | - AggregateFunction::Count => Ok(DataType::UInt64), |
95 | | - AggregateFunction::Max | AggregateFunction::Min => Ok(arg_types[0].clone()), |
96 | | - AggregateFunction::Sum => sum_return_type(&arg_types[0]), |
97 | | - AggregateFunction::Avg => avg_return_type(&arg_types[0]), |
98 | | - }, |
99 | | - WindowFunction::BuiltInWindowFunction(_) => Ok(arg_types[0].clone()), |
| 92 | + WindowFunction::AggregateFunction(fun) => aggregates::return_type(fun, arg_types), |
| 93 | + WindowFunction::BuiltInWindowFunction(fun) => Ok(match fun { |
| 94 | + BuiltInWindowFunction::RowNumber |
| 95 | + | BuiltInWindowFunction::Rank |
| 96 | + | BuiltInWindowFunction::DenseRank => DataType::UInt64, |
| 97 | + BuiltInWindowFunction::Lag |
| 98 | + | BuiltInWindowFunction::Lead |
| 99 | + | BuiltInWindowFunction::FirstValue |
| 100 | + | BuiltInWindowFunction::LastValue => arg_types[0].clone(), |
| 101 | + }), |
100 | 102 | } |
101 | 103 | } |
102 | 104 |
|
103 | | -static NUMERICS: &[DataType] = &[ |
104 | | - DataType::Int8, |
105 | | - DataType::Int16, |
106 | | - DataType::Int32, |
107 | | - DataType::Int64, |
108 | | - DataType::UInt8, |
109 | | - DataType::UInt16, |
110 | | - DataType::UInt32, |
111 | | - DataType::UInt64, |
112 | | - DataType::Float32, |
113 | | - DataType::Float64, |
114 | | -]; |
115 | | - |
116 | | -static STRINGS: &[DataType] = &[DataType::Utf8, DataType::LargeUtf8]; |
117 | | - |
118 | | -static TIMESTAMPS: &[DataType] = &[ |
119 | | - DataType::Timestamp(TimeUnit::Second, None), |
120 | | - DataType::Timestamp(TimeUnit::Millisecond, None), |
121 | | - DataType::Timestamp(TimeUnit::Microsecond, None), |
122 | | - DataType::Timestamp(TimeUnit::Nanosecond, None), |
123 | | -]; |
124 | | - |
125 | 105 | /// the signatures supported by the function `fun`. |
126 | 106 | fn signature(fun: &WindowFunction) -> Signature { |
127 | 107 | // note: the physical expression must accept every type allowed by the signature returned here, or execution panics. |
128 | 108 | match fun { |
129 | | - WindowFunction::AggregateFunction(fun) => match fun { |
130 | | - AggregateFunction::Count => Signature::Any(1), |
131 | | - AggregateFunction::Min | AggregateFunction::Max => { |
132 | | - let valid = STRINGS |
133 | | - .iter() |
134 | | - .chain(NUMERICS.iter()) |
135 | | - .chain(TIMESTAMPS.iter()) |
136 | | - .cloned() |
137 | | - .collect::<Vec<_>>(); |
138 | | - Signature::Uniform(1, valid) |
139 | | - } |
140 | | - AggregateFunction::Avg | AggregateFunction::Sum => { |
141 | | - Signature::Uniform(1, NUMERICS.to_vec()) |
142 | | - } |
| 109 | + WindowFunction::AggregateFunction(fun) => aggregates::signature(fun), |
| 110 | + WindowFunction::BuiltInWindowFunction(fun) => match fun { |
| 111 | + BuiltInWindowFunction::RowNumber |
| 112 | + | BuiltInWindowFunction::Rank |
| 113 | + | BuiltInWindowFunction::DenseRank => Signature::Any(0), |
| 114 | + BuiltInWindowFunction::Lag |
| 115 | + | BuiltInWindowFunction::Lead |
| 116 | + | BuiltInWindowFunction::FirstValue |
| 117 | + | BuiltInWindowFunction::LastValue => Signature::Any(1), |
143 | 118 | }, |
144 | | - WindowFunction::BuiltInWindowFunction(_) => { |
145 | | - Signature::Uniform(1, NUMERICS.to_vec()) |
146 | | - } |
147 | 119 | } |
148 | 120 | } |
0 commit comments