2121use crate :: error:: { DataFusionError , Result } ;
2222use crate :: physical_plan:: window_functions:: PartitionEvaluator ;
2323use crate :: physical_plan:: { window_functions:: BuiltInWindowFunctionExpr , PhysicalExpr } ;
24+ use crate :: scalar:: ScalarValue ;
2425use arrow:: array:: ArrayRef ;
25- use arrow:: compute:: kernels :: window :: shift ;
26+ use arrow:: compute:: cast ;
2627use arrow:: datatypes:: { DataType , Field } ;
2728use arrow:: record_batch:: RecordBatch ;
2829use std:: any:: Any ;
30+ use std:: ops:: Neg ;
2931use std:: ops:: Range ;
3032use std:: sync:: Arc ;
3133
@@ -36,19 +38,23 @@ pub struct WindowShift {
3638 data_type : DataType ,
3739 shift_offset : i64 ,
3840 expr : Arc < dyn PhysicalExpr > ,
41+ default_value : Option < ScalarValue > ,
3942}
4043
4144/// lead() window function
4245pub fn lead (
4346 name : String ,
4447 data_type : DataType ,
4548 expr : Arc < dyn PhysicalExpr > ,
49+ shift_offset : Option < i64 > ,
50+ default_value : Option < ScalarValue > ,
4651) -> WindowShift {
4752 WindowShift {
4853 name,
4954 data_type,
50- shift_offset : - 1 ,
55+ shift_offset : shift_offset . map ( |v| v . neg ( ) ) . unwrap_or ( - 1 ) ,
5156 expr,
57+ default_value,
5258 }
5359}
5460
@@ -57,12 +63,15 @@ pub fn lag(
5763 name : String ,
5864 data_type : DataType ,
5965 expr : Arc < dyn PhysicalExpr > ,
66+ shift_offset : Option < i64 > ,
67+ default_value : Option < ScalarValue > ,
6068) -> WindowShift {
6169 WindowShift {
6270 name,
6371 data_type,
64- shift_offset : 1 ,
72+ shift_offset : shift_offset . unwrap_or ( 1 ) ,
6573 expr,
74+ default_value,
6675 }
6776}
6877
@@ -98,20 +107,71 @@ impl BuiltInWindowFunctionExpr for WindowShift {
98107 Ok ( Box :: new ( WindowShiftEvaluator {
99108 shift_offset : self . shift_offset ,
100109 values,
110+ default_value : self . default_value . clone ( ) ,
101111 } ) )
102112 }
103113}
104114
105115pub ( crate ) struct WindowShiftEvaluator {
106116 shift_offset : i64 ,
107117 values : Vec < ArrayRef > ,
118+ default_value : Option < ScalarValue > ,
119+ }
120+
121+ fn create_empty_array (
122+ value : & Option < ScalarValue > ,
123+ data_type : & DataType ,
124+ size : usize ,
125+ ) -> Result < ArrayRef > {
126+ use arrow:: array:: new_null_array;
127+ let array = value
128+ . as_ref ( )
129+ . map ( |scalar| scalar. to_array_of_size ( size) )
130+ . unwrap_or_else ( || new_null_array ( data_type, size) ) ;
131+ if array. data_type ( ) != data_type {
132+ cast ( & array, data_type) . map_err ( DataFusionError :: ArrowError )
133+ } else {
134+ Ok ( array)
135+ }
136+ }
137+
138+ // TODO: change the original arrow::compute::kernels::window::shift impl to support an optional default value
139+ fn shift_with_default_value (
140+ array : & ArrayRef ,
141+ offset : i64 ,
142+ value : & Option < ScalarValue > ,
143+ ) -> Result < ArrayRef > {
144+ use arrow:: compute:: concat;
145+
146+ let value_len = array. len ( ) as i64 ;
147+ if offset == 0 {
148+ Ok ( arrow:: array:: make_array ( array. data_ref ( ) . clone ( ) ) )
149+ } else if offset == i64:: MIN || offset. abs ( ) >= value_len {
150+ create_empty_array ( value, array. data_type ( ) , array. len ( ) )
151+ } else {
152+ let slice_offset = ( -offset) . clamp ( 0 , value_len) as usize ;
153+ let length = array. len ( ) - offset. abs ( ) as usize ;
154+ let slice = array. slice ( slice_offset, length) ;
155+
156+ // Generate array with remaining `null` items
157+ let nulls = offset. abs ( ) as usize ;
158+ let default_values = create_empty_array ( value, slice. data_type ( ) , nulls) ?;
159+ // Concatenate both arrays, add nulls after if shift > 0 else before
160+ if offset > 0 {
161+ concat ( & [ default_values. as_ref ( ) , slice. as_ref ( ) ] )
162+ . map_err ( DataFusionError :: ArrowError )
163+ } else {
164+ concat ( & [ slice. as_ref ( ) , default_values. as_ref ( ) ] )
165+ . map_err ( DataFusionError :: ArrowError )
166+ }
167+ }
108168}
109169
110170impl PartitionEvaluator for WindowShiftEvaluator {
111171 fn evaluate_partition ( & self , partition : Range < usize > ) -> Result < ArrayRef > {
112172 let value = & self . values [ 0 ] ;
113173 let value = value. slice ( partition. start , partition. end - partition. start ) ;
114- shift ( value. as_ref ( ) , self . shift_offset ) . map_err ( DataFusionError :: ArrowError )
174+ shift_with_default_value ( & value, self . shift_offset , & self . default_value )
115175 }
116176}
117177
@@ -142,6 +202,8 @@ mod tests {
142202 "lead" . to_owned ( ) ,
143203 DataType :: Float32 ,
144204 Arc :: new ( Column :: new ( "c3" , 0 ) ) ,
205+ None ,
206+ None ,
145207 ) ,
146208 vec ! [
147209 Some ( -2 ) ,
@@ -162,6 +224,8 @@ mod tests {
162224 "lead" . to_owned ( ) ,
163225 DataType :: Float32 ,
164226 Arc :: new ( Column :: new ( "c3" , 0 ) ) ,
227+ None ,
228+ None ,
165229 ) ,
166230 vec ! [
167231 None ,
@@ -176,6 +240,28 @@ mod tests {
176240 . iter ( )
177241 . collect :: < Int32Array > ( ) ,
178242 ) ?;
243+
244+ test_i32_result (
245+ lag (
246+ "lead" . to_owned ( ) ,
247+ DataType :: Int32 ,
248+ Arc :: new ( Column :: new ( "c3" , 0 ) ) ,
249+ None ,
250+ Some ( ScalarValue :: Int32 ( Some ( 100 ) ) ) ,
251+ ) ,
252+ vec ! [
253+ Some ( 100 ) ,
254+ Some ( 1 ) ,
255+ Some ( -2 ) ,
256+ Some ( 3 ) ,
257+ Some ( -4 ) ,
258+ Some ( 5 ) ,
259+ Some ( -6 ) ,
260+ Some ( 7 ) ,
261+ ]
262+ . iter ( )
263+ . collect :: < Int32Array > ( ) ,
264+ ) ?;
179265 Ok ( ( ) )
180266 }
181267}
0 commit comments