Skip to content

Commit b9c10d7

Browse files
davidhewitt authored and friendlymatthew committed
only attempt merge on supported types
1 parent 1b17a76 commit b9c10d7

File tree

3 files changed

+52
-12
lines changed

3 files changed

+52
-12
lines changed

arrow-select/src/concat.rs

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,9 @@
3030
//! assert_eq!(arr.len(), 3);
3131
//! ```
3232
33-
use crate::dictionary::{merge_dictionary_values, should_merge_dictionary_values};
33+
use crate::dictionary::{
34+
merge_dictionary_values, should_merge_dictionary_values, ShouldMergeValues,
35+
};
3436
use arrow_array::builder::{
3537
BooleanBuilder, GenericByteBuilder, PrimitiveBuilder, PrimitiveDictionaryBuilder,
3638
};
@@ -96,9 +98,14 @@ fn concat_dictionaries<K: ArrowDictionaryKeyType>(
9698
.map(|x| x.as_dictionary::<K>())
9799
.inspect(|d| output_len += d.len())
98100
.collect();
99-
if !should_merge_dictionary_values::<K>(&dictionaries, output_len) {
100-
return concat_fallback(arrays, Capacities::Array(output_len));
101-
}
101+
102+
let is_overflow = match should_merge_dictionary_values::<K>(&dictionaries, output_len) {
103+
ShouldMergeValues::ConcatWillOverflow => true,
104+
ShouldMergeValues::Yes => false,
105+
ShouldMergeValues::No => {
106+
return concat_fallback(arrays, Capacities::Array(output_len));
107+
}
108+
};
102109

103110
macro_rules! primitive_dict_helper {
104111
($t:ty) => {
@@ -111,8 +118,12 @@ fn concat_dictionaries<K: ArrowDictionaryKeyType>(
111118
DataType::Utf8 | DataType::LargeUtf8 | DataType::Binary | DataType::LargeBinary => {
112119
merge_concat_byte_dictionaries(&dictionaries, output_len)
113120
},
121+
// merge not yet implemented for this type and it's not going to overflow, so fall back
122+
// to concatenating values
123+
_ if !is_overflow => concat_fallback(arrays, Capacities::Array(output_len)),
114124
other => Err(ArrowError::NotYetImplemented(format!(
115-
"interleave does not yet support merging dictionaries with value type {other:?}"
125+
"concat of dictionaries would overflow key type {key_type:?} with value type {other:?}",
126+
key_type = K::DATA_TYPE,
116127
)))
117128
}
118129
}

arrow-select/src/dictionary.rs

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,17 @@ fn bytes_ptr_eq<T: ByteArrayType>(a: &dyn Array, b: &dyn Array) -> bool {
101101
}
102102
}
103103

104+
/// Whether selection kernels should attempt to merge dictionary values
105+
pub enum ShouldMergeValues {
106+
/// Concatenation of the dictionary values will lead to overflowing
107+
/// the key space; it's necessary to attempt to merge
108+
ConcatWillOverflow,
109+
/// The heuristic suggests that merging will be beneficial
110+
Yes,
111+
/// The heuristic suggests that merging is not necessary
112+
No,
113+
}
114+
104115
/// A type-erased function that compares two arrays for pointer equality
105116
type PtrEq = fn(&dyn Array, &dyn Array) -> bool;
106117

@@ -112,7 +123,7 @@ type PtrEq = fn(&dyn Array, &dyn Array) -> bool;
112123
pub fn should_merge_dictionary_values<K: ArrowDictionaryKeyType>(
113124
dictionaries: &[&DictionaryArray<K>],
114125
len: usize,
115-
) -> bool {
126+
) -> ShouldMergeValues {
116127
use DataType::*;
117128
let first_values = dictionaries[0].values().as_ref();
118129
let ptr_eq: PtrEq = match first_values.data_type() {
@@ -136,7 +147,15 @@ pub fn should_merge_dictionary_values<K: ArrowDictionaryKeyType>(
136147
let overflow = K::Native::from_usize(total_values).is_none();
137148
let values_exceed_length = total_values >= len;
138149

139-
!single_dictionary && (overflow || values_exceed_length)
150+
if single_dictionary {
151+
ShouldMergeValues::No
152+
} else if overflow {
153+
ShouldMergeValues::ConcatWillOverflow
154+
} else if values_exceed_length {
155+
ShouldMergeValues::Yes
156+
} else {
157+
ShouldMergeValues::No
158+
}
140159
}
141160

142161
/// Given an array of dictionaries and an optional key mask compute a values array

arrow-select/src/interleave.rs

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,9 @@
1717

1818
//! Interleave elements from multiple arrays
1919
20-
use crate::dictionary::{merge_dictionary_values, should_merge_dictionary_values};
20+
use crate::dictionary::{
21+
merge_dictionary_values, should_merge_dictionary_values, ShouldMergeValues,
22+
};
2123
use arrow_array::builder::{
2224
BooleanBufferBuilder, BufferBuilder, PrimitiveBuilder, PrimitiveDictionaryBuilder,
2325
};
@@ -198,9 +200,13 @@ fn interleave_dictionaries<K: ArrowDictionaryKeyType>(
198200
indices: &[(usize, usize)],
199201
) -> Result<ArrayRef, ArrowError> {
200202
let dictionaries: Vec<_> = arrays.iter().map(|x| x.as_dictionary::<K>()).collect();
201-
if !should_merge_dictionary_values::<K>(&dictionaries, indices.len()) {
202-
return interleave_fallback(arrays, indices);
203-
}
203+
let is_overflow = match should_merge_dictionary_values::<K>(&dictionaries, indices.len()) {
204+
ShouldMergeValues::ConcatWillOverflow => true,
205+
ShouldMergeValues::Yes => false,
206+
ShouldMergeValues::No => {
207+
return interleave_fallback(arrays, indices);
208+
}
209+
};
204210

205211
macro_rules! primitive_dict_helper {
206212
($t:ty) => {
@@ -213,8 +219,12 @@ fn interleave_dictionaries<K: ArrowDictionaryKeyType>(
213219
DataType::Utf8 | DataType::LargeUtf8 | DataType::Binary | DataType::LargeBinary => {
214220
merge_interleave_byte_dictionaries(&dictionaries, indices)
215221
},
222+
// merge not yet implemented for this type and it's not going to overflow, so fall back
223+
// to concatenating values
224+
_ if !is_overflow => interleave_fallback(arrays, indices),
216225
other => Err(ArrowError::NotYetImplemented(format!(
217-
"interleave does not yet support merging dictionaries with value type {other:?}"
226+
"interleave of dictionaries would overflow key type {key_type:?} with value type {other:?}",
227+
key_type = K::DATA_TYPE,
218228
)))
219229
}
220230
}

0 commit comments

Comments
 (0)