@@ -62,7 +62,9 @@ use arrow::array::{builder::StringBuilder, RecordBatch};
6262use arrow:: compute:: SortOptions ;
6363use arrow:: datatypes:: { Schema , SchemaRef } ;
6464use datafusion_common:: display:: ToStringifiedPlan ;
65- use datafusion_common:: tree_node:: { TreeNode , TreeNodeRecursion , TreeNodeVisitor } ;
65+ use datafusion_common:: tree_node:: {
66+ Transformed , TransformedResult , TreeNode , TreeNodeRecursion , TreeNodeVisitor ,
67+ } ;
6668use datafusion_common:: {
6769 exec_err, internal_datafusion_err, internal_err, not_impl_err, plan_err, DFSchema ,
6870 ScalarValue ,
@@ -2075,29 +2077,36 @@ fn maybe_fix_physical_column_name(
20752077 expr : Result < Arc < dyn PhysicalExpr > > ,
20762078 input_physical_schema : & SchemaRef ,
20772079) -> Result < Arc < dyn PhysicalExpr > > {
2078- if let Ok ( e) = & expr {
2079- if let Some ( column) = e. as_any ( ) . downcast_ref :: < Column > ( ) {
2080- let physical_field = input_physical_schema. field ( column. index ( ) ) ;
2080+ let Ok ( expr) = expr else { return expr } ;
2081+ expr. transform_down ( |node| {
2082+ if let Some ( column) = node. as_any ( ) . downcast_ref :: < Column > ( ) {
2083+ let idx = column. index ( ) ;
2084+ let physical_field = input_physical_schema. field ( idx) ;
20812085 let expr_col_name = column. name ( ) ;
20822086 let physical_name = physical_field. name ( ) ;
20832087
2084- if physical_name != expr_col_name {
2088+ if expr_col_name != physical_name {
20852089 // handle edge cases where the physical_name contains ':'.
20862090 let colon_count = physical_name. matches ( ':' ) . count ( ) ;
20872091 let mut splits = expr_col_name. match_indices ( ':' ) ;
20882092 let split_pos = splits. nth ( colon_count) ;
20892093
2090- if let Some ( ( idx , _) ) = split_pos {
2091- let base_name = & expr_col_name[ ..idx ] ;
2094+ if let Some ( ( i , _) ) = split_pos {
2095+ let base_name = & expr_col_name[ ..i ] ;
20922096 if base_name == physical_name {
2093- let updated_column = Column :: new ( physical_name, column . index ( ) ) ;
2094- return Ok ( Arc :: new ( updated_column) ) ;
2097+ let updated_column = Column :: new ( physical_name, idx ) ;
2098+ return Ok ( Transformed :: yes ( Arc :: new ( updated_column) ) ) ;
20952099 }
20962100 }
20972101 }
2102+
2103+ // If names already match or fix is not possible, just leave it as it is
2104+ Ok ( Transformed :: no ( node) )
2105+ } else {
2106+ Ok ( Transformed :: no ( node) )
20982107 }
2099- }
2100- expr
2108+ } )
2109+ . data ( )
21012110}
21022111
21032112struct OptimizationInvariantChecker < ' a > {
@@ -2201,8 +2210,11 @@ mod tests {
22012210 use datafusion_common:: { assert_contains, DFSchemaRef , TableReference } ;
22022211 use datafusion_execution:: runtime_env:: RuntimeEnv ;
22032212 use datafusion_execution:: TaskContext ;
2204- use datafusion_expr:: { col, lit, LogicalPlanBuilder , UserDefinedLogicalNodeCore } ;
2213+ use datafusion_expr:: {
2214+ col, lit, LogicalPlanBuilder , Operator , UserDefinedLogicalNodeCore ,
2215+ } ;
22052216 use datafusion_functions_aggregate:: expr_fn:: sum;
2217+ use datafusion_physical_expr:: expressions:: { BinaryExpr , IsNotNullExpr } ;
22062218 use datafusion_physical_expr:: EquivalenceProperties ;
22072219 use datafusion_physical_plan:: execution_plan:: { Boundedness , EmissionType } ;
22082220
@@ -2719,6 +2731,47 @@ mod tests {
27192731
27202732 assert_eq ! ( col. name( ) , "metric:avg" ) ;
27212733 }
2734+
2735+ #[ tokio:: test]
2736+ async fn test_maybe_fix_nested_column_name_with_colon ( ) {
2737+ let schema = Schema :: new ( vec ! [ Field :: new( "column" , DataType :: Int32 , false ) ] ) ;
2738+ let schema_ref: SchemaRef = Arc :: new ( schema) ;
2739+
2740+ // Construct the nested expr
2741+ let col_expr = Arc :: new ( Column :: new ( "column:1" , 0 ) ) as Arc < dyn PhysicalExpr > ;
2742+ let is_not_null_expr = Arc :: new ( IsNotNullExpr :: new ( col_expr. clone ( ) ) ) ;
2743+
2744+ // Create a binary expression and put the column inside
2745+ let binary_expr = Arc :: new ( BinaryExpr :: new (
2746+ is_not_null_expr. clone ( ) ,
2747+ Operator :: Or ,
2748+ is_not_null_expr. clone ( ) ,
2749+ ) ) as Arc < dyn PhysicalExpr > ;
2750+
2751+ let fixed_expr =
2752+ maybe_fix_physical_column_name ( Ok ( binary_expr) , & schema_ref) . unwrap ( ) ;
2753+
2754+ let bin = fixed_expr
2755+ . as_any ( )
2756+ . downcast_ref :: < BinaryExpr > ( )
2757+ . expect ( "Expected BinaryExpr" ) ;
2758+
2759+ // Check that both sides were renamed
2760+ for expr in & [ bin. left ( ) , bin. right ( ) ] {
2761+ let is_not_null = expr
2762+ . as_any ( )
2763+ . downcast_ref :: < IsNotNullExpr > ( )
2764+ . expect ( "Expected IsNotNull" ) ;
2765+
2766+ let col = is_not_null
2767+ . arg ( )
2768+ . as_any ( )
2769+ . downcast_ref :: < Column > ( )
2770+ . expect ( "Expected Column" ) ;
2771+
2772+ assert_eq ! ( col. name( ) , "column" ) ;
2773+ }
2774+ }
27222775 struct ErrorExtensionPlanner { }
27232776
27242777 #[ async_trait]
0 commit comments