@@ -19,12 +19,11 @@ use std::any::Any;
1919use std:: str:: FromStr ;
2020use std:: sync:: { Arc , OnceLock } ;
2121
22- use arrow:: array:: { Array , ArrayRef , Float64Array } ;
22+ use arrow:: array:: { Array , ArrayRef , Float64Array , Int32Array } ;
2323use arrow:: compute:: kernels:: cast_utils:: IntervalUnit ;
24- use arrow:: compute:: { binary, cast , date_part, DatePart } ;
24+ use arrow:: compute:: { binary, date_part, DatePart } ;
2525use arrow:: datatypes:: DataType :: {
26- Date32 , Date64 , Duration , Float64 , Interval , Time32 , Time64 , Timestamp , Utf8 ,
27- Utf8View ,
26+ Date32 , Date64 , Duration , Interval , Time32 , Time64 , Timestamp , Utf8 , Utf8View ,
2827} ;
2928use arrow:: datatypes:: IntervalUnit :: { DayTime , MonthDayNano , YearMonth } ;
3029use arrow:: datatypes:: TimeUnit :: { Microsecond , Millisecond , Nanosecond , Second } ;
@@ -36,11 +35,12 @@ use datafusion_common::cast::{
3635 as_timestamp_microsecond_array, as_timestamp_millisecond_array,
3736 as_timestamp_nanosecond_array, as_timestamp_second_array,
3837} ;
39- use datafusion_common:: { exec_err, Result , ScalarValue } ;
38+ use datafusion_common:: { exec_err, internal_err , ExprSchema , Result , ScalarValue } ;
4039use datafusion_expr:: scalar_doc_sections:: DOC_SECTION_DATETIME ;
4140use datafusion_expr:: TypeSignature :: Exact ;
4241use datafusion_expr:: {
43- ColumnarValue , Documentation , ScalarUDFImpl , Signature , Volatility , TIMEZONE_WILDCARD ,
42+ ColumnarValue , Documentation , Expr , ScalarUDFImpl , Signature , Volatility ,
43+ TIMEZONE_WILDCARD ,
4444} ;
4545
4646#[ derive( Debug ) ]
@@ -148,7 +148,21 @@ impl ScalarUDFImpl for DatePartFunc {
148148 }
149149
150150 fn return_type ( & self , _arg_types : & [ DataType ] ) -> Result < DataType > {
151- Ok ( Float64 )
151+ internal_err ! ( "return_type_from_exprs shoud be called instead" )
152+ }
153+
154+ fn return_type_from_exprs (
155+ & self ,
156+ args : & [ Expr ] ,
157+ _schema : & dyn ExprSchema ,
158+ _arg_types : & [ DataType ] ,
159+ ) -> Result < DataType > {
160+ match & args[ 0 ] {
161+ Expr :: Literal ( ScalarValue :: Utf8 ( Some ( part) ) ) if is_epoch ( part) => {
162+ Ok ( DataType :: Float64 )
163+ }
164+ _ => Ok ( DataType :: Int32 ) ,
165+ }
152166 }
153167
154168 fn invoke ( & self , args : & [ ColumnarValue ] ) -> Result < ColumnarValue > {
@@ -174,35 +188,31 @@ impl ScalarUDFImpl for DatePartFunc {
174188 ColumnarValue :: Scalar ( scalar) => scalar. to_array ( ) ?,
175189 } ;
176190
177- // to remove quotes at most 2 characters
178- let part_trim = part. trim_matches ( |c| c == '\'' || c == '\"' ) ;
179- if ![ 2 , 0 ] . contains ( & ( part. len ( ) - part_trim. len ( ) ) ) {
180- return exec_err ! ( "Date part '{part}' not supported" ) ;
181- }
191+ let part_trim = part_normalization ( part) ;
182192
183193 // using IntervalUnit here means we hand off all the work of supporting plurals (like "seconds")
184194 // and synonyms ( like "ms,msec,msecond,millisecond") to Arrow
185195 let arr = if let Ok ( interval_unit) = IntervalUnit :: from_str ( part_trim) {
186196 match interval_unit {
187- IntervalUnit :: Year => date_part_f64 ( array. as_ref ( ) , DatePart :: Year ) ?,
188- IntervalUnit :: Month => date_part_f64 ( array. as_ref ( ) , DatePart :: Month ) ?,
189- IntervalUnit :: Week => date_part_f64 ( array. as_ref ( ) , DatePart :: Week ) ?,
190- IntervalUnit :: Day => date_part_f64 ( array. as_ref ( ) , DatePart :: Day ) ?,
191- IntervalUnit :: Hour => date_part_f64 ( array. as_ref ( ) , DatePart :: Hour ) ?,
192- IntervalUnit :: Minute => date_part_f64 ( array. as_ref ( ) , DatePart :: Minute ) ?,
193- IntervalUnit :: Second => seconds ( array. as_ref ( ) , Second ) ?,
194- IntervalUnit :: Millisecond => seconds ( array. as_ref ( ) , Millisecond ) ?,
195- IntervalUnit :: Microsecond => seconds ( array. as_ref ( ) , Microsecond ) ?,
196- IntervalUnit :: Nanosecond => seconds ( array. as_ref ( ) , Nanosecond ) ?,
197+ IntervalUnit :: Year => date_part ( array. as_ref ( ) , DatePart :: Year ) ?,
198+ IntervalUnit :: Month => date_part ( array. as_ref ( ) , DatePart :: Month ) ?,
199+ IntervalUnit :: Week => date_part ( array. as_ref ( ) , DatePart :: Week ) ?,
200+ IntervalUnit :: Day => date_part ( array. as_ref ( ) , DatePart :: Day ) ?,
201+ IntervalUnit :: Hour => date_part ( array. as_ref ( ) , DatePart :: Hour ) ?,
202+ IntervalUnit :: Minute => date_part ( array. as_ref ( ) , DatePart :: Minute ) ?,
203+ IntervalUnit :: Second => seconds_as_i32 ( array. as_ref ( ) , Second ) ?,
204+ IntervalUnit :: Millisecond => seconds_as_i32 ( array. as_ref ( ) , Millisecond ) ?,
205+ IntervalUnit :: Microsecond => seconds_as_i32 ( array. as_ref ( ) , Microsecond ) ?,
206+ IntervalUnit :: Nanosecond => seconds_as_i32 ( array. as_ref ( ) , Nanosecond ) ?,
197207 // century and decade are not supported by `DatePart`, although they are supported in postgres
198208 _ => return exec_err ! ( "Date part '{part}' not supported" ) ,
199209 }
200210 } else {
201211 // special cases that can be extracted (in postgres) but are not interval units
202212 match part_trim. to_lowercase ( ) . as_str ( ) {
203- "qtr" | "quarter" => date_part_f64 ( array. as_ref ( ) , DatePart :: Quarter ) ?,
204- "doy" => date_part_f64 ( array. as_ref ( ) , DatePart :: DayOfYear ) ?,
205- "dow" => date_part_f64 ( array. as_ref ( ) , DatePart :: DayOfWeekSunday0 ) ?,
213+ "qtr" | "quarter" => date_part ( array. as_ref ( ) , DatePart :: Quarter ) ?,
214+ "doy" => date_part ( array. as_ref ( ) , DatePart :: DayOfYear ) ?,
215+ "dow" => date_part ( array. as_ref ( ) , DatePart :: DayOfWeekSunday0 ) ?,
206216 "epoch" => epoch ( array. as_ref ( ) ) ?,
207217 _ => return exec_err ! ( "Date part '{part}' not supported" ) ,
208218 }
@@ -223,6 +233,18 @@ impl ScalarUDFImpl for DatePartFunc {
223233 }
224234}
225235
236+ fn is_epoch ( part : & str ) -> bool {
237+ let part = part_normalization ( part) ;
238+ matches ! ( part. to_lowercase( ) . as_str( ) , "epoch" )
239+ }
240+
/// Strips one leading and one trailing quote character (`'` or `"`) from
/// `part`.
///
/// If either the leading or the trailing quote is missing, the original
/// string is returned unchanged so that the downstream parsing reports the
/// error.
fn part_normalization(part: &str) -> &str {
    let is_quote = |c: char| matches!(c, '\'' | '"');

    match part.strip_prefix(is_quote) {
        // A leading quote was removed; the result only counts if a trailing
        // quote can be removed as well.
        Some(rest) => rest.strip_suffix(is_quote).unwrap_or(part),
        None => part,
    }
}
247+
226248static DOCUMENTATION : OnceLock < Documentation > = OnceLock :: new ( ) ;
227249
228250fn get_date_part_doc ( ) -> & ' static Documentation {
@@ -261,14 +283,63 @@ fn get_date_part_doc() -> &'static Documentation {
261283 } )
262284}
263285
264- /// Invoke [`date_part`] and cast the result to Float64
265- fn date_part_f64 ( array : & dyn Array , part : DatePart ) -> Result < ArrayRef > {
266- Ok ( cast ( date_part ( array, part) ?. as_ref ( ) , & Float64 ) ?)
286+ /// Invoke [`date_part`] on an `array` (e.g. Timestamp) and convert the
287+ /// result to a total number of seconds, milliseconds, microseconds or
288+ /// nanoseconds
289+ fn seconds_as_i32 ( array : & dyn Array , unit : TimeUnit ) -> Result < ArrayRef > {
290+ // Nanosecond is neither supported in Postgres nor DuckDB, to avoid to deal with overflow and precision issue we don't support nanosecond
291+ if unit == Nanosecond {
292+ return internal_err ! ( "unit {unit:?} not supported" ) ;
293+ }
294+
295+ let conversion_factor = match unit {
296+ Second => 1_000_000_000 ,
297+ Millisecond => 1_000_000 ,
298+ Microsecond => 1_000 ,
299+ Nanosecond => 1 ,
300+ } ;
301+
302+ let second_factor = match unit {
303+ Second => 1 ,
304+ Millisecond => 1_000 ,
305+ Microsecond => 1_000_000 ,
306+ Nanosecond => 1_000_000_000 ,
307+ } ;
308+
309+ let secs = date_part ( array, DatePart :: Second ) ?;
310+ // This assumes array is primitive and not a dictionary
311+ let secs = as_int32_array ( secs. as_ref ( ) ) ?;
312+ let subsecs = date_part ( array, DatePart :: Nanosecond ) ?;
313+ let subsecs = as_int32_array ( subsecs. as_ref ( ) ) ?;
314+
315+ // Special case where there are no nulls.
316+ if subsecs. null_count ( ) == 0 {
317+ let r: Int32Array = binary ( secs, subsecs, |secs, subsecs| {
318+ secs * second_factor + ( subsecs % 1_000_000_000 ) / conversion_factor
319+ } ) ?;
320+ Ok ( Arc :: new ( r) )
321+ } else {
322+ // Nulls in secs are preserved, nulls in subsecs are treated as zero to account for the case
323+ // where the number of nanoseconds overflows.
324+ let r: Int32Array = secs
325+ . iter ( )
326+ . zip ( subsecs)
327+ . map ( |( secs, subsecs) | {
328+ secs. map ( |secs| {
329+ let subsecs = subsecs. unwrap_or ( 0 ) ;
330+ secs * second_factor + ( subsecs % 1_000_000_000 ) / conversion_factor
331+ } )
332+ } )
333+ . collect ( ) ;
334+ Ok ( Arc :: new ( r) )
335+ }
267336}
268337
269338/// Invoke [`date_part`] on an `array` (e.g. Timestamp) and convert the
270339/// result to a total number of seconds, milliseconds, microseconds or
271340/// nanoseconds
341+ ///
342+ /// Given that epoch returns f64, this function is a duplicate optimized for the f64 type
272343fn seconds ( array : & dyn Array , unit : TimeUnit ) -> Result < ArrayRef > {
273344 let sf = match unit {
274345 Second => 1_f64 ,
0 commit comments