@@ -25,6 +25,7 @@ use datafusion_common::cast::as_int32_array;
2525use datafusion_common:: ScalarValue ;
2626use datafusion_common:: { DFSchemaRef , ToDFSchema } ;
2727use datafusion_expr:: expr:: ScalarFunction ;
28+ use datafusion_expr:: logical_plan:: builder:: table_scan_with_filters;
2829use datafusion_expr:: simplify:: SimplifyInfo ;
2930use datafusion_expr:: {
3031 expr, table_scan, BuiltinScalarFunction , Cast , ColumnarValue , Expr , ExprSchemable ,
@@ -294,6 +295,45 @@ fn select_date_plus_interval() -> Result<()> {
294295 Ok ( ( ) )
295296}
296297
298+ #[ test]
299+ fn simplify_project_scalar_fn ( ) -> Result < ( ) > {
300+ // Issue https://github.com/apache/arrow-datafusion/issues/5996
301+ let schema = Schema :: new ( vec ! [ Field :: new( "f" , DataType :: Float64 , false ) ] ) ;
302+ let plan = table_scan ( Some ( "test" ) , & schema, None ) ?
303+ . project ( vec ! [ power( col( "f" ) , lit( 1.0 ) ) ] ) ?
304+ . build ( ) ?;
305+
306+ // before simplify: power(t.f, 1.0)
307+ // after simplify: t.f as "power(t.f, 1.0)"
308+ let expected = "Projection: test.f AS power(test.f,Float64(1))\
309+ \n TableScan: test";
310+ let actual = get_optimized_plan_formatted ( & plan, & Utc :: now ( ) ) ;
311+ assert_eq ! ( expected, actual) ;
312+ Ok ( ( ) )
313+ }
314+
315+ #[ test]
316+ fn simplify_scan_predicate ( ) -> Result < ( ) > {
317+ let schema = Schema :: new ( vec ! [
318+ Field :: new( "f" , DataType :: Float64 , false ) ,
319+ Field :: new( "g" , DataType :: Float64 , false ) ,
320+ ] ) ;
321+ let plan = table_scan_with_filters (
322+ Some ( "test" ) ,
323+ & schema,
324+ None ,
325+ vec ! [ col( "g" ) . eq( power( col( "f" ) , lit( 1.0 ) ) ) ] ,
326+ ) ?
327+ . build ( ) ?;
328+
329+ // before simplify: t.g = power(t.f, 1.0)
330+ // after simplify: (t.g = t.f) as "t.g = power(t.f, 1.0)"
331+ let expected = "TableScan: test, full_filters=[g = f AS g = power(f,Float64(1))]" ;
332+ let actual = get_optimized_plan_formatted ( & plan, & Utc :: now ( ) ) ;
333+ assert_eq ! ( expected, actual) ;
334+ Ok ( ( ) )
335+ }
336+
297337#[ test]
298338fn test_const_evaluator ( ) {
299339 // true --> true
@@ -431,3 +471,99 @@ fn multiple_now() -> Result<()> {
431471 assert_eq ! ( expected, actual) ;
432472 Ok ( ( ) )
433473}
474+
475+ // ------------------------------
476+ // --- Simplifier tests -----
477+ // ------------------------------
478+
479+ fn expr_test_schema ( ) -> DFSchemaRef {
480+ Schema :: new ( vec ! [
481+ Field :: new( "c1" , DataType :: Utf8 , true ) ,
482+ Field :: new( "c2" , DataType :: Boolean , true ) ,
483+ Field :: new( "c3" , DataType :: Int64 , true ) ,
484+ Field :: new( "c4" , DataType :: UInt32 , true ) ,
485+ Field :: new( "c1_non_null" , DataType :: Utf8 , false ) ,
486+ Field :: new( "c2_non_null" , DataType :: Boolean , false ) ,
487+ Field :: new( "c3_non_null" , DataType :: Int64 , false ) ,
488+ Field :: new( "c4_non_null" , DataType :: UInt32 , false ) ,
489+ ] )
490+ . to_dfschema_ref ( )
491+ . unwrap ( )
492+ }
493+
494+ fn test_simplify ( input_expr : Expr , expected_expr : Expr ) {
495+ let info: MyInfo = MyInfo {
496+ schema : expr_test_schema ( ) ,
497+ execution_props : ExecutionProps :: new ( ) ,
498+ } ;
499+ let simplifier = ExprSimplifier :: new ( info) ;
500+ let simplified_expr = simplifier
501+ . simplify ( input_expr. clone ( ) )
502+ . expect ( "successfully evaluated" ) ;
503+
504+ assert_eq ! (
505+ simplified_expr, expected_expr,
506+ "Mismatch evaluating {input_expr}\n Expected:{expected_expr}\n Got:{simplified_expr}"
507+ ) ;
508+ }
509+
510+ #[ test]
511+ fn test_simplify_log ( ) {
512+ // Log(c3, 1) ===> 0
513+ {
514+ let expr = log ( col ( "c3_non_null" ) , lit ( 1 ) ) ;
515+ test_simplify ( expr, lit ( 0i64 ) ) ;
516+ }
517+ // Log(c3, c3) ===> 1
518+ {
519+ let expr = log ( col ( "c3_non_null" ) , col ( "c3_non_null" ) ) ;
520+ let expected = lit ( 1i64 ) ;
521+ test_simplify ( expr, expected) ;
522+ }
523+ // Log(c3, Power(c3, c4)) ===> c4
524+ {
525+ let expr = log (
526+ col ( "c3_non_null" ) ,
527+ power ( col ( "c3_non_null" ) , col ( "c4_non_null" ) ) ,
528+ ) ;
529+ let expected = col ( "c4_non_null" ) ;
530+ test_simplify ( expr, expected) ;
531+ }
532+ // Log(c3, c4) ===> Log(c3, c4)
533+ {
534+ let expr = log ( col ( "c3_non_null" ) , col ( "c4_non_null" ) ) ;
535+ let expected = log ( col ( "c3_non_null" ) , col ( "c4_non_null" ) ) ;
536+ test_simplify ( expr, expected) ;
537+ }
538+ }
539+
540+ #[ test]
541+ fn test_simplify_power ( ) {
542+ // Power(c3, 0) ===> 1
543+ {
544+ let expr = power ( col ( "c3_non_null" ) , lit ( 0 ) ) ;
545+ let expected = lit ( 1i64 ) ;
546+ test_simplify ( expr, expected)
547+ }
548+ // Power(c3, 1) ===> c3
549+ {
550+ let expr = power ( col ( "c3_non_null" ) , lit ( 1 ) ) ;
551+ let expected = col ( "c3_non_null" ) ;
552+ test_simplify ( expr, expected)
553+ }
554+ // Power(c3, Log(c3, c4)) ===> c4
555+ {
556+ let expr = power (
557+ col ( "c3_non_null" ) ,
558+ log ( col ( "c3_non_null" ) , col ( "c4_non_null" ) ) ,
559+ ) ;
560+ let expected = col ( "c4_non_null" ) ;
561+ test_simplify ( expr, expected)
562+ }
563+ // Power(c3, c4) ===> Power(c3, c4)
564+ {
565+ let expr = power ( col ( "c3_non_null" ) , col ( "c4_non_null" ) ) ;
566+ let expected = power ( col ( "c3_non_null" ) , col ( "c4_non_null" ) ) ;
567+ test_simplify ( expr, expected)
568+ }
569+ }
0 commit comments