Skip to content

Commit 3095a2c

Browse files
committed
ARROW-12426: [Rust] Fix concatentation of arrow dictionaries
1 parent 74d3567 commit 3095a2c

File tree

3 files changed

+189
-14
lines changed

3 files changed

+189
-14
lines changed

arrow/src/array/transform/mod.rs

Lines changed: 106 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,12 @@
1515
// specific language governing permissions and limitations
1616
// under the License.
1717

18-
use crate::{buffer::MutableBuffer, datatypes::DataType, util::bit_util};
18+
use crate::{
19+
buffer::MutableBuffer,
20+
datatypes::DataType,
21+
error::{ArrowError, Result},
22+
util::bit_util,
23+
};
1924

2025
use super::{
2126
data::{into_buffers, new_buffers},
@@ -166,6 +171,65 @@ impl<'a> std::fmt::Debug for MutableArrayData<'a> {
166171
}
167172
}
168173

174+
/// Builds an extend that adds `offset` to the source primitive
175+
/// Additionally validates that `max` fits into the
176+
/// the underlying primitive returning None if not
177+
fn build_extend_dictionary(
178+
array: &ArrayData,
179+
offset: usize,
180+
max: usize,
181+
) -> Option<Extend> {
182+
use crate::datatypes::*;
183+
use std::convert::TryInto;
184+
185+
match array.data_type() {
186+
DataType::Dictionary(child_data_type, _) => match child_data_type.as_ref() {
187+
DataType::UInt8 => {
188+
let _: u8 = max.try_into().ok()?;
189+
let offset: u8 = offset.try_into().ok()?;
190+
Some(primitive::build_extend_with_offset(array, offset))
191+
}
192+
DataType::UInt16 => {
193+
let _: u16 = max.try_into().ok()?;
194+
let offset: u16 = offset.try_into().ok()?;
195+
Some(primitive::build_extend_with_offset(array, offset))
196+
}
197+
DataType::UInt32 => {
198+
let _: u32 = max.try_into().ok()?;
199+
let offset: u32 = offset.try_into().ok()?;
200+
Some(primitive::build_extend_with_offset(array, offset))
201+
}
202+
DataType::UInt64 => {
203+
let _: u64 = max.try_into().ok()?;
204+
let offset: u64 = offset.try_into().ok()?;
205+
Some(primitive::build_extend_with_offset(array, offset))
206+
}
207+
DataType::Int8 => {
208+
let _: i8 = max.try_into().ok()?;
209+
let offset: i8 = offset.try_into().ok()?;
210+
Some(primitive::build_extend_with_offset(array, offset))
211+
}
212+
DataType::Int16 => {
213+
let _: i16 = max.try_into().ok()?;
214+
let offset: i16 = offset.try_into().ok()?;
215+
Some(primitive::build_extend_with_offset(array, offset))
216+
}
217+
DataType::Int32 => {
218+
let _: i32 = max.try_into().ok()?;
219+
let offset: i32 = offset.try_into().ok()?;
220+
Some(primitive::build_extend_with_offset(array, offset))
221+
}
222+
DataType::Int64 => {
223+
let _: i64 = max.try_into().ok()?;
224+
let offset: i64 = offset.try_into().ok()?;
225+
Some(primitive::build_extend_with_offset(array, offset))
226+
}
227+
_ => unreachable!(),
228+
},
229+
_ => None,
230+
}
231+
}
232+
169233
fn build_extend(array: &ArrayData) -> Extend {
170234
use crate::datatypes::*;
171235
match array.data_type() {
@@ -199,17 +263,7 @@ fn build_extend(array: &ArrayData) -> Extend {
199263
}
200264
DataType::List(_) => list::build_extend::<i32>(array),
201265
DataType::LargeList(_) => list::build_extend::<i64>(array),
202-
DataType::Dictionary(child_data_type, _) => match child_data_type.as_ref() {
203-
DataType::UInt8 => primitive::build_extend::<u8>(array),
204-
DataType::UInt16 => primitive::build_extend::<u16>(array),
205-
DataType::UInt32 => primitive::build_extend::<u32>(array),
206-
DataType::UInt64 => primitive::build_extend::<u64>(array),
207-
DataType::Int8 => primitive::build_extend::<i8>(array),
208-
DataType::Int16 => primitive::build_extend::<i16>(array),
209-
DataType::Int32 => primitive::build_extend::<i32>(array),
210-
DataType::Int64 => primitive::build_extend::<i64>(array),
211-
_ => unreachable!(),
212-
},
266+
DataType::Dictionary(_, _) => unreachable!("should use build_extend_dictionary"),
213267
DataType::Struct(_) => structure::build_extend(array),
214268
DataType::FixedSizeBinary(_) => fixed_binary::build_extend(array),
215269
DataType::Float16 => unreachable!(),
@@ -339,7 +393,29 @@ impl<'a> MutableArrayData<'a> {
339393
};
340394

341395
let dictionary = match &data_type {
342-
DataType::Dictionary(_, _) => Some(arrays[0].child_data()[0].clone()),
396+
DataType::Dictionary(_, _) => match arrays.len() {
397+
0 => unreachable!(),
398+
1 => Some(arrays[0].child_data()[0].clone()),
399+
_ => {
400+
// Concat dictionaries together
401+
let dictionaries: Vec<_> =
402+
arrays.iter().map(|array| &array.child_data()[0]).collect();
403+
let lengths: Vec<_> = dictionaries
404+
.iter()
405+
.map(|dictionary| dictionary.len())
406+
.collect();
407+
let capacity = lengths.iter().sum();
408+
409+
let mut mutable =
410+
MutableArrayData::new(dictionaries, false, capacity);
411+
412+
for (i, len) in lengths.iter().enumerate() {
413+
mutable.extend(i, 0, *len)
414+
}
415+
416+
Some(mutable.freeze())
417+
}
418+
},
343419
_ => None,
344420
};
345421

@@ -353,7 +429,23 @@ impl<'a> MutableArrayData<'a> {
353429
let null_bytes = bit_util::ceil(capacity, 8);
354430
let null_buffer = MutableBuffer::from_len_zeroed(null_bytes);
355431

356-
let extend_values = arrays.iter().map(|array| build_extend(array)).collect();
432+
let extend_values = match &data_type {
433+
DataType::Dictionary(_, _) => {
434+
let mut next_offset = 0;
435+
let extend_values: Result<Vec<_>> = arrays
436+
.iter()
437+
.map(|array| {
438+
let offset = next_offset;
439+
next_offset += array.child_data()[0].len();
440+
build_extend_dictionary(array, offset, next_offset)
441+
.ok_or(ArrowError::DictionaryKeyOverflowError)
442+
})
443+
.collect();
444+
445+
extend_values.expect("MutableArrayData::new is infallible")
446+
}
447+
_ => arrays.iter().map(|array| build_extend(array)).collect(),
448+
};
357449

358450
let data = _MutableArrayData {
359451
data_type: data_type.clone(),

arrow/src/array/transform/primitive.rs

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
// under the License.
1717

1818
use std::mem::size_of;
19+
use std::ops::Add;
1920

2021
use crate::{array::ArrayData, datatypes::ArrowNativeType};
2122

@@ -32,6 +33,20 @@ pub(super) fn build_extend<T: ArrowNativeType>(array: &ArrayData) -> Extend {
3233
)
3334
}
3435

36+
pub(super) fn build_extend_with_offset<T>(array: &ArrayData, offset: T) -> Extend
37+
where
38+
T: ArrowNativeType + Add<Output = T>,
39+
{
40+
let values = array.buffer::<T>(0);
41+
Box::new(
42+
move |mutable: &mut _MutableArrayData, _, start: usize, len: usize| {
43+
mutable
44+
.buffer1
45+
.extend(values[start..start + len].iter().map(|x| *x + offset));
46+
},
47+
)
48+
}
49+
3550
pub(super) fn extend_nulls<T: ArrowNativeType>(
3651
mutable: &mut _MutableArrayData,
3752
len: usize,

arrow/src/compute/kernels/concat.rs

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -384,4 +384,72 @@ mod tests {
384384

385385
Ok(())
386386
}
387+
388+
fn collect_string_dictionary(
389+
dictionary: &DictionaryArray<Int32Type>,
390+
) -> Vec<Option<String>> {
391+
let values = dictionary.values();
392+
let values = values.as_any().downcast_ref::<StringArray>().unwrap();
393+
394+
dictionary
395+
.keys()
396+
.iter()
397+
.map(|key| key.map(|key| values.value(key as _).to_string()))
398+
.collect()
399+
}
400+
401+
fn concat_dictionary(
402+
input_1: DictionaryArray<Int32Type>,
403+
input_2: DictionaryArray<Int32Type>,
404+
) -> Vec<Option<String>> {
405+
let concat = concat(&[&input_1 as _, &input_2 as _]).unwrap();
406+
let concat = concat
407+
.as_any()
408+
.downcast_ref::<DictionaryArray<Int32Type>>()
409+
.unwrap();
410+
411+
collect_string_dictionary(concat)
412+
}
413+
414+
#[test]
415+
fn test_string_dictionary_array() {
416+
let input_1: DictionaryArray<Int32Type> =
417+
vec!["hello", "A", "B", "hello", "hello", "C"]
418+
.into_iter()
419+
.collect();
420+
let input_2: DictionaryArray<Int32Type> =
421+
vec!["hello", "E", "E", "hello", "F", "E"]
422+
.into_iter()
423+
.collect();
424+
425+
let expected: Vec<_> = vec![
426+
"hello", "A", "B", "hello", "hello", "C", "hello", "E", "E", "hello", "F",
427+
"E",
428+
]
429+
.into_iter()
430+
.map(|x| Some(x.to_string()))
431+
.collect();
432+
433+
let concat = concat_dictionary(input_1, input_2);
434+
assert_eq!(concat, expected);
435+
}
436+
437+
#[test]
438+
fn test_string_dictionary_array_nulls() {
439+
let input_1: DictionaryArray<Int32Type> =
440+
vec![Some("foo"), Some("bar"), None, Some("fiz")]
441+
.into_iter()
442+
.collect();
443+
let input_2: DictionaryArray<Int32Type> = vec![None].into_iter().collect();
444+
let expected = vec![
445+
Some("foo".to_string()),
446+
Some("bar".to_string()),
447+
None,
448+
Some("fiz".to_string()),
449+
None,
450+
];
451+
452+
let concat = concat_dictionary(input_1, input_2);
453+
assert_eq!(concat, expected);
454+
}
387455
}

0 commit comments

Comments
 (0)