From 2450c60f378a3ba8430365d465a0ccb6f472ecc5 Mon Sep 17 00:00:00 2001 From: ZhangHuiGui Date: Wed, 10 Apr 2024 22:35:04 +0800 Subject: [PATCH] Refactor codes --- cpp/src/arrow/array/array_list_test.cc | 14 +- cpp/src/arrow/array/array_nested.cc | 173 ++++++++++--------------- cpp/src/arrow/array/array_nested.h | 33 ++--- 3 files changed, 91 insertions(+), 129 deletions(-) diff --git a/cpp/src/arrow/array/array_list_test.cc b/cpp/src/arrow/array/array_list_test.cc index 5ef018c699718..b4953416af4ec 100644 --- a/cpp/src/arrow/array/array_list_test.cc +++ b/cpp/src/arrow/array/array_list_test.cc @@ -763,7 +763,7 @@ class TestListArray : public ::testing::Test { << flattened->ToString(); } - void TestFlattenRecursion() { + void TestFlattenRecursively() { auto inner_type = std::make_shared(int32()); auto type = std::make_shared(inner_type); @@ -773,7 +773,7 @@ class TestListArray : public ::testing::Test { [null], [[2, 9], [4], [], [6, 5]] ])")); - ASSERT_OK_AND_ASSIGN(auto flattened, nested_list_array->FlattenRecursion()); + ASSERT_OK_AND_ASSIGN(auto flattened, nested_list_array->FlattenRecursively()); ASSERT_OK(flattened->ValidateFull()); ASSERT_EQ(9, flattened->length()); ASSERT_TRUE(flattened->Equals(ArrayFromJSON(int32(), "[0, 1, 2, 3, 2, 9, 4, 6, 5]"))); @@ -781,7 +781,7 @@ class TestListArray : public ::testing::Test { // Empty nested list should flatten until reach it's non-list type nested_list_array = std::dynamic_pointer_cast(ArrayFromJSON(type, R"([null])")); - ASSERT_OK_AND_ASSIGN(flattened, nested_list_array->FlattenRecursion()); + ASSERT_OK_AND_ASSIGN(flattened, nested_list_array->FlattenRecursively()); ASSERT_TRUE(flattened->type()->Equals(int32())); // List type with three nested level: list(list(list(int32))) @@ -800,7 +800,7 @@ class TestListArray : public ::testing::Test { null ] ])")); - ASSERT_OK_AND_ASSIGN(flattened, nested_list_array->FlattenRecursion()); + ASSERT_OK_AND_ASSIGN(flattened, nested_list_array->FlattenRecursively()); ASSERT_OK(flattened->ValidateFull()); ASSERT_EQ(7, flattened->length()); ASSERT_EQ(2, flattened->null_count()); @@ -974,7 +974,7 @@ TYPED_TEST(TestListArray, FlattenZeroLength) { this->TestFlattenZeroLength(); } TYPED_TEST(TestListArray, TestFlattenNonEmptyBackingNulls) { this->TestFlattenNonEmptyBackingNulls(); } -TYPED_TEST(TestListArray, FlattenRecursion) { this->TestFlattenRecursion(); } +TYPED_TEST(TestListArray, FlattenRecursively) { this->TestFlattenRecursively(); } TYPED_TEST(TestListArray, ValidateDimensions) { this->TestValidateDimensions(); } @@ -1760,7 +1760,7 @@ TEST_F(TestFixedSizeListArray, Flatten) { } } -TEST_F(TestFixedSizeListArray, FlattenRecursion) { +TEST_F(TestFixedSizeListArray, FlattenRecursively) { // Nested fixed-size list-array: fixed_size_list(fixed_size_list(int32, 2), 2) auto inner_type = fixed_size_list(value_type_, 2); type_ = fixed_size_list(inner_type, 2); @@ -1771,7 +1771,7 @@ TEST_F(TestFixedSizeListArray, FlattenRecursion) { [null, null] ])")); ASSERT_OK(values->ValidateFull()); - ASSERT_OK_AND_ASSIGN(auto flattened, values->FlattenRecursion()); + ASSERT_OK_AND_ASSIGN(auto flattened, values->FlattenRecursively()); ASSERT_OK(flattened->ValidateFull()); ASSERT_EQ(8, flattened->length()); ASSERT_EQ(2, flattened->null_count()); diff --git a/cpp/src/arrow/array/array_nested.cc b/cpp/src/arrow/array/array_nested.cc index f64e83ad49915..71b196b715543 100644 --- a/cpp/src/arrow/array/array_nested.cc +++ b/cpp/src/arrow/array/array_nested.cc @@ -216,26 +216,12 @@ static std::shared_ptr SliceArrayWithOffsets(const Array& array, int64_t return array.Slice(begin, end - begin); } -namespace { -struct FlattenWithRecursion { - // Flatten all list-like types array recursively - static Result> Flatten(const Array& array, bool with_recursion, - MemoryPool* memory_pool); -}; -} // namespace - template Result> FlattenListArray(const ListArrayT& list_array, - bool with_recursion, MemoryPool* memory_pool) { const int64_t list_array_length = list_array.length(); std::shared_ptr value_array = list_array.values(); - // If it's a nested-list related array, flatten recursively. - if (is_list_like(value_array->type_id()) && with_recursion) { - return FlattenWithRecursion::Flatten(*value_array, with_recursion, memory_pool); - } - // Shortcut: if a ListArray does not contain nulls, then simply slice its // value array with the first and the last offsets. if (list_array.null_count() == 0) { @@ -279,18 +265,12 @@ Result> FlattenListArray(const ListArrayT& list_array, template Result> FlattenListViewArray(const ListViewArrayT& list_view_array, - bool with_recursion, MemoryPool* memory_pool) { using offset_type = typename ListViewArrayT::offset_type; const int64_t list_view_array_offset = list_view_array.offset(); const int64_t list_view_array_length = list_view_array.length(); std::shared_ptr value_array = list_view_array.values(); - // If it's a nested list-view, flatten recursively. - if (is_list_view(value_array->type()->id()) && with_recursion) { - return FlattenWithRecursion::Flatten(*value_array, with_recursion, memory_pool); - } - if (list_view_array_length == 0) { return SliceArrayWithOffsets(*value_array, 0, 0); } @@ -371,44 +351,6 @@ Result> FlattenListViewArray(const ListViewArrayT& list_v return Concatenate(slices, memory_pool); } -Result> FlattenWithRecursion::Flatten(const Array& array, - bool with_recursion, - MemoryPool* memory_pool) { - const bool has_nulls = array.null_count() > 0; - switch (array.type_id()) { - case Type::LIST: - return FlattenListArray(checked_cast(array), with_recursion, - memory_pool); - case Type::LARGE_LIST: - return FlattenListArray(checked_cast(array), with_recursion, - memory_pool); - case Type::FIXED_SIZE_LIST: - return FlattenListArray(checked_cast(array), - with_recursion, memory_pool); - case Type::LIST_VIEW: { - if (has_nulls) { - return FlattenListViewArray( - checked_cast(array), with_recursion, memory_pool); - } else { - return FlattenListViewArray( - checked_cast(array), with_recursion, memory_pool); - } - } - case Type::LARGE_LIST_VIEW: { - if (has_nulls) { - return FlattenListViewArray( - checked_cast(array), with_recursion, memory_pool); - } else { - return FlattenListViewArray( - checked_cast(array), with_recursion, memory_pool); - } - } - default: - return Status::Invalid("Unknown or unsupported arrow nested type: ", - array.type()->ToString()); - } -} - std::shared_ptr BoxOffsets(const std::shared_ptr& boxed_type, const ArrayData& data) { const int64_t num_offsets = @@ -527,6 +469,69 @@ inline void SetListData(VarLengthListLikeArray* self, self->values_ = MakeArray(self->data_->child_data[0]); } +Result> FlattenLogicalListRecursively(const Array& array, + MemoryPool* memory_pool) { + Type::type kind = array.type_id(); + std::shared_ptr in_array = array.Slice(0, array.length()); + while (is_list_like(kind) || is_list_view(kind)) { + const bool has_nulls = array.null_count() > 0; + std::shared_ptr out; + switch (kind) { + case Type::LIST: { + ARROW_ASSIGN_OR_RAISE( + out, + FlattenListArray(checked_cast(*in_array), memory_pool)); + break; + } + case Type::LARGE_LIST: { + ARROW_ASSIGN_OR_RAISE( + out, FlattenListArray(checked_cast(*in_array), + memory_pool)); + break; + } + case Type::FIXED_SIZE_LIST: { + ARROW_ASSIGN_OR_RAISE( + out, FlattenListArray(checked_cast(*in_array), + memory_pool)); + break; + } + case Type::LIST_VIEW: { + if (has_nulls) { + ARROW_ASSIGN_OR_RAISE( + out, (FlattenListViewArray( + checked_cast(*in_array), memory_pool))); + break; + } else { + ARROW_ASSIGN_OR_RAISE( + out, (FlattenListViewArray( + checked_cast(*in_array), memory_pool))); + break; + } + } + case Type::LARGE_LIST_VIEW: { + if (has_nulls) { + ARROW_ASSIGN_OR_RAISE( + out, (FlattenListViewArray( + checked_cast(*in_array), memory_pool))); + break; + } else { + ARROW_ASSIGN_OR_RAISE( + out, (FlattenListViewArray( + checked_cast(*in_array), memory_pool))); + break; + } + } + default: + return Status::Invalid("Unknown or unsupported arrow nested type: ", + in_array->type()->ToString()); + } + + in_array = out; + kind = in_array->type_id(); + } + return std::move(in_array); +} + } // namespace internal // ---------------------------------------------------------------------- @@ -581,12 +586,7 @@ Result> ListArray::FromArrays( } Result> ListArray::Flatten(MemoryPool* memory_pool) const { - return FlattenListArray(*this, /*with_recursion=*/false, memory_pool); -} - -Result> ListArray::FlattenRecursion( - MemoryPool* memory_pool) const { - return FlattenListArray(*this, /*with_recursion=*/true, memory_pool); + return FlattenListArray(*this, memory_pool); } std::shared_ptr ListArray::offsets() const { return BoxOffsets(int32(), *data_); } @@ -645,12 +645,7 @@ Result> LargeListArray::FromArrays( } Result> LargeListArray::Flatten(MemoryPool* memory_pool) const { - return FlattenListArray(*this, /*with_recursion=*/false, memory_pool); -} - -Result> LargeListArray::FlattenRecursion( - MemoryPool* memory_pool) const { - return FlattenListArray(*this, /*with_recursion=*/true, memory_pool); + return FlattenListArray(*this, memory_pool); } std::shared_ptr LargeListArray::offsets() const { @@ -721,21 +716,9 @@ Result> LargeListViewArray::FromList( Result> ListViewArray::Flatten(MemoryPool* memory_pool) const { if (null_count() > 0) { - return FlattenListViewArray(*this, /*with_recursion=*/false, - memory_pool); - } - return FlattenListViewArray(*this, /*with_recursion=*/false, - memory_pool); -} - -Result> ListViewArray::FlattenRecursion( - MemoryPool* memory_pool) const { - if (null_count() > 0) { - return FlattenListViewArray(*this, /*with_recursion=*/true, - memory_pool); + return FlattenListViewArray(*this, memory_pool); } - return FlattenListViewArray(*this, /*with_recursion=*/true, - memory_pool); + return FlattenListViewArray(*this, memory_pool); } std::shared_ptr ListViewArray::offsets() const { @@ -794,21 +777,9 @@ Result> LargeListViewArray::FromArrays( Result> LargeListViewArray::Flatten( MemoryPool* memory_pool) const { if (null_count() > 0) { - return FlattenListViewArray(*this, /*with_recursion=*/false, - memory_pool); - } - return FlattenListViewArray(*this, /*with_recursion=*/false, - memory_pool); -} - -Result> LargeListViewArray::FlattenRecursion( - MemoryPool* memory_pool) const { - if (null_count() > 0) { - return FlattenListViewArray(*this, /*with_recursion=*/true, - memory_pool); + return FlattenListViewArray(*this, memory_pool); } - return FlattenListViewArray(*this, /*with_recursion=*/true, - memory_pool); + return FlattenListViewArray(*this, memory_pool); } std::shared_ptr LargeListViewArray::offsets() const { @@ -1026,12 +997,12 @@ Result> FixedSizeListArray::FromArrays( Result> FixedSizeListArray::Flatten( MemoryPool* memory_pool) const { - return FlattenListArray(*this, /*with_recursion=*/false, memory_pool); + return FlattenListArray(*this, memory_pool); } -Result> FixedSizeListArray::FlattenRecursion( +Result> FixedSizeListArray::FlattenRecursively( MemoryPool* memory_pool) const { - return FlattenListArray(*this, /*with_recursion=*/true, memory_pool); + return internal::FlattenLogicalListRecursively(*this, memory_pool); } // ---------------------------------------------------------------------- diff --git a/cpp/src/arrow/array/array_nested.h b/cpp/src/arrow/array/array_nested.h index 94ac8ea68bca7..e39ce77d0dfc3 100644 --- a/cpp/src/arrow/array/array_nested.h +++ b/cpp/src/arrow/array/array_nested.h @@ -58,6 +58,10 @@ void SetListData(VarLengthListLikeArray* self, const std::shared_ptr& data, Type::type expected_type_id = TYPE::type_id); +// Private flatten helper for logical lists: [Large]List[View]Array, FixedSizeListArray +// and MapArray +ARROW_EXPORT Result> FlattenLogicalListRecursively( + const Array& array, MemoryPool* memory_pool); } // namespace internal /// Base class for variable-sized list and list-view arrays, regardless of offset size. @@ -103,6 +107,13 @@ class VarLengthListLikeArray : public Array { return values_->Slice(value_offset(i), value_length(i)); } + /// \brief Flatten all level recursively until reach a non-list type, and return a + /// non-list type Array. + Result> FlattenRecursively( + MemoryPool* memory_pool = default_memory_pool()) const { + return internal::FlattenLogicalListRecursively(*this, memory_pool); + } + protected: friend void internal::SetListData(VarLengthListLikeArray* self, const std::shared_ptr& data, @@ -189,11 +200,6 @@ class ARROW_EXPORT ListArray : public BaseListArray { Result> Flatten( MemoryPool* memory_pool = default_memory_pool()) const; - /// \brief Flatten all level recursively until reach a non-list type, and return a - /// non-list type Array. - Result> FlattenRecursion( - MemoryPool* memory_pool = default_memory_pool()) const; - /// \brief Return list offsets as an Int32Array /// /// The returned array will not have a validity bitmap, so you cannot expect @@ -262,11 +268,6 @@ class ARROW_EXPORT LargeListArray : public BaseListArray { Result> Flatten( MemoryPool* memory_pool = default_memory_pool()) const; - /// \brief Flatten all level recursively until reach a non-list type, and return a - /// non-list type Array. - Result> FlattenRecursion( - MemoryPool* memory_pool = default_memory_pool()) const; - /// \brief Return list offsets as an Int64Array std::shared_ptr offsets() const; @@ -374,11 +375,6 @@ class ARROW_EXPORT ListViewArray : public BaseListViewArray { Result> Flatten( MemoryPool* memory_pool = default_memory_pool()) const; - /// \brief Flatten all level recursively until reach a non-list type, and return a - /// non-list type Array. - Result> FlattenRecursion( - MemoryPool* memory_pool = default_memory_pool()) const; - /// \brief Return list-view offsets as an Int32Array /// /// The returned array will not have a validity bitmap, so you cannot expect @@ -463,11 +459,6 @@ class ARROW_EXPORT LargeListViewArray : public BaseListViewArray> Flatten( MemoryPool* memory_pool = default_memory_pool()) const; - /// \brief Flatten all level recursively until reach a non-list type, and return a - /// non-list type Array. - Result> FlattenRecursion( - MemoryPool* memory_pool = default_memory_pool()) const; - /// \brief Return list-view offsets as an Int64Array /// /// The returned array will not have a validity bitmap, so you cannot expect @@ -617,7 +608,7 @@ class ARROW_EXPORT FixedSizeListArray : public Array { /// \brief Flatten all level recursively until reach a non-list type, and return a /// non-list type Array. - Result> FlattenRecursion( + Result> FlattenRecursively( MemoryPool* memory_pool = default_memory_pool()) const; /// \brief Construct FixedSizeListArray from child value array and value_length