Skip to content
77 changes: 65 additions & 12 deletions datafusion/core/src/physical_planner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,9 @@ use arrow::array::{builder::StringBuilder, RecordBatch};
use arrow::compute::SortOptions;
use arrow::datatypes::{Schema, SchemaRef};
use datafusion_common::display::ToStringifiedPlan;
use datafusion_common::tree_node::{TreeNode, TreeNodeRecursion, TreeNodeVisitor};
use datafusion_common::tree_node::{
Transformed, TransformedResult, TreeNode, TreeNodeRecursion, TreeNodeVisitor,
};
use datafusion_common::{
exec_err, internal_datafusion_err, internal_err, not_impl_err, plan_err, DFSchema,
ScalarValue,
Expand Down Expand Up @@ -2075,29 +2077,36 @@ fn maybe_fix_physical_column_name(
expr: Result<Arc<dyn PhysicalExpr>>,
input_physical_schema: &SchemaRef,
) -> Result<Arc<dyn PhysicalExpr>> {
if let Ok(e) = &expr {
if let Some(column) = e.as_any().downcast_ref::<Column>() {
let physical_field = input_physical_schema.field(column.index());
let Ok(expr) = expr else { return expr };
expr.transform_down(|node| {
if let Some(column) = node.as_any().downcast_ref::<Column>() {
let idx = column.index();
let physical_field = input_physical_schema.field(idx);
let expr_col_name = column.name();
let physical_name = physical_field.name();

if physical_name != expr_col_name {
if expr_col_name != physical_name {
// handle edge cases where the physical_name contains ':'.
let colon_count = physical_name.matches(':').count();
let mut splits = expr_col_name.match_indices(':');
let split_pos = splits.nth(colon_count);

if let Some((idx, _)) = split_pos {
let base_name = &expr_col_name[..idx];
if let Some((i, _)) = split_pos {
let base_name = &expr_col_name[..i];
if base_name == physical_name {
let updated_column = Column::new(physical_name, column.index());
return Ok(Arc::new(updated_column));
let updated_column = Column::new(physical_name, idx);
return Ok(Transformed::yes(Arc::new(updated_column)));
}
}
}

// If names already match or fix is not possible, just leave it as it is
Ok(Transformed::no(node))
} else {
Ok(Transformed::no(node))
}
}
expr
})
.data()
}

struct OptimizationInvariantChecker<'a> {
Expand Down Expand Up @@ -2203,8 +2212,11 @@ mod tests {
};
use datafusion_execution::runtime_env::RuntimeEnv;
use datafusion_execution::TaskContext;
use datafusion_expr::{col, lit, LogicalPlanBuilder, UserDefinedLogicalNodeCore};
use datafusion_expr::{
col, lit, LogicalPlanBuilder, Operator, UserDefinedLogicalNodeCore,
};
use datafusion_functions_aggregate::expr_fn::sum;
use datafusion_physical_expr::expressions::{BinaryExpr, IsNotNullExpr};
use datafusion_physical_expr::EquivalenceProperties;
use datafusion_physical_plan::execution_plan::{Boundedness, EmissionType};

Expand Down Expand Up @@ -2769,6 +2781,47 @@ mod tests {

assert_eq!(col.name(), "metric:avg");
}

/// Checks that `maybe_fix_physical_column_name` also renames `Column` nodes that are
/// nested inside other physical expressions, not just a top-level column.
#[tokio::test]
async fn test_maybe_fix_nested_column_name_with_colon() {
    // The physical schema exposes a single field called "column".
    let schema_ref: SchemaRef =
        Arc::new(Schema::new(vec![Field::new("column", DataType::Int32, false)]));

    // `column:1 IS NOT NULL`: the column name carries a `:1` suffix the schema lacks.
    let suffixed_col: Arc<dyn PhysicalExpr> = Arc::new(Column::new("column:1", 0));
    let not_null: Arc<dyn PhysicalExpr> =
        Arc::new(IsNotNullExpr::new(Arc::clone(&suffixed_col)));

    // Put the same nested column on both sides of an OR.
    let disjunction: Arc<dyn PhysicalExpr> = Arc::new(BinaryExpr::new(
        Arc::clone(&not_null),
        Operator::Or,
        Arc::clone(&not_null),
    ));

    let fixed_expr =
        maybe_fix_physical_column_name(Ok(disjunction), &schema_ref).unwrap();

    let Some(bin) = fixed_expr.as_any().downcast_ref::<BinaryExpr>() else {
        panic!("Expected BinaryExpr");
    };

    // Both sides must have been renamed to the schema's field name.
    for side in [bin.left(), bin.right()] {
        let Some(is_not_null) = side.as_any().downcast_ref::<IsNotNullExpr>() else {
            panic!("Expected IsNotNull");
        };
        let Some(col) = is_not_null.arg().as_any().downcast_ref::<Column>() else {
            panic!("Expected Column");
        };

        assert_eq!(col.name(), "column");
    }
}
struct ErrorExtensionPlanner {}

#[async_trait]
Expand Down
10 changes: 9 additions & 1 deletion datafusion/physical-plan/src/union.rs
Original file line number Diff line number Diff line change
Expand Up @@ -540,7 +540,12 @@ fn union_schema(inputs: &[Arc<dyn ExecutionPlan>]) -> SchemaRef {

let fields = (0..first_schema.fields().len())
.map(|i| {
inputs
// We take the name from the left side of the union to match how names are coerced during logical planning,
// which also uses the left side names.
let base_field = first_schema.field(i).clone();

// Coerce metadata and nullability across all inputs
let merged_field = inputs
.iter()
.enumerate()
.map(|(input_idx, input)| {
Expand All @@ -562,6 +567,9 @@ fn union_schema(inputs: &[Arc<dyn ExecutionPlan>]) -> SchemaRef {
// We can unwrap this because if inputs was empty, this would have already panicked when we
// indexed into inputs[0].
.unwrap()
.with_name(base_field.name());

merged_field
})
.collect::<Vec<_>>();

Expand Down
24 changes: 24 additions & 0 deletions datafusion/substrait/tests/cases/consumer_integration.rs
Original file line number Diff line number Diff line change
Expand Up @@ -560,4 +560,28 @@ mod tests {
);
Ok(())
}

/// Snapshot test for a plan loaded from `multiple_unions.json`.
///
/// The plan is rendered to a string by `test_plan_to_string` and compared against the
/// inline snapshot below, which contains two nested `Union` nodes.
#[tokio::test]
async fn test_multiple_unions() -> Result<()> {
// Any error from loading or converting the plan fails the test via `?`.
let plan_str = test_plan_to_string("multiple_unions.json").await?;
// NOTE(review): the top-level projection's aliases appear to come from the first
// union input's column names — confirm this is the intended naming rule.
assert_snapshot!(
plan_str,
@r#"
Projection: Utf8("people") AS product_category, Utf8("people")__temp__0 AS product_type, product_key
Union
Projection: Utf8("people"), Utf8("people") AS Utf8("people")__temp__0, sales.product_key
Left Join: sales.product_key = food.@food_id
TableScan: sales
TableScan: food
Union
Projection: people.$f3, people.$f5, people.product_key0
Left Join: people.product_key0 = food.@food_id
TableScan: people
TableScan: food
TableScan: more_products
"#
);

Ok(())
}
}
Loading