@@ -62,7 +62,9 @@ use arrow::array::{builder::StringBuilder, RecordBatch};
6262use arrow:: compute:: SortOptions ;
6363use arrow:: datatypes:: { Schema , SchemaRef } ;
6464use datafusion_common:: display:: ToStringifiedPlan ;
65- use datafusion_common:: tree_node:: { TreeNode , TreeNodeRecursion , TreeNodeVisitor } ;
65+ use datafusion_common:: tree_node:: {
66+ Transformed , TransformedResult , TreeNode , TreeNodeRecursion , TreeNodeVisitor ,
67+ } ;
6668use datafusion_common:: {
6769 exec_err, internal_datafusion_err, internal_err, not_impl_err, plan_err, DFSchema ,
6870 ScalarValue ,
@@ -2075,29 +2077,36 @@ fn maybe_fix_physical_column_name(
20752077 expr : Result < Arc < dyn PhysicalExpr > > ,
20762078 input_physical_schema : & SchemaRef ,
20772079) -> Result < Arc < dyn PhysicalExpr > > {
2078- if let Ok ( e) = & expr {
2079- if let Some ( column) = e. as_any ( ) . downcast_ref :: < Column > ( ) {
2080- let physical_field = input_physical_schema. field ( column. index ( ) ) ;
2080+ let Ok ( expr) = expr else { return expr } ;
2081+ expr. transform_down ( |node| {
2082+ if let Some ( column) = node. as_any ( ) . downcast_ref :: < Column > ( ) {
2083+ let idx = column. index ( ) ;
2084+ let physical_field = input_physical_schema. field ( idx) ;
20812085 let expr_col_name = column. name ( ) ;
20822086 let physical_name = physical_field. name ( ) ;
20832087
2084- if physical_name != expr_col_name {
2088+ if expr_col_name != physical_name {
20852089 // handle edge cases where the physical_name contains ':'.
20862090 let colon_count = physical_name. matches ( ':' ) . count ( ) ;
20872091 let mut splits = expr_col_name. match_indices ( ':' ) ;
20882092 let split_pos = splits. nth ( colon_count) ;
20892093
2090- if let Some ( ( idx , _) ) = split_pos {
2091- let base_name = & expr_col_name[ ..idx ] ;
2094+ if let Some ( ( i , _) ) = split_pos {
2095+ let base_name = & expr_col_name[ ..i ] ;
20922096 if base_name == physical_name {
2093- let updated_column = Column :: new ( physical_name, column . index ( ) ) ;
2094- return Ok ( Arc :: new ( updated_column) ) ;
2097+ let updated_column = Column :: new ( physical_name, idx ) ;
2098+ return Ok ( Transformed :: yes ( Arc :: new ( updated_column) ) ) ;
20952099 }
20962100 }
20972101 }
2102+
2103+ // If names already match or fix is not possible, just leave it as it is
2104+ Ok ( Transformed :: no ( node) )
2105+ } else {
2106+ Ok ( Transformed :: no ( node) )
20982107 }
2099- }
2100- expr
2108+ } )
2109+ . data ( )
21012110}
21022111
21032112struct OptimizationInvariantChecker < ' a > {
@@ -2203,8 +2212,11 @@ mod tests {
22032212 } ;
22042213 use datafusion_execution:: runtime_env:: RuntimeEnv ;
22052214 use datafusion_execution:: TaskContext ;
2206- use datafusion_expr:: { col, lit, LogicalPlanBuilder , UserDefinedLogicalNodeCore } ;
2215+ use datafusion_expr:: {
2216+ col, lit, LogicalPlanBuilder , Operator , UserDefinedLogicalNodeCore ,
2217+ } ;
22072218 use datafusion_functions_aggregate:: expr_fn:: sum;
2219+ use datafusion_physical_expr:: expressions:: { BinaryExpr , IsNotNullExpr } ;
22082220 use datafusion_physical_expr:: EquivalenceProperties ;
22092221 use datafusion_physical_plan:: execution_plan:: { Boundedness , EmissionType } ;
22102222
@@ -2769,6 +2781,47 @@ mod tests {
27692781
27702782 assert_eq ! ( col. name( ) , "metric:avg" ) ;
27712783 }
2784+
2785+ #[ tokio:: test]
2786+ async fn test_maybe_fix_nested_column_name_with_colon ( ) {
2787+ let schema = Schema :: new ( vec ! [ Field :: new( "column" , DataType :: Int32 , false ) ] ) ;
2788+ let schema_ref: SchemaRef = Arc :: new ( schema) ;
2789+
2790+ // Construct the nested expr
2791+ let col_expr = Arc :: new ( Column :: new ( "column:1" , 0 ) ) as Arc < dyn PhysicalExpr > ;
2792+ let is_not_null_expr = Arc :: new ( IsNotNullExpr :: new ( col_expr. clone ( ) ) ) ;
2793+
2794+ // Create a binary expression and put the column inside
2795+ let binary_expr = Arc :: new ( BinaryExpr :: new (
2796+ is_not_null_expr. clone ( ) ,
2797+ Operator :: Or ,
2798+ is_not_null_expr. clone ( ) ,
2799+ ) ) as Arc < dyn PhysicalExpr > ;
2800+
2801+ let fixed_expr =
2802+ maybe_fix_physical_column_name ( Ok ( binary_expr) , & schema_ref) . unwrap ( ) ;
2803+
2804+ let bin = fixed_expr
2805+ . as_any ( )
2806+ . downcast_ref :: < BinaryExpr > ( )
2807+ . expect ( "Expected BinaryExpr" ) ;
2808+
2809+ // Check that both sides were renamed
2810+ for expr in & [ bin. left ( ) , bin. right ( ) ] {
2811+ let is_not_null = expr
2812+ . as_any ( )
2813+ . downcast_ref :: < IsNotNullExpr > ( )
2814+ . expect ( "Expected IsNotNull" ) ;
2815+
2816+ let col = is_not_null
2817+ . arg ( )
2818+ . as_any ( )
2819+ . downcast_ref :: < Column > ( )
2820+ . expect ( "Expected Column" ) ;
2821+
2822+ assert_eq ! ( col. name( ) , "column" ) ;
2823+ }
2824+ }
27722825 struct ErrorExtensionPlanner { }
27732826
27742827 #[ async_trait]
0 commit comments