Skip to content

Commit b2eaee3

Browse files
authored
API to get Expr's type and nullability without a DFSchema (#1726)
* API to get Expr type and nullability without a `DFSchema` * Add test * publicly export * Improve docs
1 parent 78c30b6 commit b2eaee3

File tree

2 files changed

+113
-14
lines changed

2 files changed

+113
-14
lines changed

datafusion/src/logical_plan/expr.rs

Lines changed: 111 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -392,20 +392,59 @@ impl PartialOrd for Expr {
392392
}
393393
}
394394

395+
/// Provides schema information needed by [Expr] methods such as
396+
/// [Expr::nullable] and [Expr::data_type].
397+
///
398+
/// Note that this trait is implemented for [DFSchema], which is
399+
/// widely used in the DataFusion codebase.
400+
pub trait ExprSchema {
401+
/// Is this column reference nullable?
402+
fn nullable(&self, col: &Column) -> Result<bool>;
403+
404+
/// What is the datatype of this column?
405+
fn data_type(&self, col: &Column) -> Result<&DataType>;
406+
}
407+
408+
// Implement `ExprSchema` for any `P: AsRef<DFSchema>`, e.g. `Arc<DFSchema>`
409+
impl<P: AsRef<DFSchema>> ExprSchema for P {
410+
fn nullable(&self, col: &Column) -> Result<bool> {
411+
self.as_ref().nullable(col)
412+
}
413+
414+
fn data_type(&self, col: &Column) -> Result<&DataType> {
415+
self.as_ref().data_type(col)
416+
}
417+
}
418+
419+
impl ExprSchema for DFSchema {
420+
fn nullable(&self, col: &Column) -> Result<bool> {
421+
Ok(self.field_from_column(col)?.is_nullable())
422+
}
423+
424+
fn data_type(&self, col: &Column) -> Result<&DataType> {
425+
Ok(self.field_from_column(col)?.data_type())
426+
}
427+
}
428+
395429
impl Expr {
396-
/// Returns the [arrow::datatypes::DataType] of the expression based on [arrow::datatypes::Schema].
430+
/// Returns the [arrow::datatypes::DataType] of the expression
431+
/// based on [ExprSchema]
432+
///
433+
/// Note: [DFSchema] implements [ExprSchema].
397434
///
398435
/// # Errors
399436
///
400-
/// This function errors when it is not possible to compute its [arrow::datatypes::DataType].
401-
/// This happens when e.g. the expression refers to a column that does not exist in the schema, or when
402-
/// the expression is incorrectly typed (e.g. `[utf8] + [bool]`).
403-
pub fn get_type(&self, schema: &DFSchema) -> Result<DataType> {
437+
/// This function errors when it is not possible to compute its
438+
/// [arrow::datatypes::DataType]. This happens when e.g. the
439+
/// expression refers to a column that does not exist in the
440+
/// schema, or when the expression is incorrectly typed
441+
/// (e.g. `[utf8] + [bool]`).
442+
pub fn get_type<S: ExprSchema>(&self, schema: &S) -> Result<DataType> {
404443
match self {
405444
Expr::Alias(expr, _) | Expr::Sort { expr, .. } | Expr::Negative(expr) => {
406445
expr.get_type(schema)
407446
}
408-
Expr::Column(c) => Ok(schema.field_from_column(c)?.data_type().clone()),
447+
Expr::Column(c) => Ok(schema.data_type(c)?.clone()),
409448
Expr::ScalarVariable(_) => Ok(DataType::Utf8),
410449
Expr::Literal(l) => Ok(l.get_datatype()),
411450
Expr::Case { when_then_expr, .. } => when_then_expr[0].1.get_type(schema),
@@ -472,21 +511,24 @@ impl Expr {
472511
}
473512
}
474513

475-
/// Returns the nullability of the expression based on [arrow::datatypes::Schema].
514+
/// Returns the nullability of the expression based on [ExprSchema].
515+
///
516+
/// Note: [DFSchema] implements [ExprSchema].
476517
///
477518
/// # Errors
478519
///
479-
/// This function errors when it is not possible to compute its nullability.
480-
/// This happens when the expression refers to a column that does not exist in the schema.
481-
pub fn nullable(&self, input_schema: &DFSchema) -> Result<bool> {
520+
/// This function errors when it is not possible to compute its
521+
/// nullability. This happens when the expression refers to a
522+
/// column that does not exist in the schema.
523+
pub fn nullable<S: ExprSchema>(&self, input_schema: &S) -> Result<bool> {
482524
match self {
483525
Expr::Alias(expr, _)
484526
| Expr::Not(expr)
485527
| Expr::Negative(expr)
486528
| Expr::Sort { expr, .. }
487529
| Expr::Between { expr, .. }
488530
| Expr::InList { expr, .. } => expr.nullable(input_schema),
489-
Expr::Column(c) => Ok(input_schema.field_from_column(c)?.is_nullable()),
531+
Expr::Column(c) => input_schema.nullable(c),
490532
Expr::Literal(value) => Ok(value.is_null()),
491533
Expr::Case {
492534
when_then_expr,
@@ -561,7 +603,11 @@ impl Expr {
561603
///
562604
/// This function errors when it is impossible to cast the
563605
/// expression to the target [arrow::datatypes::DataType].
564-
pub fn cast_to(self, cast_to_type: &DataType, schema: &DFSchema) -> Result<Expr> {
606+
pub fn cast_to<S: ExprSchema>(
607+
self,
608+
cast_to_type: &DataType,
609+
schema: &S,
610+
) -> Result<Expr> {
565611
// TODO(kszucs): most of the operations do not validate the type correctness
566612
// like all of the binary expressions below. Perhaps Expr should track the
567613
// type of the expression?
@@ -2557,4 +2603,57 @@ mod tests {
25572603
combine_filters(&[filter1.clone(), filter2.clone(), filter3.clone()]);
25582604
assert_eq!(result, Some(and(and(filter1, filter2), filter3)));
25592605
}
2606+
2607+
#[test]
2608+
fn expr_schema_nullability() {
2609+
let expr = col("foo").eq(lit(1));
2610+
assert!(!expr.nullable(&MockExprSchema::new()).unwrap());
2611+
assert!(expr
2612+
.nullable(&MockExprSchema::new().with_nullable(true))
2613+
.unwrap());
2614+
}
2615+
2616+
#[test]
2617+
fn expr_schema_data_type() {
2618+
let expr = col("foo");
2619+
assert_eq!(
2620+
DataType::Utf8,
2621+
expr.get_type(&MockExprSchema::new().with_data_type(DataType::Utf8))
2622+
.unwrap()
2623+
);
2624+
}
2625+
2626+
struct MockExprSchema {
2627+
nullable: bool,
2628+
data_type: DataType,
2629+
}
2630+
2631+
impl MockExprSchema {
2632+
fn new() -> Self {
2633+
Self {
2634+
nullable: false,
2635+
data_type: DataType::Null,
2636+
}
2637+
}
2638+
2639+
fn with_nullable(mut self, nullable: bool) -> Self {
2640+
self.nullable = nullable;
2641+
self
2642+
}
2643+
2644+
fn with_data_type(mut self, data_type: DataType) -> Self {
2645+
self.data_type = data_type;
2646+
self
2647+
}
2648+
}
2649+
2650+
impl ExprSchema for MockExprSchema {
2651+
fn nullable(&self, _col: &Column) -> Result<bool> {
2652+
Ok(self.nullable)
2653+
}
2654+
2655+
fn data_type(&self, _col: &Column) -> Result<&DataType> {
2656+
Ok(&self.data_type)
2657+
}
2658+
}
25602659
}

datafusion/src/logical_plan/mod.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -46,8 +46,8 @@ pub use expr::{
4646
rewrite_sort_cols_by_aggs, right, round, rpad, rtrim, sha224, sha256, sha384, sha512,
4747
signum, sin, split_part, sqrt, starts_with, strpos, substr, sum, tan, to_hex,
4848
translate, trim, trunc, unalias, unnormalize_col, unnormalize_cols, upper, when,
49-
Column, Expr, ExprRewriter, ExpressionVisitor, Literal, Recursion, RewriteRecursion,
50-
SimplifyInfo,
49+
Column, Expr, ExprRewriter, ExprSchema, ExpressionVisitor, Literal, Recursion,
50+
RewriteRecursion, SimplifyInfo,
5151
};
5252
pub use extension::UserDefinedLogicalNode;
5353
pub use operators::Operator;

0 commit comments

Comments
 (0)