@@ -23,11 +23,19 @@ use std::collections::HashMap;
2323use std:: sync:: Arc ;
2424
2525use arrow:: datatypes:: { DataType , Field , Schema , SchemaRef , TimeUnit } ;
26+ use arrow_schema:: { Fields , SchemaBuilder } ;
2627use datafusion_common:: config:: ConfigOptions ;
27- use datafusion_common:: { plan_err, Result } ;
28- use datafusion_expr:: { AggregateUDF , LogicalPlan , ScalarUDF , TableSource , WindowUDF } ;
28+ use datafusion_common:: tree_node:: { TransformedResult , TreeNode } ;
29+ use datafusion_common:: { plan_err, DFSchema , Result , ScalarValue } ;
30+ use datafusion_expr:: interval_arithmetic:: { Interval , NullableInterval } ;
31+ use datafusion_expr:: {
32+ col, lit, AggregateUDF , BinaryExpr , Expr , ExprSchemable , LogicalPlan , Operator ,
33+ ScalarUDF , TableSource , WindowUDF ,
34+ } ;
35+ use datafusion_functions:: core:: expr_ext:: FieldAccessor ;
2936use datafusion_optimizer:: analyzer:: Analyzer ;
3037use datafusion_optimizer:: optimizer:: Optimizer ;
38+ use datafusion_optimizer:: simplify_expressions:: GuaranteeRewriter ;
3139use datafusion_optimizer:: { OptimizerConfig , OptimizerContext } ;
3240use datafusion_sql:: planner:: { ContextProvider , SqlToRel } ;
3341use datafusion_sql:: sqlparser:: ast:: Statement ;
@@ -233,3 +241,120 @@ impl TableSource for MyTableSource {
233241 self . schema . clone ( )
234242 }
235243}
244+
/// Builds a schema whose nullable struct column `parent` contains a field
/// `child` declared as non-nullable, then asserts that the field-access
/// expression `parent.child` is reported as nullable against that schema.
#[test]
fn test_nested_schema_nullability() {
    // `child` itself is declared non-nullable, but it lives inside a nullable struct.
    let child = Field::new("child", DataType::Int64, false);
    let parent = Field::new("parent", DataType::Struct(Fields::from(vec![child])), true);
    let foo = Field::new("foo", DataType::Int32, true);

    let mut schema_builder = SchemaBuilder::new();
    schema_builder.push(foo);
    schema_builder.push(parent);
    let arrow_schema = Arc::new(schema_builder.finish());

    // One qualifier per field: `foo` is qualified by `table_name`, `parent` is unqualified.
    let df_schema = DFSchema::from_field_specific_qualified_schema(
        vec![Some("table_name".into()), None],
        &arrow_schema,
    )
    .unwrap();

    let child_access = col("parent").field("child");
    let is_nullable = child_access.nullable(&df_schema).unwrap();
    assert!(is_nullable);
}
269+
/// Exercises `GuaranteeRewriter` with the guarantees that `x` and `s.y` are
/// both non-null and lie in `[1, 3]`: a first group of predicates is expected
/// to be rewritten to a boolean literal, a second group is expected to come
/// back unchanged.
#[test]
fn test_inequalities_non_null_bounded() {
    // Both guarantees use the same non-null interval [1, 3].
    let one_to_three = || NullableInterval::NotNull {
        values: Interval::make(Some(1_i32), Some(3_i32)).unwrap(),
    };
    let x = || col("x");
    let s_y = || col("s").field("y");
    // `x IS DISTINCT FROM <rhs>`, assembled directly as a binary expression.
    let x_is_distinct_from = |rhs: Expr| {
        Expr::BinaryExpr(BinaryExpr {
            left: Box::new(x()),
            op: Operator::IsDistinctFrom,
            right: Box::new(rhs),
        })
    };

    let guarantees = vec![
        // x ∈ [1, 3] (not null)
        (x(), one_to_three()),
        // s.y ∈ [1, 3] (not null)
        (s_y(), one_to_three()),
    ];

    let mut rewriter = GuaranteeRewriter::new(guarantees.iter());

    // (original_expr, expected_simplification)
    let simplified_cases = &[
        (x().lt(lit(0)), false),
        (s_y().lt(lit(0)), false),
        (x().lt_eq(lit(3)), true),
        (x().gt(lit(3)), false),
        (x().gt(lit(0)), true),
        (x().eq(lit(0)), false),
        (x().not_eq(lit(0)), true),
        (x().between(lit(0), lit(5)), true),
        (x().between(lit(5), lit(10)), false),
        (x().not_between(lit(0), lit(5)), false),
        (x().not_between(lit(5), lit(10)), true),
        (x_is_distinct_from(lit(ScalarValue::Null)), true),
        (x_is_distinct_from(lit(5)), true),
    ];

    validate_simplified_cases(&mut rewriter, simplified_cases);

    // Predicates whose bounds cut through [1, 3]; the rewriter is expected to
    // return them as they were.
    let unchanged_cases = &[
        x().gt(lit(2)),
        x().lt_eq(lit(2)),
        x().eq(lit(2)),
        x().not_eq(lit(2)),
        x().between(lit(3), lit(5)),
        x().not_between(lit(3), lit(10)),
    ];

    validate_unchanged_cases(&mut rewriter, unchanged_cases);
}
335+
336+ fn validate_simplified_cases < T > ( rewriter : & mut GuaranteeRewriter , cases : & [ ( Expr , T ) ] )
337+ where
338+ ScalarValue : From < T > ,
339+ T : Clone ,
340+ {
341+ for ( expr, expected_value) in cases {
342+ let output = expr. clone ( ) . rewrite ( rewriter) . data ( ) . unwrap ( ) ;
343+ let expected = lit ( ScalarValue :: from ( expected_value. clone ( ) ) ) ;
344+ assert_eq ! (
345+ output, expected,
346+ "{} simplified to {}, but expected {}" ,
347+ expr, output, expected
348+ ) ;
349+ }
350+ }
351+ fn validate_unchanged_cases ( rewriter : & mut GuaranteeRewriter , cases : & [ Expr ] ) {
352+ for expr in cases {
353+ let output = expr. clone ( ) . rewrite ( rewriter) . data ( ) . unwrap ( ) ;
354+ assert_eq ! (
355+ & output, expr,
356+ "{} was simplified to {}, but expected it to be unchanged" ,
357+ expr, output
358+ ) ;
359+ }
360+ }
0 commit comments