@@ -175,10 +175,45 @@ impl WindowExpr for BuiltInWindowExpr {
175175 // case when partition_by is supported, in which case we'll parallelize the calls.
176176 // See https://github.com/apache/arrow-datafusion/issues/299
177177 let values = self . evaluate_args ( batch) ?;
178- self . window . evaluate ( batch. num_rows ( ) , & values)
178+ let partition_points = self . evaluate_partition_points (
179+ batch. num_rows ( ) ,
180+ & self . partition_columns ( batch) ?,
181+ ) ?;
182+ let results = partition_points
183+ . iter ( )
184+ . map ( |partition_range| {
185+ let start = partition_range. start ;
186+ let len = partition_range. end - start;
187+ let values = values
188+ . iter ( )
189+ . map ( |arr| arr. slice ( start, len) )
190+ . collect :: < Vec < _ > > ( ) ;
191+ self . window . evaluate ( len, & values)
192+ } )
193+ . collect :: < Result < Vec < _ > > > ( ) ?
194+ . into_iter ( )
195+ . collect :: < Vec < ArrayRef > > ( ) ;
196+ let results = results. iter ( ) . map ( |i| i. as_ref ( ) ) . collect :: < Vec < _ > > ( ) ;
197+ concat ( & results) . map_err ( DataFusionError :: ArrowError )
179198 }
180199}
181200
201+ /// Given a partition range, and the full list of sort partition points, given that the sort
202+ /// partition points are sorted using [partition columns..., order columns...], the split
203+ /// boundaries would align (what's sorted on [partition columns...] would definitely be sorted
204+ /// on finer columns), so this will use binary search to find ranges that are within the
205+ /// partition range and return the valid slice.
206+ fn find_ranges_in_range < ' a > (
207+ partition_range : & Range < usize > ,
208+ sort_partition_points : & ' a [ Range < usize > ] ,
209+ ) -> & ' a [ Range < usize > ] {
210+ let start_idx = sort_partition_points
211+ . partition_point ( |sort_range| sort_range. start < partition_range. start ) ;
212+ let end_idx = sort_partition_points
213+ . partition_point ( |sort_range| sort_range. end <= partition_range. end ) ;
214+ & sort_partition_points[ start_idx..end_idx]
215+ }
216+
182217/// A window expr that takes the form of an aggregate function
183218#[ derive( Debug ) ]
184219pub struct AggregateWindowExpr {
@@ -205,13 +240,27 @@ impl AggregateWindowExpr {
205240 /// and then per partition point we'll evaluate the peer group (e.g. SUM or MAX gives the same
206241 /// results for peers) and concatenate the results.
207242 fn peer_based_evaluate ( & self , batch : & RecordBatch ) -> Result < ArrayRef > {
208- let sort_partition_points = self . evaluate_sort_partition_points ( batch) ?;
209- let mut window_accumulators = self . create_accumulator ( ) ?;
243+ let num_rows = batch. num_rows ( ) ;
244+ let partition_points =
245+ self . evaluate_partition_points ( num_rows, & self . partition_columns ( batch) ?) ?;
246+ let sort_partition_points =
247+ self . evaluate_partition_points ( num_rows, & self . sort_columns ( batch) ?) ?;
210248 let values = self . evaluate_args ( batch) ?;
211- let results = sort_partition_points
249+ let results = partition_points
212250 . iter ( )
213- . map ( |peer_range| window_accumulators. scan_peers ( & values, peer_range) )
214- . collect :: < Result < Vec < _ > > > ( ) ?;
251+ . map ( |partition_range| {
252+ let sort_partition_points =
253+ find_ranges_in_range ( partition_range, & sort_partition_points) ;
254+ let mut window_accumulators = self . create_accumulator ( ) ?;
255+ sort_partition_points
256+ . iter ( )
257+ . map ( |range| window_accumulators. scan_peers ( & values, range) )
258+ . collect :: < Result < Vec < _ > > > ( )
259+ } )
260+ . collect :: < Result < Vec < Vec < ArrayRef > > > > ( ) ?
261+ . into_iter ( )
262+ . flatten ( )
263+ . collect :: < Vec < ArrayRef > > ( ) ;
215264 let results = results. iter ( ) . map ( |i| i. as_ref ( ) ) . collect :: < Vec < _ > > ( ) ;
216265 concat ( & results) . map_err ( DataFusionError :: ArrowError )
217266 }
0 commit comments