Skip to content

Commit 69254ff

Browse files
committed
Keep output as scalar for scalar function if all inputs are scalar
1 parent 448dff5 commit 69254ff

File tree

2 files changed

+44
-1
lines changed

2 files changed

+44
-1
lines changed

datafusion/physical-expr/src/functions.rs

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -357,6 +357,8 @@ where
357357
ColumnarValue::Array(a) => Some(a.len()),
358358
});
359359

360+
let is_scalar = len.is_none();
361+
360362
let inferred_length = len.unwrap_or(1);
361363
let args = args
362364
.iter()
@@ -373,7 +375,14 @@ where
373375
.collect::<Vec<ArrayRef>>();
374376

375377
let result = (inner)(&args);
376-
result.map(ColumnarValue::Array)
378+
379+
if is_scalar {
380+
// If all inputs are scalar, keeps output as scalar
381+
let result = result.and_then(|arr| ScalarValue::try_from_array(&arr, 0));
382+
result.map(ColumnarValue::Scalar)
383+
} else {
384+
result.map(ColumnarValue::Array)
385+
}
377386
})
378387
}
379388

datafusion/physical-expr/src/planner.rs

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -448,3 +448,37 @@ pub fn create_physical_expr(
448448
}
449449
}
450450
}
451+
452+
#[cfg(test)]
453+
mod tests {
454+
use super::*;
455+
use arrow_array::{ArrayRef, BooleanArray, RecordBatch, StringArray};
456+
use arrow_schema::{DataType, Field, Schema};
457+
use datafusion_common::{DFSchema, Result};
458+
use datafusion_expr::{col, left, Literal};
459+
460+
#[test]
461+
fn test_create_physical_expr_scalar_input_output() -> Result<()> {
462+
let expr = col("letter").eq(left("APACHE".lit(), 1i64.lit()));
463+
464+
let schema = Schema::new(vec![Field::new("letter", DataType::Utf8, false)]);
465+
let df_schema = DFSchema::try_from_qualified_schema("data", &schema)?;
466+
let p = create_physical_expr(&expr, &df_schema, &schema, &ExecutionProps::new())?;
467+
468+
let batch = RecordBatch::try_new(
469+
Arc::new(schema),
470+
vec![Arc::new(StringArray::from_iter_values(vec![
471+
"A", "B", "C", "D",
472+
]))],
473+
)?;
474+
let result = p.evaluate(&batch)?;
475+
let result = result.into_array(4);
476+
477+
assert_eq!(
478+
&result,
479+
&(Arc::new(BooleanArray::from(vec![true, false, false, false,])) as ArrayRef)
480+
);
481+
482+
Ok(())
483+
}
484+
}

0 commit comments

Comments
 (0)