Skip to content

fix: Fix SparkSha2 to be compliant with Spark response and add support for Int32 #16350

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Jun 13, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 32 additions & 29 deletions datafusion/spark/src/function/hash/sha2.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,9 @@ extern crate datafusion_functions;
use crate::function::error_utils::{
invalid_arg_count_exec_err, unsupported_data_type_exec_err,
};
use crate::function::math::hex::spark_hex;
use crate::function::math::hex::spark_sha2_hex;
use arrow::array::{ArrayRef, AsArray, StringArray};
use arrow::datatypes::{DataType, UInt32Type};
use arrow::datatypes::{DataType, Int32Type};
use datafusion_common::{exec_err, internal_datafusion_err, Result, ScalarValue};
use datafusion_expr::Signature;
use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Volatility};
Expand Down Expand Up @@ -121,7 +121,7 @@ impl ScalarUDFImpl for SparkSha2 {
)),
}?;
let bit_length_type = if arg_types[1].is_numeric() {
Ok(DataType::UInt32)
Ok(DataType::Int32)
} else if arg_types[1].is_null() {
Ok(DataType::Null)
} else {
Expand All @@ -138,39 +138,24 @@ impl ScalarUDFImpl for SparkSha2 {

pub fn sha2(args: [ColumnarValue; 2]) -> Result<ColumnarValue> {
match args {
[ColumnarValue::Scalar(ScalarValue::Utf8(expr_arg)), ColumnarValue::Scalar(ScalarValue::UInt32(Some(bit_length_arg)))] => {
match bit_length_arg {
0 | 256 => sha256(&[ColumnarValue::from(ScalarValue::Utf8(expr_arg))]),
224 => sha224(&[ColumnarValue::from(ScalarValue::Utf8(expr_arg))]),
384 => sha384(&[ColumnarValue::from(ScalarValue::Utf8(expr_arg))]),
512 => sha512(&[ColumnarValue::from(ScalarValue::Utf8(expr_arg))]),
_ => exec_err!(
"sha2 function only supports 224, 256, 384, and 512 bit lengths."
),
}
.map(|hashed| spark_hex(&[hashed]).unwrap())
[ColumnarValue::Scalar(ScalarValue::Utf8(expr_arg)), ColumnarValue::Scalar(ScalarValue::Int32(Some(bit_length_arg)))] => {
compute_sha2(
bit_length_arg,
&[ColumnarValue::from(ScalarValue::Utf8(expr_arg))],
)
}
[ColumnarValue::Array(expr_arg), ColumnarValue::Scalar(ScalarValue::UInt32(Some(bit_length_arg)))] => {
match bit_length_arg {
0 | 256 => sha256(&[ColumnarValue::from(expr_arg)]),
224 => sha224(&[ColumnarValue::from(expr_arg)]),
384 => sha384(&[ColumnarValue::from(expr_arg)]),
512 => sha512(&[ColumnarValue::from(expr_arg)]),
_ => exec_err!(
"sha2 function only supports 224, 256, 384, and 512 bit lengths."
),
}
.map(|hashed| spark_hex(&[hashed]).unwrap())
[ColumnarValue::Array(expr_arg), ColumnarValue::Scalar(ScalarValue::Int32(Some(bit_length_arg)))] => {
compute_sha2(bit_length_arg, &[ColumnarValue::from(expr_arg)])
}
[ColumnarValue::Scalar(ScalarValue::Utf8(expr_arg)), ColumnarValue::Array(bit_length_arg)] =>
{
let arr: StringArray = bit_length_arg
.as_primitive::<UInt32Type>()
.as_primitive::<Int32Type>()
.iter()
.map(|bit_length| {
match sha2([
ColumnarValue::Scalar(ScalarValue::Utf8(expr_arg.clone())),
ColumnarValue::Scalar(ScalarValue::UInt32(bit_length)),
ColumnarValue::Scalar(ScalarValue::Int32(bit_length)),
])
.unwrap()
{
Expand All @@ -188,15 +173,15 @@ pub fn sha2(args: [ColumnarValue; 2]) -> Result<ColumnarValue> {
}
[ColumnarValue::Array(expr_arg), ColumnarValue::Array(bit_length_arg)] => {
let expr_iter = expr_arg.as_string::<i32>().iter();
let bit_length_iter = bit_length_arg.as_primitive::<UInt32Type>().iter();
let bit_length_iter = bit_length_arg.as_primitive::<Int32Type>().iter();
let arr: StringArray = expr_iter
.zip(bit_length_iter)
.map(|(expr, bit_length)| {
match sha2([
ColumnarValue::Scalar(ScalarValue::Utf8(Some(
expr.unwrap().to_string(),
))),
ColumnarValue::Scalar(ScalarValue::UInt32(bit_length)),
ColumnarValue::Scalar(ScalarValue::Int32(bit_length)),
])
.unwrap()
{
Expand All @@ -215,3 +200,21 @@ pub fn sha2(args: [ColumnarValue; 2]) -> Result<ColumnarValue> {
_ => exec_err!("Unsupported argument types for sha2 function"),
}
}

/// Computes the SHA-2 digest of `expr_arg` for the given bit length and
/// hex-encodes it in lowercase, matching Spark's `sha2` output format.
///
/// Spark semantics preserved here:
/// * bit length `0` is an alias for SHA-256;
/// * any unsupported bit length yields NULL instead of an error.
fn compute_sha2(
    bit_length_arg: i32,
    expr_arg: &[ColumnarValue],
) -> Result<ColumnarValue> {
    let hashed = match bit_length_arg {
        // Spark treats 0 as SHA-256.
        0 | 256 => sha256(expr_arg),
        224 => sha224(expr_arg),
        384 => sha384(expr_arg),
        512 => sha512(expr_arg),
        _ => {
            // Return null for unsupported bit lengths instead of error, because
            // spark sha2 does not error out for this.
            return Ok(ColumnarValue::Scalar(ScalarValue::Utf8(None)));
        }
    }?;
    // Propagate hex-encoding failures via `?` instead of panicking with
    // `unwrap`, so callers receive a proper error result.
    spark_sha2_hex(&[hashed])
}
31 changes: 23 additions & 8 deletions datafusion/spark/src/function/math/hex.rs
Original file line number Diff line number Diff line change
Expand Up @@ -159,13 +159,28 @@ fn hex_encode<T: AsRef<[u8]>>(data: T, lower_case: bool) -> String {
}

#[inline(always)]
fn hex_bytes<T: AsRef<[u8]>>(bytes: T) -> Result<String, std::fmt::Error> {
let hex_string = hex_encode(bytes, false);
fn hex_bytes<T: AsRef<[u8]>>(
    bytes: T,
    lowercase: bool,
) -> Result<String, std::fmt::Error> {
    // Delegate straight to the shared encoder; the encoding itself cannot
    // fail, and the `Result` return only keeps call sites uniform.
    Ok(hex_encode(bytes, lowercase))
}

/// Spark-compatible `hex` function.
///
/// Delegates to [`compute_hex`] with `lowercase = false`, producing
/// uppercase hexadecimal output.
pub fn spark_hex(args: &[ColumnarValue]) -> Result<ColumnarValue, DataFusionError> {
    compute_hex(args, false)
}

/// Hex encoder used by the Spark-compatible `sha2` function.
///
/// Unlike [`spark_hex`], this delegates to [`compute_hex`] with
/// `lowercase = true`, since Spark's `sha2` returns lowercase digests.
pub fn spark_sha2_hex(args: &[ColumnarValue]) -> Result<ColumnarValue, DataFusionError> {
    compute_hex(args, true)
}

pub fn compute_hex(
args: &[ColumnarValue],
lowercase: bool,
) -> Result<ColumnarValue, DataFusionError> {
if args.len() != 1 {
return Err(DataFusionError::Internal(
"hex expects exactly one argument".to_string(),
Expand All @@ -192,7 +207,7 @@ pub fn spark_hex(args: &[ColumnarValue]) -> Result<ColumnarValue, DataFusionErro

let hexed: StringArray = array
.iter()
.map(|v| v.map(hex_bytes).transpose())
.map(|v| v.map(|b| hex_bytes(b, lowercase)).transpose())
.collect::<Result<_, _>>()?;

Ok(ColumnarValue::Array(Arc::new(hexed)))
Expand All @@ -202,7 +217,7 @@ pub fn spark_hex(args: &[ColumnarValue]) -> Result<ColumnarValue, DataFusionErro

let hexed: StringArray = array
.iter()
.map(|v| v.map(hex_bytes).transpose())
.map(|v| v.map(|b| hex_bytes(b, lowercase)).transpose())
.collect::<Result<_, _>>()?;

Ok(ColumnarValue::Array(Arc::new(hexed)))
Expand All @@ -212,7 +227,7 @@ pub fn spark_hex(args: &[ColumnarValue]) -> Result<ColumnarValue, DataFusionErro

let hexed: StringArray = array
.iter()
.map(|v| v.map(hex_bytes).transpose())
.map(|v| v.map(|b| hex_bytes(b, lowercase)).transpose())
.collect::<Result<_, _>>()?;

Ok(ColumnarValue::Array(Arc::new(hexed)))
Expand All @@ -222,7 +237,7 @@ pub fn spark_hex(args: &[ColumnarValue]) -> Result<ColumnarValue, DataFusionErro

let hexed: StringArray = array
.iter()
.map(|v| v.map(hex_bytes).transpose())
.map(|v| v.map(|b| hex_bytes(b, lowercase)).transpose())
.collect::<Result<_, _>>()?;

Ok(ColumnarValue::Array(Arc::new(hexed)))
Expand All @@ -237,11 +252,11 @@ pub fn spark_hex(args: &[ColumnarValue]) -> Result<ColumnarValue, DataFusionErro
.collect::<Vec<_>>(),
DataType::Utf8 => as_string_array(dict.values())
.iter()
.map(|v| v.map(hex_bytes).transpose())
.map(|v| v.map(|b| hex_bytes(b, lowercase)).transpose())
.collect::<Result<_, _>>()?,
DataType::Binary => as_binary_array(dict.values())?
.iter()
.map(|v| v.map(hex_bytes).transpose())
.map(|v| v.map(|b| hex_bytes(b, lowercase)).transpose())
.collect::<Result<_, _>>()?,
_ => exec_err!(
"hex got an unexpected argument type: {:?}",
Expand Down
48 changes: 30 additions & 18 deletions datafusion/sqllogictest/test_files/spark/hash/sha2.slt
Original file line number Diff line number Diff line change
Expand Up @@ -18,48 +18,60 @@
query T
SELECT sha2('Spark', 0::INT);
----
529BC3B07127ECB7E53A4DCF1991D9152C24537D919178022B2C42657F79A26B
529bc3b07127ecb7e53a4dcf1991d9152c24537d919178022b2c42657f79a26b

query T
SELECT sha2('Spark', 256::INT);
----
529BC3B07127ECB7E53A4DCF1991D9152C24537D919178022B2C42657F79A26B
529bc3b07127ecb7e53a4dcf1991d9152c24537d919178022b2c42657f79a26b

query T
SELECT sha2('Spark', 224::INT);
----
DBEAB94971678D36AF2195851C0F7485775A2A7C60073D62FC04549C
dbeab94971678d36af2195851c0f7485775a2a7c60073d62fc04549c

query T
SELECT sha2('Spark', 384::INT);
----
1E40B8D06C248A1CC32428C22582B6219D072283078FA140D9AD297ECADF2CABEFC341B857AD36226AA8D6D79F2AB67D
1e40b8d06c248a1cc32428c22582b6219d072283078fa140d9ad297ecadf2cabefc341b857ad36226aa8d6d79f2ab67d

query T
SELECT sha2('Spark', 512::INT);
----
44844A586C54C9A212DA1DBFE05C5F1705DE1AF5FDA1F0D36297623249B279FD8F0CCEC03F888F4FB13BF7CD83FDAD58591C797F81121A23CFDD5E0897795238
44844a586c54c9a212da1dbfe05c5f1705de1af5fda1f0d36297623249b279fd8f0ccec03f888f4fb13bf7cd83fdad58591c797f81121a23cfdd5e0897795238

query T
SELECT sha2('Spark', 128::INT);
----
NULL

query T
SELECT sha2(expr, 256::INT) FROM VALUES ('foo'), ('bar') AS t(expr);
----
2C26B46B68FFC68FF99B453C1D30413413422D706483BFA0F98A5E886266E7AE
FCDE2B2EDBA56BF408601FB721FE9B5C338D10EE429EA04FAE5511B68FBF8FB9
2c26b46b68ffc68ff99b453c1d30413413422d706483bfa0f98a5e886266e7ae
fcde2b2edba56bf408601fb721fe9b5c338d10ee429ea04fae5511b68fbf8fb9

query T
SELECT sha2('foo', bit_length) FROM VALUES (0::INT), (256::INT), (224::INT), (384::INT), (512::INT) AS t(bit_length);
SELECT sha2(expr, 128::INT) FROM VALUES ('foo'), ('bar') AS t(expr);
----
2C26B46B68FFC68FF99B453C1D30413413422D706483BFA0F98A5E886266E7AE
2C26B46B68FFC68FF99B453C1D30413413422D706483BFA0F98A5E886266E7AE
0808F64E60D58979FCB676C96EC938270DEA42445AEEFCD3A4E6F8DB
98C11FFDFDD540676B1A137CB1A22B2A70350C9A44171D6B1180C6BE5CBB2EE3F79D532C8A1DD9EF2E8E08E752A3BABB
F7FBBA6E0636F890E56FBBF3283E524C6FA3204AE298382D624741D0DC6638326E282C41BE5E4254D8820772C5518A2C5A8C0C7F7EDA19594A7EB539453E1ED7
NULL
NULL

query T
SELECT sha2('foo', bit_length) FROM VALUES (0::INT), (256::INT), (224::INT), (384::INT), (512::INT), (128::INT) AS t(bit_length);
----
2c26b46b68ffc68ff99b453c1d30413413422d706483bfa0f98a5e886266e7ae
2c26b46b68ffc68ff99b453c1d30413413422d706483bfa0f98a5e886266e7ae
0808f64e60d58979fcb676c96ec938270dea42445aeefcd3a4e6f8db
98c11ffdfdd540676b1a137cb1a22b2a70350c9a44171d6b1180c6be5cbb2ee3f79d532c8a1dd9ef2e8e08e752a3babb
f7fbba6e0636f890e56fbbf3283e524c6fa3204ae298382d624741d0dc6638326e282c41be5e4254d8820772c5518a2c5a8c0c7f7eda19594a7eb539453e1ed7
NULL

query T
SELECT sha2(expr, bit_length) FROM VALUES ('foo',0::INT), ('bar',224::INT), ('baz',384::INT), ('qux',512::INT) AS t(expr, bit_length);
SELECT sha2(expr, bit_length) FROM VALUES ('foo',0::INT), ('bar',224::INT), ('baz',384::INT), ('qux',512::INT), ('qux',128::INT) AS t(expr, bit_length);
----
2C26B46B68FFC68FF99B453C1D30413413422D706483BFA0F98A5E886266E7AE
07DAF010DE7F7F0D8D76A76EB8D1EB40182C8D1E7A3877A6686C9BF0
967004D25DE4ABC1BD6A7C9A216254A5AC0733E8AD96DC9F1EA0FAD9619DA7C32D654EC8AD8BA2F9B5728FED6633BD91
8C6BE9ED448A34883A13A13F4EAD4AEFA036B67DCDA59020C01E57EA075EA8A4792D428F2C6FD0C09D1C49994D6C22789336E062188DF29572ED07E7F9779C52
2c26b46b68ffc68ff99b453c1d30413413422d706483bfa0f98a5e886266e7ae
07daf010de7f7f0d8d76a76eb8d1eb40182c8d1e7a3877a6686c9bf0
967004d25de4abc1bd6a7c9a216254a5ac0733e8ad96dc9f1ea0fad9619da7c32d654ec8ad8ba2f9b5728fed6633bd91
8c6be9ed448a34883a13a13f4ead4aefa036b67dcda59020c01e57ea075ea8a4792d428f2c6fd0c09d1c49994d6c22789336e062188df29572ed07e7f9779c52
NULL