Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
221 changes: 221 additions & 0 deletions vortex-array/src/expr/exprs/fill_null.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,221 @@
// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright the Vortex contributors

use std::fmt::Formatter;

use vortex_dtype::DType;
use vortex_error::VortexResult;
use vortex_error::vortex_ensure;
use vortex_error::vortex_err;
use vortex_session::VortexSession;

use crate::ArrayRef;
use crate::compute::fill_null as compute_fill_null;
use crate::expr::Arity;
use crate::expr::ChildName;
use crate::expr::EmptyOptions;
use crate::expr::ExecutionArgs;
use crate::expr::ExprId;
use crate::expr::Expression;
use crate::expr::VTable;
use crate::expr::VTableExt;

/// An expression that replaces null values in the input with a fill value.
pub struct FillNull;

impl VTable for FillNull {
type Options = EmptyOptions;

fn id(&self) -> ExprId {
ExprId::from("vortex.fill_null")
}

fn serialize(&self, _options: &Self::Options) -> VortexResult<Option<Vec<u8>>> {
Ok(Some(vec![]))
}

fn deserialize(
&self,
_metadata: &[u8],
_session: &VortexSession,
) -> VortexResult<Self::Options> {
Ok(EmptyOptions)
}

fn arity(&self, _options: &Self::Options) -> Arity {
Arity::Exact(2)
}

fn child_name(&self, _options: &Self::Options, child_idx: usize) -> ChildName {
match child_idx {
0 => ChildName::from("input"),
1 => ChildName::from("fill_value"),
_ => unreachable!("Invalid child index {} for FillNull expression", child_idx),
}
}

fn fmt_sql(
&self,
_options: &Self::Options,
expr: &Expression,
f: &mut Formatter<'_>,
) -> std::fmt::Result {
write!(f, "fill_null(")?;
expr.child(0).fmt_sql(f)?;
write!(f, ", ")?;
expr.child(1).fmt_sql(f)?;
write!(f, ")")
}

fn return_dtype(&self, _options: &Self::Options, arg_dtypes: &[DType]) -> VortexResult<DType> {
vortex_ensure!(
arg_dtypes[0].eq_ignore_nullability(&arg_dtypes[1]),
"fill_null requires input and fill value to have the same base type, got {} and {}",
arg_dtypes[0],
arg_dtypes[1]
);
// The result dtype takes the nullability of the fill value.
Ok(arg_dtypes[0]
.clone()
.with_nullability(arg_dtypes[1].nullability()))
}

fn execute(&self, _options: &Self::Options, args: ExecutionArgs) -> VortexResult<ArrayRef> {
let [input, fill_value]: [ArrayRef; _] = args
.inputs
.try_into()
.map_err(|_| vortex_err!("Wrong arg count"))?;

let fill_scalar = fill_value
.as_constant()
.ok_or_else(|| vortex_err!("fill_null fill_value must be a constant/scalar"))?;

compute_fill_null(input.as_ref(), &fill_scalar)
}

fn simplify(
&self,
_options: &Self::Options,
expr: &Expression,
ctx: &dyn crate::expr::SimplifyCtx,
) -> VortexResult<Option<Expression>> {
let input_dtype = ctx.return_dtype(expr.child(0))?;

if !input_dtype.is_nullable() {
return Ok(Some(expr.child(0).clone()));
}

Ok(None)
}

fn validity(
&self,
_options: &Self::Options,
expression: &Expression,
) -> VortexResult<Option<Expression>> {
// After fill_null, the result validity depends on the fill value's nullability.
// If fill_value is non-nullable, the result is always valid.
Ok(Some(expression.child(1).validity()?))
}

fn is_null_sensitive(&self, _options: &Self::Options) -> bool {
true
}

fn is_fallible(&self, _options: &Self::Options) -> bool {
false
}
}

/// Creates an expression that replaces null values with a fill value.
///
/// ```rust
/// # use vortex_array::expr::{fill_null, root, lit};
/// let expr = fill_null(root(), lit(0i32));
/// ```
pub fn fill_null(child: Expression, fill_value: Expression) -> Expression {
FillNull.new_expr(EmptyOptions, [child, fill_value])
}

