Skip to content

Commit

Permalink
Properly support sliced unions (#91)
Browse files Browse the repository at this point in the history
* Properly support sliced unions

* Enable the slice unit-test
  • Loading branch information
jleibs authored Jan 15, 2023
1 parent a98e2bf commit 8d08bb3
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 63 deletions.
2 changes: 1 addition & 1 deletion arrow2_convert/tests/test_enum.rs
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ fn test_nested_unit_variant() {
}

// TODO: re-enable this test once slicing for enums is fixed.
//#[test]
#[test]
#[allow(unused)]
fn test_slice() {
#[derive(Debug, PartialEq, ArrowField, ArrowSerialize, ArrowDeserialize)]
Expand Down
79 changes: 17 additions & 62 deletions arrow2_convert_derive/src/derive_enum.rs
Original file line number Diff line number Diff line change
Expand Up @@ -410,9 +410,9 @@ pub fn expand_deserialize(input: DeriveEnum) -> TokenStream {
original_name,
original_name_str,
visibility,
is_dense,
is_dense: _,
variants,
variant_names,
variant_names: _,
variant_indices,
variant_types,
..
Expand All @@ -426,7 +426,7 @@ pub fn expand_deserialize(input: DeriveEnum) -> TokenStream {
// for dense unions.
// - For sparse unions, return the value of the variant that corresponds to the matched arm, and
// consume the iterators of the rest of the variants.
let iter_next_match_block = if is_dense {
let iter_next_match_block = {
let candidates = variants.iter()
.zip(&variant_indices)
.zip(&variant_types)
Expand All @@ -435,65 +435,20 @@ pub fn expand_deserialize(input: DeriveEnum) -> TokenStream {
if v.is_unit {
quote! {
#lit_idx => {
let v = self.#name.next()
.unwrap_or_else(|| panic!("Invalid offset for {}", #original_name_str));
assert!(v.unwrap());
Some(Some(#original_name::#name))
}
}
}
else {
quote! {
#lit_idx => {
let v = self.#name.next()
.unwrap_or_else(|| panic!("Invalid offset for {}", #original_name_str));
Some(<#variant_type as arrow2_convert::deserialize::ArrowDeserialize>::arrow_deserialize(v).map(|v| #original_name::#name(v)))
}
}
}
})
.collect::<Vec<TokenStream>>();
quote! { #(#candidates)* }
} else {
let candidates = variants.iter()
.enumerate()
.zip(variant_indices.iter())
.zip(&variant_types)
.map(|(((i, v), lit_idx), variant_type)| {
let consume = variants.iter()
.enumerate()
.map(|(n, v)| {
let name = &v.syn.ident;
if i != n {
quote! {
let _ = self.#name.next();
}
}
else {
quote! {}
}
})
.collect::<Vec<TokenStream>>();
let consume = quote! { #(#consume)* };

let name = &v.syn.ident;
if v.is_unit {
quote! {
#lit_idx => {
#consume
let v = self.#name.next()
.unwrap_or_else(|| panic!("Invalid offset for {}", #original_name_str));
assert!(v.unwrap());
Some(Some(#original_name::#name))
}
}
}
else {
quote! {
#lit_idx => {
#consume
let v = self.#name.next()
.unwrap_or_else(|| panic!("Invalid offset for {}", #original_name_str));
let (_, offset) = self.arr.index(next_index);
let slice = self.arr.fields()[#lit_idx].slice(offset, 1);
let mut slice_iter = <<#variant_type as arrow2_convert::deserialize::ArrowDeserialize> ::ArrayType as arrow2_convert::deserialize::ArrowArray> ::iter_from_array_ref(slice.deref());
let v = slice_iter
.next()
.unwrap_or_else(|| panic!("Invalid offset for {}", "TensorData"));
Some(<#variant_type as arrow2_convert::deserialize::ArrowDeserialize>::arrow_deserialize(v).map(|v| #original_name::#name(v)))
}
}
Expand All @@ -516,15 +471,12 @@ pub fn expand_deserialize(input: DeriveEnum) -> TokenStream {
#[inline]
fn iter_from_array_ref<'a>(b: &'a dyn arrow2::array::Array) -> <&'a Self as IntoIterator>::IntoIter
{
use core::ops::Deref;
let arr = b.as_any().downcast_ref::<arrow2::array::UnionArray>().unwrap();
let fields = arr.fields();

#iterator_name {
#(
#variant_names: <<#variant_types as arrow2_convert::deserialize::ArrowDeserialize>::ArrayType as arrow2_convert::deserialize::ArrowArray>::iter_from_array_ref(fields[#variant_indices].deref()),
)*
arr,
types_iter: arr.types().iter(),
index_iter: 0..arr.len(),
}
}
}
Expand All @@ -545,10 +497,9 @@ pub fn expand_deserialize(input: DeriveEnum) -> TokenStream {
let array_iterator_decl = quote! {
#[allow(non_snake_case)]
#visibility struct #iterator_name<'a> {
#(
#variant_names: <&'a <#variant_types as arrow2_convert::deserialize::ArrowDeserialize>::ArrayType as IntoIterator>::IntoIter,
)*
arr: &'a arrow2::array::UnionArray,
types_iter: std::slice::Iter<'a, i8>,
index_iter: std::ops::Range<usize>,
}
};

Expand All @@ -558,6 +509,10 @@ pub fn expand_deserialize(input: DeriveEnum) -> TokenStream {

#[inline]
fn next(&mut self) -> Option<Self::Item> {
use core::ops::Deref;
let Some(next_index) = self.index_iter.next() else {
return None;
};
match self.types_iter.next() {
Some(type_idx) => {
match type_idx {
Expand Down

0 comments on commit 8d08bb3

Please sign in to comment.