chore: Move some utility methods to submodules of scalar_funcs
advancedxy committed Jun 20, 2024
1 parent 28309a4 commit a298162
Showing 4 changed files with 178 additions and 147 deletions.
12 changes: 7 additions & 5 deletions core/benches/hash.rs
@@ -19,8 +19,7 @@
 mod common;
 
 use arrow_array::ArrayRef;
-use comet::execution::datafusion::expressions::scalar_funcs::spark_murmur3_hash;
-use comet::execution::datafusion::spark_hash::create_xxhash64_hashes;
+use comet::execution::datafusion::expressions::scalar_funcs::{spark_murmur3_hash, spark_xxhash64};
 use comet::execution::kernels::hash;
 use common::*;
 use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion};
@@ -100,12 +99,15 @@ fn criterion_benchmark(c: &mut Criterion) {
         },
     );
     group.bench_function(BenchmarkId::new("xxhash64", BATCH_SIZE), |b| {
-        let input = vec![a3.clone(), a4.clone()];
-        let mut dst = vec![0; BATCH_SIZE];
+        let inputs = &[
+            ColumnarValue::Array(a3.clone()),
+            ColumnarValue::Array(a4.clone()),
+            ColumnarValue::Scalar(ScalarValue::Int64(Some(42i64))),
+        ];
 
         b.iter(|| {
             for _ in 0..NUM_ITER {
-                create_xxhash64_hashes(&input, &mut dst).unwrap();
+                spark_xxhash64(inputs).unwrap();
             }
         });
     });
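The benchmark now goes through the public spark_xxhash64 entry point, which takes a slice of ColumnarValue arguments whose last element must be the Int64 seed. A minimal standalone sketch of the same call shape (the column contents and seed value here are illustrative, not taken from the benchmark):

use std::sync::Arc;

use arrow_array::{Array, ArrayRef, Int64Array};
use datafusion_common::ScalarValue;
use datafusion_expr::ColumnarValue;

// Re-exported from the new hash_expressions submodule (see the next file).
use comet::execution::datafusion::expressions::scalar_funcs::spark_xxhash64;

fn main() -> Result<(), datafusion_common::DataFusionError> {
    // Two input columns, then the mandatory Int64 seed as the last argument.
    let col: ArrayRef = Arc::new(Int64Array::from(vec![1, 2, 3]));
    let inputs = [
        ColumnarValue::Array(col.clone()),
        ColumnarValue::Array(col),
        ColumnarValue::Scalar(ScalarValue::Int64(Some(42))),
    ];
    // With at least one array input, the result is an Int64 array of per-row hashes.
    if let ColumnarValue::Array(hashes) = spark_xxhash64(&inputs)? {
        assert_eq!(hashes.len(), 3);
    }
    Ok(())
}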
142 changes: 9 additions & 133 deletions core/src/execution/datafusion/expressions/scalar_funcs.rs
@@ -15,22 +15,16 @@
 // specific language governing permissions and limitations
 // under the License.
 
-use std::{
-    any::Any,
-    cmp::min,
-    fmt::{Debug, Write},
-    sync::Arc,
-};
+use std::{any::Any, cmp::min, fmt::Debug, sync::Arc};
 
-use crate::execution::datafusion::spark_hash::{create_murmur3_hashes, create_xxhash64_hashes};
 use arrow::{
     array::{
         ArrayRef, AsArray, Decimal128Builder, Float32Array, Float64Array, GenericStringArray,
         Int16Array, Int32Array, Int64Array, Int64Builder, Int8Array, OffsetSizeTrait,
     },
     datatypes::{validate_decimal_precision, Decimal128Type, Int64Type},
 };
-use arrow_array::{Array, ArrowNativeTypeOp, Decimal128Array, StringArray};
+use arrow_array::{Array, ArrowNativeTypeOp, Decimal128Array};
 use arrow_schema::DataType;
 use datafusion::{
     execution::FunctionRegistry,
@@ -39,8 +33,8 @@ use datafusion::{
     physical_plan::ColumnarValue,
 };
 use datafusion_common::{
-    cast::{as_binary_array, as_generic_string_array},
-    exec_err, internal_err, DataFusionError, Result as DataFusionResult, ScalarValue,
+    cast::as_generic_string_array, exec_err, internal_err, DataFusionError,
+    Result as DataFusionResult, ScalarValue,
 };
 use datafusion_expr::ScalarUDF;
 use num::{
@@ -53,11 +47,15 @@ mod unhex;
 use unhex::spark_unhex;
 
 mod hex;
-use hex::spark_hex;
+use hex::{spark_hex, wrap_digest_result_as_hex_string};
 
 mod chr;
 use chr::spark_chr;
 
+pub mod hash_expressions;
+// exposed for benchmark only
+pub use hash_expressions::{spark_murmur3_hash, spark_xxhash64};
+
 macro_rules! make_comet_scalar_udf {
     ($name:expr, $func:ident, $data_type:ident) => {{
         let scalar_func = CometScalarFunction::new(
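The pub use re-export above keeps the short import path working for callers outside this module, which is why the benchmark can keep importing from the scalar_funcs root while the implementation moves into a submodule. Both of these hypothetical imports name the same function:

// Via the re-export at the scalar_funcs root:
use comet::execution::datafusion::expressions::scalar_funcs::spark_xxhash64;
// Or directly from the public submodule:
use comet::execution::datafusion::expressions::scalar_funcs::hash_expressions::spark_xxhash64 as spark_xxhash64_direct;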
@@ -635,125 +633,3 @@ fn spark_decimal_div(
     let result = result.with_data_type(DataType::Decimal128(p3, s3));
     Ok(ColumnarValue::Array(Arc::new(result)))
 }
-
-pub fn spark_murmur3_hash(args: &[ColumnarValue]) -> Result<ColumnarValue, DataFusionError> {
-    let length = args.len();
-    let seed = &args[length - 1];
-    match seed {
-        ColumnarValue::Scalar(ScalarValue::Int32(Some(seed))) => {
-            // iterate over the arguments to find out the length of the array
-            let num_rows = args[0..args.len() - 1]
-                .iter()
-                .find_map(|arg| match arg {
-                    ColumnarValue::Array(array) => Some(array.len()),
-                    ColumnarValue::Scalar(_) => None,
-                })
-                .unwrap_or(1);
-            let mut hashes: Vec<u32> = vec![0_u32; num_rows];
-            hashes.fill(*seed as u32);
-            let arrays = args[0..args.len() - 1]
-                .iter()
-                .map(|arg| match arg {
-                    ColumnarValue::Array(array) => array.clone(),
-                    ColumnarValue::Scalar(scalar) => {
-                        scalar.clone().to_array_of_size(num_rows).unwrap()
-                    }
-                })
-                .collect::<Vec<ArrayRef>>();
-            create_murmur3_hashes(&arrays, &mut hashes)?;
-            if num_rows == 1 {
-                Ok(ColumnarValue::Scalar(ScalarValue::Int32(Some(
-                    hashes[0] as i32,
-                ))))
-            } else {
-                let hashes: Vec<i32> = hashes.into_iter().map(|x| x as i32).collect();
-                Ok(ColumnarValue::Array(Arc::new(Int32Array::from(hashes))))
-            }
-        }
-        _ => {
-            internal_err!(
-                "The seed of function murmur3_hash must be an Int32 scalar value, but got: {:?}.",
-                seed
-            )
-        }
-    }
-}
-
-fn spark_xxhash64(args: &[ColumnarValue]) -> Result<ColumnarValue, DataFusionError> {
-    let length = args.len();
-    let seed = &args[length - 1];
-    match seed {
-        ColumnarValue::Scalar(ScalarValue::Int64(Some(seed))) => {
-            // iterate over the arguments to find out the length of the array
-            let num_rows = args[0..args.len() - 1]
-                .iter()
-                .find_map(|arg| match arg {
-                    ColumnarValue::Array(array) => Some(array.len()),
-                    ColumnarValue::Scalar(_) => None,
-                })
-                .unwrap_or(1);
-            let mut hashes: Vec<u64> = vec![0_u64; num_rows];
-            hashes.fill(*seed as u64);
-            let arrays = args[0..args.len() - 1]
-                .iter()
-                .map(|arg| match arg {
-                    ColumnarValue::Array(array) => array.clone(),
-                    ColumnarValue::Scalar(scalar) => {
-                        scalar.clone().to_array_of_size(num_rows).unwrap()
-                    }
-                })
-                .collect::<Vec<ArrayRef>>();
-            create_xxhash64_hashes(&arrays, &mut hashes)?;
-            if num_rows == 1 {
-                Ok(ColumnarValue::Scalar(ScalarValue::Int64(Some(
-                    hashes[0] as i64,
-                ))))
-            } else {
-                let hashes: Vec<i64> = hashes.into_iter().map(|x| x as i64).collect();
-                Ok(ColumnarValue::Array(Arc::new(Int64Array::from(hashes))))
-            }
-        }
-        _ => {
-            internal_err!(
-                "The seed of function xxhash64 must be an Int64 scalar value, but got: {:?}.",
-                seed
-            )
-        }
-    }
-}
-
-#[inline]
-fn hex_encode<T: AsRef<[u8]>>(data: T) -> String {
-    let mut s = String::with_capacity(data.as_ref().len() * 2);
-    for b in data.as_ref() {
-        // Writing to a string never errors, so we can unwrap here.
-        write!(&mut s, "{b:02x}").unwrap();
-    }
-    s
-}
-
-fn wrap_digest_result_as_hex_string(
-    args: &[ColumnarValue],
-    digest: ScalarFunctionImplementation,
-) -> Result<ColumnarValue, DataFusionError> {
-    let value = digest(args)?;
-    match value {
-        ColumnarValue::Array(array) => {
-            let binary_array = as_binary_array(&array)?;
-            let string_array: StringArray = binary_array
-                .iter()
-                .map(|opt| opt.map(hex_encode::<_>))
-                .collect();
-            Ok(ColumnarValue::Array(Arc::new(string_array)))
-        }
-        ColumnarValue::Scalar(ScalarValue::Binary(opt)) => Ok(ColumnarValue::Scalar(
-            ScalarValue::Utf8(opt.map(hex_encode::<_>)),
-        )),
-        _ => {
-            exec_err!(
-                "digest function should return binary value, but got: {:?}",
-                value.data_type()
-            )
-        }
-    }
-}
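The digest-to-hex helpers deleted above were not dropped: wrap_digest_result_as_hex_string now lives in the hex submodule (see the updated import near the top of this file), and hex_encode presumably moved with it. As a standalone illustration of what that helper pair does, the encoding loop copied from the deleted code behaves like this:

use std::fmt::Write;

// Same logic as the deleted hex_encode: two lowercase hex digits per byte.
fn hex_encode<T: AsRef<[u8]>>(data: T) -> String {
    let mut s = String::with_capacity(data.as_ref().len() * 2);
    for b in data.as_ref() {
        // Writing to a String never errors, so we can unwrap here.
        write!(&mut s, "{b:02x}").unwrap();
    }
    s
}

fn main() {
    assert_eq!(hex_encode([0xCA, 0xFE, 0x01]), "cafe01");
}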
108 changes: 108 additions & 0 deletions core/src/execution/datafusion/expressions/scalar_funcs/hash_expressions.rs
@@ -0,0 +1,108 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use crate::execution::datafusion::spark_hash::{create_murmur3_hashes, create_xxhash64_hashes};
+use arrow_array::{ArrayRef, Int32Array, Int64Array};
+use datafusion_common::{internal_err, DataFusionError, ScalarValue};
+use datafusion_expr::ColumnarValue;
+use std::sync::Arc;
+
+pub fn spark_murmur3_hash(args: &[ColumnarValue]) -> Result<ColumnarValue, DataFusionError> {
+    let length = args.len();
+    let seed = &args[length - 1];
+    match seed {
+        ColumnarValue::Scalar(ScalarValue::Int32(Some(seed))) => {
+            // iterate over the arguments to find out the length of the array
+            let num_rows = args[0..args.len() - 1]
+                .iter()
+                .find_map(|arg| match arg {
+                    ColumnarValue::Array(array) => Some(array.len()),
+                    ColumnarValue::Scalar(_) => None,
+                })
+                .unwrap_or(1);
+            let mut hashes: Vec<u32> = vec![0_u32; num_rows];
+            hashes.fill(*seed as u32);
+            let arrays = args[0..args.len() - 1]
+                .iter()
+                .map(|arg| match arg {
+                    ColumnarValue::Array(array) => array.clone(),
+                    ColumnarValue::Scalar(scalar) => {
+                        scalar.clone().to_array_of_size(num_rows).unwrap()
+                    }
+                })
+                .collect::<Vec<ArrayRef>>();
+            create_murmur3_hashes(&arrays, &mut hashes)?;
+            if num_rows == 1 {
+                Ok(ColumnarValue::Scalar(ScalarValue::Int32(Some(
+                    hashes[0] as i32,
+                ))))
+            } else {
+                let hashes: Vec<i32> = hashes.into_iter().map(|x| x as i32).collect();
+                Ok(ColumnarValue::Array(Arc::new(Int32Array::from(hashes))))
+            }
+        }
+        _ => {
+            internal_err!(
+                "The seed of function murmur3_hash must be an Int32 scalar value, but got: {:?}.",
+                seed
+            )
+        }
+    }
+}
+
+pub fn spark_xxhash64(args: &[ColumnarValue]) -> Result<ColumnarValue, DataFusionError> {
+    let length = args.len();
+    let seed = &args[length - 1];
+    match seed {
+        ColumnarValue::Scalar(ScalarValue::Int64(Some(seed))) => {
+            // iterate over the arguments to find out the length of the array
+            let num_rows = args[0..args.len() - 1]
+                .iter()
+                .find_map(|arg| match arg {
+                    ColumnarValue::Array(array) => Some(array.len()),
+                    ColumnarValue::Scalar(_) => None,
+                })
+                .unwrap_or(1);
+            let mut hashes: Vec<u64> = vec![0_u64; num_rows];
+            hashes.fill(*seed as u64);
+            let arrays = args[0..args.len() - 1]
+                .iter()
+                .map(|arg| match arg {
+                    ColumnarValue::Array(array) => array.clone(),
+                    ColumnarValue::Scalar(scalar) => {
+                        scalar.clone().to_array_of_size(num_rows).unwrap()
+                    }
+                })
+                .collect::<Vec<ArrayRef>>();
+            create_xxhash64_hashes(&arrays, &mut hashes)?;
+            if num_rows == 1 {
+                Ok(ColumnarValue::Scalar(ScalarValue::Int64(Some(
+                    hashes[0] as i64,
+                ))))
+            } else {
+                let hashes: Vec<i64> = hashes.into_iter().map(|x| x as i64).collect();
+                Ok(ColumnarValue::Array(Arc::new(Int64Array::from(hashes))))
+            }
+        }
+        _ => {
+            internal_err!(
+                "The seed of function xxhash64 must be an Int64 scalar value, but got: {:?}.",
+                seed
+            )
+        }
+    }
+}
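Both functions expect the seed as the final argument and fall back to a single row when every input is scalar, in which case they return a scalar hash instead of an array. A hedged usage sketch for spark_murmur3_hash (the string value is illustrative; 42 is Spark's default seed for this hash):

use datafusion_common::{DataFusionError, ScalarValue};
use datafusion_expr::ColumnarValue;

// Assumes the re-export from scalar_funcs shown earlier.
use comet::execution::datafusion::expressions::scalar_funcs::spark_murmur3_hash;

fn main() -> Result<(), DataFusionError> {
    // All-scalar inputs take the num_rows == 1 path, so the result is a scalar.
    let args = [
        ColumnarValue::Scalar(ScalarValue::Utf8(Some("hello".to_string()))),
        ColumnarValue::Scalar(ScalarValue::Int32(Some(42))), // seed, last argument
    ];
    match spark_murmur3_hash(&args)? {
        ColumnarValue::Scalar(ScalarValue::Int32(Some(h))) => println!("hash = {h}"),
        other => return Err(DataFusionError::Internal(format!("unexpected: {other:?}"))),
    }
    Ok(())
}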
[diff for the fourth changed file not loaded]