#[cfg(test)]
mod tests {
use vortex_buffer::buffer;
use vortex_dtype::DType;
use vortex_dtype::Nullability;
use vortex_dtype::PType;
use vortex_error::VortexExpect;

use super::fill_null;
use crate::IntoArray;
use crate::arrays::PrimitiveArray;
use crate::arrays::StructArray;
use crate::assert_arrays_eq;
use crate::expr::exprs::get_item::get_item;
use crate::expr::exprs::literal::lit;
use crate::expr::exprs::root::root;

#[test]
fn dtype() {
let dtype = DType::Primitive(PType::I32, Nullability::Nullable);
assert_eq!(
fill_null(root(), lit(0i32)).return_dtype(&dtype).unwrap(),
DType::Primitive(PType::I32, Nullability::NonNullable)
);
}

#[test]
fn replace_children() {
let expr = fill_null(root(), lit(0i32));
expr.with_children(vec![root(), lit(0i32)])
.vortex_expect("operation should succeed in test");
}

#[test]
fn evaluate() {
let test_array =
PrimitiveArray::from_option_iter([Some(1i32), None, Some(3), None, Some(5)])
.into_array();

let expr = fill_null(root(), lit(42i32));
let result = test_array.apply(&expr).unwrap();

assert_eq!(
result.dtype(),
&DType::Primitive(PType::I32, Nullability::NonNullable)
);
assert_arrays_eq!(result, PrimitiveArray::from_iter([1i32, 42, 3, 42, 5]));
}

#[test]
fn evaluate_struct_field() {
let test_array = StructArray::from_fields(&[(
"a",
PrimitiveArray::from_option_iter([Some(1i32), None, Some(3)]).into_array(),
)])
.unwrap()
.into_array();

let expr = fill_null(get_item("a", root()), lit(0i32));
let result = test_array.apply(&expr).unwrap();

assert_eq!(
result.dtype(),
&DType::Primitive(PType::I32, Nullability::NonNullable)
);
assert_arrays_eq!(result, PrimitiveArray::from_iter([1i32, 0, 3]));
}

#[test]
fn evaluate_non_nullable_input() {
let test_array = buffer![1i32, 2, 3].into_array();
let expr = fill_null(root(), lit(0i32));
let result = test_array.apply(&expr).unwrap();
assert_arrays_eq!(result, PrimitiveArray::from_iter([1i32, 2, 3]));
}

#[test]
fn test_display() {
let expr = fill_null(get_item("value", root()), lit(0i32));
assert_eq!(expr.to_string(), "fill_null($.value, 0i32)");
}
}
2 changes: 2 additions & 0 deletions vortex-array/src/expr/exprs/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ pub(crate) mod between;
pub(crate) mod binary;
pub(crate) mod cast;
pub(crate) mod dynamic;
pub(crate) mod fill_null;
pub(crate) mod get_item;
pub(crate) mod is_null;
pub(crate) mod like;
Expand All @@ -21,6 +22,7 @@ pub use between::*;
pub use binary::*;
pub use cast::*;
pub use dynamic::*;
pub use fill_null::*;
pub use get_item::*;
pub use is_null::*;
pub use like::*;
Expand Down
2 changes: 2 additions & 0 deletions vortex-array/src/expr/session.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ use crate::expr::ExprVTable;
use crate::expr::exprs::between::Between;
use crate::expr::exprs::binary::Binary;
use crate::expr::exprs::cast::Cast;
use crate::expr::exprs::fill_null::FillNull;
use crate::expr::exprs::get_item::GetItem;
use crate::expr::exprs::is_null::IsNull;
use crate::expr::exprs::like::Like;
Expand Down Expand Up @@ -56,6 +57,7 @@ impl Default for ExprSession {
ExprVTable::new_static(&Between),
ExprVTable::new_static(&Binary),
ExprVTable::new_static(&Cast),
ExprVTable::new_static(&FillNull),
ExprVTable::new_static(&GetItem),
ExprVTable::new_static(&IsNull),
ExprVTable::new_static(&Like),
Expand Down
3 changes: 3 additions & 0 deletions vortex-array/src/expr/vtable.rs
Original file line number Diff line number Diff line change
Expand Up @@ -693,6 +693,7 @@ mod tests {
use crate::expr::exprs::binary::not_eq;
use crate::expr::exprs::binary::or;
use crate::expr::exprs::cast::cast;
use crate::expr::exprs::fill_null::fill_null;
use crate::expr::exprs::get_item::col;
use crate::expr::exprs::get_item::get_item;
use crate::expr::exprs::is_null::is_null;
Expand Down Expand Up @@ -734,6 +735,8 @@ mod tests {
#[case(checked_add(col("a"), lit(5)))]
// Null check expressions
#[case(is_null(col("nullable_col")))]
// Fill null expressions
#[case(fill_null(col("a"), lit(0)))]
// Type casting expressions
#[case(cast(
col("a"),
Expand Down
Loading