@@ -392,20 +392,59 @@ impl PartialOrd for Expr {
392392 }
393393}
394394
395+ /// Provides schema information needed by [Expr] methods such as
396+ /// [Expr::nullable] and [Expr::data_type].
397+ ///
398+ /// Note that this trait is implemented for &[DFSchema] which is
399+ /// widely used in the DataFusion codebase.
400+ pub trait ExprSchema {
401+ /// Is this column reference nullable?
402+ fn nullable ( & self , col : & Column ) -> Result < bool > ;
403+
404+ /// What is the datatype of this column?
405+ fn data_type ( & self , col : & Column ) -> Result < & DataType > ;
406+ }
407+
408+ // Implement `ExprSchema` for `Arc<DFSchema>`
409+ impl < P : AsRef < DFSchema > > ExprSchema for P {
410+ fn nullable ( & self , col : & Column ) -> Result < bool > {
411+ self . as_ref ( ) . nullable ( col)
412+ }
413+
414+ fn data_type ( & self , col : & Column ) -> Result < & DataType > {
415+ self . as_ref ( ) . data_type ( col)
416+ }
417+ }
418+
419+ impl ExprSchema for DFSchema {
420+ fn nullable ( & self , col : & Column ) -> Result < bool > {
421+ Ok ( self . field_from_column ( col) ?. is_nullable ( ) )
422+ }
423+
424+ fn data_type ( & self , col : & Column ) -> Result < & DataType > {
425+ Ok ( self . field_from_column ( col) ?. data_type ( ) )
426+ }
427+ }
428+
395429impl Expr {
396- /// Returns the [arrow::datatypes::DataType] of the expression based on [arrow::datatypes::Schema].
430+ /// Returns the [arrow::datatypes::DataType] of the expression
431+ /// based on [ExprSchema]
432+ ///
433+ /// Note: [DFSchema] implements [ExprSchema].
397434 ///
398435 /// # Errors
399436 ///
400- /// This function errors when it is not possible to compute its [arrow::datatypes::DataType].
401- /// This happens when e.g. the expression refers to a column that does not exist in the schema, or when
402- /// the expression is incorrectly typed (e.g. `[utf8] + [bool]`).
403- pub fn get_type ( & self , schema : & DFSchema ) -> Result < DataType > {
437+ /// This function errors when it is not possible to compute its
438+ /// [arrow::datatypes::DataType]. This happens when e.g. the
439+ /// expression refers to a column that does not exist in the
440+ /// schema, or when the expression is incorrectly typed
441+ /// (e.g. `[utf8] + [bool]`).
442+ pub fn get_type < S : ExprSchema > ( & self , schema : & S ) -> Result < DataType > {
404443 match self {
405444 Expr :: Alias ( expr, _) | Expr :: Sort { expr, .. } | Expr :: Negative ( expr) => {
406445 expr. get_type ( schema)
407446 }
408- Expr :: Column ( c) => Ok ( schema. field_from_column ( c) ?. data_type ( ) . clone ( ) ) ,
447+ Expr :: Column ( c) => Ok ( schema. data_type ( c) ?. clone ( ) ) ,
409448 Expr :: ScalarVariable ( _) => Ok ( DataType :: Utf8 ) ,
410449 Expr :: Literal ( l) => Ok ( l. get_datatype ( ) ) ,
411450 Expr :: Case { when_then_expr, .. } => when_then_expr[ 0 ] . 1 . get_type ( schema) ,
@@ -472,21 +511,24 @@ impl Expr {
472511 }
473512 }
474513
475- /// Returns the nullability of the expression based on [arrow::datatypes::Schema].
514+ /// Returns the nullability of the expression based on [ExprSchema].
515+ ///
516+ /// Note: [DFSchema] implements [ExprSchema].
476517 ///
477518 /// # Errors
478519 ///
479- /// This function errors when it is not possible to compute its nullability.
480- /// This happens when the expression refers to a column that does not exist in the schema.
481- pub fn nullable ( & self , input_schema : & DFSchema ) -> Result < bool > {
520+ /// This function errors when it is not possible to compute its
521+ /// nullability. This happens when the expression refers to a
522+ /// column that does not exist in the schema.
523+ pub fn nullable < S : ExprSchema > ( & self , input_schema : & S ) -> Result < bool > {
482524 match self {
483525 Expr :: Alias ( expr, _)
484526 | Expr :: Not ( expr)
485527 | Expr :: Negative ( expr)
486528 | Expr :: Sort { expr, .. }
487529 | Expr :: Between { expr, .. }
488530 | Expr :: InList { expr, .. } => expr. nullable ( input_schema) ,
489- Expr :: Column ( c) => Ok ( input_schema. field_from_column ( c ) ? . is_nullable ( ) ) ,
531+ Expr :: Column ( c) => input_schema. nullable ( c ) ,
490532 Expr :: Literal ( value) => Ok ( value. is_null ( ) ) ,
491533 Expr :: Case {
492534 when_then_expr,
@@ -561,7 +603,11 @@ impl Expr {
561603 ///
562604 /// This function errors when it is impossible to cast the
563605 /// expression to the target [arrow::datatypes::DataType].
564- pub fn cast_to ( self , cast_to_type : & DataType , schema : & DFSchema ) -> Result < Expr > {
606+ pub fn cast_to < S : ExprSchema > (
607+ self ,
608+ cast_to_type : & DataType ,
609+ schema : & S ,
610+ ) -> Result < Expr > {
565611 // TODO(kszucs): most of the operations do not validate the type correctness
566612 // like all of the binary expressions below. Perhaps Expr should track the
567613 // type of the expression?
@@ -2557,4 +2603,57 @@ mod tests {
25572603 combine_filters ( & [ filter1. clone ( ) , filter2. clone ( ) , filter3. clone ( ) ] ) ;
25582604 assert_eq ! ( result, Some ( and( and( filter1, filter2) , filter3) ) ) ;
25592605 }
2606+
/// `Expr::nullable` must reflect whatever nullability the supplied
/// `ExprSchema` reports for the referenced column.
#[test]
fn expr_schema_nullability() {
    let expr = col("foo").eq(lit(1));

    // The default mock reports every column as non-nullable.
    let non_nullable_schema = MockExprSchema::new();
    assert!(!expr.nullable(&non_nullable_schema).unwrap());

    // Flipping the mock to nullable must flip the expression's answer.
    let nullable_schema = MockExprSchema::new().with_nullable(true);
    assert!(expr.nullable(&nullable_schema).unwrap());
}
2615+
/// `Expr::get_type` on a column reference must return the data type that
/// the supplied `ExprSchema` reports for that column.
#[test]
fn expr_schema_data_type() {
    let expr = col("foo");
    let schema = MockExprSchema::new().with_data_type(DataType::Utf8);

    let resolved = expr.get_type(&schema).unwrap();

    assert_eq!(DataType::Utf8, resolved);
}
2625+
2626+ struct MockExprSchema {
2627+ nullable : bool ,
2628+ data_type : DataType ,
2629+ }
2630+
2631+ impl MockExprSchema {
2632+ fn new ( ) -> Self {
2633+ Self {
2634+ nullable : false ,
2635+ data_type : DataType :: Null ,
2636+ }
2637+ }
2638+
2639+ fn with_nullable ( mut self , nullable : bool ) -> Self {
2640+ self . nullable = nullable;
2641+ self
2642+ }
2643+
2644+ fn with_data_type ( mut self , data_type : DataType ) -> Self {
2645+ self . data_type = data_type;
2646+ self
2647+ }
2648+ }
2649+
2650+ impl ExprSchema for MockExprSchema {
2651+ fn nullable ( & self , _col : & Column ) -> Result < bool > {
2652+ Ok ( self . nullable )
2653+ }
2654+
2655+ fn data_type ( & self , _col : & Column ) -> Result < & DataType > {
2656+ Ok ( & self . data_type )
2657+ }
2658+ }
25602659}
0 commit comments