From 8d08bb323697d34c5f909d02a8f9b94c4c3b400f Mon Sep 17 00:00:00 2001
From: Jeremy Leibs
Date: Sun, 15 Jan 2023 12:32:38 +0100
Subject: [PATCH] Properly support sliced unions (#91)

* Properly support sliced unions

* Enable the slice unit-test
---
 arrow2_convert/tests/test_enum.rs        |  2 +-
 arrow2_convert_derive/src/derive_enum.rs | 79 +++++------------------
 2 files changed, 18 insertions(+), 63 deletions(-)

diff --git a/arrow2_convert/tests/test_enum.rs b/arrow2_convert/tests/test_enum.rs
index 576575e..b39cbed 100644
--- a/arrow2_convert/tests/test_enum.rs
+++ b/arrow2_convert/tests/test_enum.rs
@@ -87,7 +87,7 @@ fn test_nested_unit_variant() {
 }
 
 // TODO: reenable this test once slices for enums is fixed.
-//#[test]
+#[test]
 #[allow(unused)]
 fn test_slice() {
     #[derive(Debug, PartialEq, ArrowField, ArrowSerialize, ArrowDeserialize)]
diff --git a/arrow2_convert_derive/src/derive_enum.rs b/arrow2_convert_derive/src/derive_enum.rs
index cf96892..d84db8c 100644
--- a/arrow2_convert_derive/src/derive_enum.rs
+++ b/arrow2_convert_derive/src/derive_enum.rs
@@ -410,9 +410,9 @@ pub fn expand_deserialize(input: DeriveEnum) -> TokenStream {
         original_name,
         original_name_str,
         visibility,
-        is_dense,
+        is_dense: _,
         variants,
-        variant_names,
+        variant_names: _,
         variant_indices,
         variant_types,
         ..
@@ -426,7 +426,7 @@ pub fn expand_deserialize(input: DeriveEnum) -> TokenStream {
     // for dense unions.
     // - For sparse unions, return the value of the variant that corresponds to the matched arm, and
     //   consume the iterators of the rest of the variants.
-    let iter_next_match_block = if is_dense {
+    let iter_next_match_block = {
         let candidates = variants.iter()
             .zip(&variant_indices)
             .zip(&variant_types)
@@ -435,9 +435,6 @@ pub fn expand_deserialize(input: DeriveEnum) -> TokenStream {
                 if v.is_unit {
                     quote! {
                         #lit_idx => {
-                            let v = self.#name.next()
-                                .unwrap_or_else(|| panic!("Invalid offset for {}", #original_name_str));
-                            assert!(v.unwrap());
                             Some(Some(#original_name::#name))
                         }
                     }
@@ -445,55 +442,13 @@ pub fn expand_deserialize(input: DeriveEnum) -> TokenStream {
                 else {
                     quote! {
                         #lit_idx => {
-                            let v = self.#name.next()
-                                .unwrap_or_else(|| panic!("Invalid offset for {}", #original_name_str));
-                            Some(<#variant_type as arrow2_convert::deserialize::ArrowDeserialize>::arrow_deserialize(v).map(|v| #original_name::#name(v)))
-                        }
-                    }
-                }
-            })
-            .collect::<Vec<TokenStream>>();
-        quote! { #(#candidates)* }
-    } else {
-        let candidates = variants.iter()
-            .enumerate()
-            .zip(variant_indices.iter())
-            .zip(&variant_types)
-            .map(|(((i, v), lit_idx), variant_type)| {
-                let consume = variants.iter()
-                    .enumerate()
-                    .map(|(n, v)| {
-                        let name = &v.syn.ident;
-                        if i != n {
-                            quote! {
-                                let _ = self.#name.next();
-                            }
-                        }
-                        else {
-                            quote! {}
-                        }
-                    })
-                    .collect::<Vec<TokenStream>>();
-                let consume = quote! { #(#consume)* };
-                let name = &v.syn.ident;
-                if v.is_unit {
-                    quote! {
-                        #lit_idx => {
-                            #consume
-                            let v = self.#name.next()
-                                .unwrap_or_else(|| panic!("Invalid offset for {}", #original_name_str));
-                            assert!(v.unwrap());
-                            Some(Some(#original_name::#name))
-                        }
-                    }
-                }
-                else {
-                    quote! {
-                        #lit_idx => {
-                            #consume
-                            let v = self.#name.next()
-                                .unwrap_or_else(|| panic!("Invalid offset for {}", #original_name_str));
+                            let (_, offset) = self.arr.index(next_index);
+                            let slice = self.arr.fields()[#lit_idx].slice(offset, 1);
+                            let mut slice_iter = <<#variant_type as arrow2_convert::deserialize::ArrowDeserialize>::ArrayType as arrow2_convert::deserialize::ArrowArray>::iter_from_array_ref(slice.deref());
+                            let v = slice_iter
+                                .next()
+                                .unwrap_or_else(|| panic!("Invalid offset for {}", #original_name_str));
                             Some(<#variant_type as arrow2_convert::deserialize::ArrowDeserialize>::arrow_deserialize(v).map(|v| #original_name::#name(v)))
                         }
                     }
                 }
@@ -516,15 +471,12 @@ pub fn expand_deserialize(input: DeriveEnum) -> TokenStream {
 
             #[inline]
             fn iter_from_array_ref<'a>(b: &'a dyn arrow2::array::Array) -> <&'a Self as IntoIterator>::IntoIter {
-                use core::ops::Deref;
                 let arr = b.as_any().downcast_ref::<arrow2::array::UnionArray>().unwrap();
-                let fields = arr.fields();
 
                 #iterator_name {
-                    #(
-                        #variant_names: <<#variant_types as arrow2_convert::deserialize::ArrowDeserialize>::ArrayType as arrow2_convert::deserialize::ArrowArray>::iter_from_array_ref(fields[#variant_indices].deref()),
-                    )*
+                    arr,
                     types_iter: arr.types().iter(),
+                    index_iter: 0..arr.len(),
                 }
             }
         }
@@ -545,10 +497,9 @@ pub fn expand_deserialize(input: DeriveEnum) -> TokenStream {
     let array_iterator_decl = quote! {
         #[allow(non_snake_case)]
        #visibility struct #iterator_name<'a> {
-            #(
-                #variant_names: <&'a <#variant_types as arrow2_convert::deserialize::ArrowDeserialize>::ArrayType as IntoIterator>::IntoIter,
-            )*
+            arr: &'a arrow2::array::UnionArray,
             types_iter: std::slice::Iter<'a, i8>,
+            index_iter: std::ops::Range<usize>,
         }
     };
 
@@ -558,6 +509,10 @@ pub fn expand_deserialize(input: DeriveEnum) -> TokenStream {
 
             #[inline]
             fn next(&mut self) -> Option<Self::Item> {
+                use core::ops::Deref;
+                let Some(next_index) = self.index_iter.next() else {
+                    return None;
+                };
                 match self.types_iter.next() {
                     Some(type_idx) => {
                         match type_idx {
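
For a concrete picture of what the generated deserializer now does per row, the sketch below hand-expands the new strategy for a hypothetical two-variant enum. Instead of holding one free-running iterator per child array (which ignores the union's own offsets once the UnionArray has been sliced), every call resolves the row's offset via UnionArray::index, slices a single element out of the matching child array, and decodes it. The names MyEnum and union_to_vec are illustrative only; the example uses plain arrow2 downcasts in place of the ArrowDeserialize machinery the macro emits, and it assumes the same arrow2 API the patch targets (Array::slice(&self, offset, len) returning a boxed array).

// Illustrative sketch only -- not part of the patch or of the generated code.
use arrow2::array::{Array, PrimitiveArray, UnionArray};

// Hypothetical enum standing in for a #[derive(ArrowDeserialize)] union type.
#[derive(Debug, PartialEq)]
enum MyEnum {
    A(i32),
    B(f64),
}

// Hand-expanded version of the per-row lookup the derive now generates:
// resolve the row's offset, slice one element from the child array picked by
// the type id, and decode it. Because UnionArray::index accounts for the
// array's own offset, this also works on a sliced union.
fn union_to_vec(arr: &UnionArray) -> Vec<Option<MyEnum>> {
    (0..arr.len())
        .zip(arr.types().iter())
        .map(|(row, type_idx)| {
            let (_, offset) = arr.index(row);
            let child = arr.fields()[*type_idx as usize].slice(offset, 1);
            match *type_idx {
                0 => child
                    .as_any()
                    .downcast_ref::<PrimitiveArray<i32>>()
                    .and_then(|a| a.iter().next().flatten().copied())
                    .map(MyEnum::A),
                1 => child
                    .as_any()
                    .downcast_ref::<PrimitiveArray<f64>>()
                    .and_then(|a| a.iter().next().flatten().copied())
                    .map(MyEnum::B),
                _ => None,
            }
        })
        .collect()
}

The real macro output goes through ArrowDeserialize::arrow_deserialize instead of hard-coded downcasts, but the control flow (index, slice one element, decode) is the same, which is also what lets the previous dense/sparse split in expand_deserialize collapse into a single code path.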