Create datafusion-functions crate, extract encode and decode to #8705

Merged: 20 commits, merged on Jan 30, 2024
Changes from all commits:
6 changes: 6 additions & 0 deletions .github/workflows/rust.yml
@@ -76,6 +76,12 @@ jobs:

- name: Check workspace with all features
run: cargo check --workspace --benches --features avro,json

# Ensure that the datafusion crate can be built with only a subset of the function
# packages enabled.
- name: Check function packages (encoding_expressions)
run: cargo check --no-default-features --features=encoding_expressions -p datafusion

- name: Check Cargo.lock for datafusion-cli
run: |
# If this test fails, try running `cargo update` in the `datafusion-cli` directory
3 changes: 2 additions & 1 deletion Cargo.toml
@@ -17,7 +17,7 @@

[workspace]
exclude = ["datafusion-cli"]
members = ["datafusion/common", "datafusion/core", "datafusion/expr", "datafusion/execution", "datafusion/optimizer", "datafusion/physical-expr", "datafusion/physical-plan", "datafusion/proto", "datafusion/proto/gen", "datafusion/sql", "datafusion/sqllogictest", "datafusion/substrait", "datafusion/wasmtest", "datafusion-examples", "docs", "test-utils", "benchmarks",
members = ["datafusion/common", "datafusion/core", "datafusion/expr", "datafusion/execution", "datafusion/functions", "datafusion/optimizer", "datafusion/physical-expr", "datafusion/physical-plan", "datafusion/proto", "datafusion/proto/gen", "datafusion/sql", "datafusion/sqllogictest", "datafusion/substrait", "datafusion/wasmtest", "datafusion-examples", "docs", "test-utils", "benchmarks",
]
resolver = "2"

@@ -49,6 +49,7 @@ datafusion = { path = "datafusion/core", version = "35.0.0" }
datafusion-common = { path = "datafusion/common", version = "35.0.0" }
datafusion-execution = { path = "datafusion/execution", version = "35.0.0" }
datafusion-expr = { path = "datafusion/expr", version = "35.0.0" }
datafusion-functions = { path = "datafusion/functions", version = "35.0.0" }
datafusion-optimizer = { path = "datafusion/optimizer", version = "35.0.0" }
datafusion-physical-expr = { path = "datafusion/physical-expr", version = "35.0.0" }
datafusion-physical-plan = { path = "datafusion/physical-plan", version = "35.0.0" }
14 changes: 14 additions & 0 deletions datafusion-cli/Cargo.lock

(Generated file; diff not rendered.)

3 changes: 2 additions & 1 deletion datafusion/core/Cargo.toml
@@ -40,7 +40,7 @@ backtrace = ["datafusion-common/backtrace"]
compression = ["xz2", "bzip2", "flate2", "zstd", "async-compression", "tokio-util"]
crypto_expressions = ["datafusion-physical-expr/crypto_expressions", "datafusion-optimizer/crypto_expressions"]
default = ["crypto_expressions", "encoding_expressions", "regex_expressions", "unicode_expressions", "compression", "parquet"]
encoding_expressions = ["datafusion-physical-expr/encoding_expressions"]
encoding_expressions = ["datafusion-functions/encoding_expressions"]
# Used for testing ONLY: causes all values to hash to the same value (test for collisions)
force_hash_collisions = []
parquet = ["datafusion-common/parquet", "dep:parquet"]
@@ -65,6 +65,7 @@ dashmap = { workspace = true }
datafusion-common = { path = "../common", version = "35.0.0", features = ["object_store"], default-features = false }
datafusion-execution = { workspace = true }
datafusion-expr = { workspace = true }
datafusion-functions = { path = "../functions", version = "35.0.0" }
datafusion-optimizer = { path = "../optimizer", version = "35.0.0", default-features = false }
datafusion-physical-expr = { path = "../physical-expr", version = "35.0.0", default-features = false }
datafusion-physical-plan = { workspace = true }
2 changes: 1 addition & 1 deletion datafusion/core/src/dataframe/mod.rs
@@ -40,7 +40,6 @@ use crate::physical_plan::{
collect, collect_partitioned, execute_stream, execute_stream_partitioned,
ExecutionPlan, SendableRecordBatchStream,
};
use crate::prelude::SessionContext;

use arrow::array::{Array, ArrayRef, Int64Array, StringArray};
use arrow::compute::{cast, concat};
@@ -59,6 +58,7 @@ use datafusion_expr::{
TableProviderFilterPushDown, UNNAMED_TABLE,
};

use crate::prelude::SessionContext;
use async_trait::async_trait;

/// Contains options that control how data is
14 changes: 12 additions & 2 deletions datafusion/core/src/execution/context/mod.rs
@@ -1340,7 +1340,7 @@ impl SessionState {
);
}

SessionState {
let mut new_self = SessionState {
session_id,
analyzer: Analyzer::new(),
optimizer: Optimizer::new(),
@@ -1356,7 +1356,13 @@
execution_props: ExecutionProps::new(),
runtime_env: runtime,
table_factories,
}
};

// register built in functions
datafusion_functions::register_all(&mut new_self)
.expect("can not register built in functions");

new_self
}
/// Returns new [`SessionState`] using the provided
/// [`SessionConfig`] and [`RuntimeEnv`].
@@ -1976,6 +1982,10 @@ impl FunctionRegistry for SessionState {
plan_datafusion_err!("There is no UDWF named \"{name}\" in the registry")
})
}

fn register_udf(&mut self, udf: Arc<ScalarUDF>) -> Result<Option<Arc<ScalarUDF>>> {
Ok(self.scalar_functions.insert(udf.name().into(), udf))
}
}

impl OptimizerConfig for SessionState {
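For context, the `register_all` call above is the new crate's hook for installing its function packages into the session. A minimal sketch of what such an entry point can look like, assuming only the `FunctionRegistry` trait added in this PR (the empty vector stands in for the real per-package UDF lists, which are not shown here):

```rust
use std::sync::Arc;

use datafusion_common::Result;
use datafusion_execution::registry::FunctionRegistry;
use datafusion_expr::ScalarUDF;

/// Sketch of a `register_all`-style entry point: hand each packaged UDF
/// to whatever registry the caller provides (e.g. a `SessionState`).
pub fn register_all(registry: &mut dyn FunctionRegistry) -> Result<()> {
    // in the real crate this would include the encode/decode UDFs when the
    // `encoding_expressions` feature is enabled
    let functions: Vec<Arc<ScalarUDF>> = vec![];

    for udf in functions {
        // `register_udf` returns any previously registered implementation;
        // for built-in registration it can be ignored
        registry.register_udf(udf)?;
    }
    Ok(())
}
```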
5 changes: 5 additions & 0 deletions datafusion/core/src/lib.rs
@@ -516,6 +516,11 @@ pub mod sql {
pub use datafusion_sql::*;
}

/// re-export of [`datafusion_functions`] crate
pub mod functions {
pub use datafusion_functions::*;
}

#[cfg(test)]
pub mod test;
pub mod test_util;
1 change: 1 addition & 0 deletions datafusion/core/src/prelude.rs
@@ -38,6 +38,7 @@ pub use datafusion_expr::{
logical_plan::{JoinType, Partitioning},
Expr,
};
pub use datafusion_functions::expr_fn::*;

pub use std::ops::Not;
pub use std::ops::{Add, Div, Mul, Neg, Rem, Sub};
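Because the prelude now pulls in `datafusion_functions::expr_fn`, downstream code that does `use datafusion::prelude::*;` keeps compiling unchanged. A small sketch, assuming a registered table `test` with a Utf8 column `a`:

```rust
use datafusion::error::Result;
use datafusion::prelude::*;

async fn hex_encode_column(ctx: &SessionContext) -> Result<()> {
    // `encode` now resolves to datafusion_functions::expr_fn::encode via the
    // prelude, so the call site is identical to the pre-PR builtin version
    let df = ctx.table("test").await?;
    df.select(vec![encode(col("a"), lit("hex"))])?.show().await?;
    Ok(())
}
```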
59 changes: 55 additions & 4 deletions datafusion/core/tests/dataframe/dataframe_functions.rs
@@ -20,6 +20,7 @@ use arrow::{
array::{Int32Array, StringArray},
record_batch::RecordBatch,
};
use arrow_schema::SchemaRef;
use std::sync::Arc;

use datafusion::dataframe::DataFrame;
@@ -31,14 +32,19 @@ use datafusion::prelude::*;
use datafusion::execution::context::SessionContext;

use datafusion::assert_batches_eq;
use datafusion_common::DFSchema;
use datafusion_expr::expr::Alias;
use datafusion_expr::{approx_median, cast};
use datafusion_expr::{approx_median, cast, ExprSchemable};

async fn create_test_table() -> Result<DataFrame> {
let schema = Arc::new(Schema::new(vec![
fn test_schema() -> SchemaRef {
Arc::new(Schema::new(vec![
Field::new("a", DataType::Utf8, false),
Field::new("b", DataType::Int32, false),
]));
]))
}

async fn create_test_table() -> Result<DataFrame> {
let schema = test_schema();

// define data.
let batch = RecordBatch::try_new(
@@ -790,3 +796,48 @@ async fn test_fn_upper() -> Result<()> {

Ok(())
}

#[tokio::test]
async fn test_fn_encode() -> Result<()> {
Author comment: These functions weren't previously tested. I added tests to show that it is still possible to use the fluent API even after moving the expr_fn implementation.

let expr = encode(col("a"), lit("hex"));

let expected = [
"+----------------------------+",
"| encode(test.a,Utf8(\"hex\")) |",
"+----------------------------+",
"| 616263444546 |",
"| 616263313233 |",
"| 434241646566 |",
"| 313233416263446566 |",
"+----------------------------+",
];
assert_fn_batches!(expr, expected);

Ok(())
}

#[tokio::test]
async fn test_fn_decode() -> Result<()> {
// Note that the decode function returns binary, and the default display of
// binary is "hexadecimal" and therefore the output looks like decode did
// nothing. So compare to a constant.
let df_schema = DFSchema::try_from(test_schema().as_ref().clone())?;
let expr = decode(encode(col("a"), lit("hex")), lit("hex"))
// need to cast to utf8 otherwise the default display of binary array is hex
// so it looks like nothing is done
.cast_to(&DataType::Utf8, &df_schema)?;

let expected = [
"+------------------------------------------------+",
"| decode(encode(test.a,Utf8(\"hex\")),Utf8(\"hex\")) |",
"+------------------------------------------------+",
"| abcDEF |",
"| abc123 |",
"| CBAdef |",
"| 123AbcDef |",
"+------------------------------------------------+",
];
assert_fn_batches!(expr, expected);

Ok(())
}
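The SQL path works the same way: because `register_all` put `encode`/`decode` into the `SessionState`, the planner finds them via `FunctionRegistry` lookup rather than the `BuiltinScalarFunction` enum. A sketch against the same hypothetical `test` table as above:

```rust
use datafusion::error::Result;
use datafusion::prelude::*;

async fn hex_encode_sql(ctx: &SessionContext) -> Result<()> {
    // the planner resolves `encode` by name from the registered UDFs
    let df = ctx.sql("SELECT encode(a, 'hex') FROM test").await?;
    df.show().await?;
    Ok(())
}
```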
62 changes: 61 additions & 1 deletion datafusion/execution/src/registry.rs
@@ -17,8 +17,9 @@

//! FunctionRegistry trait

use datafusion_common::Result;
use datafusion_common::{not_impl_err, plan_datafusion_err, DataFusionError, Result};
use datafusion_expr::{AggregateUDF, ScalarUDF, UserDefinedLogicalNode, WindowUDF};
use std::collections::HashMap;
use std::{collections::HashSet, sync::Arc};

/// A registry knows how to build logical expressions out of user-defined function' names
@@ -34,6 +35,17 @@ pub trait FunctionRegistry {

/// Returns a reference to the udwf named `name`.
fn udwf(&self, name: &str) -> Result<Arc<WindowUDF>>;

/// Registers a new [`ScalarUDF`], returning any previously registered
/// implementation.
///
/// Returns an error (the default) if the function can not be registered,
/// for example if the registry is read only.
fn register_udf(&mut self, _udf: Arc<ScalarUDF>) -> Result<Option<Arc<ScalarUDF>>> {
not_impl_err!("Registering ScalarUDF")
}

// TODO add register_udaf and register_udwf
Author comment: I will file follow-on tickets for this.

}

/// Serializer and deserializer registry for extensions like [UserDefinedLogicalNode].
@@ -53,3 +65,51 @@ pub trait SerializerRegistry: Send + Sync {
bytes: &[u8],
) -> Result<Arc<dyn UserDefinedLogicalNode>>;
}

/// A [`FunctionRegistry`] that uses in memory [`HashMap`]s
#[derive(Default, Debug)]
pub struct MemoryFunctionRegistry {
Author comment: This is currently only used for testing, but I plan to use it to reduce the duplication between SessionContext and TaskContext.

/// Scalar Functions
udfs: HashMap<String, Arc<ScalarUDF>>,
/// Aggregate Functions
udafs: HashMap<String, Arc<AggregateUDF>>,
/// Window Functions
udwfs: HashMap<String, Arc<WindowUDF>>,
}

impl MemoryFunctionRegistry {
pub fn new() -> Self {
Self::default()
}
}

impl FunctionRegistry for MemoryFunctionRegistry {
fn udfs(&self) -> HashSet<String> {
self.udfs.keys().cloned().collect()
}

fn udf(&self, name: &str) -> Result<Arc<ScalarUDF>> {
self.udfs
.get(name)
.cloned()
.ok_or_else(|| plan_datafusion_err!("Function {name} not found"))
}

fn udaf(&self, name: &str) -> Result<Arc<AggregateUDF>> {
self.udafs
.get(name)
.cloned()
.ok_or_else(|| plan_datafusion_err!("Aggregate Function {name} not found"))
}

fn udwf(&self, name: &str) -> Result<Arc<WindowUDF>> {
self.udwfs
.get(name)
.cloned()
.ok_or_else(|| plan_datafusion_err!("Window Function {name} not found"))
}

fn register_udf(&mut self, udf: Arc<ScalarUDF>) -> Result<Option<Arc<ScalarUDF>>> {
Ok(self.udfs.insert(udf.name().to_string(), udf))
}
}
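A minimal usage sketch of `MemoryFunctionRegistry`, assuming the `create_udf` helper from `datafusion_expr` (the `identity` UDF is purely illustrative):

```rust
use std::sync::Arc;

use arrow_schema::DataType;
use datafusion_common::Result;
use datafusion_expr::{create_udf, ColumnarValue, Volatility};

fn demo() -> Result<()> {
    // trivial UDF that returns its first argument unchanged
    let udf = create_udf(
        "identity",
        vec![DataType::Utf8],
        Arc::new(DataType::Utf8),
        Volatility::Immutable,
        Arc::new(|args: &[ColumnarValue]| Ok(args[0].clone())),
    );

    let mut registry = MemoryFunctionRegistry::new();
    // first registration returns None (no previous implementation)
    assert!(registry.register_udf(Arc::new(udf))?.is_none());
    // lookup by name now succeeds
    assert_eq!(registry.udf("identity")?.name(), "identity");
    Ok(())
}
```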
52 changes: 0 additions & 52 deletions datafusion/expr/src/built_in_function.rs
@@ -69,14 +69,10 @@ pub enum BuiltinScalarFunction {
Cos,
/// cos
Cosh,
/// Decode
Author comment: This is the key point of this PR: to begin removing items from this (giant) enum.

Decode,
/// degrees
Degrees,
/// Digest
Digest,
/// Encode
Encode,
/// exp
Exp,
/// factorial
@@ -381,9 +377,7 @@ impl BuiltinScalarFunction {
BuiltinScalarFunction::Coalesce => Volatility::Immutable,
BuiltinScalarFunction::Cos => Volatility::Immutable,
BuiltinScalarFunction::Cosh => Volatility::Immutable,
BuiltinScalarFunction::Decode => Volatility::Immutable,
BuiltinScalarFunction::Degrees => Volatility::Immutable,
BuiltinScalarFunction::Encode => Volatility::Immutable,
BuiltinScalarFunction::Exp => Volatility::Immutable,
BuiltinScalarFunction::Factorial => Volatility::Immutable,
BuiltinScalarFunction::Floor => Volatility::Immutable,
@@ -774,30 +768,6 @@ impl BuiltinScalarFunction {
BuiltinScalarFunction::Digest => {
utf8_or_binary_to_binary_type(&input_expr_types[0], "digest")
}
BuiltinScalarFunction::Encode => Ok(match input_expr_types[0] {
Author comment: This logic is moved into the ScalarUDFImpl for Encode/Decode.

Utf8 => Utf8,
LargeUtf8 => LargeUtf8,
Binary => Utf8,
LargeBinary => LargeUtf8,
Null => Null,
_ => {
return plan_err!(
"The encode function can only accept utf8 or binary."
);
}
}),
BuiltinScalarFunction::Decode => Ok(match input_expr_types[0] {
Utf8 => Binary,
LargeUtf8 => LargeBinary,
Binary => Binary,
LargeBinary => LargeBinary,
Null => Null,
_ => {
return plan_err!(
"The decode function can only accept utf8 or binary."
);
}
}),
BuiltinScalarFunction::SplitPart => {
utf8_to_str_type(&input_expr_types[0], "split_part")
}
@@ -1089,24 +1059,6 @@
],
self.volatility(),
),
BuiltinScalarFunction::Encode => Signature::one_of(
vec![
Exact(vec![Utf8, Utf8]),
Exact(vec![LargeUtf8, Utf8]),
Exact(vec![Binary, Utf8]),
Exact(vec![LargeBinary, Utf8]),
],
self.volatility(),
),
BuiltinScalarFunction::Decode => Signature::one_of(
vec![
Exact(vec![Utf8, Utf8]),
Exact(vec![LargeUtf8, Utf8]),
Exact(vec![Binary, Utf8]),
Exact(vec![LargeBinary, Utf8]),
],
self.volatility(),
),
BuiltinScalarFunction::DateTrunc => Signature::one_of(
vec![
Exact(vec![Utf8, Timestamp(Nanosecond, None)]),
@@ -1551,10 +1503,6 @@ impl BuiltinScalarFunction {
BuiltinScalarFunction::SHA384 => &["sha384"],
BuiltinScalarFunction::SHA512 => &["sha512"],

// encode/decode
BuiltinScalarFunction::Encode => &["encode"],
BuiltinScalarFunction::Decode => &["decode"],

// other functions
BuiltinScalarFunction::ArrowTypeof => &["arrow_typeof"],

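For reference, the signature and return-type rules removed above reappear as a `ScalarUDFImpl` in the new crate. A hedged sketch of the shape (the struct name and `invoke` body are illustrative; the type logic mirrors the deleted match arms):

```rust
use std::any::Any;

use arrow_schema::DataType;
use datafusion_common::{plan_err, Result};
use datafusion_expr::{
    ColumnarValue, ScalarUDFImpl, Signature, TypeSignature::Exact, Volatility,
};

#[derive(Debug)]
struct EncodeFunc {
    signature: Signature,
}

impl EncodeFunc {
    fn new() -> Self {
        use DataType::*;
        Self {
            // same signature the enum declared: (utf8|binary, utf8 encoding name)
            signature: Signature::one_of(
                vec![
                    Exact(vec![Utf8, Utf8]),
                    Exact(vec![LargeUtf8, Utf8]),
                    Exact(vec![Binary, Utf8]),
                    Exact(vec![LargeBinary, Utf8]),
                ],
                Volatility::Immutable,
            ),
        }
    }
}

impl ScalarUDFImpl for EncodeFunc {
    fn as_any(&self) -> &dyn Any {
        self
    }
    fn name(&self) -> &str {
        "encode"
    }
    fn signature(&self) -> &Signature {
        &self.signature
    }

    // the return-type match moved out of BuiltinScalarFunction::return_type
    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
        use DataType::*;
        Ok(match arg_types[0] {
            Utf8 => Utf8,
            LargeUtf8 => LargeUtf8,
            Binary => Utf8,
            LargeBinary => LargeUtf8,
            Null => Null,
            _ => return plan_err!("The encode function can only accept utf8 or binary."),
        })
    }

    fn invoke(&self, _args: &[ColumnarValue]) -> Result<ColumnarValue> {
        // the real implementation dispatches to the encode kernel that moved
        // over from datafusion-physical-expr; elided in this sketch
        todo!("dispatch to encode kernel")
    }
}
```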