Skip to content

Commit 05d5f01

Browse files
authored
implement window functions with partition by (#558)
1 parent 5900b4c commit 05d5f01

File tree

9 files changed

+275
-30
lines changed

9 files changed

+275
-30
lines changed

datafusion/src/execution/context.rs

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1355,6 +1355,80 @@ mod tests {
13551355
Ok(())
13561356
}
13571357

1358+
#[tokio::test]
1359+
async fn window_partition_by() -> Result<()> {
1360+
let results = execute(
1361+
"SELECT \
1362+
c1, \
1363+
c2, \
1364+
SUM(c2) OVER (PARTITION BY c2), \
1365+
COUNT(c2) OVER (PARTITION BY c2), \
1366+
MAX(c2) OVER (PARTITION BY c2), \
1367+
MIN(c2) OVER (PARTITION BY c2), \
1368+
AVG(c2) OVER (PARTITION BY c2) \
1369+
FROM test \
1370+
ORDER BY c1, c2 \
1371+
LIMIT 5",
1372+
4,
1373+
)
1374+
.await?;
1375+
1376+
let expected = vec![
1377+
"+----+----+---------+-----------+---------+---------+---------+",
1378+
"| c1 | c2 | SUM(c2) | COUNT(c2) | MAX(c2) | MIN(c2) | AVG(c2) |",
1379+
"+----+----+---------+-----------+---------+---------+---------+",
1380+
"| 0 | 1 | 4 | 4 | 1 | 1 | 1 |",
1381+
"| 0 | 2 | 8 | 4 | 2 | 2 | 2 |",
1382+
"| 0 | 3 | 12 | 4 | 3 | 3 | 3 |",
1383+
"| 0 | 4 | 16 | 4 | 4 | 4 | 4 |",
1384+
"| 0 | 5 | 20 | 4 | 5 | 5 | 5 |",
1385+
"+----+----+---------+-----------+---------+---------+---------+",
1386+
];
1387+
1388+
// window function shall respect ordering
1389+
assert_batches_eq!(expected, &results);
1390+
Ok(())
1391+
}
1392+
1393+
#[tokio::test]
1394+
async fn window_partition_by_order_by() -> Result<()> {
1395+
let results = execute(
1396+
"SELECT \
1397+
c1, \
1398+
c2, \
1399+
ROW_NUMBER() OVER (PARTITION BY c2 ORDER BY c1), \
1400+
FIRST_VALUE(c2 + c1) OVER (PARTITION BY c2 ORDER BY c1), \
1401+
LAST_VALUE(c2 + c1) OVER (PARTITION BY c2 ORDER BY c1), \
1402+
NTH_VALUE(c2 + c1, 2) OVER (PARTITION BY c2 ORDER BY c1), \
1403+
SUM(c2) OVER (PARTITION BY c2 ORDER BY c1), \
1404+
COUNT(c2) OVER (PARTITION BY c2 ORDER BY c1), \
1405+
MAX(c2) OVER (PARTITION BY c2 ORDER BY c1), \
1406+
MIN(c2) OVER (PARTITION BY c2 ORDER BY c1), \
1407+
AVG(c2) OVER (PARTITION BY c2 ORDER BY c1) \
1408+
FROM test \
1409+
ORDER BY c1, c2 \
1410+
LIMIT 5",
1411+
4,
1412+
)
1413+
.await?;
1414+
1415+
let expected = vec![
1416+
"+----+----+--------------+-------------------------+------------------------+--------------------------------+---------+-----------+---------+---------+---------+",
1417+
"| c1 | c2 | ROW_NUMBER() | FIRST_VALUE(c2 Plus c1) | LAST_VALUE(c2 Plus c1) | NTH_VALUE(c2 Plus c1,Int64(2)) | SUM(c2) | COUNT(c2) | MAX(c2) | MIN(c2) | AVG(c2) |",
1418+
"+----+----+--------------+-------------------------+------------------------+--------------------------------+---------+-----------+---------+---------+---------+",
1419+
"| 0 | 1 | 1 | 1 | 4 | 2 | 1 | 1 | 1 | 1 | 1 |",
1420+
"| 0 | 2 | 1 | 2 | 5 | 3 | 2 | 1 | 2 | 2 | 2 |",
1421+
"| 0 | 3 | 1 | 3 | 6 | 4 | 3 | 1 | 3 | 3 | 3 |",
1422+
"| 0 | 4 | 1 | 4 | 7 | 5 | 4 | 1 | 4 | 4 | 4 |",
1423+
"| 0 | 5 | 1 | 5 | 8 | 6 | 5 | 1 | 5 | 5 | 5 |",
1424+
"+----+----+--------------+-------------------------+------------------------+--------------------------------+---------+-----------+---------+---------+---------+",
1425+
];
1426+
1427+
// window function shall respect ordering
1428+
assert_batches_eq!(expected, &results);
1429+
Ok(())
1430+
}
1431+
13581432
#[tokio::test]
13591433
async fn aggregate() -> Result<()> {
13601434
let results = execute("SELECT SUM(c1), SUM(c2) FROM test", 4).await?;

datafusion/src/physical_plan/expressions/nth_value.rs

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
use crate::error::{DataFusionError, Result};
2121
use crate::physical_plan::{window_functions::BuiltInWindowFunctionExpr, PhysicalExpr};
2222
use crate::scalar::ScalarValue;
23-
use arrow::array::{new_empty_array, ArrayRef};
23+
use arrow::array::{new_empty_array, new_null_array, ArrayRef};
2424
use arrow::datatypes::{DataType, Field};
2525
use std::any::Any;
2626
use std::sync::Arc;
@@ -135,8 +135,12 @@ impl BuiltInWindowFunctionExpr for NthValue {
135135
NthValueKind::Last => (num_rows as usize) - 1,
136136
NthValueKind::Nth(n) => (n as usize) - 1,
137137
};
138-
let value = ScalarValue::try_from_array(value, index)?;
139-
Ok(value.to_array_of_size(num_rows))
138+
Ok(if index >= num_rows {
139+
new_null_array(value.data_type(), num_rows)
140+
} else {
141+
let value = ScalarValue::try_from_array(value, index)?;
142+
value.to_array_of_size(num_rows)
143+
})
140144
}
141145
}
142146

