Skip to content

Commit 345b31c

Browse files
chenkovskytobixdev
authored andcommitted
feat: spark udf array shuffle (apache#17674)
## Which issue does this PR close? ## Rationale for this change support shuffle udf ## What changes are included in this PR? support shuffle udf ## Are these changes tested? UT ## Are there any user-facing changes? No
1 parent 25ef514 commit 345b31c

File tree

4 files changed

+313
-2
lines changed

4 files changed

+313
-2
lines changed

datafusion/spark/Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,12 +46,12 @@ datafusion-execution = { workspace = true }
4646
datafusion-expr = { workspace = true }
4747
datafusion-functions = { workspace = true, features = ["crypto_expressions"] }
4848
log = { workspace = true }
49+
rand = { workspace = true }
4950
sha1 = "0.10"
5051
url = { workspace = true }
5152

5253
[dev-dependencies]
5354
criterion = { workspace = true }
54-
rand = { workspace = true }
5555

5656
[[bench]]
5757
harness = false

datafusion/spark/src/function/array/mod.rs

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,20 +15,27 @@
1515
// specific language governing permissions and limitations
1616
// under the License.
1717

18+
pub mod shuffle;
1819
pub mod spark_array;
1920

2021
use datafusion_expr::ScalarUDF;
2122
use datafusion_functions::make_udf_function;
2223
use std::sync::Arc;
2324

2425
make_udf_function!(spark_array::SparkArray, array);
26+
make_udf_function!(shuffle::SparkShuffle, shuffle);
2527

2628
pub mod expr_fn {
2729
use datafusion_functions::export_functions;
2830

2931
export_functions!((array, "Returns an array with the given elements.", args));
32+
export_functions!((
33+
shuffle,
34+
"Returns a random permutation of the given array.",
35+
args
36+
));
3037
}
3138

3239
pub fn functions() -> Vec<Arc<ScalarUDF>> {
33-
vec![array()]
40+
vec![array(), shuffle()]
3441
}
Lines changed: 191 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,191 @@
1+
// Licensed to the Apache Software Foundation (ASF) under one
2+
// or more contributor license agreements. See the NOTICE file
3+
// distributed with this work for additional information
4+
// regarding copyright ownership. The ASF licenses this file
5+
// to you under the Apache License, Version 2.0 (the
6+
// "License"); you may not use this file except in compliance
7+
// with the License. You may obtain a copy of the License at
8+
//
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
//
11+
// Unless required by applicable law or agreed to in writing,
12+
// software distributed under the License is distributed on an
13+
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
// KIND, either express or implied. See the License for the
15+
// specific language governing permissions and limitations
16+
// under the License.
17+
18+
use crate::function::functions_nested_utils::make_scalar_function;
19+
use arrow::array::{
20+
Array, ArrayRef, Capacities, FixedSizeListArray, GenericListArray, MutableArrayData,
21+
OffsetSizeTrait,
22+
};
23+
use arrow::buffer::OffsetBuffer;
24+
use arrow::datatypes::DataType::{FixedSizeList, LargeList, List, Null};
25+
use arrow::datatypes::{DataType, FieldRef};
26+
use datafusion_common::cast::{
27+
as_fixed_size_list_array, as_large_list_array, as_list_array,
28+
};
29+
use datafusion_common::{exec_err, utils::take_function_args, Result};
30+
use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility};
31+
use rand::rng;
32+
use rand::seq::SliceRandom;
33+
use std::any::Any;
34+
use std::sync::Arc;
35+
36+
#[derive(Debug, PartialEq, Eq, Hash)]
37+
pub struct SparkShuffle {
38+
signature: Signature,
39+
}
40+
41+
impl Default for SparkShuffle {
42+
fn default() -> Self {
43+
Self::new()
44+
}
45+
}
46+
47+
impl SparkShuffle {
48+
pub fn new() -> Self {
49+
Self {
50+
signature: Signature::arrays(1, None, Volatility::Volatile),
51+
}
52+
}
53+
}
54+
55+
impl ScalarUDFImpl for SparkShuffle {
56+
fn as_any(&self) -> &dyn Any {
57+
self
58+
}
59+
60+
fn name(&self) -> &str {
61+
"shuffle"
62+
}
63+
64+
fn signature(&self) -> &Signature {
65+
&self.signature
66+
}
67+
68+
fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
69+
Ok(arg_types[0].clone())
70+
}
71+
72+
fn invoke_with_args(
73+
&self,
74+
args: datafusion_expr::ScalarFunctionArgs,
75+
) -> Result<ColumnarValue> {
76+
make_scalar_function(array_shuffle_inner)(&args.args)
77+
}
78+
}
79+
80+
/// array_shuffle SQL function
81+
pub fn array_shuffle_inner(arg: &[ArrayRef]) -> Result<ArrayRef> {
82+
let [input_array] = take_function_args("shuffle", arg)?;
83+
match &input_array.data_type() {
84+
List(field) => {
85+
let array = as_list_array(input_array)?;
86+
general_array_shuffle::<i32>(array, field)
87+
}
88+
LargeList(field) => {
89+
let array = as_large_list_array(input_array)?;
90+
general_array_shuffle::<i64>(array, field)
91+
}
92+
FixedSizeList(field, _) => {
93+
let array = as_fixed_size_list_array(input_array)?;
94+
fixed_size_array_shuffle(array, field)
95+
}
96+
Null => Ok(Arc::clone(input_array)),
97+
array_type => exec_err!("shuffle does not support type '{array_type}'."),
98+
}
99+
}
100+
101+
fn general_array_shuffle<O: OffsetSizeTrait>(
102+
array: &GenericListArray<O>,
103+
field: &FieldRef,
104+
) -> Result<ArrayRef> {
105+
let values = array.values();
106+
let original_data = values.to_data();
107+
let capacity = Capacities::Array(original_data.len());
108+
let mut offsets = vec![O::usize_as(0)];
109+
let mut nulls = vec![];
110+
let mut mutable =
111+
MutableArrayData::with_capacities(vec![&original_data], false, capacity);
112+
let mut rng = rng();
113+
114+
for (row_index, offset_window) in array.offsets().windows(2).enumerate() {
115+
// skip the null value
116+
if array.is_null(row_index) {
117+
nulls.push(false);
118+
offsets.push(offsets[row_index] + O::one());
119+
mutable.extend(0, 0, 1);
120+
continue;
121+
}
122+
nulls.push(true);
123+
let start = offset_window[0];
124+
let end = offset_window[1];
125+
let length = (end - start).to_usize().unwrap();
126+
127+
// Create indices and shuffle them
128+
let mut indices: Vec<usize> =
129+
(start.to_usize().unwrap()..end.to_usize().unwrap()).collect();
130+
indices.shuffle(&mut rng);
131+
132+
// Add shuffled elements
133+
for &index in &indices {
134+
mutable.extend(0, index, index + 1);
135+
}
136+
137+
offsets.push(offsets[row_index] + O::usize_as(length));
138+
}
139+
140+
let data = mutable.freeze();
141+
Ok(Arc::new(GenericListArray::<O>::try_new(
142+
Arc::clone(field),
143+
OffsetBuffer::<O>::new(offsets.into()),
144+
arrow::array::make_array(data),
145+
Some(nulls.into()),
146+
)?))
147+
}
148+
149+
fn fixed_size_array_shuffle(
150+
array: &FixedSizeListArray,
151+
field: &FieldRef,
152+
) -> Result<ArrayRef> {
153+
let values = array.values();
154+
let original_data = values.to_data();
155+
let capacity = Capacities::Array(original_data.len());
156+
let mut nulls = vec![];
157+
let mut mutable =
158+
MutableArrayData::with_capacities(vec![&original_data], false, capacity);
159+
let value_length = array.value_length() as usize;
160+
let mut rng = rng();
161+
162+
for row_index in 0..array.len() {
163+
// skip the null value
164+
if array.is_null(row_index) {
165+
nulls.push(false);
166+
mutable.extend(0, 0, value_length);
167+
continue;
168+
}
169+
nulls.push(true);
170+
171+
let start = row_index * value_length;
172+
let end = start + value_length;
173+
174+
// Create indices and shuffle them
175+
let mut indices: Vec<usize> = (start..end).collect();
176+
indices.shuffle(&mut rng);
177+
178+
// Add shuffled elements
179+
for &index in &indices {
180+
mutable.extend(0, index, index + 1);
181+
}
182+
}
183+
184+
let data = mutable.freeze();
185+
Ok(Arc::new(FixedSizeListArray::try_new(
186+
Arc::clone(field),
187+
array.value_length(),
188+
arrow::array::make_array(data),
189+
Some(nulls.into()),
190+
)?))
191+
}
Lines changed: 113 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,113 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
18+
# Test shuffle function with simple arrays
19+
query B
20+
SELECT array_sort(shuffle([1, 2, 3, 4, 5, NULL])) = [NULL,1, 2, 3, 4, 5];
21+
----
22+
true
23+
24+
query B
25+
SELECT shuffle([1, 2, 3, 4, 5, NULL]) != [1, 2, 3, 4, 5, NULL];
26+
----
27+
true
28+
29+
# Test shuffle function with string arrays
30+
31+
query B
32+
SELECT array_sort(shuffle(['a', 'b', 'c', 'd', 'e', 'f'])) = ['a', 'b', 'c', 'd', 'e', 'f'];
33+
----
34+
true
35+
36+
query B
37+
SELECT shuffle(['a', 'b', 'c', 'd', 'e', 'f']) != ['a', 'b', 'c', 'd', 'e', 'f'];;
38+
----
39+
true
40+
41+
# Test shuffle function with empty array
42+
query ?
43+
SELECT shuffle([]);
44+
----
45+
[]
46+
47+
# Test shuffle function with single element
48+
query ?
49+
SELECT shuffle([42]);
50+
----
51+
[42]
52+
53+
# Test shuffle function with null array
54+
query ?
55+
SELECT shuffle(NULL);
56+
----
57+
NULL
58+
59+
# Test shuffle function with fixed size list arrays
60+
query B
61+
SELECT array_sort(shuffle(arrow_cast([1, 2, NULL, 3, 4, 5], 'FixedSizeList(6, Int64)'))) = [NULL, 1, 2, 3, 4, 5];
62+
----
63+
true
64+
65+
query B
66+
SELECT shuffle(arrow_cast([1, 2, NULL, 3, 4, 5], 'FixedSizeList(6, Int64)')) != [1, 2, NULL, 3, 4, 5];
67+
----
68+
true
69+
70+
# Test shuffle on table data with different list types
71+
statement ok
72+
CREATE TABLE test_shuffle_list_types AS VALUES
73+
([1, 2, 3, 4]),
74+
([5, 6, 7, 8, 9]),
75+
([10]),
76+
(NULL),
77+
([]);
78+
79+
# Test shuffle with large list from table
80+
query ?
81+
SELECT array_sort(shuffle(column1)) FROM test_shuffle_list_types;
82+
----
83+
[1, 2, 3, 4]
84+
[5, 6, 7, 8, 9]
85+
[10]
86+
NULL
87+
[]
88+
89+
# Test fixed size list table
90+
statement ok
91+
CREATE TABLE test_shuffle_fixed_size AS VALUES
92+
(arrow_cast([1, 2, 3], 'FixedSizeList(3, Int64)')),
93+
(arrow_cast([4, 5, 6], 'FixedSizeList(3, Int64)')),
94+
(arrow_cast([NULL, 8, 9], 'FixedSizeList(3, Int64)')),
95+
(NULL);
96+
97+
# Test shuffle with fixed size list from table
98+
query ?
99+
SELECT array_sort(shuffle(column1)) FROM test_shuffle_fixed_size;
100+
----
101+
[1, 2, 3]
102+
[4, 5, 6]
103+
[NULL, 8, 9]
104+
NULL
105+
106+
# Clean up
107+
statement ok
108+
DROP TABLE test_shuffle_list_types;
109+
110+
statement ok
111+
DROP TABLE test_shuffle_fixed_size;
112+
113+

0 commit comments

Comments
 (0)