Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions vortex-array/benches/expr/case_when_bench.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ fn make_struct_array(size: usize) -> ArrayRef {
}

/// Benchmark a simple binary CASE WHEN with varying array sizes.
#[divan::bench(args = [10000, 100000, 1000000])]
#[divan::bench(args = [1000, 10000, 100000])]
fn case_when_simple(bencher: Bencher, size: usize) {
let array = make_struct_array(size);

Expand Down Expand Up @@ -94,7 +94,7 @@ fn case_when_nary_3_conditions(bencher: Bencher, size: usize) {
}

/// Benchmark CASE WHEN where all conditions are true (short-circuit path).
#[divan::bench(args = [10000, 100000, 1000000])]
#[divan::bench(args = [1000, 10000, 100000])]
fn case_when_all_true(bencher: Bencher, size: usize) {
let array = make_struct_array(size);

Expand All @@ -117,7 +117,7 @@ fn case_when_all_true(bencher: Bencher, size: usize) {
}

/// Benchmark CASE WHEN where all conditions are false (short-circuit path).
#[divan::bench(args = [10000, 100000, 1000000])]
#[divan::bench(args = [1000, 10000, 100000])]
fn case_when_all_false(bencher: Bencher, size: usize) {
let array = make_struct_array(size);

Expand Down Expand Up @@ -181,7 +181,7 @@ fn case_when_nary_10_conditions(bencher: Bencher, size: usize) {
}

/// Benchmark n-ary CASE WHEN with 100 conditions.
#[divan::bench(args = [10000, 100000, 1000000])]
#[divan::bench(args = [1000, 10000, 100000])]
fn case_when_nary_100_conditions(bencher: Bencher, size: usize) {
use vortex_array::expr::Expression;

Expand Down
44 changes: 24 additions & 20 deletions vortex-array/src/expr/exprs/case_when.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,6 @@ use crate::compute::zip;
use crate::expr::Arity;
use crate::expr::ChildName;
use crate::expr::ExecutionArgs;
use crate::expr::ExecutionResult;
use crate::expr::ExprId;
use crate::expr::VTable;
use crate::expr::VTableExt;
Expand Down Expand Up @@ -79,12 +78,10 @@ impl VTable for CaseWhen {
}

fn serialize(&self, options: &Self::Options) -> VortexResult<Option<Vec<u8>>> {
let num_children =
options.num_when_then_pairs * 2 + if options.has_else { 1 } else { 0 };
Ok(Some(
pb::CaseWhenOpts {
num_when_then_pairs: options.num_when_then_pairs,
has_else: options.has_else,
}
.encode_to_vec(),
pb::CaseWhenOpts { num_children }.encode_to_vec(),
))
}

Expand All @@ -95,8 +92,8 @@ impl VTable for CaseWhen {
) -> VortexResult<Self::Options> {
let opts = pb::CaseWhenOpts::decode(metadata)?;
Ok(CaseWhenOptions {
num_when_then_pairs: opts.num_when_then_pairs,
has_else: opts.has_else,
num_when_then_pairs: opts.num_children / 2,
has_else: opts.num_children % 2 == 1,
})
}

Expand Down Expand Up @@ -156,6 +153,18 @@ impl VTable for CaseWhen {
// The return dtype is based on the first THEN expression (index 1)
let then_dtype = &arg_dtypes[1];

// All THEN (and ELSE) value dtypes must match
debug_assert!(
(0..options.num_when_then_pairs as usize).all(|i| {
let idx = i * 2 + 1;
&arg_dtypes[idx] == then_dtype
}),
"All THEN expression dtypes must match, got {:?}",
(0..options.num_when_then_pairs as usize)
.map(|i| &arg_dtypes[i * 2 + 1])
.collect::<Vec<_>>()
);

// If there's no ELSE, the result is always nullable (unmatched rows are NULL)
if !options.has_else {
Ok(then_dtype.as_nullable())
Expand All @@ -168,7 +177,7 @@ impl VTable for CaseWhen {
&self,
options: &Self::Options,
args: ExecutionArgs,
) -> VortexResult<ExecutionResult> {
) -> VortexResult<ArrayRef> {
let row_count = args.row_count;
let num_pairs = options.num_when_then_pairs as usize;

Expand Down Expand Up @@ -222,7 +231,7 @@ impl VTable for CaseWhen {
result = zip(then_value.as_ref(), result.as_ref(), &mask)?;
}

result.execute::<ExecutionResult>(args.ctx)
Ok(result)
}

fn is_null_sensitive(&self, _options: &Self::Options) -> bool {
Expand All @@ -236,7 +245,7 @@ impl VTable for CaseWhen {
}

/// Efficient implementation for binary CASE WHEN (single when/then pair)
fn execute_binary_case_when(_has_else: bool, args: ExecutionArgs) -> VortexResult<ExecutionResult> {
fn execute_binary_case_when(_has_else: bool, args: ExecutionArgs) -> VortexResult<ArrayRef> {
let row_count = args.row_count;

// Extract inputs based on arity: [condition, then_value] or [condition, then_value, else_value]
Expand Down Expand Up @@ -265,20 +274,17 @@ fn execute_binary_case_when(_has_else: bool, args: ExecutionArgs) -> VortexResul

// Short-circuit: all true -> just return THEN value
if mask.all_true() {
return then_value.execute::<ExecutionResult>(args.ctx);
return Ok(then_value);
}

// Short-circuit: all false -> return ELSE value or NULL
if mask.all_false() {
return match else_value {
Some(else_value) => else_value.execute::<ExecutionResult>(args.ctx),
Some(else_value) => Ok(else_value),
None => {
// Create NULL constant of appropriate type
let then_dtype = then_value.dtype().as_nullable();
Ok(ExecutionResult::constant(
Scalar::null(then_dtype),
row_count,
))
Ok(ConstantArray::new(Scalar::null(then_dtype), row_count).into_array())
}
};
}
Expand All @@ -290,9 +296,7 @@ fn execute_binary_case_when(_has_else: bool, args: ExecutionArgs) -> VortexResul
});

// Use zip to select: where mask is true, take then_value; else take else_value
let result = zip(then_value.as_ref(), else_value.as_ref(), &mask)?;

result.execute::<ExecutionResult>(args.ctx)
zip(then_value.as_ref(), else_value.as_ref(), &mask)
}

/// Creates an N-ary CASE WHEN expression from a flat list of children.
Expand Down
7 changes: 5 additions & 2 deletions vortex-proto/proto/expr.proto
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,10 @@ message SelectOpts {
}

// Options for `vortex.case_when`
// Encodes num_when_then_pairs and has_else into a single u32 (num_children).
// num_children = num_when_then_pairs * 2 + (has_else ? 1 : 0)
// has_else = num_children % 2 == 1
// num_when_then_pairs = num_children / 2
message CaseWhenOpts {
uint32 num_when_then_pairs = 1;
bool has_else = 2;
uint32 num_children = 1;
}
8 changes: 5 additions & 3 deletions vortex-proto/src/generated/vortex.expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -146,10 +146,12 @@ pub mod select_opts {
}
}
/// Options for `vortex.case_when`
/// Encodes num_when_then_pairs and has_else into a single u32 (num_children).
/// num_children = num_when_then_pairs * 2 + (has_else ? 1 : 0)
/// has_else = num_children % 2 == 1
/// num_when_then_pairs = num_children / 2
#[derive(Clone, Copy, PartialEq, Eq, Hash, ::prost::Message)]
pub struct CaseWhenOpts {
#[prost(uint32, tag = "1")]
pub num_when_then_pairs: u32,
#[prost(bool, tag = "2")]
pub has_else: bool,
pub num_children: u32,
}
Loading