Skip to content

feature: sort by/cluster by/distribute by #16310

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 2 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
127 changes: 127 additions & 0 deletions datafusion/core/src/dataframe/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1112,6 +1112,69 @@ impl DataFrame {
)
}

/// Sorts the `DataFrame` within each partition, ascending with nulls last,
/// by the given expressions.
///
/// The sort is applied independently to every partition — the results are
/// not merged across partitions. This is cheaper than a global sort when
/// per-partition ordering is sufficient.
///
/// # Arguments
///
/// * `expr` - Expressions to sort by, in decreasing priority: earlier
///   expressions take precedence over later ones. Each expression is
///   sorted ascending, nulls last.
///
/// # Returns
///
/// A new `DataFrame` whose rows are sorted within each partition by the
/// given expressions.
///
/// # Example
///
/// ```
/// # use datafusion::prelude::*;
/// # use datafusion::error::Result;
/// # use datafusion_common::assert_batches_sorted_eq;
/// # #[tokio::main]
/// # async fn main() -> Result<()> {
/// let ctx = SessionContext::new();
/// let df = ctx.read_csv("tests/data/example_long.csv", CsvReadOptions::new()).await?;
/// // Repartition first so there is more than one partition to sort
/// let df = df.repartition(Partitioning::RoundRobinBatch(2))?;
/// // Sort each partition by a, then b (both ascending)
/// let df = df.sort_by_within_partitions(vec![
///     col("a"), // a ASC
///     col("b"), // b ASC
/// ])?;
/// let expected = vec![
///     "+---+---+---+",
///     "| a | b | c |",
///     "+---+---+---+",
///     "| 7 | 8 | 9 |",
///     "| 4 | 5 | 6 |",
///     "| 1 | 2 | 3 |",
///     "+---+---+---+",
/// ];
/// # assert_batches_sorted_eq!(expected, &df.collect().await?);
/// # Ok(())
/// # }
/// ```
///
/// # Note
///
/// - The existing partitioning of the data is preserved
/// - No ordering is guaranteed across partitions
/// - For a global sort over all partitions, use [`sort_by()`](Self::sort_by) instead
///
pub fn sort_by_within_partitions(self, expr: Vec<Expr>) -> Result<DataFrame> {
    // Wrap each expression as ascending / nulls-last, then delegate to
    // the general per-partition sort.
    let sort_exprs: Vec<SortExpr> =
        expr.into_iter().map(|e| e.sort(true, false)).collect();
    self.sort_within_partitions(sort_exprs)
}

/// Sort the DataFrame by the specified sorting expressions.
///
/// Note that any expression can be turned into
Expand Down Expand Up @@ -1152,6 +1215,70 @@ impl DataFrame {
projection_requires_validation: self.projection_requires_validation,
})
}
/// Sorts the `DataFrame` within each partition using the given sort
/// expressions.
///
/// The sort runs locally inside every partition; results are never merged
/// across partitions. This is cheaper than a global sort when only
/// per-partition ordering is needed.
///
/// # Arguments
///
/// * `expr` - Sort expressions, in decreasing priority: earlier
///   expressions take precedence over later ones. Each [`SortExpr`]
///   carries its own direction and null ordering.
///
/// # Returns
///
/// A new `DataFrame` whose rows are sorted within each partition by the
/// given sort expressions.
///
/// # Example
///
/// ```
/// # use datafusion::prelude::*;
/// # use datafusion::error::Result;
/// # use datafusion_common::assert_batches_sorted_eq;
/// # #[tokio::main]
/// # async fn main() -> Result<()> {
/// let ctx = SessionContext::new();
/// let df = ctx.read_csv("tests/data/example_long.csv", CsvReadOptions::new()).await?;
/// // Repartition first so there is more than one partition to sort
/// let df = df.repartition(Partitioning::RoundRobinBatch(2))?;
/// // Sort each partition independently
/// let df = df.sort_within_partitions(vec![
///     col("a").sort(false, true), // a DESC, nulls first
///     col("b").sort(true, false), // b ASC, nulls last
/// ])?;
/// let expected = vec![
///     "+---+---+---+",
///     "| a | b | c |",
///     "+---+---+---+",
///     "| 7 | 8 | 9 |",
///     "| 4 | 5 | 6 |",
///     "| 1 | 2 | 3 |",
///     "+---+---+---+",
/// ];
/// # assert_batches_sorted_eq!(expected, &df.collect().await?);
/// # Ok(())
/// # }
/// ```
///
/// # Note
///
/// - The existing partitioning of the data is preserved
/// - No ordering is guaranteed across partitions
/// - For a global sort over all partitions, use [`sort()`](Self::sort) instead
pub fn sort_within_partitions(self, expr: Vec<SortExpr>) -> Result<DataFrame> {
    // Build the logical Sort node flagged as partition-preserving.
    let builder = LogicalPlanBuilder::from(self.plan);
    let plan = builder.sort_within_partitions(expr)?.build()?;
    Ok(DataFrame {
        session_state: self.session_state,
        plan,
        projection_requires_validation: self.projection_requires_validation,
    })
}

