Skip to content

feat: add metadata to literal expressions #16170

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 10 commits into from
Jun 6, 2025
2 changes: 1 addition & 1 deletion datafusion-cli/src/functions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -322,7 +322,7 @@ pub struct ParquetMetadataFunc {}
impl TableFunctionImpl for ParquetMetadataFunc {
fn call(&self, exprs: &[Expr]) -> Result<Arc<dyn TableProvider>> {
let filename = match exprs.first() {
Some(Expr::Literal(ScalarValue::Utf8(Some(s)))) => s, // single quote: parquet_metadata('x.parquet')
Some(Expr::Literal(ScalarValue::Utf8(Some(s)), _)) => s, // single quote: parquet_metadata('x.parquet')
Some(Expr::Column(Column { name, .. })) => name, // double quote: parquet_metadata("x.parquet")
_ => {
return plan_err!(
Expand Down
2 changes: 1 addition & 1 deletion datafusion-examples/examples/expr_api.rs
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ async fn main() -> Result<()> {
let expr2 = Expr::BinaryExpr(BinaryExpr::new(
Box::new(col("a")),
Operator::Plus,
Box::new(Expr::Literal(ScalarValue::Int32(Some(5)))),
Box::new(Expr::Literal(ScalarValue::Int32(Some(5)), None)),
));
assert_eq!(expr, expr2);

Expand Down
2 changes: 1 addition & 1 deletion datafusion-examples/examples/optimizer_rule.rs
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,7 @@ fn is_binary_eq(binary_expr: &BinaryExpr) -> bool {

/// Return true if the expression is a literal or column reference
fn is_lit_or_col(expr: &Expr) -> bool {
matches!(expr, Expr::Column(_) | Expr::Literal(_))
matches!(expr, Expr::Column(_) | Expr::Literal(_, _))
}

/// A simple user defined filter function
Expand Down
5 changes: 3 additions & 2 deletions datafusion-examples/examples/simple_udtf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,8 @@ struct LocalCsvTableFunc {}

impl TableFunctionImpl for LocalCsvTableFunc {
fn call(&self, exprs: &[Expr]) -> Result<Arc<dyn TableProvider>> {
let Some(Expr::Literal(ScalarValue::Utf8(Some(ref path)))) = exprs.first() else {
let Some(Expr::Literal(ScalarValue::Utf8(Some(ref path)), _)) = exprs.first()
else {
return plan_err!("read_csv requires at least one string argument");
};

Expand All @@ -145,7 +146,7 @@ impl TableFunctionImpl for LocalCsvTableFunc {
let info = SimplifyContext::new(&execution_props);
let expr = ExprSimplifier::new(info).simplify(expr.clone())?;

if let Expr::Literal(ScalarValue::Int64(Some(limit))) = expr {
if let Expr::Literal(ScalarValue::Int64(Some(limit)), _) = expr {
Ok(limit as usize)
} else {
plan_err!("Limit must be an integer")
Expand Down
15 changes: 8 additions & 7 deletions datafusion/catalog-listing/src/helpers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ pub fn expr_applicable_for_cols(col_names: &[&str], expr: &Expr) -> bool {
Ok(TreeNodeRecursion::Stop)
}
}
Expr::Literal(_)
Expr::Literal(_, _)
| Expr::Alias(_)
| Expr::OuterReferenceColumn(_, _)
| Expr::ScalarVariable(_, _)
Expand Down Expand Up @@ -346,8 +346,8 @@ fn populate_partition_values<'a>(
{
match op {
Operator::Eq => match (left.as_ref(), right.as_ref()) {
(Expr::Column(Column { ref name, .. }), Expr::Literal(val))
| (Expr::Literal(val), Expr::Column(Column { ref name, .. })) => {
(Expr::Column(Column { ref name, .. }), Expr::Literal(val, _))
| (Expr::Literal(val, _), Expr::Column(Column { ref name, .. })) => {
if partition_values
.insert(name, PartitionValue::Single(val.to_string()))
.is_some()
Expand Down Expand Up @@ -984,7 +984,7 @@ mod tests {
assert_eq!(
evaluate_partition_prefix(
partitions,
&[col("a").eq(Expr::Literal(ScalarValue::Date32(Some(3))))],
&[col("a").eq(Expr::Literal(ScalarValue::Date32(Some(3)), None))],
),
Some(Path::from("a=1970-01-04")),
);
Expand All @@ -993,9 +993,10 @@ mod tests {
assert_eq!(
evaluate_partition_prefix(
partitions,
&[col("a").eq(Expr::Literal(ScalarValue::Date64(Some(
4 * 24 * 60 * 60 * 1000
)))),],
&[col("a").eq(Expr::Literal(
ScalarValue::Date64(Some(4 * 24 * 60 * 60 * 1000)),
None
)),],
),
Some(Path::from("a=1970-01-05")),
);
Expand Down
7 changes: 5 additions & 2 deletions datafusion/core/benches/map_query_sql.rs
Original file line number Diff line number Diff line change
Expand Up @@ -71,8 +71,11 @@ fn criterion_benchmark(c: &mut Criterion) {
let mut value_buffer = Vec::new();

for i in 0..1000 {
key_buffer.push(Expr::Literal(ScalarValue::Utf8(Some(keys[i].clone()))));
value_buffer.push(Expr::Literal(ScalarValue::Int32(Some(values[i]))));
key_buffer.push(Expr::Literal(
ScalarValue::Utf8(Some(keys[i].clone())),
None,
));
value_buffer.push(Expr::Literal(ScalarValue::Int32(Some(values[i])), None));
}
c.bench_function("map_1000_1", |b| {
b.iter(|| {
Expand Down
5 changes: 4 additions & 1 deletion datafusion/core/src/dataframe/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1337,7 +1337,10 @@ impl DataFrame {
/// ```
pub async fn count(self) -> Result<usize> {
let rows = self
.aggregate(vec![], vec![count(Expr::Literal(COUNT_STAR_EXPANSION))])?
.aggregate(
vec![],
vec![count(Expr::Literal(COUNT_STAR_EXPANSION, None))],
)?
.collect()
.await?;
let len = *rows
Expand Down
2 changes: 1 addition & 1 deletion datafusion/core/src/datasource/listing/table.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2230,7 +2230,7 @@ mod tests {
let filter_predicate = Expr::BinaryExpr(BinaryExpr::new(
Box::new(Expr::Column("column1".into())),
Operator::GtEq,
Box::new(Expr::Literal(ScalarValue::Int32(Some(0)))),
Box::new(Expr::Literal(ScalarValue::Int32(Some(0)), None)),
));

// Create a new batch of data to insert into the table
Expand Down
2 changes: 1 addition & 1 deletion datafusion/core/src/execution/context/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1214,7 +1214,7 @@ impl SessionContext {
let mut params: Vec<ScalarValue> = parameters
.into_iter()
.map(|e| match e {
Expr::Literal(scalar) => Ok(scalar),
Expr::Literal(scalar, _) => Ok(scalar),
_ => not_impl_err!("Unsupported parameter type: {}", e),
})
.collect::<Result<_>>()?;
Expand Down
9 changes: 5 additions & 4 deletions datafusion/core/src/physical_planner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2257,7 +2257,8 @@ mod tests {
// verify that the plan correctly casts u8 to i64
// the cast from u8 to i64 for literal will be simplified, and get lit(int64(5))
// the cast here is implicit so has CastOptions with safe=true
let expected = "BinaryExpr { left: Column { name: \"c7\", index: 2 }, op: Lt, right: Literal { value: Int64(5) }, fail_on_overflow: false }";
let expected = r#"BinaryExpr { left: Column { name: "c7", index: 2 }, op: Lt, right: Literal { value: Int64(5), field: Field { name: "5", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} } }, fail_on_overflow: false }"#;

assert!(format!("{exec_plan:?}").contains(expected));
Ok(())
}
Expand All @@ -2282,7 +2283,7 @@ mod tests {
&session_state,
);

let expected = r#"Ok(PhysicalGroupBy { expr: [(Column { name: "c1", index: 0 }, "c1"), (Column { name: "c2", index: 1 }, "c2"), (Column { name: "c3", index: 2 }, "c3")], null_expr: [(Literal { value: Utf8(NULL) }, "c1"), (Literal { value: Int64(NULL) }, "c2"), (Literal { value: Int64(NULL) }, "c3")], groups: [[false, false, false], [true, false, false], [false, true, false], [false, false, true], [true, true, false], [true, false, true], [false, true, true], [true, true, true]] })"#;
let expected = r#"Ok(PhysicalGroupBy { expr: [(Column { name: "c1", index: 0 }, "c1"), (Column { name: "c2", index: 1 }, "c2"), (Column { name: "c3", index: 2 }, "c3")], null_expr: [(Literal { value: Utf8(NULL), field: Field { name: "NULL", data_type: Utf8, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} } }, "c1"), (Literal { value: Int64(NULL), field: Field { name: "NULL", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} } }, "c2"), (Literal { value: Int64(NULL), field: Field { name: "NULL", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} } }, "c3")], groups: [[false, false, false], [true, false, false], [false, true, false], [false, false, true], [true, true, false], [true, false, true], [false, true, true], [true, true, true]] })"#;

assert_eq!(format!("{cube:?}"), expected);

Expand All @@ -2309,7 +2310,7 @@ mod tests {
&session_state,
);

let expected = r#"Ok(PhysicalGroupBy { expr: [(Column { name: "c1", index: 0 }, "c1"), (Column { name: "c2", index: 1 }, "c2"), (Column { name: "c3", index: 2 }, "c3")], null_expr: [(Literal { value: Utf8(NULL) }, "c1"), (Literal { value: Int64(NULL) }, "c2"), (Literal { value: Int64(NULL) }, "c3")], groups: [[true, true, true], [false, true, true], [false, false, true], [false, false, false]] })"#;
let expected = r#"Ok(PhysicalGroupBy { expr: [(Column { name: "c1", index: 0 }, "c1"), (Column { name: "c2", index: 1 }, "c2"), (Column { name: "c3", index: 2 }, "c3")], null_expr: [(Literal { value: Utf8(NULL), field: Field { name: "NULL", data_type: Utf8, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} } }, "c1"), (Literal { value: Int64(NULL), field: Field { name: "NULL", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} } }, "c2"), (Literal { value: Int64(NULL), field: Field { name: "NULL", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} } }, "c3")], groups: [[true, true, true], [false, true, true], [false, false, true], [false, false, false]] })"#;

assert_eq!(format!("{rollup:?}"), expected);

Expand Down Expand Up @@ -2493,7 +2494,7 @@ mod tests {
let execution_plan = plan(&logical_plan).await?;
// verify that the plan correctly adds cast from Int64(1) to Utf8, and the const will be evaluated.

let expected = "expr: [(BinaryExpr { left: BinaryExpr { left: Column { name: \"c1\", index: 0 }, op: Eq, right: Literal { value: Utf8(\"a\") }, fail_on_overflow: false }, op: Or, right: BinaryExpr { left: Column { name: \"c1\", index: 0 }, op: Eq, right: Literal { value: Utf8(\"1\") }, fail_on_overflow: false }, fail_on_overflow: false }";
let expected = "expr: [(BinaryExpr { left: BinaryExpr { left: Column { name: \"c1\", index: 0 }, op: Eq, right: Literal { value: Utf8(\"a\"), field: Field { name: \"a\", data_type: Utf8, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} } }, fail_on_overflow: false }, op: Or, right: BinaryExpr { left: Column { name: \"c1\", index: 0 }, op: Eq, right: Literal { value: Utf8(\"1\"), field: Field { name: \"1\", data_type: Utf8, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} } }, fail_on_overflow: false }, fail_on_overflow: false }";

let actual = format!("{execution_plan:?}");
assert!(actual.contains(expected), "{}", actual);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -179,12 +179,12 @@ impl TableProvider for CustomProvider {
match &filters[0] {
Expr::BinaryExpr(BinaryExpr { right, .. }) => {
let int_value = match &**right {
Expr::Literal(ScalarValue::Int8(Some(i))) => *i as i64,
Expr::Literal(ScalarValue::Int16(Some(i))) => *i as i64,
Expr::Literal(ScalarValue::Int32(Some(i))) => *i as i64,
Expr::Literal(ScalarValue::Int64(Some(i))) => *i,
Expr::Literal(ScalarValue::Int8(Some(i)), _) => *i as i64,
Expr::Literal(ScalarValue::Int16(Some(i)), _) => *i as i64,
Expr::Literal(ScalarValue::Int32(Some(i)), _) => *i as i64,
Expr::Literal(ScalarValue::Int64(Some(i)), _) => *i,
Expr::Cast(Cast { expr, data_type: _ }) => match expr.deref() {
Expr::Literal(lit_value) => match lit_value {
Expr::Literal(lit_value, _) => match lit_value {
ScalarValue::Int8(Some(v)) => *v as i64,
ScalarValue::Int16(Some(v)) => *v as i64,
ScalarValue::Int32(Some(v)) => *v as i64,
Expand Down
7 changes: 5 additions & 2 deletions datafusion/core/tests/dataframe/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1210,7 +1210,7 @@ async fn join_on_filter_datatype() -> Result<()> {
let join = left.clone().join_on(
right.clone(),
JoinType::Inner,
Some(Expr::Literal(ScalarValue::Null)),
Some(Expr::Literal(ScalarValue::Null, None)),
)?;
assert_snapshot!(join.into_optimized_plan().unwrap(), @"EmptyRelation");

Expand Down Expand Up @@ -4527,7 +4527,10 @@ async fn consecutive_projection_same_schema() -> Result<()> {

// Add `t` column full of nulls
let df = df
.with_column("t", cast(Expr::Literal(ScalarValue::Null), DataType::Int32))
.with_column(
"t",
cast(Expr::Literal(ScalarValue::Null, None), DataType::Int32),
)
.unwrap();
df.clone().show().await.unwrap();

Expand Down
6 changes: 3 additions & 3 deletions datafusion/core/tests/execution/logical_plan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -47,9 +47,9 @@ async fn count_only_nulls() -> Result<()> {
let input = Arc::new(LogicalPlan::Values(Values {
schema: input_schema,
values: vec![
vec![Expr::Literal(ScalarValue::Null)],
vec![Expr::Literal(ScalarValue::Null)],
vec![Expr::Literal(ScalarValue::Null)],
vec![Expr::Literal(ScalarValue::Null, None)],
vec![Expr::Literal(ScalarValue::Null, None)],
vec![Expr::Literal(ScalarValue::Null, None)],
],
}));
let input_col_ref = Expr::Column(Column {
Expand Down
11 changes: 7 additions & 4 deletions datafusion/core/tests/expr_api/simplification.rs
Original file line number Diff line number Diff line change
Expand Up @@ -282,10 +282,13 @@ fn select_date_plus_interval() -> Result<()> {

let date_plus_interval_expr = to_timestamp_expr(ts_string)
.cast_to(&DataType::Date32, schema)?
+ Expr::Literal(ScalarValue::IntervalDayTime(Some(IntervalDayTime {
days: 123,
milliseconds: 0,
})));
+ Expr::Literal(
ScalarValue::IntervalDayTime(Some(IntervalDayTime {
days: 123,
milliseconds: 0,
})),
None,
);

let plan = LogicalPlanBuilder::from(table_scan.clone())
.project(vec![date_plus_interval_expr])?
Expand Down
2 changes: 1 addition & 1 deletion datafusion/core/tests/user_defined/expr_planner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ impl ExprPlanner for MyCustomPlanner {
}
BinaryOperator::Question => {
Ok(PlannerResult::Planned(Expr::Alias(Alias::new(
Expr::Literal(ScalarValue::Boolean(Some(true))),
Expr::Literal(ScalarValue::Boolean(Some(true)), None),
None::<&str>,
format!("{} ? {}", expr.left, expr.right),
))))
Expand Down
9 changes: 5 additions & 4 deletions datafusion/core/tests/user_defined/user_defined_plan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -912,11 +912,12 @@ impl MyAnalyzerRule {
.map(|e| {
e.transform(|e| {
Ok(match e {
Expr::Literal(ScalarValue::Int64(i)) => {
Expr::Literal(ScalarValue::Int64(i), _) => {
// transform to UInt64
Transformed::yes(Expr::Literal(ScalarValue::UInt64(
i.map(|i| i as u64),
)))
Transformed::yes(Expr::Literal(
ScalarValue::UInt64(i.map(|i| i as u64)),
None,
))
}
_ => Transformed::no(e),
})
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ use std::collections::HashMap;
use std::hash::{DefaultHasher, Hash, Hasher};
use std::sync::Arc;

use arrow::array::{as_string_array, record_batch, Int8Array, UInt64Array};
use arrow::array::{as_string_array, create_array, record_batch, Int8Array, UInt64Array};
use arrow::array::{
builder::BooleanBuilder, cast::AsArray, Array, ArrayRef, Float32Array, Float64Array,
Int32Array, RecordBatch, StringArray,
Expand All @@ -42,9 +42,9 @@ use datafusion_common::{
};
use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyInfo};
use datafusion_expr::{
Accumulator, ColumnarValue, CreateFunction, CreateFunctionBody, LogicalPlanBuilder,
OperateFunctionArg, ReturnFieldArgs, ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl,
Signature, Volatility,
lit_with_metadata, Accumulator, ColumnarValue, CreateFunction, CreateFunctionBody,
LogicalPlanBuilder, OperateFunctionArg, ReturnFieldArgs, ScalarFunctionArgs,
ScalarUDF, ScalarUDFImpl, Signature, Volatility,
};
use datafusion_functions_nested::range::range_udf;
use parking_lot::Mutex;
Expand Down Expand Up @@ -1529,6 +1529,65 @@ async fn test_metadata_based_udf() -> Result<()> {
Ok(())
}

#[tokio::test]
async fn test_metadata_based_udf_with_literal() -> Result<()> {
let ctx = SessionContext::new();
let input_metadata: HashMap<String, String> =
[("modify_values".to_string(), "double_output".to_string())]
.into_iter()
.collect();
let df = ctx.sql("select 0;").await?.select(vec![
lit(5u64).alias_with_metadata("lit_with_doubling", Some(input_metadata.clone())),
lit(5u64).alias("lit_no_doubling"),
lit_with_metadata(5u64, Some(input_metadata))
.alias("lit_with_double_no_alias_metadata"),
])?;

let output_metadata: HashMap<String, String> =
[("output_metatype".to_string(), "custom_value".to_string())]
.into_iter()
.collect();
let custom_udf = ScalarUDF::from(MetadataBasedUdf::new(output_metadata.clone()));

let plan = LogicalPlanBuilder::from(df.into_optimized_plan()?)
.project(vec![
custom_udf
.call(vec![col("lit_with_doubling")])
.alias("doubled_output"),
custom_udf
.call(vec![col("lit_no_doubling")])
.alias("not_doubled_output"),
custom_udf
.call(vec![col("lit_with_double_no_alias_metadata")])
.alias("double_without_alias_metadata"),
])?
.build()?;

let actual = DataFrame::new(ctx.state(), plan).collect().await?;

let schema = Arc::new(Schema::new(vec![
Field::new("doubled_output", DataType::UInt64, false)
.with_metadata(output_metadata.clone()),
Field::new("not_doubled_output", DataType::UInt64, false)
.with_metadata(output_metadata.clone()),
Field::new("double_without_alias_metadata", DataType::UInt64, false)
.with_metadata(output_metadata.clone()),
]));

let expected = RecordBatch::try_new(
schema,
vec![
create_array!(UInt64, [10]),
create_array!(UInt64, [5]),
create_array!(UInt64, [10]),
],
)?;

assert_eq!(expected, actual[0]);

Ok(())
}

/// This UDF is to test extension handling, both on the input and output
/// sides. For the input, we will handle the data differently if there is
/// the canonical extension type Bool8. For the output we will add a
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -205,7 +205,7 @@ impl TableFunctionImpl for SimpleCsvTableFunc {
let mut filepath = String::new();
for expr in exprs {
match expr {
Expr::Literal(ScalarValue::Utf8(Some(ref path))) => {
Expr::Literal(ScalarValue::Utf8(Some(ref path)), _) => {
filepath.clone_from(path);
}
expr => new_exprs.push(expr.clone()),
Expand Down
4 changes: 3 additions & 1 deletion datafusion/datasource-parquet/src/row_filter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -557,6 +557,7 @@ mod test {
// Test all should fail
let expr = col("timestamp_col").lt(Expr::Literal(
ScalarValue::TimestampNanosecond(Some(1), Some(Arc::from("UTC"))),
None,
));
let expr = logical2physical(&expr, &table_schema);
let schema_adapter_factory = Arc::new(DefaultSchemaAdapterFactory);
Expand Down Expand Up @@ -597,6 +598,7 @@ mod test {
// Test all should pass
let expr = col("timestamp_col").gt(Expr::Literal(
ScalarValue::TimestampNanosecond(Some(0), Some(Arc::from("UTC"))),
None,
));
let expr = logical2physical(&expr, &table_schema);
let schema_adapter_factory = Arc::new(DefaultSchemaAdapterFactory);
Expand Down Expand Up @@ -660,7 +662,7 @@ mod test {

let expr = col("string_col")
.is_not_null()
.or(col("bigint_col").gt(Expr::Literal(ScalarValue::Int64(Some(5)))));
.or(col("bigint_col").gt(Expr::Literal(ScalarValue::Int64(Some(5)), None)));
let expr = logical2physical(&expr, &table_schema);

assert!(can_expr_be_pushed_down_with_schemas(&expr, &table_schema));
Expand Down
Loading