Skip to content

Commit fc5888b

Browse files
authored
feat(spark): implement Spark length function (#17475)
1 parent a96dcc8 commit fc5888b

File tree

6 files changed

+330
-65
lines changed

6 files changed

+330
-65
lines changed
Lines changed: 282 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,282 @@
1+
// Licensed to the Apache Software Foundation (ASF) under one
2+
// or more contributor license agreements. See the NOTICE file
3+
// distributed with this work for additional information
4+
// regarding copyright ownership. The ASF licenses this file
5+
// to you under the Apache License, Version 2.0 (the
6+
// "License"); you may not use this file except in compliance
7+
// with the License. You may obtain a copy of the License at
8+
//
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
//
11+
// Unless required by applicable law or agreed to in writing,
12+
// software distributed under the License is distributed on an
13+
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
// KIND, either express or implied. See the License for the
15+
// specific language governing permissions and limitations
16+
// under the License.
17+
18+
use arrow::array::{
19+
Array, ArrayRef, AsArray, BinaryArrayType, PrimitiveArray, StringArrayType,
20+
};
21+
use arrow::datatypes::{DataType, Int32Type};
22+
use datafusion_common::exec_err;
23+
use datafusion_expr::{
24+
ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility,
25+
};
26+
use datafusion_functions::utils::make_scalar_function;
27+
use std::sync::Arc;
28+
29+
/// Spark-compatible `length` expression
30+
/// <https://spark.apache.org/docs/latest/api/sql/index.html#length>
31+
#[derive(Debug, PartialEq, Eq, Hash)]
32+
pub struct SparkLengthFunc {
33+
signature: Signature,
34+
aliases: Vec<String>,
35+
}
36+
37+
impl Default for SparkLengthFunc {
38+
fn default() -> Self {
39+
Self::new()
40+
}
41+
}
42+
43+
impl SparkLengthFunc {
44+
pub fn new() -> Self {
45+
Self {
46+
signature: Signature::uniform(
47+
1,
48+
vec![
49+
DataType::Utf8View,
50+
DataType::Utf8,
51+
DataType::LargeUtf8,
52+
DataType::Binary,
53+
DataType::LargeBinary,
54+
DataType::BinaryView,
55+
],
56+
Volatility::Immutable,
57+
),
58+
aliases: vec![
59+
String::from("character_length"),
60+
String::from("char_length"),
61+
String::from("len"),
62+
],
63+
}
64+
}
65+
}
66+
67+
impl ScalarUDFImpl for SparkLengthFunc {
68+
fn as_any(&self) -> &dyn std::any::Any {
69+
self
70+
}
71+
72+
fn name(&self) -> &str {
73+
"length"
74+
}
75+
76+
fn signature(&self) -> &Signature {
77+
&self.signature
78+
}
79+
80+
fn return_type(&self, _args: &[DataType]) -> datafusion_common::Result<DataType> {
81+
// spark length always returns Int32
82+
Ok(DataType::Int32)
83+
}
84+
85+
fn invoke_with_args(
86+
&self,
87+
args: ScalarFunctionArgs,
88+
) -> datafusion_common::Result<ColumnarValue> {
89+
make_scalar_function(spark_length, vec![])(&args.args)
90+
}
91+
92+
fn aliases(&self) -> &[String] {
93+
&self.aliases
94+
}
95+
}
96+
97+
fn spark_length(args: &[ArrayRef]) -> datafusion_common::Result<ArrayRef> {
98+
match args[0].data_type() {
99+
DataType::Utf8 => {
100+
let string_array = args[0].as_string::<i32>();
101+
character_length::<_>(string_array)
102+
}
103+
DataType::LargeUtf8 => {
104+
let string_array = args[0].as_string::<i64>();
105+
character_length::<_>(string_array)
106+
}
107+
DataType::Utf8View => {
108+
let string_array = args[0].as_string_view();
109+
character_length::<_>(string_array)
110+
}
111+
DataType::Binary => {
112+
let binary_array = args[0].as_binary::<i32>();
113+
byte_length::<_>(binary_array)
114+
}
115+
DataType::LargeBinary => {
116+
let binary_array = args[0].as_binary::<i64>();
117+
byte_length::<_>(binary_array)
118+
}
119+
DataType::BinaryView => {
120+
let binary_array = args[0].as_binary_view();
121+
byte_length::<_>(binary_array)
122+
}
123+
other => exec_err!("Unsupported data type {other:?} for function `length`"),
124+
}
125+
}
126+
127+
fn character_length<'a, V>(array: V) -> datafusion_common::Result<ArrayRef>
128+
where
129+
V: StringArrayType<'a>,
130+
{
131+
// String characters are variable length encoded in UTF-8, counting the
132+
// number of chars requires expensive decoding, however checking if the
133+
// string is ASCII only is relatively cheap.
134+
// If strings are ASCII only, count bytes instead.
135+
let is_array_ascii_only = array.is_ascii();
136+
let nulls = array.nulls().cloned();
137+
let array = {
138+
if is_array_ascii_only {
139+
let values: Vec<_> = (0..array.len())
140+
.map(|i| {
141+
// Safety: we are iterating with array.len() so the index is always valid
142+
let value = unsafe { array.value_unchecked(i) };
143+
value.len() as i32
144+
})
145+
.collect();
146+
PrimitiveArray::<Int32Type>::new(values.into(), nulls)
147+
} else {
148+
let values: Vec<_> = (0..array.len())
149+
.map(|i| {
150+
// Safety: we are iterating with array.len() so the index is always valid
151+
if array.is_null(i) {
152+
i32::default()
153+
} else {
154+
let value = unsafe { array.value_unchecked(i) };
155+
if value.is_empty() {
156+
i32::default()
157+
} else if value.is_ascii() {
158+
value.len() as i32
159+
} else {
160+
value.chars().count() as i32
161+
}
162+
}
163+
})
164+
.collect();
165+
PrimitiveArray::<Int32Type>::new(values.into(), nulls)
166+
}
167+
};
168+
169+
Ok(Arc::new(array))
170+
}
171+
172+
fn byte_length<'a, V>(array: V) -> datafusion_common::Result<ArrayRef>
173+
where
174+
V: BinaryArrayType<'a>,
175+
{
176+
let nulls = array.nulls().cloned();
177+
let values: Vec<_> = (0..array.len())
178+
.map(|i| {
179+
// Safety: we are iterating with array.len() so the index is always valid
180+
let value = unsafe { array.value_unchecked(i) };
181+
value.len() as i32
182+
})
183+
.collect();
184+
Ok(Arc::new(PrimitiveArray::<Int32Type>::new(
185+
values.into(),
186+
nulls,
187+
)))
188+
}
189+
190+
#[cfg(test)]
mod tests {
    use super::*;
    use crate::function::utils::test::test_scalar_function;
    use arrow::array::{Array, Int32Array};
    use arrow::datatypes::DataType::Int32;
    use datafusion_common::{Result, ScalarValue};
    use datafusion_expr::{ColumnarValue, ScalarUDFImpl};

    /// Runs one scalar `length` check per listed `ScalarValue` variant, each
    /// time wrapping `$INPUT` in that variant and expecting an `Int32` result
    /// equal to `$EXPECTED`.
    macro_rules! check_length_for_variants {
        ($INPUT:expr, $EXPECTED:expr, $($VARIANT:ident),+) => {
            $(
                test_scalar_function!(
                    SparkLengthFunc::new(),
                    vec![ColumnarValue::Scalar(ScalarValue::$VARIANT($INPUT))],
                    $EXPECTED,
                    i32,
                    Int32,
                    Int32Array
                );
            )+
        };
    }

    /// Checks a string input against all three string representations.
    macro_rules! test_spark_length_string {
        ($INPUT:expr, $EXPECTED:expr) => {
            check_length_for_variants!($INPUT, $EXPECTED, Utf8, LargeUtf8, Utf8View);
        };
    }

    /// Checks a binary input against all three binary representations.
    macro_rules! test_spark_length_binary {
        ($INPUT:expr, $EXPECTED:expr) => {
            check_length_for_variants!($INPUT, $EXPECTED, Binary, LargeBinary, BinaryView);
        };
    }

    #[test]
    fn test_functions() -> Result<()> {
        // Strings are measured in characters: "josé" is 4 characters.
        test_spark_length_string!(Some(String::from("chars")), Ok(Some(5)));
        test_spark_length_string!(Some(String::from("josé")), Ok(Some(4)));
        // test long strings (more than 12 bytes for StringView)
        test_spark_length_string!(Some(String::from("joséjoséjoséjosé")), Ok(Some(16)));
        test_spark_length_string!(Some(String::from("")), Ok(Some(0)));
        test_spark_length_string!(None, Ok(None));

        // Binary values are measured in bytes: "josé" is 5 bytes in UTF-8.
        test_spark_length_binary!(Some(String::from("chars").into_bytes()), Ok(Some(5)));
        test_spark_length_binary!(Some(String::from("josé").into_bytes()), Ok(Some(5)));
        // test long strings (more than 12 bytes for BinaryView)
        test_spark_length_binary!(
            Some(String::from("joséjoséjoséjosé").into_bytes()),
            Ok(Some(20))
        );
        test_spark_length_binary!(Some(String::from("").into_bytes()), Ok(Some(0)));
        test_spark_length_binary!(None, Ok(None));

        Ok(())
    }
}

datafusion/spark/src/function/string/mod.rs

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
pub mod ascii;
1919
pub mod char;
2020
pub mod ilike;
21+
pub mod length;
2122
pub mod like;
2223
pub mod luhn_check;
2324

@@ -28,6 +29,7 @@ use std::sync::Arc;
2829
make_udf_function!(ascii::SparkAscii, ascii);
2930
make_udf_function!(char::CharFunc, char);
3031
make_udf_function!(ilike::SparkILike, ilike);
32+
make_udf_function!(length::SparkLengthFunc, length);
3133
make_udf_function!(like::SparkLike, like);
3234
make_udf_function!(luhn_check::SparkLuhnCheck, luhn_check);
3335

@@ -49,6 +51,11 @@ pub mod expr_fn {
4951
"Returns true if str matches pattern (case insensitive).",
5052
str pattern
5153
));
54+
export_functions!((
55+
length,
56+
"Returns the character length of string data or number of bytes of binary data. The length of string data includes the trailing spaces. The length of binary data includes binary zeros.",
57+
arg1
58+
));
5259
export_functions!((
5360
like,
5461
"Returns true if str matches pattern (case sensitive).",
@@ -62,5 +69,5 @@ pub mod expr_fn {
6269
}
6370

6471
pub fn functions() -> Vec<Arc<ScalarUDF>> {
65-
vec![ascii(), char(), ilike(), like(), luhn_check()]
72+
vec![ascii(), char(), ilike(), length(), like(), luhn_check()]
6673
}

datafusion/sqllogictest/test_files/spark/string/char_length.slt

Lines changed: 12 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -15,23 +15,17 @@
1515
# specific language governing permissions and limitations
1616
# under the License.
1717

18-
# This file was originally created by a porting script from:
19-
# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function
20-
# This file is part of the implementation of the datafusion-spark function library.
21-
# For more information, please see:
22-
# https://github.com/apache/datafusion/issues/15914
18+
query I
19+
SELECT CHAR_LENGTH('Spark SQL ');
20+
----
21+
10
2322

24-
## Original Query: SELECT CHAR_LENGTH('Spark SQL ');
25-
## PySpark 3.5.5 Result: {'char_length(Spark SQL )': 10, 'typeof(char_length(Spark SQL ))': 'int', 'typeof(Spark SQL )': 'string'}
26-
#query
27-
#SELECT CHAR_LENGTH('Spark SQL '::string);
23+
query I
24+
SELECT char_length('Spark SQL ');
25+
----
26+
10
2827

29-
## Original Query: SELECT char_length('Spark SQL ');
30-
## PySpark 3.5.5 Result: {'char_length(Spark SQL )': 10, 'typeof(char_length(Spark SQL ))': 'int', 'typeof(Spark SQL )': 'string'}
31-
#query
32-
#SELECT char_length('Spark SQL '::string);
33-
34-
## Original Query: SELECT char_length(x'537061726b2053514c');
35-
## PySpark 3.5.5 Result: {"char_length(X'537061726B2053514C')": 9, "typeof(char_length(X'537061726B2053514C'))": 'int', "typeof(X'537061726B2053514C')": 'binary'}
36-
#query
37-
#SELECT char_length(X'537061726B2053514C'::binary);
28+
query I
29+
SELECT char_length(x'537061726b2053514c');
30+
----
31+
9

datafusion/sqllogictest/test_files/spark/string/character_length.slt

Lines changed: 12 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -15,23 +15,17 @@
1515
# specific language governing permissions and limitations
1616
# under the License.
1717

18-
# This file was originally created by a porting script from:
19-
# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function
20-
# This file is part of the implementation of the datafusion-spark function library.
21-
# For more information, please see:
22-
# https://github.com/apache/datafusion/issues/15914
18+
query I
19+
SELECT CHARACTER_LENGTH('Spark SQL ');
20+
----
21+
10
2322

24-
## Original Query: SELECT CHARACTER_LENGTH('Spark SQL ');
25-
## PySpark 3.5.5 Result: {'character_length(Spark SQL )': 10, 'typeof(character_length(Spark SQL ))': 'int', 'typeof(Spark SQL )': 'string'}
26-
#query
27-
#SELECT CHARACTER_LENGTH('Spark SQL '::string);
23+
query I
24+
SELECT character_length('Spark SQL ');
25+
----
26+
10
2827

29-
## Original Query: SELECT character_length('Spark SQL ');
30-
## PySpark 3.5.5 Result: {'character_length(Spark SQL )': 10, 'typeof(character_length(Spark SQL ))': 'int', 'typeof(Spark SQL )': 'string'}
31-
#query
32-
#SELECT character_length('Spark SQL '::string);
33-
34-
## Original Query: SELECT character_length(x'537061726b2053514c');
35-
## PySpark 3.5.5 Result: {"character_length(X'537061726B2053514C')": 9, "typeof(character_length(X'537061726B2053514C'))": 'int', "typeof(X'537061726B2053514C')": 'binary'}
36-
#query
37-
#SELECT character_length(X'537061726B2053514C'::binary);
28+
query I
29+
SELECT character_length(x'537061726b2053514c');
30+
----
31+
9

0 commit comments

Comments
 (0)