/// Join this `DataFrame` with another `DataFrame` using explicitly specified
/// columns and an optional filter expression.
Expand Down
29 changes: 22 additions & 7 deletions datafusion/core/src/physical_planner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -806,10 +806,20 @@ impl DefaultPhysicalPlanner {
.collect::<Result<Vec<_>>>()?;
Partitioning::Hash(runtime_expr, *n)
}
LogicalPartitioning::DistributeBy(_) => {
return not_impl_err!(
"Physical plan does not support DistributeBy partitioning"
);
LogicalPartitioning::DistributeBy(expr) => {
let n =
session_state.config().options().execution.target_partitions;
let runtime_expr = expr
.iter()
.map(|e| {
self.create_physical_expr(
e,
input_dfschema,
session_state,
)
})
.collect::<Result<Vec<_>>>()?;
Partitioning::Hash(runtime_expr, n)
}
};
Arc::new(RepartitionExec::try_new(
Expand All @@ -818,7 +828,11 @@ impl DefaultPhysicalPlanner {
)?)
}
LogicalPlan::Sort(Sort {
expr, input, fetch, ..
expr,
input,
fetch,
preserve_partitioning,
..
}) => {
let physical_input = children.one()?;
let input_dfschema = input.as_ref().schema();
Expand All @@ -827,8 +841,9 @@ impl DefaultPhysicalPlanner {
input_dfschema,
session_state.execution_props(),
)?;
let new_sort =
SortExec::new(sort_expr, physical_input).with_fetch(*fetch);
let new_sort = SortExec::new(sort_expr, physical_input)
.with_preserve_partitioning(*preserve_partitioning)
.with_fetch(*fetch);
Arc::new(new_sort)
}
LogicalPlan::Subquery(_) => todo!(),
Expand Down
12 changes: 11 additions & 1 deletion datafusion/expr/src/logical_plan/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -775,14 +775,22 @@ impl LogicalPlanBuilder {
self,
sorts: impl IntoIterator<Item = impl Into<SortExpr>> + Clone,
) -> Result<Self> {
self.sort_with_limit(sorts, None)
self.sort_with_limit(sorts, None, false)
}

/// Apply a sort within each partition, keeping the input partitioning
/// intact (no merge across partitions).
///
/// Equivalent to [`Self::sort_with_limit`] with no fetch limit and
/// `preserve_partitioning` set to `true`. Direction and null ordering
/// come from each [`SortExpr`].
pub fn sort_within_partitions(
    self,
    sorts: impl IntoIterator<Item = impl Into<SortExpr>> + Clone,
) -> Result<Self> {
    self.sort_with_limit(sorts, None, true)
}

/// Apply a sort
pub fn sort_with_limit(
self,
sorts: impl IntoIterator<Item = impl Into<SortExpr>> + Clone,
fetch: Option<usize>,
preserve_partitioning: bool,
) -> Result<Self> {
let sorts = rewrite_sort_cols_by_aggs(sorts, &self.plan)?;

Expand All @@ -808,6 +816,7 @@ impl LogicalPlanBuilder {
expr: normalize_sorts(sorts, &self.plan)?,
input: self.plan,
fetch,
preserve_partitioning,
})));
}

Expand All @@ -825,6 +834,7 @@ impl LogicalPlanBuilder {
expr: normalize_sorts(sorts, &plan)?,
input: Arc::new(plan),
fetch,
preserve_partitioning: false,
});

