Skip to content

Commit d524da9

Browse files
committed
Simplify expression rewrite in ProjectionPushdown, make it more general
1 parent 43cc870 commit d524da9

File tree

1 file changed

+35
-115
lines changed

1 file changed

+35
-115
lines changed

datafusion/core/src/physical_optimizer/projection_pushdown.rs

Lines changed: 35 additions & 115 deletions
Original file line numberDiff line numberDiff line change
@@ -43,12 +43,9 @@ use arrow_schema::SchemaRef;
4343
use datafusion_common::config::ConfigOptions;
4444
use datafusion_common::tree_node::{Transformed, TreeNode};
4545
use datafusion_common::JoinSide;
46-
use datafusion_physical_expr::expressions::{
47-
BinaryExpr, CaseExpr, CastExpr, Column, Literal, NegativeExpr,
48-
};
46+
use datafusion_physical_expr::expressions::Column;
4947
use datafusion_physical_expr::{
5048
Partitioning, PhysicalExpr, PhysicalSortExpr, PhysicalSortRequirement,
51-
ScalarFunctionExpr,
5249
};
5350
use datafusion_physical_plan::union::UnionExec;
5451

@@ -791,119 +788,42 @@ fn update_expr(
791788
projected_exprs: &[(Arc<dyn PhysicalExpr>, String)],
792789
sync_with_child: bool,
793790
) -> Result<Option<Arc<dyn PhysicalExpr>>> {
794-
let expr_any = expr.as_any();
795-
if let Some(column) = expr_any.downcast_ref::<Column>() {
796-
if sync_with_child {
797-
// Update the index of `column`:
798-
Ok(Some(projected_exprs[column.index()].0.clone()))
799-
} else {
800-
// Determine how to update `column` to accommodate `projected_exprs`:
801-
Ok(projected_exprs.iter().enumerate().find_map(
802-
|(index, (projected_expr, alias))| {
803-
projected_expr.as_any().downcast_ref::<Column>().and_then(
804-
|projected_column| {
805-
column
806-
.name()
807-
.eq(projected_column.name())
808-
.then(|| Arc::new(Column::new(alias, index)) as _)
809-
},
810-
)
811-
},
812-
))
813-
}
814-
} else if let Some(binary) = expr_any.downcast_ref::<BinaryExpr>() {
815-
match (
816-
update_expr(binary.left(), projected_exprs, sync_with_child)?,
817-
update_expr(binary.right(), projected_exprs, sync_with_child)?,
818-
) {
819-
(Some(left), Some(right)) => {
820-
Ok(Some(Arc::new(BinaryExpr::new(left, *binary.op(), right))))
821-
}
822-
_ => Ok(None),
823-
}
824-
} else if let Some(cast) = expr_any.downcast_ref::<CastExpr>() {
825-
update_expr(cast.expr(), projected_exprs, sync_with_child).map(|maybe_expr| {
826-
maybe_expr.map(|expr| {
827-
Arc::new(CastExpr::new(
828-
expr,
829-
cast.cast_type().clone(),
830-
Some(cast.cast_options().clone()),
831-
)) as _
832-
})
833-
})
834-
} else if expr_any.is::<Literal>() {
835-
Ok(Some(expr.clone()))
836-
} else if let Some(negative) = expr_any.downcast_ref::<NegativeExpr>() {
837-
update_expr(negative.arg(), projected_exprs, sync_with_child).map(|maybe_expr| {
838-
maybe_expr.map(|expr| Arc::new(NegativeExpr::new(expr)) as _)
839-
})
840-
} else if let Some(scalar_func) = expr_any.downcast_ref::<ScalarFunctionExpr>() {
841-
scalar_func
842-
.args()
843-
.iter()
844-
.map(|expr| update_expr(expr, projected_exprs, sync_with_child))
845-
.collect::<Result<Option<Vec<_>>>>()
846-
.map(|maybe_args| {
847-
maybe_args.map(|new_args| {
848-
Arc::new(ScalarFunctionExpr::new(
849-
scalar_func.name(),
850-
scalar_func.fun().clone(),
851-
new_args,
852-
scalar_func.return_type(),
853-
scalar_func.monotonicity().clone(),
854-
)) as _
855-
})
856-
})
857-
} else if let Some(case) = expr_any.downcast_ref::<CaseExpr>() {
858-
update_case_expr(case, projected_exprs, sync_with_child)
859-
} else {
860-
Ok(None)
861-
}
862-
}
863-
864-
/// Updates the indices `case` refers to according to `projected_exprs`.
865-
fn update_case_expr(
866-
case: &CaseExpr,
867-
projected_exprs: &[(Arc<dyn PhysicalExpr>, String)],
868-
sync_with_child: bool,
869-
) -> Result<Option<Arc<dyn PhysicalExpr>>> {
870-
let new_case = case
871-
.expr()
872-
.map(|expr| update_expr(expr, projected_exprs, sync_with_child))
873-
.transpose()?
874-
.flatten();
875-
876-
let new_else = case
877-
.else_expr()
878-
.map(|expr| update_expr(expr, projected_exprs, sync_with_child))
879-
.transpose()?
880-
.flatten();
881-
882-
let new_when_then = case
883-
.when_then_expr()
884-
.iter()
885-
.map(|(when, then)| {
886-
Ok((
887-
update_expr(when, projected_exprs, sync_with_child)?,
888-
update_expr(then, projected_exprs, sync_with_child)?,
889-
))
890-
})
891-
.collect::<Result<Vec<_>>>()?
892-
.into_iter()
893-
.filter_map(|(maybe_when, maybe_then)| match (maybe_when, maybe_then) {
894-
(Some(when), Some(then)) => Some((when, then)),
895-
_ => None,
896-
})
897-
.collect::<Vec<_>>();
791+
let mut rewritten = false;
898792

899-
if new_when_then.len() != case.when_then_expr().len()
900-
|| case.expr().is_some() && new_case.is_none()
901-
|| case.else_expr().is_some() && new_else.is_none()
902-
{
903-
return Ok(None);
904-
}
793+
let new_expr = expr
794+
.clone()
795+
.transform_down_mut(&mut |expr: Arc<dyn PhysicalExpr>| {
796+
let Some(column) = expr.as_any().downcast_ref::<Column>() else {
797+
return Ok(Transformed::No(expr));
798+
};
799+
if sync_with_child {
800+
rewritten = true;
801+
// Update the index of `column`:
802+
Ok(Transformed::Yes(projected_exprs[column.index()].0.clone()))
803+
} else {
804+
// Determine how to update `column` to accommodate `projected_exprs`:
805+
let new_col = projected_exprs.iter().enumerate().find_map(
806+
|(index, (projected_expr, alias))| {
807+
projected_expr.as_any().downcast_ref::<Column>().and_then(
808+
|projected_column| {
809+
column
810+
.name()
811+
.eq(projected_column.name())
812+
.then(|| Arc::new(Column::new(alias, index)) as _)
813+
},
814+
)
815+
},
816+
);
817+
if let Some(new_col) = new_col {
818+
rewritten = true;
819+
Ok(Transformed::Yes(new_col))
820+
} else {
821+
Ok(Transformed::No(expr))
822+
}
823+
}
824+
});
905825

906-
CaseExpr::try_new(new_case, new_when_then, new_else).map(|e| Some(Arc::new(e) as _))
826+
new_expr.map(|new_expr| if rewritten { Some(new_expr) } else { None })
907827
}
908828

909829
/// Creates a new [`ProjectionExec`] instance with the given child plan and

0 commit comments

Comments
 (0)