Create datafusion-functions crate, extract encode and decode to (#8705)

* Extract encode and decode to `datafusion-functions` crate

* better docs

* Improve docs + macros

* tweaks

* updates

* fix doc

* tomlfmt

* fix doc

* update datafusion-cli Cargo.lock

* update datafusion-cli cargo.lock

* Remove outdated comment, make non-pub

* Apply suggestions from code review

Co-authored-by: Liang-Chi Hsieh <viirya@gmail.com>

---------

Co-authored-by: Liang-Chi Hsieh <viirya@gmail.com>
alamb and viirya authored Jan 30, 2024
1 parent 262d093 commit d6d35f7
Showing 26 changed files with 639 additions and 159 deletions.
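For orientation: this commit removes `Encode` and `Decode` from `BuiltinScalarFunction` and re-implements them as `ScalarUDF`s in the new `datafusion-functions` crate, which `SessionState` now registers at construction time. User-visible behavior is unchanged; a minimal sketch of the unchanged surface (illustrative only, not part of the diff, assuming datafusion 35 with default features and a tokio runtime):

// Sketch (not part of this diff): encode/decode still resolve in SQL, now
// looked up through the FunctionRegistry instead of BuiltinScalarFunction.
use datafusion::error::Result;
use datafusion::prelude::*;

#[tokio::main]
async fn main() -> Result<()> {
    let ctx = SessionContext::new();
    let df = ctx.sql("SELECT encode('abc', 'hex')").await?;
    df.show().await?; // table output contains 616263
    Ok(())
}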
6 changes: 6 additions & 0 deletions .github/workflows/rust.yml
@@ -76,6 +76,12 @@ jobs:

- name: Check workspace with all features
run: cargo check --workspace --benches --features avro,json

# Ensure that the datafusion crate can be built with only a subset of the function
# packages enabled.
- name: Check function packages (encoding_expressions)
run: cargo check --no-default-features --features=encoding_expressions -p datafusion

- name: Check Cargo.lock for datafusion-cli
run: |
# If this test fails, try running `cargo update` in the `datafusion-cli` directory
3 changes: 2 additions & 1 deletion Cargo.toml
@@ -17,7 +17,7 @@

[workspace]
exclude = ["datafusion-cli"]
members = ["datafusion/common", "datafusion/core", "datafusion/expr", "datafusion/execution", "datafusion/optimizer", "datafusion/physical-expr", "datafusion/physical-plan", "datafusion/proto", "datafusion/proto/gen", "datafusion/sql", "datafusion/sqllogictest", "datafusion/substrait", "datafusion/wasmtest", "datafusion-examples", "docs", "test-utils", "benchmarks",
members = ["datafusion/common", "datafusion/core", "datafusion/expr", "datafusion/execution", "datafusion/functions", "datafusion/optimizer", "datafusion/physical-expr", "datafusion/physical-plan", "datafusion/proto", "datafusion/proto/gen", "datafusion/sql", "datafusion/sqllogictest", "datafusion/substrait", "datafusion/wasmtest", "datafusion-examples", "docs", "test-utils", "benchmarks",
]
resolver = "2"

@@ -49,6 +49,7 @@ datafusion = { path = "datafusion/core", version = "35.0.0" }
datafusion-common = { path = "datafusion/common", version = "35.0.0" }
datafusion-execution = { path = "datafusion/execution", version = "35.0.0" }
datafusion-expr = { path = "datafusion/expr", version = "35.0.0" }
datafusion-functions = { path = "datafusion/functions", version = "35.0.0" }
datafusion-optimizer = { path = "datafusion/optimizer", version = "35.0.0" }
datafusion-physical-expr = { path = "datafusion/physical-expr", version = "35.0.0" }
datafusion-physical-plan = { path = "datafusion/physical-plan", version = "35.0.0" }
14 changes: 14 additions & 0 deletions datafusion-cli/Cargo.lock

Some generated files are not rendered by default.

3 changes: 2 additions & 1 deletion datafusion/core/Cargo.toml
@@ -40,7 +40,7 @@ backtrace = ["datafusion-common/backtrace"]
compression = ["xz2", "bzip2", "flate2", "zstd", "async-compression", "tokio-util"]
crypto_expressions = ["datafusion-physical-expr/crypto_expressions", "datafusion-optimizer/crypto_expressions"]
default = ["crypto_expressions", "encoding_expressions", "regex_expressions", "unicode_expressions", "compression", "parquet"]
encoding_expressions = ["datafusion-physical-expr/encoding_expressions"]
encoding_expressions = ["datafusion-functions/encoding_expressions"]
# Used for testing ONLY: causes all values to hash to the same value (test for collisions)
force_hash_collisions = []
parquet = ["datafusion-common/parquet", "dep:parquet"]
@@ -65,6 +65,7 @@ dashmap = { workspace = true }
datafusion-common = { path = "../common", version = "35.0.0", features = ["object_store"], default-features = false }
datafusion-execution = { workspace = true }
datafusion-expr = { workspace = true }
datafusion-functions = { path = "../functions", version = "35.0.0" }
datafusion-optimizer = { path = "../optimizer", version = "35.0.0", default-features = false }
datafusion-physical-expr = { path = "../physical-expr", version = "35.0.0", default-features = false }
datafusion-physical-plan = { workspace = true }
2 changes: 1 addition & 1 deletion datafusion/core/src/dataframe/mod.rs
@@ -40,7 +40,6 @@ use crate::physical_plan::{
collect, collect_partitioned, execute_stream, execute_stream_partitioned,
ExecutionPlan, SendableRecordBatchStream,
};
use crate::prelude::SessionContext;

use arrow::array::{Array, ArrayRef, Int64Array, StringArray};
use arrow::compute::{cast, concat};
@@ -59,6 +58,7 @@ use datafusion_expr::{
TableProviderFilterPushDown, UNNAMED_TABLE,
};

use crate::prelude::SessionContext;
use async_trait::async_trait;

/// Contains options that control how data is
14 changes: 12 additions & 2 deletions datafusion/core/src/execution/context/mod.rs
@@ -1340,7 +1340,7 @@ impl SessionState {
);
}

SessionState {
let mut new_self = SessionState {
session_id,
analyzer: Analyzer::new(),
optimizer: Optimizer::new(),
@@ -1356,7 +1356,13 @@
execution_props: ExecutionProps::new(),
runtime_env: runtime,
table_factories,
}
};

// register built-in functions
datafusion_functions::register_all(&mut new_self)
.expect("cannot register built-in functions");

new_self
}
/// Returns new [`SessionState`] using the provided
/// [`SessionConfig`] and [`RuntimeEnv`].
@@ -1976,6 +1982,10 @@ impl FunctionRegistry for SessionState {
plan_datafusion_err!("There is no UDWF named \"{name}\" in the registry")
})
}

fn register_udf(&mut self, udf: Arc<ScalarUDF>) -> Result<Option<Arc<ScalarUDF>>> {
Ok(self.scalar_functions.insert(udf.name().into(), udf))
}
}

impl OptimizerConfig for SessionState {
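Because `SessionState` now implements `FunctionRegistry::register_udf` (see the impl above), the hook that `register_all` uses is equally available for user functions. A hedged sketch, with a made-up identity UDF for illustration (`create_udf`, `ColumnarValue`, and `Volatility` are pre-existing `datafusion_expr` APIs):

// Sketch: register an extra ScalarUDF through the new register_udf method.
// Generic over the registry, so it works for SessionState or any other impl.
use std::sync::Arc;
use arrow::datatypes::DataType;
use datafusion_common::Result;
use datafusion_execution::registry::FunctionRegistry;
use datafusion_expr::expr_fn::create_udf;
use datafusion_expr::{ColumnarValue, Volatility};

fn add_identity_udf<R: FunctionRegistry>(registry: &mut R) -> Result<()> {
    // hypothetical UDF: returns its single Utf8 argument unchanged
    let udf = create_udf(
        "my_identity",
        vec![DataType::Utf8],
        Arc::new(DataType::Utf8),
        Volatility::Immutable,
        Arc::new(|args: &[ColumnarValue]| Ok(args[0].clone())),
    );
    // any previously registered UDF under the same name is handed back
    let _previous = registry.register_udf(Arc::new(udf))?;
    Ok(())
}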
5 changes: 5 additions & 0 deletions datafusion/core/src/lib.rs
@@ -516,6 +516,11 @@ pub mod sql {
pub use datafusion_sql::*;
}

/// re-export of [`datafusion_functions`] crate
pub mod functions {
pub use datafusion_functions::*;
}

#[cfg(test)]
pub mod test;
pub mod test_util;
1 change: 1 addition & 0 deletions datafusion/core/src/prelude.rs
@@ -38,6 +38,7 @@ pub use datafusion_expr::{
logical_plan::{JoinType, Partitioning},
Expr,
};
pub use datafusion_functions::expr_fn::*;

pub use std::ops::Not;
pub use std::ops::{Add, Div, Mul, Neg, Rem, Sub};
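With `expr_fn` re-exported from the prelude, the fluent builders for the relocated functions keep working unchanged. A small sketch (the column name `a` is hypothetical; the tests below exercise exactly this pattern):

// Sketch: encode/decode builders now originate in datafusion-functions but
// remain reachable from the prelude as before.
use datafusion::prelude::*;

fn hex_round_trip() -> Expr {
    // builds decode(encode(a, 'hex'), 'hex') as a logical expression
    decode(encode(col("a"), lit("hex")), lit("hex"))
}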
59 changes: 55 additions & 4 deletions datafusion/core/tests/dataframe/dataframe_functions.rs
@@ -20,6 +20,7 @@ use arrow::{
array::{Int32Array, StringArray},
record_batch::RecordBatch,
};
use arrow_schema::SchemaRef;
use std::sync::Arc;

use datafusion::dataframe::DataFrame;
@@ -31,14 +32,19 @@ use datafusion::prelude::*;
use datafusion::execution::context::SessionContext;

use datafusion::assert_batches_eq;
use datafusion_common::DFSchema;
use datafusion_expr::expr::Alias;
use datafusion_expr::{approx_median, cast};
use datafusion_expr::{approx_median, cast, ExprSchemable};

async fn create_test_table() -> Result<DataFrame> {
let schema = Arc::new(Schema::new(vec![
fn test_schema() -> SchemaRef {
Arc::new(Schema::new(vec![
Field::new("a", DataType::Utf8, false),
Field::new("b", DataType::Int32, false),
]));
]))
}

async fn create_test_table() -> Result<DataFrame> {
let schema = test_schema();

// define data.
let batch = RecordBatch::try_new(
@@ -790,3 +796,48 @@ async fn test_fn_upper() -> Result<()> {

Ok(())
}

#[tokio::test]
async fn test_fn_encode() -> Result<()> {
let expr = encode(col("a"), lit("hex"));

let expected = [
"+----------------------------+",
"| encode(test.a,Utf8(\"hex\")) |",
"+----------------------------+",
"| 616263444546 |",
"| 616263313233 |",
"| 434241646566 |",
"| 313233416263446566 |",
"+----------------------------+",
];
assert_fn_batches!(expr, expected);

Ok(())
}

#[tokio::test]
async fn test_fn_decode() -> Result<()> {
// Note that decode returns binary, and the default display of binary is
// hex, so an encode/decode round trip would otherwise look like a no-op.
// Cast the result to utf8 so the original strings show up in the output.
let df_schema = DFSchema::try_from(test_schema().as_ref().clone())?;
let expr = decode(encode(col("a"), lit("hex")), lit("hex"))
.cast_to(&DataType::Utf8, &df_schema)?;

let expected = [
"+------------------------------------------------+",
"| decode(encode(test.a,Utf8(\"hex\")),Utf8(\"hex\")) |",
"+------------------------------------------------+",
"| abcDEF |",
"| abc123 |",
"| CBAdef |",
"| 123AbcDef |",
"+------------------------------------------------+",
];
assert_fn_batches!(expr, expected);

Ok(())
}
62 changes: 61 additions & 1 deletion datafusion/execution/src/registry.rs
@@ -17,8 +17,9 @@

//! FunctionRegistry trait

use datafusion_common::Result;
use datafusion_common::{not_impl_err, plan_datafusion_err, DataFusionError, Result};
use datafusion_expr::{AggregateUDF, ScalarUDF, UserDefinedLogicalNode, WindowUDF};
use std::collections::HashMap;
use std::{collections::HashSet, sync::Arc};

/// A registry knows how to build logical expressions out of user-defined functions' names
@@ -34,6 +35,17 @@ pub trait FunctionRegistry {

/// Returns a reference to the udwf named `name`.
fn udwf(&self, name: &str) -> Result<Arc<WindowUDF>>;

/// Registers a new [`ScalarUDF`], returning any previously registered
/// implementation.
///
/// Returns an error (the default) if the function cannot be registered,
/// for example if the registry is read only.
fn register_udf(&mut self, _udf: Arc<ScalarUDF>) -> Result<Option<Arc<ScalarUDF>>> {
not_impl_err!("Registering ScalarUDF")
}

// TODO add register_udaf and register_udwf
}

/// Serializer and deserializer registry for extensions like [UserDefinedLogicalNode].
@@ -53,3 +65,51 @@ pub trait SerializerRegistry: Send + Sync {
bytes: &[u8],
) -> Result<Arc<dyn UserDefinedLogicalNode>>;
}

/// A [`FunctionRegistry`] that uses in memory [`HashMap`]s
#[derive(Default, Debug)]
pub struct MemoryFunctionRegistry {
/// Scalar Functions
udfs: HashMap<String, Arc<ScalarUDF>>,
/// Aggregate Functions
udafs: HashMap<String, Arc<AggregateUDF>>,
/// Window Functions
udwfs: HashMap<String, Arc<WindowUDF>>,
}

impl MemoryFunctionRegistry {
pub fn new() -> Self {
Self::default()
}
}

impl FunctionRegistry for MemoryFunctionRegistry {
fn udfs(&self) -> HashSet<String> {
self.udfs.keys().cloned().collect()
}

fn udf(&self, name: &str) -> Result<Arc<ScalarUDF>> {
self.udfs
.get(name)
.cloned()
.ok_or_else(|| plan_datafusion_err!("Function {name} not found"))
}

fn udaf(&self, name: &str) -> Result<Arc<AggregateUDF>> {
self.udafs
.get(name)
.cloned()
.ok_or_else(|| plan_datafusion_err!("Aggregate Function {name} not found"))
}

fn udwf(&self, name: &str) -> Result<Arc<WindowUDF>> {
self.udwfs
.get(name)
.cloned()
.ok_or_else(|| plan_datafusion_err!("Window Function {name} not found"))
}

fn register_udf(&mut self, udf: Arc<ScalarUDF>) -> Result<Option<Arc<ScalarUDF>>> {
Ok(self.udfs.insert(udf.name().to_string(), udf))
}
}
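`MemoryFunctionRegistry` gives callers outside `SessionState` a standalone place to hold functions. A sketch pairing it with `register_all` (assuming `register_all` accepts any `FunctionRegistry` implementation, as its call from `SessionState` earlier in this diff suggests):

// Sketch: fill the new in-memory registry with the built-in function
// packages, then resolve one implementation by name.
use datafusion_common::Result;
use datafusion_execution::registry::{FunctionRegistry, MemoryFunctionRegistry};

fn lookup_encode() -> Result<()> {
    let mut registry = MemoryFunctionRegistry::new();
    datafusion_functions::register_all(&mut registry)?;

    let udf = registry.udf("encode")?; // plan error if the name is unknown
    assert_eq!(udf.name(), "encode");
    Ok(())
}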
52 changes: 0 additions & 52 deletions datafusion/expr/src/built_in_function.rs
@@ -69,14 +69,10 @@ pub enum BuiltinScalarFunction {
Cos,
/// cos
Cosh,
/// Decode
Decode,
/// degrees
Degrees,
/// Digest
Digest,
/// Encode
Encode,
/// exp
Exp,
/// factorial
@@ -381,9 +377,7 @@ impl BuiltinScalarFunction {
BuiltinScalarFunction::Coalesce => Volatility::Immutable,
BuiltinScalarFunction::Cos => Volatility::Immutable,
BuiltinScalarFunction::Cosh => Volatility::Immutable,
BuiltinScalarFunction::Decode => Volatility::Immutable,
BuiltinScalarFunction::Degrees => Volatility::Immutable,
BuiltinScalarFunction::Encode => Volatility::Immutable,
BuiltinScalarFunction::Exp => Volatility::Immutable,
BuiltinScalarFunction::Factorial => Volatility::Immutable,
BuiltinScalarFunction::Floor => Volatility::Immutable,
@@ -774,30 +768,6 @@ impl BuiltinScalarFunction {
BuiltinScalarFunction::Digest => {
utf8_or_binary_to_binary_type(&input_expr_types[0], "digest")
}
BuiltinScalarFunction::Encode => Ok(match input_expr_types[0] {
Utf8 => Utf8,
LargeUtf8 => LargeUtf8,
Binary => Utf8,
LargeBinary => LargeUtf8,
Null => Null,
_ => {
return plan_err!(
"The encode function can only accept utf8 or binary."
);
}
}),
BuiltinScalarFunction::Decode => Ok(match input_expr_types[0] {
Utf8 => Binary,
LargeUtf8 => LargeBinary,
Binary => Binary,
LargeBinary => LargeBinary,
Null => Null,
_ => {
return plan_err!(
"The decode function can only accept utf8 or binary."
);
}
}),
BuiltinScalarFunction::SplitPart => {
utf8_to_str_type(&input_expr_types[0], "split_part")
}
@@ -1089,24 +1059,6 @@ impl BuiltinScalarFunction {
],
self.volatility(),
),
BuiltinScalarFunction::Encode => Signature::one_of(
vec![
Exact(vec![Utf8, Utf8]),
Exact(vec![LargeUtf8, Utf8]),
Exact(vec![Binary, Utf8]),
Exact(vec![LargeBinary, Utf8]),
],
self.volatility(),
),
BuiltinScalarFunction::Decode => Signature::one_of(
vec![
Exact(vec![Utf8, Utf8]),
Exact(vec![LargeUtf8, Utf8]),
Exact(vec![Binary, Utf8]),
Exact(vec![LargeBinary, Utf8]),
],
self.volatility(),
),
BuiltinScalarFunction::DateTrunc => Signature::one_of(
vec![
Exact(vec![Utf8, Timestamp(Nanosecond, None)]),
@@ -1551,10 +1503,6 @@ impl BuiltinScalarFunction {
BuiltinScalarFunction::SHA384 => &["sha384"],
BuiltinScalarFunction::SHA512 => &["sha512"],

// encode/decode
BuiltinScalarFunction::Encode => &["encode"],
BuiltinScalarFunction::Decode => &["decode"],

// other functions
BuiltinScalarFunction::ArrowTypeof => &["arrow_typeof"],

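The deleted arms do not simply vanish: the same type rules move into the new crate's `ScalarUDF` implementations. An illustrative restatement of the removed mapping (condensed from the match arms deleted above, not the new crate's literal code):

// Encode maps its input to a text type; decode maps it to a binary type;
// both reject anything that is not utf8 or binary.
use arrow::datatypes::DataType;
use datafusion_common::{plan_err, Result};

fn encode_return_type(input: &DataType) -> Result<DataType> {
    Ok(match input {
        DataType::Utf8 | DataType::Binary => DataType::Utf8,
        DataType::LargeUtf8 | DataType::LargeBinary => DataType::LargeUtf8,
        DataType::Null => DataType::Null,
        _ => return plan_err!("The encode function can only accept utf8 or binary."),
    })
}

fn decode_return_type(input: &DataType) -> Result<DataType> {
    Ok(match input {
        DataType::Utf8 | DataType::Binary => DataType::Binary,
        DataType::LargeUtf8 | DataType::LargeBinary => DataType::LargeBinary,
        DataType::Null => DataType::Null,
        _ => return plan_err!("The decode function can only accept utf8 or binary."),
    })
}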
