Skip to content

Commit

Permalink
Convert nth_value builtIn function to User Defined Window Function (#…
Browse files Browse the repository at this point in the history
…13201)

* refactored nth_value

* continue

* test

* proto and rustlint

* fix datatype

* cont

* cont

* apply jcsherins early validation

* docs

* doc

* Apply suggestions from code review

Co-authored-by: Sherin Jacob <jacob@protoship.io>

* passes lint but does not have tests

* continue

* Update roundtrip_physical_plan.rs

* udwf, not udaf

* fix bounded but not fixed roundtrip

* added

* Update datafusion/sqllogictest/test_files/errors.slt

Co-authored-by: Sherin Jacob <jacob@protoship.io>

---------

Co-authored-by: Sherin Jacob <jacob@protoship.io>
Co-authored-by: berkaysynnada <berkay.sahin@synnada.ai>
Co-authored-by: Andrew Lamb <andrew@nerdnetworks.org>
  • Loading branch information
4 people authored Nov 13, 2024
1 parent 4e1f839 commit 54ab128
Show file tree
Hide file tree
Showing 27 changed files with 728 additions and 828 deletions.
10 changes: 4 additions & 6 deletions datafusion/core/src/dataframe/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1946,12 +1946,12 @@ mod tests {
use datafusion_common_runtime::SpawnedTask;
use datafusion_expr::expr::WindowFunction;
use datafusion_expr::{
cast, create_udf, lit, BuiltInWindowFunction, ExprFunctionExt,
ScalarFunctionImplementation, Volatility, WindowFrame, WindowFrameBound,
WindowFrameUnits, WindowFunctionDefinition,
cast, create_udf, lit, ExprFunctionExt, ScalarFunctionImplementation, Volatility,
WindowFrame, WindowFrameBound, WindowFrameUnits, WindowFunctionDefinition,
};
use datafusion_functions_aggregate::expr_fn::{array_agg, count_distinct};
use datafusion_functions_window::expr_fn::row_number;
use datafusion_functions_window::nth_value::first_value_udwf;
use datafusion_physical_expr::expressions::Column;
use datafusion_physical_plan::{get_plan_string, ExecutionPlanProperties};
use sqlparser::ast::NullTreatment;
Expand Down Expand Up @@ -2177,9 +2177,7 @@ mod tests {
// build plan using Table API
let t = test_table().await?;
let first_row = Expr::WindowFunction(WindowFunction::new(
WindowFunctionDefinition::BuiltInWindowFunction(
BuiltInWindowFunction::FirstValue,
),
WindowFunctionDefinition::WindowUDF(first_value_udwf()),
vec![col("aggregate_test_100.c1")],
))
.partition_by(vec![col("aggregate_test_100.c2")])
Expand Down
18 changes: 7 additions & 11 deletions datafusion/core/tests/fuzz_cases/window_fuzz.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,7 @@ use datafusion_common::{Result, ScalarValue};
use datafusion_common_runtime::SpawnedTask;
use datafusion_expr::type_coercion::functions::data_types_with_aggregate_udf;
use datafusion_expr::{
BuiltInWindowFunction, WindowFrame, WindowFrameBound, WindowFrameUnits,
WindowFunctionDefinition,
WindowFrame, WindowFrameBound, WindowFrameUnits, WindowFunctionDefinition,
};
use datafusion_functions_aggregate::count::count_udaf;
use datafusion_functions_aggregate::min_max::{max_udaf, min_udaf};
Expand All @@ -47,6 +46,9 @@ use test_utils::add_empty_batches;
use datafusion::functions_window::row_number::row_number_udwf;
use datafusion_common::HashMap;
use datafusion_functions_window::lead_lag::{lag_udwf, lead_udwf};
use datafusion_functions_window::nth_value::{
first_value_udwf, last_value_udwf, nth_value_udwf,
};
use datafusion_functions_window::rank::{dense_rank_udwf, rank_udwf};
use datafusion_physical_expr_common::sort_expr::LexOrdering;
use rand::distributions::Alphanumeric;
Expand Down Expand Up @@ -418,27 +420,21 @@ fn get_random_function(
window_fn_map.insert(
"first_value",
(
WindowFunctionDefinition::BuiltInWindowFunction(
BuiltInWindowFunction::FirstValue,
),
WindowFunctionDefinition::WindowUDF(first_value_udwf()),
vec![arg.clone()],
),
);
window_fn_map.insert(
"last_value",
(
WindowFunctionDefinition::BuiltInWindowFunction(
BuiltInWindowFunction::LastValue,
),
WindowFunctionDefinition::WindowUDF(last_value_udwf()),
vec![arg.clone()],
),
);
window_fn_map.insert(
"nth_value",
(
WindowFunctionDefinition::BuiltInWindowFunction(
BuiltInWindowFunction::NthValue,
),
WindowFunctionDefinition::WindowUDF(nth_value_udwf()),
vec![
arg.clone(),
lit(ScalarValue::Int64(Some(rng.gen_range(1..10)))),
Expand Down
89 changes: 0 additions & 89 deletions datafusion/expr/src/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@ use std::collections::HashSet;
use std::fmt::{self, Display, Formatter, Write};
use std::hash::{Hash, Hasher};
use std::mem;
use std::str::FromStr;
use std::sync::Arc;

use crate::expr_fn::binary_expr;
Expand Down Expand Up @@ -832,23 +831,6 @@ impl WindowFunction {
}
}

/// Find DataFusion's built-in window function by name.
pub fn find_df_window_func(name: &str) -> Option<WindowFunctionDefinition> {
let name = name.to_lowercase();
// Code paths for window functions leveraging ordinary aggregators and
// built-in window functions are quite different, and the same function
// may have different implementations for these cases. If the sought
// function is not found among built-in window functions, we search for
// it among aggregate functions.
if let Ok(built_in_function) = BuiltInWindowFunction::from_str(name.as_str()) {
Some(WindowFunctionDefinition::BuiltInWindowFunction(
built_in_function,
))
} else {
None
}
}

/// EXISTS expression
#[derive(Clone, PartialEq, Eq, PartialOrd, Hash, Debug)]
pub struct Exists {
Expand Down Expand Up @@ -2548,77 +2530,6 @@ mod test {

use super::*;

#[test]
fn test_first_value_return_type() -> Result<()> {
let fun = find_df_window_func("first_value").unwrap();
let observed = fun.return_type(&[DataType::Utf8], &[true], "")?;
assert_eq!(DataType::Utf8, observed);

let observed = fun.return_type(&[DataType::UInt64], &[true], "")?;
assert_eq!(DataType::UInt64, observed);

Ok(())
}

#[test]
fn test_last_value_return_type() -> Result<()> {
let fun = find_df_window_func("last_value").unwrap();
let observed = fun.return_type(&[DataType::Utf8], &[true], "")?;
assert_eq!(DataType::Utf8, observed);

let observed = fun.return_type(&[DataType::Float64], &[true], "")?;
assert_eq!(DataType::Float64, observed);

Ok(())
}

#[test]
fn test_nth_value_return_type() -> Result<()> {
let fun = find_df_window_func("nth_value").unwrap();
let observed =
fun.return_type(&[DataType::Utf8, DataType::UInt64], &[true, true], "")?;
assert_eq!(DataType::Utf8, observed);

let observed =
fun.return_type(&[DataType::Float64, DataType::UInt64], &[true, true], "")?;
assert_eq!(DataType::Float64, observed);

Ok(())
}

#[test]
fn test_window_function_case_insensitive() -> Result<()> {
let names = vec!["first_value", "last_value", "nth_value"];
for name in names {
let fun = find_df_window_func(name).unwrap();
let fun2 = find_df_window_func(name.to_uppercase().as_str()).unwrap();
assert_eq!(fun, fun2);
if fun.to_string() == "first_value" || fun.to_string() == "last_value" {
assert_eq!(fun.to_string(), name);
} else {
assert_eq!(fun.to_string(), name.to_uppercase());
}
}
Ok(())
}

#[test]
fn test_find_df_window_function() {
assert_eq!(
find_df_window_func("first_value"),
Some(WindowFunctionDefinition::BuiltInWindowFunction(
BuiltInWindowFunction::FirstValue
))
);
assert_eq!(
find_df_window_func("LAST_value"),
Some(WindowFunctionDefinition::BuiltInWindowFunction(
BuiltInWindowFunction::LastValue
))
);
assert_eq!(find_df_window_func("not_exist"), None)
}

#[test]
fn test_display_wildcard() {
assert_eq!(format!("{}", wildcard()), "*");
Expand Down
1 change: 0 additions & 1 deletion datafusion/expr/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,6 @@ pub mod type_coercion;
pub mod utils;
pub mod var_provider;
pub mod window_frame;
pub mod window_function;
pub mod window_state;

pub use built_in_window_function::BuiltInWindowFunction;
Expand Down
26 changes: 0 additions & 26 deletions datafusion/expr/src/window_function.rs

This file was deleted.

6 changes: 6 additions & 0 deletions datafusion/functions-window/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
//!
//! [DataFusion]: https://crates.io/crates/datafusion
//!

use std::sync::Arc;

use log::debug;
Expand All @@ -34,6 +35,7 @@ pub mod macros;

pub mod cume_dist;
pub mod lead_lag;
pub mod nth_value;
pub mod ntile;
pub mod rank;
pub mod row_number;
Expand All @@ -44,6 +46,7 @@ pub mod expr_fn {
pub use super::cume_dist::cume_dist;
pub use super::lead_lag::lag;
pub use super::lead_lag::lead;
pub use super::nth_value::{first_value, last_value, nth_value};
pub use super::ntile::ntile;
pub use super::rank::{dense_rank, percent_rank, rank};
pub use super::row_number::row_number;
Expand All @@ -60,6 +63,9 @@ pub fn all_default_window_functions() -> Vec<Arc<WindowUDF>> {
rank::dense_rank_udwf(),
rank::percent_rank_udwf(),
ntile::ntile_udwf(),
nth_value::first_value_udwf(),
nth_value::last_value_udwf(),
nth_value::nth_value_udwf(),
]
}
/// Registers all enabled packages with a [`FunctionRegistry`]
Expand Down
Loading

0 comments on commit 54ab128

Please sign in to comment.