Projection::try_new(new_expr, Arc::new(sort_plan))
Expand Down
8 changes: 7 additions & 1 deletion datafusion/expr/src/logical_plan/display.rs
Original file line number Diff line number Diff line change
Expand Up @@ -473,10 +473,16 @@ impl<'a, 'b> PgJsonVisitor<'a, 'b> {
"Aggregates": expr_vec_fmt!(aggr_expr)
})
}
LogicalPlan::Sort(Sort { expr, fetch, .. }) => {
LogicalPlan::Sort(Sort {
expr,
fetch,
preserve_partitioning,
..
}) => {
let mut object = json!({
"Node Type": "Sort",
"Sort Key": expr_vec_fmt!(expr),
"Preserve Partitioning": preserve_partitioning,
});

if let Some(fetch) = fetch {
Expand Down
10 changes: 9 additions & 1 deletion datafusion/expr/src/logical_plan/plan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -877,6 +877,7 @@ impl LogicalPlan {
LogicalPlan::Sort(Sort {
expr: sort_expr,
fetch,
preserve_partitioning,
..
}) => {
let input = self.only_input(inputs)?;
Expand All @@ -888,6 +889,7 @@ impl LogicalPlan {
.collect(),
input: Arc::new(input),
fetch: *fetch,
preserve_partitioning: *preserve_partitioning,
}))
}
LogicalPlan::Join(Join {
Expand Down Expand Up @@ -1868,7 +1870,7 @@ impl LogicalPlan {
expr_vec_fmt!(group_expr),
expr_vec_fmt!(aggr_expr)
),
LogicalPlan::Sort(Sort { expr, fetch, .. }) => {
LogicalPlan::Sort(Sort { expr, fetch, preserve_partitioning, .. }) => {
write!(f, "Sort: ")?;
for (i, expr_item) in expr.iter().enumerate() {
if i > 0 {
Expand All @@ -1879,6 +1881,9 @@ impl LogicalPlan {
if let Some(a) = fetch {
write!(f, ", fetch={a}")?;
}
if *preserve_partitioning {
write!(f, ", preserve_ordering={preserve_partitioning}")?;
}

Ok(())
}
Expand Down Expand Up @@ -3681,6 +3686,9 @@ pub struct Sort {
pub input: Arc<LogicalPlan>,
/// Optional fetch limit
pub fetch: Option<usize>,
/// Preserve partitions of input plan. If false, the input partitions
/// will be sorted and merged into a single output partition.
pub preserve_partitioning: bool,
}

/// Join two logical plans on one or more join columns
Expand Down
32 changes: 26 additions & 6 deletions datafusion/expr/src/logical_plan/tree_node.rs
Original file line number Diff line number Diff line change
Expand Up @@ -121,9 +121,19 @@ impl TreeNode for LogicalPlan {
schema,
})
}),
LogicalPlan::Sort(Sort { expr, input, fetch }) => input
.map_elements(f)?
.update_data(|input| LogicalPlan::Sort(Sort { expr, input, fetch })),
LogicalPlan::Sort(Sort {
expr,
input,
fetch,
preserve_partitioning,
}) => input.map_elements(f)?.update_data(|input| {
LogicalPlan::Sort(Sort {
expr,
input,
fetch,
preserve_partitioning,
})
}),
LogicalPlan::Join(Join {
left,
right,
Expand Down Expand Up @@ -574,9 +584,19 @@ impl LogicalPlan {
null_equals_null,
})
}),
LogicalPlan::Sort(Sort { expr, input, fetch }) => expr
.map_elements(f)?
.update_data(|expr| LogicalPlan::Sort(Sort { expr, input, fetch })),
LogicalPlan::Sort(Sort {
expr,
input,
fetch,
preserve_partitioning,
}) => expr.map_elements(f)?.update_data(|expr| {
LogicalPlan::Sort(Sort {
expr,
input,
fetch,
preserve_partitioning,
})
}),
LogicalPlan::Extension(Extension { node }) => {
// would be nice to avoid this copy -- maybe can
// update extension to just observer Exprs
Expand Down
2 changes: 2 additions & 0 deletions datafusion/optimizer/src/analyzer/type_coercion.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1298,6 +1298,7 @@ mod test {
expr: vec![sort_expr],
input: Arc::new(plan),
fetch: None,
preserve_partitioning: false,
});

// Plan C: no coerce
Expand Down Expand Up @@ -1421,6 +1422,7 @@ mod test {
expr: vec![sort_expr],
input: Arc::new(plan),
fetch: None,
preserve_partitioning: false,
});

// Plan C: no coerce
Expand Down
8 changes: 7 additions & 1 deletion datafusion/optimizer/src/common_subexpr_eliminate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,12 @@ impl CommonSubexprEliminate {
sort: Sort,
config: &dyn OptimizerConfig,
) -> Result<Transformed<LogicalPlan>> {
let Sort { expr, input, fetch } = sort;
let Sort {
expr,
input,
fetch,
preserve_partitioning,
} = sort;
let input = Arc::unwrap_or_clone(input);
let (sort_expressions, sort_params): (Vec<_>, Vec<(_, _)>) = expr
.into_iter()
Expand All @@ -117,6 +122,7 @@ impl CommonSubexprEliminate {
.collect(),
input: Arc::new(new_input),
fetch,
preserve_partitioning,
})
});
Ok(new_sort)
Expand Down
1 change: 1 addition & 0 deletions datafusion/optimizer/src/eliminate_duplicated_expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@ impl OptimizerRule for EliminateDuplicatedExpr {
expr: unique_exprs,
input: sort.input,
fetch: sort.fetch,
preserve_partitioning: sort.preserve_partitioning,
})))
}
LogicalPlan::Aggregate(agg) => {
Expand Down
1 change: 1 addition & 0 deletions datafusion/proto/proto/datafusion.proto
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,7 @@ message SortNode {
repeated SortExprNode expr = 2;
// Maximum number of highest/lowest rows to fetch; negative means no limit
int64 fetch = 3;
bool preserve_partitioning = 4;
}

message RepartitionNode {
Expand Down
Loading
Loading