Skip to content

Commit

Permalink
Refactor codes
Browse files Browse the repository at this point in the history
  • Loading branch information
ZhangHuiGui committed Apr 10, 2024
1 parent 3d5a2ec commit 2450c60
Show file tree
Hide file tree
Showing 3 changed files with 91 additions and 129 deletions.
14 changes: 7 additions & 7 deletions cpp/src/arrow/array/array_list_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -763,7 +763,7 @@ class TestListArray : public ::testing::Test {
<< flattened->ToString();
}

void TestFlattenRecursion() {
void TestFlattenRecursively() {
auto inner_type = std::make_shared<T>(int32());
auto type = std::make_shared<T>(inner_type);

Expand All @@ -773,15 +773,15 @@ class TestListArray : public ::testing::Test {
[null],
[[2, 9], [4], [], [6, 5]]
])"));
ASSERT_OK_AND_ASSIGN(auto flattened, nested_list_array->FlattenRecursion());
ASSERT_OK_AND_ASSIGN(auto flattened, nested_list_array->FlattenRecursively());
ASSERT_OK(flattened->ValidateFull());
ASSERT_EQ(9, flattened->length());
ASSERT_TRUE(flattened->Equals(ArrayFromJSON(int32(), "[0, 1, 2, 3, 2, 9, 4, 6, 5]")));

// Empty nested list should flatten until reaching its non-list type
nested_list_array =
std::dynamic_pointer_cast<ArrayType>(ArrayFromJSON(type, R"([null])"));
ASSERT_OK_AND_ASSIGN(flattened, nested_list_array->FlattenRecursion());
ASSERT_OK_AND_ASSIGN(flattened, nested_list_array->FlattenRecursively());
ASSERT_TRUE(flattened->type()->Equals(int32()));

// List type with three nested level: list(list(list(int32)))
Expand All @@ -800,7 +800,7 @@ class TestListArray : public ::testing::Test {
null
]
])"));
ASSERT_OK_AND_ASSIGN(flattened, nested_list_array->FlattenRecursion());
ASSERT_OK_AND_ASSIGN(flattened, nested_list_array->FlattenRecursively());
ASSERT_OK(flattened->ValidateFull());
ASSERT_EQ(7, flattened->length());
ASSERT_EQ(2, flattened->null_count());
Expand Down Expand Up @@ -974,7 +974,7 @@ TYPED_TEST(TestListArray, FlattenZeroLength) { this->TestFlattenZeroLength(); }
TYPED_TEST(TestListArray, TestFlattenNonEmptyBackingNulls) {
this->TestFlattenNonEmptyBackingNulls();
}
TYPED_TEST(TestListArray, FlattenRecursion) { this->TestFlattenRecursion(); }
TYPED_TEST(TestListArray, FlattenRecursively) { this->TestFlattenRecursively(); }

TYPED_TEST(TestListArray, ValidateDimensions) { this->TestValidateDimensions(); }

Expand Down Expand Up @@ -1760,7 +1760,7 @@ TEST_F(TestFixedSizeListArray, Flatten) {
}
}

