Skip to content

Commit

Permalink
Adds field method to WindowUDFImpl trait
Browse files Browse the repository at this point in the history
  • Loading branch information
jcsherin committed Sep 3, 2024
1 parent f09383c commit ca256fa
Show file tree
Hide file tree
Showing 13 changed files with 54 additions and 10 deletions.
1 change: 1 addition & 0 deletions datafusion/core/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,7 @@ bigdecimal = { workspace = true }
criterion = { version = "0.5", features = ["async_tokio"] }
csv = "1.1.6"
ctor = { workspace = true }
datafusion-functions-window-common = {workspace = true}
doc-comment = { workspace = true }
env_logger = { workspace = true }
half = { workspace = true, default-features = true }
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,12 +29,13 @@ use std::{

use arrow::array::AsArray;
use arrow_array::{ArrayRef, Int64Array, RecordBatch, StringArray};
use arrow_schema::DataType;
use arrow_schema::{DataType, Field};
use datafusion::{assert_batches_eq, prelude::SessionContext};
use datafusion_common::{Result, ScalarValue};
use datafusion_expr::{
PartitionEvaluator, Signature, Volatility, WindowUDF, WindowUDFImpl,
};
use datafusion_functions_window_common::field::FieldArgs;

/// A query with a window function evaluated over the entire partition
const UNBOUNDED_WINDOW_QUERY: &str = "SELECT x, y, val, \
Expand Down Expand Up @@ -565,6 +566,10 @@ impl OddCounter {
/// Returns the alias names held in this value's `aliases` field, borrowed as a slice.
fn aliases(&self) -> &[String] {
&self.aliases
}

/// Builds the output [`Field`] for this window function.
///
/// The field takes its name from `field_args.display_name` and its data
/// type from `field_args.return_type`, and is always marked nullable.
fn field(&self, field_args: FieldArgs) -> Result<Field> {
    let name = field_args.display_name;
    let data_type = field_args.return_type;
    let nullable = true;
    Ok(Field::new(name, data_type, nullable))
}
}

ctx.register_udwf(WindowUDF::from(SimpleWindowUDF::new(test_state)))
Expand Down
1 change: 1 addition & 0 deletions datafusion/expr/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ chrono = { workspace = true }
datafusion-common = { workspace = true }
datafusion-expr-common = { workspace = true }
datafusion-functions-aggregate-common = { workspace = true }
datafusion-functions-window-common = {workspace = true}
datafusion-physical-expr-common = { workspace = true }
paste = "^1.0"
serde_json = { workspace = true }
Expand Down
5 changes: 5 additions & 0 deletions datafusion/expr/src/expr_fn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ use arrow::compute::kernels::cast_utils::{
};
use arrow::datatypes::{DataType, Field};
use datafusion_common::{plan_err, Column, Result, ScalarValue, TableReference};
use datafusion_functions_window_common::field::FieldArgs;
use sqlparser::ast::NullTreatment;
use std::any::Any;
use std::fmt::Debug;
Expand Down Expand Up @@ -665,6 +666,10 @@ impl WindowUDFImpl for SimpleWindowUDF {
/// Invokes the stored `partition_evaluator_factory` and returns whatever it
/// produces, including any error.
fn partition_evaluator(&self) -> Result<Box<dyn crate::PartitionEvaluator>> {
(self.partition_evaluator_factory)()
}

/// Builds the output [`Field`] for this window function.
///
/// The name comes from `field_args.display_name`, the data type from
/// `field_args.return_type`, and the field is always marked nullable.
fn field(&self, field_args: FieldArgs) -> Result<Field> {
    let name = field_args.display_name;
    let data_type = field_args.return_type;
    let nullable = true;
    Ok(Field::new(name, data_type, nullable))
}
}

