diff --git a/.github/workflows/rust.yml b/.github/workflows/rust.yml index d384e4bc7ebf..501b05c25d8e 100644 --- a/.github/workflows/rust.yml +++ b/.github/workflows/rust.yml @@ -76,6 +76,12 @@ jobs: - name: Check workspace with all features run: cargo check --workspace --benches --features avro,json + + # Ensure that the datafusion crate can be built with only a subset of the function + # packages enabled. + - name: Check function packages (encoding_expressions) + run: cargo check --no-default-features --features=encoding_expressions -p datafusion + - name: Check Cargo.lock for datafusion-cli run: | # If this test fails, try running `cargo update` in the `datafusion-cli` directory diff --git a/Cargo.toml b/Cargo.toml index 89018aab7606..d56d37ad2b35 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -17,7 +17,7 @@ [workspace] exclude = ["datafusion-cli"] -members = ["datafusion/common", "datafusion/core", "datafusion/expr", "datafusion/execution", "datafusion/optimizer", "datafusion/physical-expr", "datafusion/physical-plan", "datafusion/proto", "datafusion/proto/gen", "datafusion/sql", "datafusion/sqllogictest", "datafusion/substrait", "datafusion/wasmtest", "datafusion-examples", "docs", "test-utils", "benchmarks", +members = ["datafusion/common", "datafusion/core", "datafusion/expr", "datafusion/execution", "datafusion/functions", "datafusion/optimizer", "datafusion/physical-expr", "datafusion/physical-plan", "datafusion/proto", "datafusion/proto/gen", "datafusion/sql", "datafusion/sqllogictest", "datafusion/substrait", "datafusion/wasmtest", "datafusion-examples", "docs", "test-utils", "benchmarks", ] resolver = "2" @@ -49,6 +49,7 @@ datafusion = { path = "datafusion/core", version = "35.0.0" } datafusion-common = { path = "datafusion/common", version = "35.0.0" } datafusion-execution = { path = "datafusion/execution", version = "35.0.0" } datafusion-expr = { path = "datafusion/expr", version = "35.0.0" } +datafusion-functions = { path = "datafusion/functions", version = "35.0.0" } datafusion-optimizer = { path = "datafusion/optimizer", version = "35.0.0" } datafusion-physical-expr = { path = "datafusion/physical-expr", version = "35.0.0" } datafusion-physical-plan = { path = "datafusion/physical-plan", version = "35.0.0" } diff --git a/datafusion-cli/Cargo.lock b/datafusion-cli/Cargo.lock index 4ebd2badaf9e..e89a8f172f74 100644 --- a/datafusion-cli/Cargo.lock +++ b/datafusion-cli/Cargo.lock @@ -1115,6 +1115,7 @@ dependencies = [ "datafusion-common", "datafusion-execution", "datafusion-expr", + "datafusion-functions", "datafusion-optimizer", "datafusion-physical-expr", "datafusion-physical-plan", @@ -1224,6 +1225,19 @@ dependencies = [ "strum_macros 0.26.1", ] +[[package]] +name = "datafusion-functions" +version = "35.0.0" +dependencies = [ + "arrow", + "base64", + "datafusion-common", + "datafusion-execution", + "datafusion-expr", + "hex", + "log", +] + [[package]] name = "datafusion-optimizer" version = "35.0.0" diff --git a/datafusion/core/Cargo.toml b/datafusion/core/Cargo.toml index 69b18a326951..f9a4c54b7dc6 100644 --- a/datafusion/core/Cargo.toml +++ b/datafusion/core/Cargo.toml @@ -40,7 +40,7 @@ backtrace = ["datafusion-common/backtrace"] compression = ["xz2", "bzip2", "flate2", "zstd", "async-compression", "tokio-util"] crypto_expressions = ["datafusion-physical-expr/crypto_expressions", "datafusion-optimizer/crypto_expressions"] default = ["crypto_expressions", "encoding_expressions", "regex_expressions", "unicode_expressions", "compression", "parquet"] -encoding_expressions = 
["datafusion-physical-expr/encoding_expressions"] +encoding_expressions = ["datafusion-functions/encoding_expressions"] # Used for testing ONLY: causes all values to hash to the same value (test for collisions) force_hash_collisions = [] parquet = ["datafusion-common/parquet", "dep:parquet"] @@ -65,6 +65,7 @@ dashmap = { workspace = true } datafusion-common = { path = "../common", version = "35.0.0", features = ["object_store"], default-features = false } datafusion-execution = { workspace = true } datafusion-expr = { workspace = true } +datafusion-functions = { path = "../functions", version = "35.0.0" } datafusion-optimizer = { path = "../optimizer", version = "35.0.0", default-features = false } datafusion-physical-expr = { path = "../physical-expr", version = "35.0.0", default-features = false } datafusion-physical-plan = { workspace = true } diff --git a/datafusion/core/src/dataframe/mod.rs b/datafusion/core/src/dataframe/mod.rs index 4fd543f0eab8..f4023642ef04 100644 --- a/datafusion/core/src/dataframe/mod.rs +++ b/datafusion/core/src/dataframe/mod.rs @@ -40,7 +40,6 @@ use crate::physical_plan::{ collect, collect_partitioned, execute_stream, execute_stream_partitioned, ExecutionPlan, SendableRecordBatchStream, }; -use crate::prelude::SessionContext; use arrow::array::{Array, ArrayRef, Int64Array, StringArray}; use arrow::compute::{cast, concat}; @@ -59,6 +58,7 @@ use datafusion_expr::{ TableProviderFilterPushDown, UNNAMED_TABLE, }; +use crate::prelude::SessionContext; use async_trait::async_trait; /// Contains options that control how data is diff --git a/datafusion/core/src/execution/context/mod.rs b/datafusion/core/src/execution/context/mod.rs index b5ad6174821b..f5ca3992fa18 100644 --- a/datafusion/core/src/execution/context/mod.rs +++ b/datafusion/core/src/execution/context/mod.rs @@ -1340,7 +1340,7 @@ impl SessionState { ); } - SessionState { + let mut new_self = SessionState { session_id, analyzer: Analyzer::new(), optimizer: Optimizer::new(), @@ -1356,7 +1356,13 @@ impl SessionState { execution_props: ExecutionProps::new(), runtime_env: runtime, table_factories, - } + }; + + // register built in functions + datafusion_functions::register_all(&mut new_self) + .expect("can not register built in functions"); + + new_self } /// Returns new [`SessionState`] using the provided /// [`SessionConfig`] and [`RuntimeEnv`]. 
@@ -1976,6 +1982,10 @@ impl FunctionRegistry for SessionState { plan_datafusion_err!("There is no UDWF named \"{name}\" in the registry") }) } + + fn register_udf(&mut self, udf: Arc<ScalarUDF>) -> Result<Option<Arc<ScalarUDF>>> { + Ok(self.scalar_functions.insert(udf.name().into(), udf)) + } } impl OptimizerConfig for SessionState { diff --git a/datafusion/core/src/lib.rs b/datafusion/core/src/lib.rs index 365f359f495d..0f7292e1c3d3 100644 --- a/datafusion/core/src/lib.rs +++ b/datafusion/core/src/lib.rs @@ -516,6 +516,11 @@ pub mod sql { pub use datafusion_sql::*; } +/// re-export of [`datafusion_functions`] crate +pub mod functions { + pub use datafusion_functions::*; +} + #[cfg(test)] pub mod test; pub mod test_util; diff --git a/datafusion/core/src/prelude.rs b/datafusion/core/src/prelude.rs index 5cd8b3870f81..69c33355402b 100644 --- a/datafusion/core/src/prelude.rs +++ b/datafusion/core/src/prelude.rs @@ -38,6 +38,7 @@ pub use datafusion_expr::{ logical_plan::{JoinType, Partitioning}, Expr, }; +pub use datafusion_functions::expr_fn::*; pub use std::ops::Not; pub use std::ops::{Add, Div, Mul, Neg, Rem, Sub}; diff --git a/datafusion/core/tests/dataframe/dataframe_functions.rs b/datafusion/core/tests/dataframe/dataframe_functions.rs index 2d4203464300..486ea712edeb 100644 --- a/datafusion/core/tests/dataframe/dataframe_functions.rs +++ b/datafusion/core/tests/dataframe/dataframe_functions.rs @@ -20,6 +20,7 @@ use arrow::{ array::{Int32Array, StringArray}, record_batch::RecordBatch, }; +use arrow_schema::SchemaRef; use std::sync::Arc; use datafusion::dataframe::DataFrame; @@ -31,14 +32,19 @@ use datafusion::prelude::*; use datafusion::execution::context::SessionContext; use datafusion::assert_batches_eq; +use datafusion_common::DFSchema; use datafusion_expr::expr::Alias; -use datafusion_expr::{approx_median, cast}; +use datafusion_expr::{approx_median, cast, ExprSchemable}; -async fn create_test_table() -> Result<DataFrame> { - let schema = Arc::new(Schema::new(vec![ Field::new("a", DataType::Utf8, false), Field::new("b", DataType::Int32, false), - ])); +fn test_schema() -> SchemaRef { + Arc::new(Schema::new(vec![ Field::new("a", DataType::Utf8, false), Field::new("b", DataType::Int32, false), + ])) +} + +async fn create_test_table() -> Result<DataFrame> { + let schema = test_schema(); // define data. let batch = RecordBatch::try_new( @@ -790,3 +796,48 @@ async fn test_fn_upper() -> Result<()> { Ok(()) } + +#[tokio::test] +async fn test_fn_encode() -> Result<()> { + let expr = encode(col("a"), lit("hex")); + + let expected = [ + "+----------------------------+", + "| encode(test.a,Utf8(\"hex\")) |", + "+----------------------------+", + "| 616263444546 |", + "| 616263313233 |", + "| 434241646566 |", + "| 313233416263446566 |", + "+----------------------------+", + ]; + assert_fn_batches!(expr, expected); + + Ok(()) +} + +#[tokio::test] +async fn test_fn_decode() -> Result<()> { + // Note that the decode function returns binary, and the default display of + // binary is "hexadecimal" and therefore the output looks like decode did + // nothing. So compare to a constant.
+ let df_schema = DFSchema::try_from(test_schema().as_ref().clone())?; + let expr = decode(encode(col("a"), lit("hex")), lit("hex")) + // need to cast to utf8 otherwise the default display of binary array is hex + // so it looks like nothing is done + .cast_to(&DataType::Utf8, &df_schema)?; + + let expected = [ + "+------------------------------------------------+", + "| decode(encode(test.a,Utf8(\"hex\")),Utf8(\"hex\")) |", + "+------------------------------------------------+", + "| abcDEF |", + "| abc123 |", + "| CBAdef |", + "| 123AbcDef |", + "+------------------------------------------------+", + ]; + assert_fn_batches!(expr, expected); + + Ok(()) +} diff --git a/datafusion/execution/src/registry.rs b/datafusion/execution/src/registry.rs index 9ba487e715b3..4d5b80f054df 100644 --- a/datafusion/execution/src/registry.rs +++ b/datafusion/execution/src/registry.rs @@ -17,8 +17,9 @@ //! FunctionRegistry trait -use datafusion_common::Result; +use datafusion_common::{not_impl_err, plan_datafusion_err, DataFusionError, Result}; use datafusion_expr::{AggregateUDF, ScalarUDF, UserDefinedLogicalNode, WindowUDF}; +use std::collections::HashMap; use std::{collections::HashSet, sync::Arc}; /// A registry knows how to build logical expressions out of user-defined functions' names @@ -34,6 +35,17 @@ pub trait FunctionRegistry { /// Returns a reference to the udwf named `name`. fn udwf(&self, name: &str) -> Result<Arc<WindowUDF>>; + + /// Registers a new [`ScalarUDF`], returning any previously registered + /// implementation. + /// + /// Returns an error (the default) if the function cannot be registered, + /// for example if the registry is read-only. + fn register_udf(&mut self, _udf: Arc<ScalarUDF>) -> Result<Option<Arc<ScalarUDF>>> { + not_impl_err!("Registering ScalarUDF") + } + + // TODO add register_udaf and register_udwf } /// Serializer and deserializer registry for extensions like [UserDefinedLogicalNode].
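The defaulted `register_udf` above keeps this trait backwards compatible: existing `FunctionRegistry` implementations compile unchanged and simply report that registration is unsupported. As a hedged illustration (the type name here is invented, not part of the patch), a read-only registry only needs the lookup methods and inherits the default:

```rust
use std::collections::HashSet;
use std::sync::Arc;

use datafusion_common::{plan_datafusion_err, Result};
use datafusion_execution::FunctionRegistry;
use datafusion_expr::{AggregateUDF, ScalarUDF, WindowUDF};

/// A registry that exposes no functions; `register_udf` falls back to the
/// trait default and returns a "not implemented" error.
struct ReadOnlyRegistry;

impl FunctionRegistry for ReadOnlyRegistry {
    fn udfs(&self) -> HashSet<String> {
        HashSet::new()
    }
    fn udf(&self, name: &str) -> Result<Arc<ScalarUDF>> {
        Err(plan_datafusion_err!("Function {name} not found"))
    }
    fn udaf(&self, name: &str) -> Result<Arc<AggregateUDF>> {
        Err(plan_datafusion_err!("Aggregate function {name} not found"))
    }
    fn udwf(&self, name: &str) -> Result<Arc<WindowUDF>> {
        Err(plan_datafusion_err!("Window function {name} not found"))
    }
}
```

The `MemoryFunctionRegistry` added in the next hunk is the mutable counterpart, used by the new crate's doctests.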
@@ -53,3 +65,51 @@ pub trait SerializerRegistry: Send + Sync { bytes: &[u8], ) -> Result<Arc<dyn UserDefinedLogicalNode>>; } + +/// A [`FunctionRegistry`] that uses in-memory [`HashMap`]s +#[derive(Default, Debug)] +pub struct MemoryFunctionRegistry { + /// Scalar Functions + udfs: HashMap<String, Arc<ScalarUDF>>, + /// Aggregate Functions + udafs: HashMap<String, Arc<AggregateUDF>>, + /// Window Functions + udwfs: HashMap<String, Arc<WindowUDF>>, +} + +impl MemoryFunctionRegistry { + pub fn new() -> Self { + Self::default() + } +} + +impl FunctionRegistry for MemoryFunctionRegistry { + fn udfs(&self) -> HashSet<String> { + self.udfs.keys().cloned().collect() + } + + fn udf(&self, name: &str) -> Result<Arc<ScalarUDF>> { + self.udfs + .get(name) + .cloned() + .ok_or_else(|| plan_datafusion_err!("Function {name} not found")) + } + + fn udaf(&self, name: &str) -> Result<Arc<AggregateUDF>> { + self.udafs + .get(name) + .cloned() + .ok_or_else(|| plan_datafusion_err!("Aggregate Function {name} not found")) + } + + fn udwf(&self, name: &str) -> Result<Arc<WindowUDF>> { + self.udwfs + .get(name) + .cloned() + .ok_or_else(|| plan_datafusion_err!("Window Function {name} not found")) + } + + fn register_udf(&mut self, udf: Arc<ScalarUDF>) -> Result<Option<Arc<ScalarUDF>>> { + Ok(self.udfs.insert(udf.name().to_string(), udf)) + } +} diff --git a/datafusion/expr/src/built_in_function.rs b/datafusion/expr/src/built_in_function.rs index d8c22e69335d..f2eb82ebf9bd 100644 --- a/datafusion/expr/src/built_in_function.rs +++ b/datafusion/expr/src/built_in_function.rs @@ -69,14 +69,10 @@ pub enum BuiltinScalarFunction { Cos, /// cos Cosh, - /// Decode - Decode, /// degrees Degrees, /// Digest Digest, - /// Encode - Encode, /// exp Exp, /// factorial @@ -381,9 +377,7 @@ impl BuiltinScalarFunction { BuiltinScalarFunction::Coalesce => Volatility::Immutable, BuiltinScalarFunction::Cos => Volatility::Immutable, BuiltinScalarFunction::Cosh => Volatility::Immutable, - BuiltinScalarFunction::Decode => Volatility::Immutable, BuiltinScalarFunction::Degrees => Volatility::Immutable, - BuiltinScalarFunction::Encode => Volatility::Immutable, BuiltinScalarFunction::Exp => Volatility::Immutable, BuiltinScalarFunction::Factorial => Volatility::Immutable, BuiltinScalarFunction::Floor => Volatility::Immutable, @@ -774,30 +768,6 @@ impl BuiltinScalarFunction { BuiltinScalarFunction::Digest => { utf8_or_binary_to_binary_type(&input_expr_types[0], "digest") } - BuiltinScalarFunction::Encode => Ok(match input_expr_types[0] { - Utf8 => Utf8, - LargeUtf8 => LargeUtf8, - Binary => Utf8, - LargeBinary => LargeUtf8, - Null => Null, - _ => { - return plan_err!( - "The encode function can only accept utf8 or binary." - ); - } - }), - BuiltinScalarFunction::Decode => Ok(match input_expr_types[0] { - Utf8 => Binary, - LargeUtf8 => LargeBinary, - Binary => Binary, - LargeBinary => LargeBinary, - Null => Null, - _ => { - return plan_err!( - "The decode function can only accept utf8 or binary."
- ); - } - }), BuiltinScalarFunction::SplitPart => { utf8_to_str_type(&input_expr_types[0], "split_part") } @@ -1089,24 +1059,6 @@ impl BuiltinScalarFunction { ], self.volatility(), ), - BuiltinScalarFunction::Encode => Signature::one_of( - vec![ - Exact(vec![Utf8, Utf8]), - Exact(vec![LargeUtf8, Utf8]), - Exact(vec![Binary, Utf8]), - Exact(vec![LargeBinary, Utf8]), - ], - self.volatility(), - ), - BuiltinScalarFunction::Decode => Signature::one_of( - vec![ - Exact(vec![Utf8, Utf8]), - Exact(vec![LargeUtf8, Utf8]), - Exact(vec![Binary, Utf8]), - Exact(vec![LargeBinary, Utf8]), - ], - self.volatility(), - ), BuiltinScalarFunction::DateTrunc => Signature::one_of( vec![ Exact(vec![Utf8, Timestamp(Nanosecond, None)]), @@ -1551,10 +1503,6 @@ impl BuiltinScalarFunction { BuiltinScalarFunction::SHA384 => &["sha384"], BuiltinScalarFunction::SHA512 => &["sha512"], - // encode/decode - BuiltinScalarFunction::Encode => &["encode"], - BuiltinScalarFunction::Decode => &["decode"], - // other functions BuiltinScalarFunction::ArrowTypeof => &["arrow_typeof"], diff --git a/datafusion/expr/src/expr_fn.rs b/datafusion/expr/src/expr_fn.rs index dddf176dbe9f..4608badde231 100644 --- a/datafusion/expr/src/expr_fn.rs +++ b/datafusion/expr/src/expr_fn.rs @@ -496,7 +496,7 @@ pub fn is_not_unknown(expr: Expr) -> Expr { macro_rules! scalar_expr { ($ENUM:ident, $FUNC:ident, $($arg:ident)*, $DOC:expr) => { - #[doc = $DOC ] + #[doc = $DOC] pub fn $FUNC($($arg: Expr),*) -> Expr { Expr::ScalarFunction(ScalarFunction::new( built_in_function::BuiltinScalarFunction::$ENUM, @@ -795,8 +795,6 @@ scalar_expr!( "converts the Unicode code point to a UTF8 character" ); scalar_expr!(Digest, digest, input algorithm, "compute the binary hash of `input`, using the `algorithm`"); -scalar_expr!(Encode, encode, input encoding, "encode the `input`, using the `encoding`. encoding can be base64 or hex"); -scalar_expr!(Decode, decode, input encoding, "decode the`input`, using the `encoding`. 
encoding can be base64 or hex"); scalar_expr!(InitCap, initcap, string, "converts the first letter of each word in `string` in uppercase and the remaining characters in lowercase"); scalar_expr!(InStr, instr, string substring, "returns the position of the first occurrence of `substring` in `string`"); scalar_expr!(Left, left, string n, "returns the first `n` characters in the `string`"); @@ -1370,8 +1368,6 @@ mod test { test_scalar_expr!(CharacterLength, character_length, string); test_scalar_expr!(Chr, chr, string); test_scalar_expr!(Digest, digest, string, algorithm); - test_scalar_expr!(Encode, encode, string, encoding); - test_scalar_expr!(Decode, decode, string, encoding); test_scalar_expr!(Gcd, gcd, arg_1, arg_2); test_scalar_expr!(Lcm, lcm, arg_1, arg_2); test_scalar_expr!(InitCap, initcap, string); @@ -1486,34 +1482,4 @@ mod test { unreachable!(); } } - - #[test] - fn encode_function_definitions() { - if let Expr::ScalarFunction(ScalarFunction { - func_def: ScalarFunctionDefinition::BuiltIn(fun), - args, - }) = encode(col("tableA.a"), lit("base64")) - { - let name = BuiltinScalarFunction::Encode; - assert_eq!(name, fun); - assert_eq!(2, args.len()); - } else { - unreachable!(); - } - } - - #[test] - fn decode_function_definitions() { - if let Expr::ScalarFunction(ScalarFunction { - func_def: ScalarFunctionDefinition::BuiltIn(fun), - args, - }) = decode(col("tableA.a"), lit("hex")) - { - let name = BuiltinScalarFunction::Decode; - assert_eq!(name, fun); - assert_eq!(2, args.len()); - } else { - unreachable!(); - } - } } diff --git a/datafusion/functions/Cargo.toml b/datafusion/functions/Cargo.toml new file mode 100644 index 000000000000..6d4a716e2e8e --- /dev/null +++ b/datafusion/functions/Cargo.toml @@ -0,0 +1,51 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +[package] +name = "datafusion-functions" +description = "Function packages for the DataFusion query engine" +keywords = ["datafusion", "logical", "plan", "expressions"] +readme = "README.md" +version = { workspace = true } +edition = { workspace = true } +homepage = { workspace = true } +repository = { workspace = true } +license = { workspace = true } +authors = { workspace = true } +rust-version = { workspace = true } + +[features] +# Enable encoding by default so the doctests work. In general don't automatically enable all packages. 
+default = ["encoding_expressions"] +# enable the encode/decode functions +encoding_expressions = ["base64", "hex"] + + +[lib] +name = "datafusion_functions" +path = "src/lib.rs" + +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html + +[dependencies] +arrow = { workspace = true } +base64 = { version = "0.21", optional = true } +datafusion-common = { workspace = true } +datafusion-execution = { workspace = true } +datafusion-expr = { workspace = true } +hex = { version = "0.4", optional = true } +log = "0.4.20" diff --git a/datafusion/functions/README.md b/datafusion/functions/README.md new file mode 100644 index 000000000000..a610d135c0f6 --- /dev/null +++ b/datafusion/functions/README.md @@ -0,0 +1,27 @@ + + +# DataFusion Function Library + +[DataFusion][df] is an extensible query execution framework, written in Rust, that uses Apache Arrow as its in-memory format. + +This crate contains packages of function that can be used to customize the +functionality of DataFusion. + +[df]: https://crates.io/crates/datafusion diff --git a/datafusion/physical-expr/src/encoding_expressions.rs b/datafusion/functions/src/encoding/inner.rs similarity index 82% rename from datafusion/physical-expr/src/encoding_expressions.rs rename to datafusion/functions/src/encoding/inner.rs index b74310485fb7..886a031a5269 100644 --- a/datafusion/physical-expr/src/encoding_expressions.rs +++ b/datafusion/functions/src/encoding/inner.rs @@ -32,12 +32,122 @@ use datafusion_expr::ColumnarValue; use std::sync::Arc; use std::{fmt, str::FromStr}; +use datafusion_expr::TypeSignature::*; +use datafusion_expr::{ScalarUDFImpl, Signature, Volatility}; +use std::any::Any; + +#[derive(Debug)] +pub(super) struct EncodeFunc { + signature: Signature, +} + +impl EncodeFunc { + pub fn new() -> Self { + use DataType::*; + Self { + signature: Signature::one_of( + vec![ + Exact(vec![Utf8, Utf8]), + Exact(vec![LargeUtf8, Utf8]), + Exact(vec![Binary, Utf8]), + Exact(vec![LargeBinary, Utf8]), + ], + Volatility::Immutable, + ) + } + } +} + +impl ScalarUDFImpl for EncodeFunc { + fn as_any(&self) -> &dyn Any { + self + } + fn name(&self) -> &str { + "encode" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, arg_types: &[DataType]) -> Result { + use DataType::*; + + Ok(match arg_types[0] { + Utf8 => Utf8, + LargeUtf8 => LargeUtf8, + Binary => Utf8, + LargeBinary => LargeUtf8, + Null => Null, + _ => { + return plan_err!("The encode function can only accept utf8 or binary."); + } + }) + } + + fn invoke(&self, args: &[ColumnarValue]) -> Result { + encode(args) + } +} + +#[derive(Debug)] +pub(super) struct DecodeFunc { + signature: Signature, +} + +impl DecodeFunc { + pub fn new() -> Self { + use DataType::*; + Self { + signature: Signature::one_of( + vec![ + Exact(vec![Utf8, Utf8]), + Exact(vec![LargeUtf8, Utf8]), + Exact(vec![Binary, Utf8]), + Exact(vec![LargeBinary, Utf8]), + ], + Volatility::Immutable, + ) + } + } +} +impl ScalarUDFImpl for DecodeFunc { + fn as_any(&self) -> &dyn Any { + self + } + fn name(&self) -> &str { + "decode" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, arg_types: &[DataType]) -> Result { + use DataType::*; + + Ok(match arg_types[0] { + Utf8 => Binary, + LargeUtf8 => LargeBinary, + Binary => Binary, + LargeBinary => LargeBinary, + Null => Null, + _ => { + return plan_err!("The decode function can only accept utf8 or binary."); + } + }) + } + + fn invoke(&self, args: &[ColumnarValue]) -> 
Result<ColumnarValue> { + decode(args) + } +} + #[derive(Debug, Copy, Clone)] enum Encoding { Base64, Hex, } - fn encode_process(value: &ColumnarValue, encoding: Encoding) -> Result<ColumnarValue> { match value { ColumnarValue::Array(a) => match a.data_type() { @@ -293,7 +403,7 @@ impl FromStr for Encoding { /// Encodes the given data, accepts Binary, LargeBinary, Utf8 or LargeUtf8 and returns a [`ColumnarValue`]. /// Second argument is the encoding to use. /// Standard encodings are base64 and hex. -pub fn encode(args: &[ColumnarValue]) -> Result<ColumnarValue> { +fn encode(args: &[ColumnarValue]) -> Result<ColumnarValue> { if args.len() != 2 { return internal_err!( "{:?} args were supplied but encode takes exactly two arguments", @@ -319,7 +429,7 @@ /// Decodes the given data, accepts Binary, LargeBinary, Utf8 or LargeUtf8 and returns a [`ColumnarValue`]. /// Second argument is the encoding to use. /// Standard encodings are base64 and hex. -pub fn decode(args: &[ColumnarValue]) -> Result<ColumnarValue> { +fn decode(args: &[ColumnarValue]) -> Result<ColumnarValue> { if args.len() != 2 { return internal_err!( "{:?} args were supplied but decode takes exactly two arguments", diff --git a/datafusion/functions/src/encoding/mod.rs b/datafusion/functions/src/encoding/mod.rs new file mode 100644 index 000000000000..7bb1a0ea3aa3 --- /dev/null +++ b/datafusion/functions/src/encoding/mod.rs @@ -0,0 +1,30 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +mod inner; + + +// create `encode` and `decode` UDFs +make_udf_function!(inner::EncodeFunc, ENCODE, encode); +make_udf_function!(inner::DecodeFunc, DECODE, decode); + +// Export the functions out of this package, both as expr_fn as well as a list of functions +export_functions!( + (encode, input encoding, "encode the `input`, using the `encoding`. encoding can be base64 or hex"), + (decode, input encoding, "decode the `input`, using the `encoding`. encoding can be base64 or hex") ); + diff --git a/datafusion/functions/src/lib.rs b/datafusion/functions/src/lib.rs new file mode 100644 index 000000000000..91a5c510f0f9 --- /dev/null +++ b/datafusion/functions/src/lib.rs @@ -0,0 +1,109 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License.
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Function packages for [DataFusion]. +//! +//! This crate contains a collection of various function packages for DataFusion, +//! implemented using the extension API. Users may wish to control which functions +//! are available, both to limit the binary size of their application and to +//! use dialect-specific implementations of functions (e.g. Spark vs Postgres). +//! +//! Each package is implemented as a separate +//! module, activated by a feature flag. +//! +//! [DataFusion]: https://crates.io/crates/datafusion +//! +//! # Available Packages +//! See the list of [modules](#modules) in this crate for available packages. +//! +//! # Using A Package +//! You can register all functions in all packages using the [`register_all`] function. +//! +//! To access and use only the functions in a certain package, use the +//! `functions()` method in each module. +//! +//! ``` +//! # fn main() -> datafusion_common::Result<()> { +//! # let mut registry = datafusion_execution::registry::MemoryFunctionRegistry::new(); +//! # use datafusion_execution::FunctionRegistry; +//! // get the encoding functions +//! use datafusion_functions::encoding; +//! for udf in encoding::functions() { +//! registry.register_udf(udf)?; +//! } +//! # Ok(()) +//! # } +//! ``` +//! +//! Each package also exports an `expr_fn` submodule to help create [`Expr`]s that invoke +//! functions using a fluent style. For example: +//! +//! ``` +//! // create an Expr that will invoke the encode function +//! use datafusion_expr::{col, lit}; +//! use datafusion_functions::expr_fn; +//! // Equivalent to "encode(my_data, 'hex')" in SQL: +//! let expr = expr_fn::encode(col("my_data"), lit("hex")); +//! ``` +//! +//! [`Expr`]: datafusion_expr::Expr +//! +//! # Implementing A New Package +//! +//! To add a new package to this crate, you should follow the model of existing +//! packages. The high-level steps are: +//! +//! 1. Create a new module with the appropriate [`ScalarUDF`] implementations. +//! +//! 2. Use the macros in [`macros`] to create standard entry points. +//! +//! 3. Add a new feature to `Cargo.toml`, with any optional dependencies. +//! +//! 4. Use the `make_package!` macro to expose the module when the +//! feature is enabled. +//! +//! [`ScalarUDF`]: datafusion_expr::ScalarUDF +use datafusion_common::Result; +use datafusion_execution::FunctionRegistry; +use log::debug; + +#[macro_use] +pub mod macros; + +make_package!( + encoding, + "encoding_expressions", + "Hex and binary `encode` and `decode` functions."
+); + +/// Fluent-style API for creating `Expr`s +pub mod expr_fn { + #[cfg(feature = "encoding_expressions")] + pub use super::encoding::expr_fn::*; +} + +/// Registers all enabled packages with a [`FunctionRegistry`] +pub fn register_all(registry: &mut dyn FunctionRegistry) -> Result<()> { + encoding::functions().into_iter().try_for_each(|udf| { + let existing_udf = registry.register_udf(udf)?; + if let Some(existing_udf) = existing_udf { + debug!("Overwrite existing UDF: {}", existing_udf.name()); + } + Ok(()) as Result<()> + })?; + Ok(()) +} diff --git a/datafusion/functions/src/macros.rs b/datafusion/functions/src/macros.rs new file mode 100644 index 000000000000..1931ee279421 --- /dev/null +++ b/datafusion/functions/src/macros.rs @@ -0,0 +1,123 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +/// Macro that exports a list of function names as: +/// 1. individual functions in an `expr_fn` module +/// 2. a single function that returns a list of all functions +/// +/// Equivalent to +/// ```text +/// pub mod expr_fn { +/// use super::*; +/// /// Return encode(arg) +/// pub fn encode(args: Vec<Expr>) -> Expr { +/// super::encode().call(args) +/// } +/// ... +/// /// Return a list of all functions in this package +/// pub(crate) fn functions() -> Vec<Arc<ScalarUDF>> { +/// vec![ +/// encode(), +/// decode() +/// ] +/// } +/// ``` macro_rules! export_functions { + ($(($FUNC:ident, $($arg:ident)*, $DOC:expr)),*) => { + pub mod expr_fn { + $( + #[doc = $DOC] + /// Return $name(arg) + pub fn $FUNC($($arg: datafusion_expr::Expr),*) -> datafusion_expr::Expr { + super::$FUNC().call(vec![$($arg),*],) + } + )* + } + + /// Return a list of all functions in this package + pub fn functions() -> Vec<std::sync::Arc<datafusion_expr::ScalarUDF>> { + vec![ + $( + $FUNC(), + )* + ] + } + }; +} + +/// Creates a singleton `ScalarUDF` of the `$UDF` function named `$GNAME` and a +/// function named `$NAME` which returns that singleton. +/// +/// This is used to ensure creating the list of `ScalarUDF` only happens once. +macro_rules! make_udf_function { + ($UDF:ty, $GNAME:ident, $NAME:ident) => { + /// Singleton instance of the function + static $GNAME: std::sync::OnceLock<std::sync::Arc<datafusion_expr::ScalarUDF>> = + std::sync::OnceLock::new(); + + /// Return a [`ScalarUDF`] for [`$UDF`] + /// + /// [`ScalarUDF`]: datafusion_expr::ScalarUDF + fn $NAME() -> std::sync::Arc<datafusion_expr::ScalarUDF> { + $GNAME + .get_or_init(|| { + std::sync::Arc::new(datafusion_expr::ScalarUDF::new_from_impl( + <$UDF>::new(), + )) + }) + .clone() + } + }; +} + +/// Macro that creates the named module if the feature is enabled, +/// otherwise creates a stub +/// +/// Which returns: +/// +/// 1. The list of actual function implementations when the relevant +/// feature is activated, +/// +/// 2.
A list of stub functions, when the feature is not activated, that produce +/// a runtime error (and explain what feature flag is needed to activate them). +/// +/// The rationale for providing stub functions is to help users configure DataFusion +/// properly (so they get an error telling them why a function is not available) +/// instead of getting a cryptic "no function found" message at runtime. + +macro_rules! make_package { + ($name:ident, $feature:literal, $DOC:expr) => { + #[cfg(feature = $feature)] + #[doc = $DOC] + #[doc = concat!("Enabled via feature flag `", $feature, "`")] + pub mod $name; + + #[cfg(not(feature = $feature))] + #[doc = concat!("Disabled. Enable via feature flag `", $feature, "`")] + pub mod $name { + use datafusion_expr::ScalarUDF; + use log::debug; + use std::sync::Arc; + + /// Returns an empty list of functions when the feature is not enabled + pub fn functions() -> Vec<Arc<ScalarUDF>> { + debug!("{} functions disabled", stringify!($name)); + vec![] + } + } + }; +} diff --git a/datafusion/physical-expr/src/functions.rs b/datafusion/physical-expr/src/functions.rs index 7a175e23e3b4..cd4e6f96f0fe 100644 --- a/datafusion/physical-expr/src/functions.rs +++ b/datafusion/physical-expr/src/functions.rs @@ -85,26 +85,6 @@ pub fn create_physical_expr( ))) } -#[cfg(feature = "encoding_expressions")] -macro_rules! invoke_if_encoding_expressions_feature_flag { - ($FUNC:ident, $NAME:expr) => {{ - use crate::encoding_expressions; - encoding_expressions::$FUNC - }}; -} - -#[cfg(not(feature = "encoding_expressions"))] -macro_rules! invoke_if_encoding_expressions_feature_flag { - ($FUNC:ident, $NAME:expr) => { - |_: &[ColumnarValue]| -> Result<ColumnarValue> { - internal_err!( - "function {} requires compilation with feature flag: encoding_expressions.", - $NAME - ) - } - }; -} - #[cfg(feature = "crypto_expressions")] macro_rules!
invoke_if_crypto_expressions_feature_flag { ($FUNC:ident, $NAME:expr) => {{ @@ -652,12 +632,6 @@ pub fn create_physical_fun( BuiltinScalarFunction::Digest => { Arc::new(invoke_if_crypto_expressions_feature_flag!(digest, "digest")) } - BuiltinScalarFunction::Decode => Arc::new( - invoke_if_encoding_expressions_feature_flag!(decode, "decode"), - ), - BuiltinScalarFunction::Encode => Arc::new( - invoke_if_encoding_expressions_feature_flag!(encode, "encode"), - ), BuiltinScalarFunction::NullIf => Arc::new(nullif_func), BuiltinScalarFunction::OctetLength => Arc::new(|args| match &args[0] { ColumnarValue::Array(v) => Ok(ColumnarValue::Array(length(v.as_ref())?)), diff --git a/datafusion/physical-expr/src/lib.rs b/datafusion/physical-expr/src/lib.rs index 6f55f56916e7..95c1f3591d59 100644 --- a/datafusion/physical-expr/src/lib.rs +++ b/datafusion/physical-expr/src/lib.rs @@ -22,8 +22,6 @@ pub mod conditional_expressions; #[cfg(feature = "crypto_expressions")] pub mod crypto_expressions; pub mod datetime_expressions; -#[cfg(feature = "encoding_expressions")] -pub mod encoding_expressions; pub mod equivalence; pub mod execution_props; pub mod expressions; diff --git a/datafusion/proto/proto/datafusion.proto b/datafusion/proto/proto/datafusion.proto index 0ac7120d242b..0b93820db841 100644 --- a/datafusion/proto/proto/datafusion.proto +++ b/datafusion/proto/proto/datafusion.proto @@ -639,8 +639,6 @@ enum ScalarFunction { Cardinality = 98; ArrayElement = 99; ArraySlice = 100; - Encode = 101; - Decode = 102; Cot = 103; ArrayHas = 104; ArrayHasAny = 105; diff --git a/datafusion/proto/src/generated/pbjson.rs b/datafusion/proto/src/generated/pbjson.rs index c0cc0b943ada..55e83a885382 100644 --- a/datafusion/proto/src/generated/pbjson.rs +++ b/datafusion/proto/src/generated/pbjson.rs @@ -22391,8 +22391,6 @@ impl serde::Serialize for ScalarFunction { Self::Cardinality => "Cardinality", Self::ArrayElement => "ArrayElement", Self::ArraySlice => "ArraySlice", - Self::Encode => "Encode", - Self::Decode => "Decode", Self::Cot => "Cot", Self::ArrayHas => "ArrayHas", Self::ArrayHasAny => "ArrayHasAny", @@ -22536,8 +22534,6 @@ impl<'de> serde::Deserialize<'de> for ScalarFunction { "Cardinality", "ArrayElement", "ArraySlice", - "Encode", - "Decode", "Cot", "ArrayHas", "ArrayHasAny", @@ -22710,8 +22706,6 @@ impl<'de> serde::Deserialize<'de> for ScalarFunction { "Cardinality" => Ok(ScalarFunction::Cardinality), "ArrayElement" => Ok(ScalarFunction::ArrayElement), "ArraySlice" => Ok(ScalarFunction::ArraySlice), - "Encode" => Ok(ScalarFunction::Encode), - "Decode" => Ok(ScalarFunction::Decode), "Cot" => Ok(ScalarFunction::Cot), "ArrayHas" => Ok(ScalarFunction::ArrayHas), "ArrayHasAny" => Ok(ScalarFunction::ArrayHasAny), diff --git a/datafusion/proto/src/generated/prost.rs b/datafusion/proto/src/generated/prost.rs index e3b83748b89b..b17bcd3a49d7 100644 --- a/datafusion/proto/src/generated/prost.rs +++ b/datafusion/proto/src/generated/prost.rs @@ -2734,8 +2734,6 @@ pub enum ScalarFunction { Cardinality = 98, ArrayElement = 99, ArraySlice = 100, - Encode = 101, - Decode = 102, Cot = 103, ArrayHas = 104, ArrayHasAny = 105, @@ -2876,8 +2874,6 @@ impl ScalarFunction { ScalarFunction::Cardinality => "Cardinality", ScalarFunction::ArrayElement => "ArrayElement", ScalarFunction::ArraySlice => "ArraySlice", - ScalarFunction::Encode => "Encode", - ScalarFunction::Decode => "Decode", ScalarFunction::Cot => "Cot", ScalarFunction::ArrayHas => "ArrayHas", ScalarFunction::ArrayHasAny => "ArrayHasAny", @@ -3015,8 +3011,6 @@ impl 
ScalarFunction { "Cardinality" => Some(Self::Cardinality), "ArrayElement" => Some(Self::ArrayElement), "ArraySlice" => Some(Self::ArraySlice), - "Encode" => Some(Self::Encode), - "Decode" => Some(Self::Decode), "Cot" => Some(Self::Cot), "ArrayHas" => Some(Self::ArrayHas), "ArrayHasAny" => Some(Self::ArrayHasAny), diff --git a/datafusion/proto/src/logical_plan/from_proto.rs b/datafusion/proto/src/logical_plan/from_proto.rs index bbaa280d63c5..b025f79bd1d0 100644 --- a/datafusion/proto/src/logical_plan/from_proto.rs +++ b/datafusion/proto/src/logical_plan/from_proto.rs @@ -55,8 +55,8 @@ use datafusion_expr::{ array_resize, array_slice, array_sort, array_to_string, array_union, arrow_typeof, ascii, asin, asinh, atan, atan2, atanh, bit_length, btrim, cardinality, cbrt, ceil, character_length, chr, coalesce, concat_expr, concat_ws_expr, cos, cosh, cot, - current_date, current_time, date_bin, date_part, date_trunc, decode, degrees, digest, - encode, ends_with, exp, + current_date, current_time, date_bin, date_part, date_trunc, degrees, digest, + ends_with, exp, expr::{self, InList, Sort, WindowFunction}, factorial, find_in_set, flatten, floor, from_unixtime, gcd, gen_range, initcap, instr, isnan, iszero, lcm, left, levenshtein, ln, log, log10, log2, @@ -519,8 +519,6 @@ impl From<&protobuf::ScalarFunction> for BuiltinScalarFunction { ScalarFunction::Sha384 => Self::SHA384, ScalarFunction::Sha512 => Self::SHA512, ScalarFunction::Digest => Self::Digest, - ScalarFunction::Encode => Self::Encode, - ScalarFunction::Decode => Self::Decode, ScalarFunction::Log2 => Self::Log2, ScalarFunction::Signum => Self::Signum, ScalarFunction::Ascii => Self::Ascii, @@ -1569,14 +1567,6 @@ pub fn parse_expr( ScalarFunction::Sha384 => Ok(sha384(parse_expr(&args[0], registry)?)), ScalarFunction::Sha512 => Ok(sha512(parse_expr(&args[0], registry)?)), ScalarFunction::Md5 => Ok(md5(parse_expr(&args[0], registry)?)), - ScalarFunction::Encode => Ok(encode( - parse_expr(&args[0], registry)?, - parse_expr(&args[1], registry)?, - )), - ScalarFunction::Decode => Ok(decode( - parse_expr(&args[0], registry)?, - parse_expr(&args[1], registry)?, - )), ScalarFunction::NullIf => Ok(nullif( parse_expr(&args[0], registry)?, parse_expr(&args[1], registry)?, diff --git a/datafusion/proto/src/logical_plan/to_proto.rs b/datafusion/proto/src/logical_plan/to_proto.rs index e7b9474a2d23..f7be15136bbb 100644 --- a/datafusion/proto/src/logical_plan/to_proto.rs +++ b/datafusion/proto/src/logical_plan/to_proto.rs @@ -1517,8 +1517,6 @@ impl TryFrom<&BuiltinScalarFunction> for protobuf::ScalarFunction { BuiltinScalarFunction::SHA384 => Self::Sha384, BuiltinScalarFunction::SHA512 => Self::Sha512, BuiltinScalarFunction::Digest => Self::Digest, - BuiltinScalarFunction::Decode => Self::Decode, - BuiltinScalarFunction::Encode => Self::Encode, BuiltinScalarFunction::ToTimestampMillis => Self::ToTimestampMillis, BuiltinScalarFunction::Log2 => Self::Log2, BuiltinScalarFunction::Signum => Self::Signum, diff --git a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs index 1d935ebcd383..0db086419a79 100644 --- a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs @@ -35,7 +35,7 @@ use datafusion::datasource::TableProvider; use datafusion::execution::context::SessionState; use datafusion::execution::runtime_env::{RuntimeConfig, RuntimeEnv}; use datafusion::parquet::file::properties::{WriterProperties, WriterVersion}; -use 
datafusion::prelude::{create_udf, CsvReadOptions, SessionConfig, SessionContext}; +use datafusion::prelude::*; use datafusion::test_util::{TestTableFactory, TestTableProvider}; use datafusion_common::file_options::csv_writer::CsvWriterOptions; use datafusion_common::file_options::parquet_writer::ParquetWriterOptions; @@ -53,8 +53,8 @@ use datafusion_expr::logical_plan::{Extension, UserDefinedLogicalNodeCore}; use datafusion_expr::{ col, create_udaf, lit, Accumulator, AggregateFunction, BuiltinScalarFunction::{Sqrt, Substr}, - ColumnarValue, Expr, LogicalPlan, Operator, PartitionEvaluator, Signature, TryCast, - Volatility, WindowFrame, WindowFrameBound, WindowFrameUnits, + ColumnarValue, Expr, ExprSchemable, LogicalPlan, Operator, PartitionEvaluator, + Signature, TryCast, Volatility, WindowFrame, WindowFrameBound, WindowFrameUnits, WindowFunctionDefinition, WindowUDF, WindowUDFImpl, }; use datafusion_proto::bytes::{ @@ -558,6 +558,27 @@ async fn roundtrip_logical_plan_with_extension() -> Result<()> { Ok(()) } +#[tokio::test] +async fn roundtrip_expr_api() -> Result<()> { + let ctx = SessionContext::new(); + ctx.register_csv("t1", "tests/testdata/test.csv", CsvReadOptions::default()) + .await?; + let table = ctx.table("t1").await?; + let schema = table.schema().clone(); + + // ensure expressions created with the expr api can be round tripped + let plan = table + .select(vec![ + encode(col("a").cast_to(&DataType::Utf8, &schema)?, lit("hex")), + decode(lit("1234"), lit("hex")), + ])? + .into_optimized_plan()?; + let bytes = logical_plan_to_bytes(&plan)?; + let logical_round_trip = logical_plan_from_bytes(&bytes, &ctx)?; + assert_eq!(format!("{plan:?}"), format!("{logical_round_trip:?}")); + Ok(()) +} + #[tokio::test] async fn roundtrip_logical_plan_with_view_scan() -> Result<()> { let ctx = SessionContext::new();
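To close the loop, here is a hedged end-to-end sketch (assuming datafusion v35 built with the default `encoding_expressions` feature): because `register_all` runs during `SessionState` construction, the relocated functions are reachable both from SQL, via name lookup in the registry, and from the `expr_fn` API re-exported through the prelude.

```rust
use datafusion::error::Result;
use datafusion::prelude::*;

#[tokio::main]
async fn main() -> Result<()> {
    let ctx = SessionContext::new();

    // SQL path: "encode" resolves through the FunctionRegistry rather than
    // the (now slimmer) BuiltinScalarFunction enum.
    ctx.sql("SELECT encode('abc', 'hex')").await?.show().await?;

    // Expr API path: `encode` comes from `datafusion_functions::expr_fn`,
    // re-exported by the prelude change in this patch.
    ctx.read_empty()?
        .select(vec![encode(lit("abc"), lit("hex"))])?
        .show()
        .await?;

    Ok(())
}
```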