TEST_F(TestFixedSizeListArray, FlattenRecursion) {
TEST_F(TestFixedSizeListArray, FlattenRecursively) {
// Nested fixed-size list-array: fixed_size_list(fixed_size_list(int32, 2), 2)
auto inner_type = fixed_size_list(value_type_, 2);
type_ = fixed_size_list(inner_type, 2);
Expand All @@ -1771,7 +1771,7 @@ TEST_F(TestFixedSizeListArray, FlattenRecursion) {
[null, null]
])"));
ASSERT_OK(values->ValidateFull());
ASSERT_OK_AND_ASSIGN(auto flattened, values->FlattenRecursion());
ASSERT_OK_AND_ASSIGN(auto flattened, values->FlattenRecursively());
ASSERT_OK(flattened->ValidateFull());
ASSERT_EQ(8, flattened->length());
ASSERT_EQ(2, flattened->null_count());
Expand Down
173 changes: 72 additions & 101 deletions cpp/src/arrow/array/array_nested.cc
Original file line number Diff line number Diff line change
Expand Up @@ -216,26 +216,12 @@ static std::shared_ptr<Array> SliceArrayWithOffsets(const Array& array, int64_t
return array.Slice(begin, end - begin);
}

namespace {
struct FlattenWithRecursion {
// Flatten all list-like types array recursively
static Result<std::shared_ptr<Array>> Flatten(const Array& array, bool with_recursion,
MemoryPool* memory_pool);
};
} // namespace

template <typename ListArrayT>
Result<std::shared_ptr<Array>> FlattenListArray(const ListArrayT& list_array,
bool with_recursion,
MemoryPool* memory_pool) {
const int64_t list_array_length = list_array.length();
std::shared_ptr<arrow::Array> value_array = list_array.values();

// If it's a nested-list related array, flatten recursively.
if (is_list_like(value_array->type_id()) && with_recursion) {
return FlattenWithRecursion::Flatten(*value_array, with_recursion, memory_pool);
}

// Shortcut: if a ListArray does not contain nulls, then simply slice its
// value array with the first and the last offsets.
if (list_array.null_count() == 0) {
Expand Down Expand Up @@ -279,18 +265,12 @@ Result<std::shared_ptr<Array>> FlattenListArray(const ListArrayT& list_array,

template <typename ListViewArrayT, bool HasNulls>
Result<std::shared_ptr<Array>> FlattenListViewArray(const ListViewArrayT& list_view_array,
bool with_recursion,
MemoryPool* memory_pool) {
using offset_type = typename ListViewArrayT::offset_type;
const int64_t list_view_array_offset = list_view_array.offset();
const int64_t list_view_array_length = list_view_array.length();
std::shared_ptr<arrow::Array> value_array = list_view_array.values();

// If it's a nested list-view, flatten recursively.
if (is_list_view(value_array->type()->id()) && with_recursion) {
return FlattenWithRecursion::Flatten(*value_array, with_recursion, memory_pool);
}

if (list_view_array_length == 0) {
return SliceArrayWithOffsets(*value_array, 0, 0);
}
Expand Down Expand Up @@ -371,44 +351,6 @@ Result<std::shared_ptr<Array>> FlattenListViewArray(const ListViewArrayT& list_v
return Concatenate(slices, memory_pool);
}

Result<std::shared_ptr<Array>> FlattenWithRecursion::Flatten(const Array& array,
bool with_recursion,
MemoryPool* memory_pool) {
const bool has_nulls = array.null_count() > 0;
switch (array.type_id()) {
case Type::LIST:
return FlattenListArray(checked_cast<const ListArray&>(array), with_recursion,
memory_pool);
case Type::LARGE_LIST:
return FlattenListArray(checked_cast<const LargeListArray&>(array), with_recursion,
memory_pool);
case Type::FIXED_SIZE_LIST:
return FlattenListArray(checked_cast<const FixedSizeListArray&>(array),
with_recursion, memory_pool);
case Type::LIST_VIEW: {
if (has_nulls) {
return FlattenListViewArray<ListViewArray, true>(
checked_cast<const ListViewArray&>(array), with_recursion, memory_pool);
} else {
return FlattenListViewArray<ListViewArray, false>(
checked_cast<const ListViewArray&>(array), with_recursion, memory_pool);
}
}
case Type::LARGE_LIST_VIEW: {
if (has_nulls) {
return FlattenListViewArray<LargeListViewArray, true>(
checked_cast<const LargeListViewArray&>(array), with_recursion, memory_pool);
} else {
return FlattenListViewArray<LargeListViewArray, false>(
checked_cast<const LargeListViewArray&>(array), with_recursion, memory_pool);
}
}
default:
return Status::Invalid("Unknown or unsupported arrow nested type: ",
array.type()->ToString());
}
}

std::shared_ptr<Array> BoxOffsets(const std::shared_ptr<DataType>& boxed_type,
const ArrayData& data) {
const int64_t num_offsets =
Expand Down Expand Up @@ -527,6 +469,69 @@ inline void SetListData(VarLengthListLikeArray<TYPE>* self,
self->values_ = MakeArray(self->data_->child_data[0]);
}

/// \brief Flatten a logical list array one nesting level at a time until its type is
/// neither a list-like nor a list-view type.
///
/// Every pass of the loop strips exactly one level of nesting via FlattenListArray
/// or FlattenListViewArray and then re-examines the type id of the result.
///
/// \param[in] array the array to flatten
/// \param[in] memory_pool forwarded to every single-level flatten call
/// \return the flattened array; the error of a failing flatten step; or
/// Status::Invalid when a type id passes the loop condition but has no case in the
/// switch below
Result<std::shared_ptr<Array>> FlattenLogicalListRecursively(const Array& array,
                                                             MemoryPool* memory_pool) {
  Type::type kind = array.type_id();
  // A zero-offset, full-length slice yields a shared_ptr handle for the input so the
  // loop can uniformly rebind `in_array` to each intermediate result.
  std::shared_ptr<Array> in_array = array.Slice(0, array.length());
  while (is_list_like(kind) || is_list_view(kind)) {
    std::shared_ptr<Array> out;
    switch (kind) {
      case Type::LIST: {
        ARROW_ASSIGN_OR_RAISE(
            out,
            FlattenListArray(checked_cast<const ListArray&>(*in_array), memory_pool));
        break;
      }
      case Type::LARGE_LIST: {
        ARROW_ASSIGN_OR_RAISE(
            out, FlattenListArray(checked_cast<const LargeListArray&>(*in_array),
                                  memory_pool));
        break;
      }
      case Type::FIXED_SIZE_LIST: {
        ARROW_ASSIGN_OR_RAISE(
            out, FlattenListArray(checked_cast<const FixedSizeListArray&>(*in_array),
                                  memory_pool));
        break;
      }
      case Type::LIST_VIEW: {
        // The HasNulls template argument must describe the array being flattened at
        // THIS level (in_array), not the outermost input: nested levels can have a
        // different null count than the top-level array.
        if (in_array->null_count() > 0) {
          ARROW_ASSIGN_OR_RAISE(
              out, (FlattenListViewArray<ListViewArray, true>(
                       checked_cast<const ListViewArray&>(*in_array), memory_pool)));
        } else {
          ARROW_ASSIGN_OR_RAISE(
              out, (FlattenListViewArray<ListViewArray, false>(
                       checked_cast<const ListViewArray&>(*in_array), memory_pool)));
        }
        break;
      }
      case Type::LARGE_LIST_VIEW: {
        // Same as LIST_VIEW: consult the current level's null count.
        if (in_array->null_count() > 0) {
          ARROW_ASSIGN_OR_RAISE(
              out, (FlattenListViewArray<LargeListViewArray, true>(
                       checked_cast<const LargeListViewArray&>(*in_array), memory_pool)));
        } else {
          ARROW_ASSIGN_OR_RAISE(
              out, (FlattenListViewArray<LargeListViewArray, false>(
                       checked_cast<const LargeListViewArray&>(*in_array), memory_pool)));
        }
        break;
      }
      default:
        // NOTE(review): reached if the loop condition accepts a type id that has no
        // dedicated case above (e.g. if is_list_like() also covers Type::MAP) —
        // confirm whether such inputs should be flattened instead of rejected.
        return Status::Invalid("Unknown or unsupported arrow nested type: ",
                               in_array->type()->ToString());
    }

    in_array = out;
    kind = in_array->type_id();
  }
  // Explicit move into the Result<> wrapper (the return type differs from the
  // local's type).
  return std::move(in_array);
}

} // namespace internal

// ----------------------------------------------------------------------
Expand Down Expand Up @@ -581,12 +586,7 @@ Result<std::shared_ptr<ListArray>> ListArray::FromArrays(
}

Result<std::shared_ptr<Array>> ListArray::Flatten(MemoryPool* memory_pool) const {
return FlattenListArray(*this, /*with_recursion=*/false, memory_pool);
}

Result<std::shared_ptr<Array>> ListArray::FlattenRecursion(
MemoryPool* memory_pool) const {
return FlattenListArray(*this, /*with_recursion=*/true, memory_pool);
return FlattenListArray(*this, memory_pool);
}

std::shared_ptr<Array> ListArray::offsets() const { return BoxOffsets(int32(), *data_); }
Expand Down Expand Up @@ -645,12 +645,7 @@ Result<std::shared_ptr<LargeListArray>> LargeListArray::FromArrays(
}

Result<std::shared_ptr<Array>> LargeListArray::Flatten(MemoryPool* memory_pool) const {
return FlattenListArray(*this, /*with_recursion=*/false, memory_pool);
}

Result<std::shared_ptr<Array>> LargeListArray::FlattenRecursion(
MemoryPool* memory_pool) const {
return FlattenListArray(*this, /*with_recursion=*/true, memory_pool);
return FlattenListArray(*this, memory_pool);
}

std::shared_ptr<Array> LargeListArray::offsets() const {
Expand Down Expand Up @@ -721,21 +716,9 @@ Result<std::shared_ptr<LargeListViewArray>> LargeListViewArray::FromList(

Result<std::shared_ptr<Array>> ListViewArray::Flatten(MemoryPool* memory_pool) const {
if (null_count() > 0) {
return FlattenListViewArray<ListViewArray, true>(*this, /*with_recursion=*/false,
memory_pool);
}
return FlattenListViewArray<ListViewArray, false>(*this, /*with_recursion=*/false,
memory_pool);
}

Result<std::shared_ptr<Array>> ListViewArray::FlattenRecursion(
MemoryPool* memory_pool) const {
if (null_count() > 0) {
return FlattenListViewArray<ListViewArray, true>(*this, /*with_recursion=*/true,
memory_pool);
return FlattenListViewArray<ListViewArray, true>(*this, memory_pool);
}
return FlattenListViewArray<ListViewArray, false>(*this, /*with_recursion=*/true,
memory_pool);
return FlattenListViewArray<ListViewArray, false>(*this, memory_pool);
}

std::shared_ptr<Array> ListViewArray::offsets() const {
Expand Down Expand Up @@ -794,21 +777,9 @@ Result<std::shared_ptr<LargeListViewArray>> LargeListViewArray::FromArrays(
Result<std::shared_ptr<Array>> LargeListViewArray::Flatten(
MemoryPool* memory_pool) const {
if (null_count() > 0) {
return FlattenListViewArray<LargeListViewArray, true>(*this, /*with_recursion=*/false,
memory_pool);
}
return FlattenListViewArray<LargeListViewArray, false>(*this, /*with_recursion=*/false,
memory_pool);
}

Result<std::shared_ptr<Array>> LargeListViewArray::FlattenRecursion(
MemoryPool* memory_pool) const {
if (null_count() > 0) {
return FlattenListViewArray<LargeListViewArray, true>(*this, /*with_recursion=*/true,
memory_pool);
return FlattenListViewArray<LargeListViewArray, true>(*this, memory_pool);
}
return FlattenListViewArray<LargeListViewArray, false>(*this, /*with_recursion=*/true,
memory_pool);
return FlattenListViewArray<LargeListViewArray, false>(*this, memory_pool);
}

std::shared_ptr<Array> LargeListViewArray::offsets() const {
Expand Down Expand Up @@ -1026,12 +997,12 @@ Result<std::shared_ptr<Array>> FixedSizeListArray::FromArrays(

Result<std::shared_ptr<Array>> FixedSizeListArray::Flatten(
MemoryPool* memory_pool) const {
return FlattenListArray(*this, /*with_recursion=*/false, memory_pool);
return FlattenListArray(*this, memory_pool);
}

Result<std::shared_ptr<Array>> FixedSizeListArray::FlattenRecursion(
Result<std::shared_ptr<Array>> FixedSizeListArray::FlattenRecursively(
MemoryPool* memory_pool) const {
return FlattenListArray(*this, /*with_recursion=*/true, memory_pool);
return internal::FlattenLogicalListRecursively(*this, memory_pool);
}

// ----------------------------------------------------------------------
Expand Down
33 changes: 12 additions & 21 deletions cpp/src/arrow/array/array_nested.h
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,10 @@ void SetListData(VarLengthListLikeArray<TYPE>* self,
const std::shared_ptr<ArrayData>& data,
Type::type expected_type_id = TYPE::type_id);

// Private flatten helper for logical lists: [Large]List[View]Array, FixedSizeListArray
// and MapArray.
//
// Flattens `array` one nesting level at a time until its type is no longer a
// list-like or list-view type, then returns the resulting array. `memory_pool` is
// forwarded to each per-level flatten call. An error Status is returned if a
// per-level flatten fails or if a nested type is not supported by the definition.
// NOTE(review): the definition's switch has no dedicated Type::MAP case — confirm
// how MapArray inputs are expected to be handled.
ARROW_EXPORT Result<std::shared_ptr<Array>> FlattenLogicalListRecursively(
const Array& array, MemoryPool* memory_pool);
} // namespace internal

/// Base class for variable-sized list and list-view arrays, regardless of offset size.
Expand Down Expand Up @@ -103,6 +107,13 @@ class VarLengthListLikeArray : public Array {
return values_->Slice(value_offset(i), value_length(i));
}

/// \brief Flatten every nesting level of this array until a non-list type is
/// reached, and return the resulting non-list-type Array.
///
/// Delegates to internal::FlattenLogicalListRecursively.
///
/// \param[in] memory_pool forwarded unchanged to the internal helper
/// \return the flattened array, or the error reported by the helper
Result<std::shared_ptr<Array>> FlattenRecursively(
    MemoryPool* memory_pool = default_memory_pool()) const {
  // View this array through its Array base before handing it to the helper.
  const Array& self = *this;
  return internal::FlattenLogicalListRecursively(self, memory_pool);
}

protected:
friend void internal::SetListData<TYPE>(VarLengthListLikeArray<TYPE>* self,
const std::shared_ptr<ArrayData>& data,
Expand Down Expand Up @@ -189,11 +200,6 @@ class ARROW_EXPORT ListArray : public BaseListArray<ListType> {
Result<std::shared_ptr<Array>> Flatten(
MemoryPool* memory_pool = default_memory_pool()) const;

/// \brief Flatten all level recursively until reach a non-list type, and return a
/// non-list type Array.
Result<std::shared_ptr<Array>> FlattenRecursion(
MemoryPool* memory_pool = default_memory_pool()) const;

/// \brief Return list offsets as an Int32Array
///
/// The returned array will not have a validity bitmap, so you cannot expect
Expand Down Expand Up @@ -262,11 +268,6 @@ class ARROW_EXPORT LargeListArray : public BaseListArray<LargeListType> {
Result<std::shared_ptr<Array>> Flatten(
MemoryPool* memory_pool = default_memory_pool()) const;

/// \brief Flatten all level recursively until reach a non-list type, and return a
/// non-list type Array.
Result<std::shared_ptr<Array>> FlattenRecursion(
MemoryPool* memory_pool = default_memory_pool()) const;

/// \brief Return list offsets as an Int64Array
std::shared_ptr<Array> offsets() const;

Expand Down Expand Up @@ -374,11 +375,6 @@ class ARROW_EXPORT ListViewArray : public BaseListViewArray<ListViewType> {
Result<std::shared_ptr<Array>> Flatten(
MemoryPool* memory_pool = default_memory_pool()) const;

/// \brief Flatten all level recursively until reach a non-list type, and return a
/// non-list type Array.
Result<std::shared_ptr<Array>> FlattenRecursion(
MemoryPool* memory_pool = default_memory_pool()) const;

/// \brief Return list-view offsets as an Int32Array
///
/// The returned array will not have a validity bitmap, so you cannot expect
Expand Down Expand Up @@ -463,11 +459,6 @@ class ARROW_EXPORT LargeListViewArray : public BaseListViewArray<LargeListViewTy
Result<std::shared_ptr<Array>> Flatten(
MemoryPool* memory_pool = default_memory_pool()) const;

/// \brief Flatten all level recursively until reach a non-list type, and return a
/// non-list type Array.
Result<std::shared_ptr<Array>> FlattenRecursion(
MemoryPool* memory_pool = default_memory_pool()) const;

/// \brief Return list-view offsets as an Int64Array
///
/// The returned array will not have a validity bitmap, so you cannot expect
Expand Down Expand Up @@ -617,7 +608,7 @@ class ARROW_EXPORT FixedSizeListArray : public Array {

/// \brief Flatten all levels recursively until a non-list type is reached, and
/// return the resulting non-list-type Array.
Result<std::shared_ptr<Array>> FlattenRecursion(
Result<std::shared_ptr<Array>> FlattenRecursively(
MemoryPool* memory_pool = default_memory_pool()) const;

/// \brief Construct FixedSizeListArray from child value array and value_length
Expand Down

0 comments on commit 2450c60

Please sign in to comment.