Skip to content

Commit 00f05c2

Browse files
authored
Make limit pushdown work for SortPreservingMergeExec (#17893)
* Make limit pushdown work for sortpreservingmerge * Remove test * datafusion/physical-optimizer/src/limit_pushdown_past_window.rs
1 parent 10a437b commit 00f05c2

File tree

2 files changed

+19
-2
lines changed

2 files changed

+19
-2
lines changed

datafusion/physical-optimizer/src/limit_pushdown_past_window.rs

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ use datafusion_expr::{WindowFrameBound, WindowFrameUnits};
2323
use datafusion_physical_plan::execution_plan::CardinalityEffect;
2424
use datafusion_physical_plan::limit::GlobalLimitExec;
2525
use datafusion_physical_plan::sorts::sort::SortExec;
26+
use datafusion_physical_plan::sorts::sort_preserving_merge::SortPreservingMergeExec;
2627
use datafusion_physical_plan::windows::BoundedWindowAggExec;
2728
use datafusion_physical_plan::ExecutionPlan;
2829
use std::cmp;
@@ -91,6 +92,22 @@ impl PhysicalOptimizerRule for LimitPushPastWindows {
9192
return Ok(Transformed::no(node));
9293
}
9394

95+
// Apply the limit if we hit a sortpreservingmerge node
96+
if let Some(spm) = node.as_any().downcast_ref::<SortPreservingMergeExec>() {
97+
let latest = latest_limit.take();
98+
let Some(fetch) = latest else {
99+
latest_max = 0;
100+
return Ok(Transformed::no(node));
101+
};
102+
let fetch = match spm.fetch() {
103+
None => fetch + latest_max,
104+
Some(existing) => cmp::min(existing, fetch + latest_max),
105+
};
106+
let spm: Arc<dyn ExecutionPlan> = spm.with_fetch(Some(fetch)).unwrap();
107+
latest_max = 0;
108+
return Ok(Transformed::complete(spm));
109+
}
110+
94111
// Apply the limit if we hit a sort node
95112
if let Some(sort) = node.as_any().downcast_ref::<SortExec>() {
96113
let latest = latest_limit.take();

datafusion/sqllogictest/test_files/window.slt

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5966,8 +5966,8 @@ physical_plan
59665966
01)ProjectionExec: expr=[c1@2 as c1, c2@3 as c2, sum(test.c2) FILTER (WHERE test.c2 >= Int64(2)) ORDER BY [test.c1 ASC NULLS LAST, test.c2 ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@4 as sum1, sum(test.c2) FILTER (WHERE test.c2 >= Int64(2) AND test.c2 < Int64(4) AND test.c1 > Int64(0)) ORDER BY [test.c1 ASC NULLS LAST, test.c2 ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@5 as sum2, count(test.c2) FILTER (WHERE test.c2 >= Int64(2)) ORDER BY [test.c1 ASC NULLS LAST, test.c2 ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@6 as count1, array_agg(test.c2) FILTER (WHERE test.c2 >= Int64(2)) ORDER BY [test.c1 ASC NULLS LAST, test.c2 ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@7 as array_agg1, array_agg(test.c2) FILTER (WHERE test.c2 >= Int64(2) AND test.c2 < Int64(4) AND test.c1 > Int64(0)) ORDER BY [test.c1 ASC NULLS LAST, test.c2 ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@8 as array_agg2]
59675967
02)--GlobalLimitExec: skip=0, fetch=5
59685968
03)----BoundedWindowAggExec: wdw=[sum(test.c2) FILTER (WHERE test.c2 >= Int64(2)) ORDER BY [test.c1 ASC NULLS LAST, test.c2 ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Field { name: "sum(test.c2) FILTER (WHERE test.c2 >= Int64(2)) ORDER BY [test.c1 ASC NULLS LAST, test.c2 ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW, sum(test.c2) FILTER (WHERE test.c2 >= Int64(2) AND test.c2 < Int64(4) AND test.c1 > Int64(0)) ORDER BY [test.c1 ASC NULLS LAST, test.c2 ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Field { name: "sum(test.c2) FILTER (WHERE test.c2 >= Int64(2) AND test.c2 < Int64(4) AND test.c1 > Int64(0)) ORDER BY [test.c1 ASC NULLS LAST, test.c2 ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW, count(test.c2) FILTER (WHERE test.c2 >= Int64(2)) ORDER BY [test.c1 ASC NULLS LAST, test.c2 ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Field { name: "count(test.c2) FILTER (WHERE test.c2 >= Int64(2)) ORDER BY [test.c1 ASC NULLS LAST, test.c2 ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW, array_agg(test.c2) FILTER (WHERE test.c2 >= Int64(2)) ORDER BY [test.c1 ASC NULLS LAST, test.c2 ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Field { name: "array_agg(test.c2) FILTER (WHERE test.c2 >= Int64(2)) ORDER BY [test.c1 ASC NULLS LAST, test.c2 ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: List(Field { name: "item", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW, array_agg(test.c2) FILTER (WHERE test.c2 >= Int64(2) AND test.c2 < Int64(4) AND test.c1 > Int64(0)) ORDER BY [test.c1 ASC NULLS LAST, test.c2 ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Field { name: "array_agg(test.c2) FILTER (WHERE test.c2 >= Int64(2) AND test.c2 < Int64(4) AND test.c1 > Int64(0)) ORDER BY [test.c1 ASC NULLS LAST, test.c2 ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: List(Field { name: "item", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], mode=[Sorted]
5969-
04)------SortPreservingMergeExec: [c1@2 ASC NULLS LAST, c2@3 ASC NULLS LAST]
5970-
05)--------SortExec: expr=[c1@2 ASC NULLS LAST, c2@3 ASC NULLS LAST], preserve_partitioning=[true]
5969+
04)------SortPreservingMergeExec: [c1@2 ASC NULLS LAST, c2@3 ASC NULLS LAST], fetch=5
5970+
05)--------SortExec: TopK(fetch=5), expr=[c1@2 ASC NULLS LAST, c2@3 ASC NULLS LAST], preserve_partitioning=[true]
59715971
06)----------ProjectionExec: expr=[__common_expr_3@0 as __common_expr_1, __common_expr_3@0 AND c2@2 < 4 AND c1@1 > 0 as __common_expr_2, c1@1 as c1, c2@2 as c2]
59725972
07)------------ProjectionExec: expr=[c2@1 >= 2 as __common_expr_3, c1@0 as c1, c2@1 as c2]
59735973
08)--------------DataSourceExec: file_groups={4 groups: [[WORKSPACE_ROOT/datafusion/core/tests/data/partitioned_csv/partition-0.csv], [WORKSPACE_ROOT/datafusion/core/tests/data/partitioned_csv/partition-1.csv], [WORKSPACE_ROOT/datafusion/core/tests/data/partitioned_csv/partition-2.csv], [WORKSPACE_ROOT/datafusion/core/tests/data/partitioned_csv/partition-3.csv]]}, projection=[c1, c2], file_type=csv, has_header=false

0 commit comments

Comments
 (0)