pub fn interval_year_month_lit(value: &str) -> Expr {
Expand Down
17 changes: 16 additions & 1 deletion datafusion/expr/src/udwf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,10 @@ use std::{
sync::Arc,
};

use arrow::datatypes::DataType;
use arrow::datatypes::{DataType, Field};

use datafusion_common::{not_impl_err, Result};
use datafusion_functions_window_common::field::FieldArgs;

use crate::expr::WindowFunction;
use crate::{
Expand Down Expand Up @@ -185,6 +186,10 @@ impl WindowUDF {
self.inner.nullable()
}

/// Returns the [`Field`] describing this window function's output.
///
/// Forwards `field_args` to the wrapped implementation's `field` method and
/// returns its result unchanged.
pub fn field(&self, field_args: FieldArgs) -> Result<Field> {
self.inner.field(field_args)
}

/// Returns custom result ordering introduced by this window function
/// which is used to update ordering equivalences.
///
Expand Down Expand Up @@ -353,6 +358,8 @@ pub trait WindowUDFImpl: Debug + Send + Sync {
true
}

/// Builds the [`Field`] describing this window function's output from
/// `field_args`.
///
/// This is a required method: no default body is provided, so every
/// implementation must supply its own.
fn field(&self, field_args: FieldArgs) -> Result<Field>;

/// Allows the window UDF to define a custom result ordering.
///
/// By default, a window UDF doesn't introduce an ordering.
Expand Down Expand Up @@ -454,6 +461,10 @@ impl WindowUDFImpl for AliasedWindowUDFImpl {
self.inner.nullable()
}

/// Forwards to the wrapped implementation's `field`, passing `field_args`
/// through untouched.
fn field(&self, field_args: FieldArgs) -> Result<Field> {
self.inner.field(field_args)
}

/// Forwards to the wrapped implementation's `sort_options`.
fn sort_options(&self) -> Option<SortOptions> {
self.inner.sort_options()
}
Expand Down Expand Up @@ -510,4 +521,8 @@ impl WindowUDFImpl for WindowUDFLegacyWrapper {
/// Invokes the stored `partition_evaluator_factory` and returns whatever it
/// produces, including any error.
fn partition_evaluator(&self) -> Result<Box<dyn PartitionEvaluator>> {
(self.partition_evaluator_factory)()
}

/// Builds the output [`Field`] for this wrapped window function.
///
/// The name comes from `field_args.display_name`, the data type from
/// `field_args.return_type`, and the field is always marked nullable.
fn field(&self, field_args: FieldArgs) -> Result<Field> {
    let name = field_args.display_name;
    let data_type = field_args.return_type;
    let nullable = true;
    Ok(Field::new(name, data_type, nullable))
}
}
3 changes: 2 additions & 1 deletion datafusion/functions-window-common/src/field.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

use datafusion_common::arrow::datatypes::DataType;

pub struct FieldArgs {
/// Inputs handed to a window function's `field` method when it builds the
/// field describing its output.
pub struct FieldArgs<'a> {
/// Data type to give the resulting field.
pub return_type: DataType,
/// Borrowed name to give the resulting field.
pub display_name: &'a str,
}
1 change: 1 addition & 0 deletions datafusion/functions-window/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ path = "src/lib.rs"
[dependencies]
datafusion-common = { workspace = true }
datafusion-expr = { workspace = true }
datafusion-functions-window-common = { workspace = true }
datafusion-physical-expr-common = { workspace = true }
log = { workspace = true }

Expand Down
7 changes: 7 additions & 0 deletions datafusion/functions-window/src/row_number.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,12 @@ use datafusion_common::arrow::array::ArrayRef;
use datafusion_common::arrow::array::UInt64Array;
use datafusion_common::arrow::compute::SortOptions;
use datafusion_common::arrow::datatypes::DataType;
use datafusion_common::arrow::datatypes::Field;
use datafusion_common::{Result, ScalarValue};
use datafusion_expr::expr::WindowFunction;
use datafusion_expr::{Expr, PartitionEvaluator, Signature, Volatility, WindowUDFImpl};
use datafusion_functions_window_common::field;
use field::FieldArgs;

/// Create a [`WindowFunction`](Expr::WindowFunction) expression for
/// `row_number` user-defined window function.
Expand Down Expand Up @@ -96,6 +99,10 @@ impl WindowUDFImpl for RowNumber {
false
}

/// Builds the output [`Field`] for `row_number`.
///
/// The name comes from `field_args.display_name`, the data type from
/// `field_args.return_type`, and the field is marked non-nullable.
fn field(&self, field_args: FieldArgs) -> Result<Field> {
    let name = field_args.display_name;
    let data_type = field_args.return_type;
    let nullable = false;
    Ok(Field::new(name, data_type, nullable))
}

fn sort_options(&self) -> Option<SortOptions> {
Some(SortOptions {
descending: false,
Expand Down
1 change: 1 addition & 0 deletions datafusion/optimizer/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -57,5 +57,6 @@ regex-syntax = "0.8.0"
arrow-buffer = { workspace = true }
ctor = { workspace = true }
datafusion-functions-aggregate = { workspace = true }
datafusion-functions-window-common = {workspace = true}
datafusion-sql = { workspace = true }
env_logger = { workspace = true }
Original file line number Diff line number Diff line change
Expand Up @@ -1798,6 +1798,7 @@ mod tests {
interval_arithmetic::Interval,
*,
};
use datafusion_functions_window_common::field::FieldArgs;
use std::{
collections::HashMap,
ops::{BitAnd, BitOr, BitXor},
Expand Down Expand Up @@ -3916,5 +3917,9 @@ mod tests {
/// Test stub: always panics via `unimplemented!` if called.
fn partition_evaluator(&self) -> Result<Box<dyn PartitionEvaluator>> {
unimplemented!("not needed for tests")
}

/// Test stub: ignores `_field_args` and always panics via `unimplemented!`
/// if called.
fn field(&self, _field_args: FieldArgs) -> Result<Field> {
unimplemented!("not needed for tests")
}
}
}
10 changes: 3 additions & 7 deletions datafusion/physical-plan/src/windows/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -362,14 +362,10 @@ impl BuiltInWindowFunctionExpr for WindowUDFExpr {
}

fn field(&self) -> Result<Field> {
let _field_args = FieldArgs {
self.fun.field(FieldArgs {
return_type: self.data_type.clone(),
};
Ok(Field::new(
&self.name,
self.data_type.clone(),
self.fun.nullable(),
))
display_name: &self.name,
})
}

fn expressions(&self) -> Vec<Arc<dyn PhysicalExpr>> {
Expand Down
1 change: 1 addition & 0 deletions datafusion/proto/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ serde_json = { workspace = true, optional = true }
[dev-dependencies]
datafusion-functions = { workspace = true, default-features = true }
datafusion-functions-aggregate = { workspace = true }
datafusion-functions-window-common = {workspace = true}
doc-comment = { workspace = true }
strum = { version = "0.26.1", features = ["derive"] }
tokio = { workspace = true, features = ["rt-multi-thread"] }
5 changes: 5 additions & 0 deletions datafusion/proto/tests/cases/roundtrip_logical_plan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ use datafusion_functions_aggregate::expr_fn::{
nth_value,
};
use datafusion_functions_aggregate::string_agg::string_agg;
use datafusion_functions_window_common::field::FieldArgs;
use datafusion_proto::bytes::{
logical_plan_from_bytes, logical_plan_from_bytes_with_extension_codec,
logical_plan_to_bytes, logical_plan_to_bytes_with_extension_codec,
Expand Down Expand Up @@ -2417,6 +2418,10 @@ fn roundtrip_window() {
/// Delegates to the free function `make_partition_evaluator` and returns its
/// result.
fn partition_evaluator(&self) -> Result<Box<dyn PartitionEvaluator>> {
make_partition_evaluator()
}

/// Builds the output [`Field`] for this test window function.
///
/// The name comes from `field_args.display_name`, the data type from
/// `field_args.return_type`, and the field is always marked nullable.
fn field(&self, field_args: FieldArgs) -> Result<Field> {
    let name = field_args.display_name;
    let data_type = field_args.return_type;
    let nullable = true;
    Ok(Field::new(name, data_type, nullable))
}
}

fn make_partition_evaluator() -> Result<Box<dyn PartitionEvaluator>> {
Expand Down

0 comments on commit ca256fa

Please sign in to comment.