@@ -32,7 +32,6 @@ use datafusion_common::{
3232 tree_node:: { Transformed , TransformedResult , TreeNode , TreeNodeRewriter } ,
3333} ;
3434use datafusion_common:: { internal_err, DFSchema , DataFusionError , Result , ScalarValue } ;
35- use datafusion_expr:: simplify:: ExprSimplifyResult ;
3635use datafusion_expr:: {
3736 and, lit, or, BinaryExpr , Case , ColumnarValue , Expr , Like , Operator , Volatility ,
3837 WindowFunctionDefinition ,
@@ -42,14 +41,23 @@ use datafusion_expr::{
4241 expr:: { InList , InSubquery , WindowFunction } ,
4342 utils:: { iter_conjunction, iter_conjunction_owned} ,
4443} ;
44+ use datafusion_expr:: { simplify:: ExprSimplifyResult , Cast , TryCast } ;
4545use datafusion_physical_expr:: { create_physical_expr, execution_props:: ExecutionProps } ;
4646
4747use super :: inlist_simplifier:: ShortenInListSimplifier ;
4848use super :: utils:: * ;
49- use crate :: analyzer:: type_coercion:: TypeCoercionRewriter ;
5049use crate :: simplify_expressions:: guarantees:: GuaranteeRewriter ;
5150use crate :: simplify_expressions:: regex:: simplify_regex_expr;
51+ use crate :: simplify_expressions:: unwrap_cast:: {
52+ is_cast_expr_and_support_unwrap_cast_in_comparison_for_binary,
53+ is_cast_expr_and_support_unwrap_cast_in_comparison_for_inlist,
54+ unwrap_cast_in_comparison_for_binary,
55+ } ;
5256use crate :: simplify_expressions:: SimplifyInfo ;
57+ use crate :: {
58+ analyzer:: type_coercion:: TypeCoercionRewriter ,
59+ simplify_expressions:: unwrap_cast:: try_cast_literal_to_type,
60+ } ;
5361use indexmap:: IndexSet ;
5462use regex:: Regex ;
5563
@@ -1742,6 +1750,86 @@ impl<S: SimplifyInfo> TreeNodeRewriter for Simplifier<'_, S> {
17421750 }
17431751 }
17441752
1753+ // =======================================
1754+ // unwrap_cast_in_comparison
1755+ // =======================================
1756+ //
1757+ // For case:
1758+ // try_cast/cast(expr as data_type) op literal
1759+ Expr :: BinaryExpr ( BinaryExpr { left, op, right } )
1760+ if is_cast_expr_and_support_unwrap_cast_in_comparison_for_binary (
1761+ info, & left, & right,
1762+ ) && op. supports_propagation ( ) =>
1763+ {
1764+ unwrap_cast_in_comparison_for_binary ( info, left, right, op) ?
1765+ }
1766+ // literal op try_cast/cast(expr as data_type)
1767+ // -->
1768+ // try_cast/cast(expr as data_type) op_swap literal
1769+ Expr :: BinaryExpr ( BinaryExpr { left, op, right } )
1770+ if is_cast_expr_and_support_unwrap_cast_in_comparison_for_binary (
1771+ info, & right, & left,
1772+ ) && op. supports_propagation ( )
1773+ && op. swap ( ) . is_some ( ) =>
1774+ {
1775+ unwrap_cast_in_comparison_for_binary (
1776+ info,
1777+ right,
1778+ left,
1779+ op. swap ( ) . unwrap ( ) ,
1780+ ) ?
1781+ }
1782+ // For case:
1783+ // try_cast/cast(expr as left_type) in (expr1,expr2,expr3)
1784+ Expr :: InList ( InList {
1785+ expr : mut left,
1786+ list,
1787+ negated,
1788+ } ) if is_cast_expr_and_support_unwrap_cast_in_comparison_for_inlist (
1789+ info, & left, & list,
1790+ ) =>
1791+ {
1792+ let ( Expr :: TryCast ( TryCast {
1793+ expr : left_expr, ..
1794+ } )
1795+ | Expr :: Cast ( Cast {
1796+ expr : left_expr, ..
1797+ } ) ) = left. as_mut ( )
1798+ else {
1799+ return internal_err ! ( "Expect cast expr, but got {:?}" , left) ?;
1800+ } ;
1801+
1802+ let expr_type = info. get_data_type ( left_expr) ?;
1803+ let right_exprs = list
1804+ . into_iter ( )
1805+ . map ( |right| {
1806+ match right {
1807+ Expr :: Literal ( right_lit_value) => {
1808+ // if the right_lit_value can be casted to the type of internal_left_expr
1809+ // we need to unwrap the cast for cast/try_cast expr, and add cast to the literal
1810+ let Some ( value) = try_cast_literal_to_type ( & right_lit_value, & expr_type) else {
1811+ internal_err ! (
1812+ "Can't cast the list expr {:?} to type {:?}" ,
1813+ right_lit_value, & expr_type
1814+ ) ?
1815+ } ;
1816+ Ok ( lit ( value) )
1817+ }
1818+ other_expr => internal_err ! (
1819+ "Only support literal expr to optimize, but the expr is {:?}" ,
1820+ & other_expr
1821+ ) ,
1822+ }
1823+ } )
1824+ . collect :: < Result < Vec < _ > > > ( ) ?;
1825+
1826+ Transformed :: yes ( Expr :: InList ( InList {
1827+ expr : std:: mem:: take ( left_expr) ,
1828+ list : right_exprs,
1829+ negated,
1830+ } ) )
1831+ }
1832+
17451833 // no additional rewrites possible
17461834 expr => Transformed :: no ( expr) ,
17471835 } )
0 commit comments