Skip to content

Commit f0d3117

Browse files
Fix: handle column name collisions when combining UNION logical inputs & nested Column expressions in maybe_fix_physical_column_name (apache#16064)
* Fix union schema name coercion
* Address renaming for columns that are not in the top level as well
* Add unit test
* Format
* Use insta tests properly
* Address review - comment + minor simplification change

---------

Co-authored-by: Berkay Şahin <124376117+berkaysynnada@users.noreply.github.com>

(cherry picked from commit e5f596b)
1 parent 11f5af5 commit f0d3117

File tree

4 files changed

+426
-13
lines changed

4 files changed

+426
-13
lines changed

datafusion/core/src/physical_planner.rs

Lines changed: 65 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,9 @@ use arrow::array::{builder::StringBuilder, RecordBatch};
6262
use arrow::compute::SortOptions;
6363
use arrow::datatypes::{Schema, SchemaRef};
6464
use datafusion_common::display::ToStringifiedPlan;
65-
use datafusion_common::tree_node::{TreeNode, TreeNodeRecursion, TreeNodeVisitor};
65+
use datafusion_common::tree_node::{
66+
Transformed, TransformedResult, TreeNode, TreeNodeRecursion, TreeNodeVisitor,
67+
};
6668
use datafusion_common::{
6769
exec_err, internal_datafusion_err, internal_err, not_impl_err, plan_err, DFSchema,
6870
ScalarValue,
@@ -2075,29 +2077,36 @@ fn maybe_fix_physical_column_name(
20752077
expr: Result<Arc<dyn PhysicalExpr>>,
20762078
input_physical_schema: &SchemaRef,
20772079
) -> Result<Arc<dyn PhysicalExpr>> {
2078-
if let Ok(e) = &expr {
2079-
if let Some(column) = e.as_any().downcast_ref::<Column>() {
2080-
let physical_field = input_physical_schema.field(column.index());
2080+
let Ok(expr) = expr else { return expr };
2081+
expr.transform_down(|node| {
2082+
if let Some(column) = node.as_any().downcast_ref::<Column>() {
2083+
let idx = column.index();
2084+
let physical_field = input_physical_schema.field(idx);
20812085
let expr_col_name = column.name();
20822086
let physical_name = physical_field.name();
20832087

2084-
if physical_name != expr_col_name {
2088+
if expr_col_name != physical_name {
20852089
// handle edge cases where the physical_name contains ':'.
20862090
let colon_count = physical_name.matches(':').count();
20872091
let mut splits = expr_col_name.match_indices(':');
20882092
let split_pos = splits.nth(colon_count);
20892093

2090-
if let Some((idx, _)) = split_pos {
2091-
let base_name = &expr_col_name[..idx];
2094+
if let Some((i, _)) = split_pos {
2095+
let base_name = &expr_col_name[..i];
20922096
if base_name == physical_name {
2093-
let updated_column = Column::new(physical_name, column.index());
2094-
return Ok(Arc::new(updated_column));
2097+
let updated_column = Column::new(physical_name, idx);
2098+
return Ok(Transformed::yes(Arc::new(updated_column)));
20952099
}
20962100
}
20972101
}
2102+
2103+
// If names already match or fix is not possible, just leave it as it is
2104+
Ok(Transformed::no(node))
2105+
} else {
2106+
Ok(Transformed::no(node))
20982107
}
2099-
}
2100-
expr
2108+
})
2109+
.data()
21012110
}
21022111

21032112
struct OptimizationInvariantChecker<'a> {
@@ -2201,8 +2210,11 @@ mod tests {
22012210
use datafusion_common::{assert_contains, DFSchemaRef, TableReference};
22022211
use datafusion_execution::runtime_env::RuntimeEnv;
22032212
use datafusion_execution::TaskContext;
2204-
use datafusion_expr::{col, lit, LogicalPlanBuilder, UserDefinedLogicalNodeCore};
2213+
use datafusion_expr::{
2214+
col, lit, LogicalPlanBuilder, Operator, UserDefinedLogicalNodeCore,
2215+
};
22052216
use datafusion_functions_aggregate::expr_fn::sum;
2217+
use datafusion_physical_expr::expressions::{BinaryExpr, IsNotNullExpr};
22062218
use datafusion_physical_expr::EquivalenceProperties;
22072219
use datafusion_physical_plan::execution_plan::{Boundedness, EmissionType};
22082220

@@ -2719,6 +2731,47 @@ mod tests {
27192731

27202732
assert_eq!(col.name(), "metric:avg");
27212733
}
2734+
2735+
#[tokio::test]
2736+
async fn test_maybe_fix_nested_column_name_with_colon() {
2737+
let schema = Schema::new(vec![Field::new("column", DataType::Int32, false)]);
2738+
let schema_ref: SchemaRef = Arc::new(schema);
2739+
2740+
// Construct the nested expr
2741+
let col_expr = Arc::new(Column::new("column:1", 0)) as Arc<dyn PhysicalExpr>;
2742+
let is_not_null_expr = Arc::new(IsNotNullExpr::new(col_expr.clone()));
2743+
2744+
// Create a binary expression and put the column inside
2745+
let binary_expr = Arc::new(BinaryExpr::new(
2746+
is_not_null_expr.clone(),
2747+
Operator::Or,
2748+
is_not_null_expr.clone(),
2749+
)) as Arc<dyn PhysicalExpr>;
2750+
2751+
let fixed_expr =
2752+
maybe_fix_physical_column_name(Ok(binary_expr), &schema_ref).unwrap();
2753+
2754+
let bin = fixed_expr
2755+
.as_any()
2756+
.downcast_ref::<BinaryExpr>()
2757+
.expect("Expected BinaryExpr");
2758+
2759+
// Check that both sides were renamed
2760+
for expr in &[bin.left(), bin.right()] {
2761+
let is_not_null = expr
2762+
.as_any()
2763+
.downcast_ref::<IsNotNullExpr>()
2764+
.expect("Expected IsNotNull");
2765+
2766+
let col = is_not_null
2767+
.arg()
2768+
.as_any()
2769+
.downcast_ref::<Column>()
2770+
.expect("Expected Column");
2771+
2772+
assert_eq!(col.name(), "column");
2773+
}
2774+
}
27222775
struct ErrorExtensionPlanner {}
27232776

27242777
#[async_trait]

datafusion/physical-plan/src/union.rs

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -513,7 +513,12 @@ fn union_schema(inputs: &[Arc<dyn ExecutionPlan>]) -> SchemaRef {
513513

514514
let fields = (0..first_schema.fields().len())
515515
.map(|i| {
516-
inputs
516+
// We take the name from the left side of the union to match how names are coerced during logical planning,
517+
// which also uses the left side names.
518+
let base_field = first_schema.field(i).clone();
519+
520+
// Coerce metadata and nullability across all inputs
521+
let merged_field = inputs
517522
.iter()
518523
.enumerate()
519524
.map(|(input_idx, input)| {
@@ -535,6 +540,9 @@ fn union_schema(inputs: &[Arc<dyn ExecutionPlan>]) -> SchemaRef {
535540
// We can unwrap this because if inputs was empty, this would've already panic'ed when we
536541
// indexed into inputs[0].
537542
.unwrap()
543+
.with_name(base_field.name());
544+
545+
merged_field
538546
})
539547
.collect::<Vec<_>>();
540548

datafusion/substrait/tests/cases/consumer_integration.rs

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -560,4 +560,28 @@ mod tests {
560560
);
561561
Ok(())
562562
}
563+
564+
#[tokio::test]
565+
async fn test_multiple_unions() -> Result<()> {
566+
let plan_str = test_plan_to_string("multiple_unions.json").await?;
567+
assert_snapshot!(
568+
plan_str,
569+
@r#"
570+
Projection: Utf8("people") AS product_category, Utf8("people")__temp__0 AS product_type, product_key
571+
Union
572+
Projection: Utf8("people"), Utf8("people") AS Utf8("people")__temp__0, sales.product_key
573+
Left Join: sales.product_key = food.@food_id
574+
TableScan: sales
575+
TableScan: food
576+
Union
577+
Projection: people.$f3, people.$f5, people.product_key0
578+
Left Join: people.product_key0 = food.@food_id
579+
TableScan: people
580+
TableScan: food
581+
TableScan: more_products
582+
"#
583+
);
584+
585+
Ok(())
586+
}
563587
}

0 commit comments

Comments
 (0)