datafusion/src/physical_plan/mod.rs

Lines changed: 22 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -485,19 +485,20 @@ pub trait WindowExpr: Send + Sync + Debug {
485485
/// evaluate the window function values against the batch
486486
fn evaluate(&self, batch: &RecordBatch) -> Result<ArrayRef>;
487487

488-
/// evaluate the sort partition points
489-
fn evaluate_sort_partition_points(
488+
/// evaluate the partition points given the sort columns; if the sort columns are
489+
/// empty then the result will be a single element vec of the whole column rows.
490+
fn evaluate_partition_points(
490491
&self,
491-
batch: &RecordBatch,
492+
num_rows: usize,
493+
partition_columns: &[SortColumn],
492494
) -> Result<Vec<Range<usize>>> {
493-
let sort_columns = self.sort_columns(batch)?;
494-
if sort_columns.is_empty() {
495+
if partition_columns.is_empty() {
495496
Ok(vec![Range {
496497
start: 0,
497-
end: batch.num_rows(),
498+
end: num_rows,
498499
}])
499500
} else {
500-
lexicographical_partition_ranges(&sort_columns)
501+
lexicographical_partition_ranges(partition_columns)
501502
.map_err(DataFusionError::ArrowError)
502503
}
503504
}
@@ -508,8 +509,8 @@ pub trait WindowExpr: Send + Sync + Debug {
508509
/// expressions that's from the window function's order by clause, empty if absent
509510
fn order_by(&self) -> &[PhysicalSortExpr];
510511

511-
/// get sort columns that can be used for partitioning, empty if absent
512-
fn sort_columns(&self, batch: &RecordBatch) -> Result<Vec<SortColumn>> {
512+
/// get partition columns that can be used for partitioning, empty if absent
513+
fn partition_columns(&self, batch: &RecordBatch) -> Result<Vec<SortColumn>> {
513514
self.partition_by()
514515
.iter()
515516
.map(|expr| {
@@ -519,13 +520,20 @@ pub trait WindowExpr: Send + Sync + Debug {
519520
}
520521
.evaluate_to_sort_column(batch)
521522
})
522-
.chain(
523-
self.order_by()
524-
.iter()
525-
.map(|e| e.evaluate_to_sort_column(batch)),
526-
)
527523
.collect()
528524
}
525+
526+
/// get sort columns that can be used for peer evaluation, empty if absent
527+
fn sort_columns(&self, batch: &RecordBatch) -> Result<Vec<SortColumn>> {
528+
let mut sort_columns = self.partition_columns(batch)?;
529+
let order_by_columns = self
530+
.order_by()
531+
.iter()
532+
.map(|e| e.evaluate_to_sort_column(batch))
533+
.collect::<Result<Vec<SortColumn>>>()?;
534+
sort_columns.extend(order_by_columns);
535+
Ok(sort_columns)
536+
}
529537
}
530538

531539
/// An accumulator represents a stateful object that lives throughout the evaluation of multiple rows and

datafusion/src/physical_plan/planner.rs

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -775,12 +775,6 @@ impl DefaultPhysicalPlanner {
775775
)),
776776
})
777777
.collect::<Result<Vec<_>>>()?;
778-
if !partition_by.is_empty() {
779-
return Err(DataFusionError::NotImplemented(
780-
"window expression with non-empty partition by clause is not yet supported"
781-
.to_owned(),
782-
));
783-
}
784778
if window_frame.is_some() {
785779
return Err(DataFusionError::NotImplemented(
786780
"window expression with window frame definition is not yet supported"

datafusion/src/physical_plan/windows.rs

Lines changed: 55 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -175,10 +175,45 @@ impl WindowExpr for BuiltInWindowExpr {
175175
// case when partition_by is supported, in which case we'll parallelize the calls.
176176
// See https://github.com/apache/arrow-datafusion/issues/299
177177
let values = self.evaluate_args(batch)?;
178-
self.window.evaluate(batch.num_rows(), &values)
178+
let partition_points = self.evaluate_partition_points(
179+
batch.num_rows(),
180+
&self.partition_columns(batch)?,
181+
)?;
182+
let results = partition_points
183+
.iter()
184+
.map(|partition_range| {
185+
let start = partition_range.start;
186+
let len = partition_range.end - start;
187+
let values = values
188+
.iter()
189+
.map(|arr| arr.slice(start, len))
190+
.collect::<Vec<_>>();
191+
self.window.evaluate(len, &values)
192+
})
193+
.collect::<Result<Vec<_>>>()?
194+
.into_iter()
195+
.collect::<Vec<ArrayRef>>();
196+
let results = results.iter().map(|i| i.as_ref()).collect::<Vec<_>>();
197+
concat(&results).map_err(DataFusionError::ArrowError)
179198
}
180199
}
181200

201+
/// Given a partition range, and the full list of sort partition points, given that the sort
202+
/// partition points are sorted using [partition columns..., order columns...], the split
203+
/// boundaries would align (what's sorted on [partition columns...] would definitely be sorted
204+
/// on finer columns), so this will use binary search to find ranges that are within the
205+
/// partition range and return the valid slice.
206+
fn find_ranges_in_range<'a>(
207+
partition_range: &Range<usize>,
208+
sort_partition_points: &'a [Range<usize>],
209+
) -> &'a [Range<usize>] {
210+
let start_idx = sort_partition_points
211+
.partition_point(|sort_range| sort_range.start < partition_range.start);
212+
let end_idx = sort_partition_points
213+
.partition_point(|sort_range| sort_range.end <= partition_range.end);
214+
&sort_partition_points[start_idx..end_idx]
215+
}
216+
182217
/// A window expr that takes the form of an aggregate function
183218
#[derive(Debug)]
184219
pub struct AggregateWindowExpr {
@@ -205,13 +240,27 @@ impl AggregateWindowExpr {
205240
/// and then per partition point we'll evaluate the peer group (e.g. SUM or MAX gives the same
206241
/// results for peers) and concatenate the results.
207242
fn peer_based_evaluate(&self, batch: &RecordBatch) -> Result<ArrayRef> {
208-
let sort_partition_points = self.evaluate_sort_partition_points(batch)?;
209-
let mut window_accumulators = self.create_accumulator()?;
243+
let num_rows = batch.num_rows();
244+
let partition_points =
245+
self.evaluate_partition_points(num_rows, &self.partition_columns(batch)?)?;
246+
let sort_partition_points =
247+
self.evaluate_partition_points(num_rows, &self.sort_columns(batch)?)?;
210248
let values = self.evaluate_args(batch)?;
211-
let results = sort_partition_points
249+
let results = partition_points
212250
.iter()
213-
.map(|peer_range| window_accumulators.scan_peers(&values, peer_range))
214-
.collect::<Result<Vec<_>>>()?;
251+
.map(|partition_range| {
252+
let sort_partition_points =
253+
find_ranges_in_range(partition_range, &sort_partition_points);
254+
let mut window_accumulators = self.create_accumulator()?;
255+
sort_partition_points
256+
.iter()
257+
.map(|range| window_accumulators.scan_peers(&values, range))
258+
.collect::<Result<Vec<_>>>()
259+
})
260+
.collect::<Result<Vec<Vec<ArrayRef>>>>()?
261+
.into_iter()
262+
.flatten()
263+
.collect::<Vec<ArrayRef>>();
215264
let results = results.iter().map(|i| i.as_ref()).collect::<Vec<_>>();
216265
concat(&results).map_err(DataFusionError::ArrowError)
217266
}

datafusion/tests/sql.rs

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -868,6 +868,70 @@ async fn csv_query_window_with_empty_over() -> Result<()> {
868868
Ok(())
869869
}
870870

871+
#[tokio::test]
872+
async fn csv_query_window_with_partition_by() -> Result<()> {
873+
let mut ctx = ExecutionContext::new();
874+
register_aggregate_csv(&mut ctx)?;
875+
let sql = "select \
876+
c9, \
877+
sum(cast(c4 as Int)) over (partition by c3), \
878+
avg(cast(c4 as Int)) over (partition by c3), \
879+
count(cast(c4 as Int)) over (partition by c3), \
880+
max(cast(c4 as Int)) over (partition by c3), \
881+
min(cast(c4 as Int)) over (partition by c3), \
882+
first_value(cast(c4 as Int)) over (partition by c3), \
883+
last_value(cast(c4 as Int)) over (partition by c3), \
884+
nth_value(cast(c4 as Int), 2) over (partition by c3) \
885+
from aggregate_test_100 \
886+
order by c9 \
887+
limit 5";
888+
let actual = execute(&mut ctx, sql).await;
889+
let expected = vec![
890+
vec![
891+
"28774375", "-16110", "-16110", "1", "-16110", "-16110", "-16110", "-16110",
892+
"NULL",
893+
],
894+
vec![
895+
"63044568", "3917", "3917", "1", "3917", "3917", "3917", "3917", "NULL",
896+
],
897+
vec![
898+
"141047417",
899+
"-38455",
900+
"-19227.5",
901+
"2",
902+
"-16974",
903+
"-21481",
904+
"-16974",
905+
"-21481",
906+
"-21481",
907+
],
908+
vec![
909+
"141680161",
910+
"-1114",
911+
"-1114",
912+
"1",
913+
"-1114",
914+
"-1114",
915+
"-1114",
916+
"-1114",
917+
"NULL",
918+
],
919+
vec![
920+
"145294611",
921+
"15673",
922+
"15673",
923+
"1",
924+
"15673",
925+
"15673",
926+
"15673",
927+
"15673",
928+
"NULL",
929+
],
930+
];
931+
assert_eq!(expected, actual);
932+
Ok(())
933+
}
934+
871935
#[tokio::test]
872936
async fn csv_query_window_with_order_by() -> Result<()> {
873937
let mut ctx = ExecutionContext::new();
Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
-- Licensed to the Apache Software Foundation (ASF) under one
2+
-- or more contributor license agreements. See the NOTICE file
3+
-- distributed with this work for additional information
4+
-- regarding copyright ownership. The ASF licenses this file
5+
-- to you under the Apache License, Version 2.0 (the
6+
-- "License"); you may not use this file except in compliance
7+
-- with the License. You may obtain a copy of the License at
8+
9+
-- http://www.apache.org/licenses/LICENSE-2.0
10+
11+
-- Unless required by applicable law or agreed to in writing, software
12+
-- distributed under the License is distributed on an "AS IS" BASIS,
13+
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
-- See the License for the specific language gOVERning permissions and
15+
-- limitations under the License.
16+
17+
SELECT
18+
c9,
19+
row_number() OVER (PARTITION BY c2, c9) AS row_number,
20+
count(c3) OVER (PARTITION BY c2) AS count_c3,
21+
avg(c3) OVER (PARTITION BY c2) AS avg_c3_by_c2,
22+
sum(c3) OVER (PARTITION BY c2) AS sum_c3_by_c2,
23+
max(c3) OVER (PARTITION BY c2) AS max_c3_by_c2,
24+
min(c3) OVER (PARTITION BY c2) AS min_c3_by_c2
25+
FROM test
26+
ORDER BY c9;
Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
-- Licensed to the Apache Software Foundation (ASF) under one
2+
-- or more contributor license agreements. See the NOTICE file
3+
-- distributed with this work for additional information
4+
-- regarding copyright ownership. The ASF licenses this file
5+
-- to you under the Apache License, Version 2.0 (the
6+
-- "License"); you may not use this file except in compliance
7+
-- with the License. You may obtain a copy of the License at
8+
9+
-- http://www.apache.org/licenses/LICENSE-2.0
10+
11+
-- Unless required by applicable law or agreed to in writing, software
12+
-- distributed under the License is distributed on an "AS IS" BASIS,
13+
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
-- See the License for the specific language gOVERning permissions and
15+
-- limitations under the License.
16+
17+
SELECT
18+
c9,
19+
row_number() OVER (PARTITION BY c2 ORDER BY c9) AS row_number,
20+
count(c3) OVER (PARTITION BY c2 ORDER BY c9) AS count_c3,
21+
avg(c3) OVER (PARTITION BY c2 ORDER BY c9) AS avg_c3_by_c2,
22+
sum(c3) OVER (PARTITION BY c2 ORDER BY c9) AS sum_c3_by_c2,
23+
max(c3) OVER (PARTITION BY c2 ORDER BY c9) AS max_c3_by_c2,
24+
min(c3) OVER (PARTITION BY c2 ORDER BY c9) AS min_c3_by_c2
25+
FROM test
26+
ORDER BY c9;

integration-tests/test_psql_parity.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@ class PsqlParityTest(unittest.TestCase):
7474
def test_parity(self):
7575
root = Path(os.path.dirname(__file__)) / "sqls"
7676
files = set(root.glob("*.sql"))
77-
self.assertEqual(len(files), 7, msg="tests are missed")
77+
self.assertEqual(len(files), 9, msg="tests are missed")
7878
for fname in files:
7979
with self.subTest(fname=fname):
8080
datafusion_output = pd.read_csv(

0 commit comments

Comments
 (0)