Skip to content

Commit

Permalink
add a new flatten api with a recursion argument
Browse files Browse the repository at this point in the history
  • Loading branch information
ZhangHuiGui committed Apr 9, 2024
1 parent 26b7a6f commit 3d5a2ec
Show file tree
Hide file tree
Showing 3 changed files with 86 additions and 38 deletions.
15 changes: 7 additions & 8 deletions cpp/src/arrow/array/array_list_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -763,7 +763,7 @@ class TestListArray : public ::testing::Test {
<< flattened->ToString();
}

void TestFlattenNested() {
void TestFlattenRecursion() {
auto inner_type = std::make_shared<T>(int32());
auto type = std::make_shared<T>(inner_type);

Expand All @@ -773,16 +773,15 @@ class TestListArray : public ::testing::Test {
[null],
[[2, 9], [4], [], [6, 5]]
])"));
ASSERT_OK_AND_ASSIGN(auto flattened,
nested_list_array->Flatten(/*with_recursion=*/true));
ASSERT_OK_AND_ASSIGN(auto flattened, nested_list_array->FlattenRecursion());
ASSERT_OK(flattened->ValidateFull());
ASSERT_EQ(9, flattened->length());
ASSERT_TRUE(flattened->Equals(ArrayFromJSON(int32(), "[0, 1, 2, 3, 2, 9, 4, 6, 5]")));

// Empty nested list should flatten until it reaches its non-list type
nested_list_array =
std::dynamic_pointer_cast<ArrayType>(ArrayFromJSON(type, R"([null])"));
ASSERT_OK_AND_ASSIGN(flattened, nested_list_array->Flatten(/*with_recursion=*/true));
ASSERT_OK_AND_ASSIGN(flattened, nested_list_array->FlattenRecursion());
ASSERT_TRUE(flattened->type()->Equals(int32()));

// List type with three nested level: list(list(list(int32)))
Expand All @@ -801,7 +800,7 @@ class TestListArray : public ::testing::Test {
null
]
])"));
ASSERT_OK_AND_ASSIGN(flattened, nested_list_array->Flatten(/*with_recursion=*/true));
ASSERT_OK_AND_ASSIGN(flattened, nested_list_array->FlattenRecursion());
ASSERT_OK(flattened->ValidateFull());
ASSERT_EQ(7, flattened->length());
ASSERT_EQ(2, flattened->null_count());
Expand Down Expand Up @@ -975,7 +974,7 @@ TYPED_TEST(TestListArray, FlattenZeroLength) { this->TestFlattenZeroLength(); }
// Registers the fixture's TestFlattenNonEmptyBackingNulls() check as a typed
// test of TestListArray.
TYPED_TEST(TestListArray, TestFlattenNonEmptyBackingNulls) {
this->TestFlattenNonEmptyBackingNulls();
}
TYPED_TEST(TestListArray, FlattenNested) { this->TestFlattenNested(); }
// Registers the fixture's TestFlattenRecursion() check (recursive flatten of
// nested lists) as a typed test of TestListArray.
TYPED_TEST(TestListArray, FlattenRecursion) { this->TestFlattenRecursion(); }

// Registers the fixture's TestValidateDimensions() check as a typed test of
// TestListArray.
TYPED_TEST(TestListArray, ValidateDimensions) { this->TestValidateDimensions(); }

Expand Down Expand Up @@ -1761,7 +1760,7 @@ TEST_F(TestFixedSizeListArray, Flatten) {
}
}

TEST_F(TestFixedSizeListArray, FlattenNested) {
TEST_F(TestFixedSizeListArray, FlattenRecursion) {
// Nested fixed-size list-array: fixed_size_list(fixed_size_list(int32, 2), 2)
auto inner_type = fixed_size_list(value_type_, 2);
type_ = fixed_size_list(inner_type, 2);
Expand All @@ -1772,7 +1771,7 @@ TEST_F(TestFixedSizeListArray, FlattenNested) {
[null, null]
])"));
ASSERT_OK(values->ValidateFull());
ASSERT_OK_AND_ASSIGN(auto flattened, values->Flatten(/*with_recursion=*/true));
ASSERT_OK_AND_ASSIGN(auto flattened, values->FlattenRecursion());
ASSERT_OK(flattened->ValidateFull());
ASSERT_EQ(8, flattened->length());
ASSERT_EQ(2, flattened->null_count());
Expand Down
64 changes: 49 additions & 15 deletions cpp/src/arrow/array/array_nested.cc
Original file line number Diff line number Diff line change
Expand Up @@ -580,9 +580,13 @@ Result<std::shared_ptr<ListArray>> ListArray::FromArrays(
null_bitmap, null_count);
}

Result<std::shared_ptr<Array>> ListArray::Flatten(bool with_recursion,
MemoryPool* memory_pool) const {
return FlattenListArray(*this, with_recursion, memory_pool);
// Non-recursive flatten: delegates to FlattenListArray with recursion disabled.
Result<std::shared_ptr<Array>> ListArray::Flatten(MemoryPool* memory_pool) const {
  constexpr bool kWithRecursion = false;
  return FlattenListArray(*this, kWithRecursion, memory_pool);
}

// Recursive flatten: delegates to FlattenListArray with recursion enabled.
Result<std::shared_ptr<Array>> ListArray::FlattenRecursion(
    MemoryPool* memory_pool) const {
  constexpr bool kWithRecursion = true;
  return FlattenListArray(*this, kWithRecursion, memory_pool);
}

// Returns this array's offsets boxed via BoxOffsets using the int32() type.
std::shared_ptr<Array> ListArray::offsets() const {
  const auto offsets_type = int32();
  return BoxOffsets(offsets_type, *data_);
}
Expand Down Expand Up @@ -640,9 +644,13 @@ Result<std::shared_ptr<LargeListArray>> LargeListArray::FromArrays(
null_bitmap, null_count);
}

Result<std::shared_ptr<Array>> LargeListArray::Flatten(bool with_recursion,
MemoryPool* memory_pool) const {
return FlattenListArray(*this, with_recursion, memory_pool);
// Non-recursive flatten: delegates to FlattenListArray with recursion disabled.
Result<std::shared_ptr<Array>> LargeListArray::Flatten(MemoryPool* memory_pool) const {
  constexpr bool kWithRecursion = false;
  return FlattenListArray(*this, kWithRecursion, memory_pool);
}

// Recursive flatten: delegates to FlattenListArray with recursion enabled.
Result<std::shared_ptr<Array>> LargeListArray::FlattenRecursion(
    MemoryPool* memory_pool) const {
  constexpr bool kWithRecursion = true;
  return FlattenListArray(*this, kWithRecursion, memory_pool);
}

std::shared_ptr<Array> LargeListArray::offsets() const {
Expand Down Expand Up @@ -711,12 +719,23 @@ Result<std::shared_ptr<LargeListViewArray>> LargeListViewArray::FromList(
return std::make_shared<LargeListViewArray>(std::move(data));
}

Result<std::shared_ptr<Array>> ListViewArray::Flatten(bool with_recursion,
MemoryPool* memory_pool) const {
// Non-recursive flatten of a list-view array. The boolean template argument
// passed to FlattenListViewArray is chosen from whether null_count() is
// positive; recursion is always disabled here.
Result<std::shared_ptr<Array>> ListViewArray::Flatten(MemoryPool* memory_pool) const {
  constexpr bool kWithRecursion = false;
  if (null_count() <= 0) {
    // No nulls reported: use the `false` instantiation.
    return FlattenListViewArray<ListViewArray, false>(*this, kWithRecursion,
                                                      memory_pool);
  }
  return FlattenListViewArray<ListViewArray, true>(*this, kWithRecursion, memory_pool);
}

Result<std::shared_ptr<Array>> ListViewArray::FlattenRecursion(
MemoryPool* memory_pool) const {
if (null_count() > 0) {
return FlattenListViewArray<ListViewArray, true>(*this, with_recursion, memory_pool);
return FlattenListViewArray<ListViewArray, true>(*this, /*with_recursion=*/true,
memory_pool);
}
return FlattenListViewArray<ListViewArray, false>(*this, with_recursion, memory_pool);
return FlattenListViewArray<ListViewArray, false>(*this, /*with_recursion=*/true,
memory_pool);
}

std::shared_ptr<Array> ListViewArray::offsets() const {
Expand Down Expand Up @@ -773,12 +792,22 @@ Result<std::shared_ptr<LargeListViewArray>> LargeListViewArray::FromArrays(
}

Result<std::shared_ptr<Array>> LargeListViewArray::Flatten(
bool with_recursion, MemoryPool* memory_pool) const {
MemoryPool* memory_pool) const {
if (null_count() > 0) {
return FlattenListViewArray<LargeListViewArray, true>(*this, with_recursion,
return FlattenListViewArray<LargeListViewArray, true>(*this, /*with_recursion=*/false,
memory_pool);
}
return FlattenListViewArray<LargeListViewArray, false>(*this, with_recursion,
return FlattenListViewArray<LargeListViewArray, false>(*this, /*with_recursion=*/false,
memory_pool);
}

// Recursive flatten of a large list-view array. The boolean template argument
// passed to FlattenListViewArray is chosen from whether null_count() is
// positive; recursion is always enabled here.
Result<std::shared_ptr<Array>> LargeListViewArray::FlattenRecursion(
    MemoryPool* memory_pool) const {
  constexpr bool kWithRecursion = true;
  if (null_count() <= 0) {
    // No nulls reported: use the `false` instantiation.
    return FlattenListViewArray<LargeListViewArray, false>(*this, kWithRecursion,
                                                           memory_pool);
  }
  return FlattenListViewArray<LargeListViewArray, true>(*this, kWithRecursion,
                                                        memory_pool);
}

Expand Down Expand Up @@ -996,8 +1025,13 @@ Result<std::shared_ptr<Array>> FixedSizeListArray::FromArrays(
}

Result<std::shared_ptr<Array>> FixedSizeListArray::Flatten(
bool with_recursion, MemoryPool* memory_pool) const {
return FlattenListArray(*this, with_recursion, memory_pool);
MemoryPool* memory_pool) const {
return FlattenListArray(*this, /*with_recursion=*/false, memory_pool);
}

// Recursive flatten: delegates to FlattenListArray with recursion enabled.
Result<std::shared_ptr<Array>> FixedSizeListArray::FlattenRecursion(
    MemoryPool* memory_pool) const {
  constexpr bool kWithRecursion = true;
  return FlattenListArray(*this, kWithRecursion, memory_pool);
}

// ----------------------------------------------------------------------
Expand Down
45 changes: 30 additions & 15 deletions cpp/src/arrow/array/array_nested.h
Original file line number Diff line number Diff line change
Expand Up @@ -183,13 +183,16 @@ class ARROW_EXPORT ListArray : public BaseListArray<ListType> {

/// \brief Return an Array that is a concatenation of the lists in this array.
///
/// \param[in] with_recursion Flatten recursively until reach non-list type
///
/// Note that it's different from `values()` in that it takes into
/// consideration of this array's offsets as well as null elements backed
/// by non-empty lists (they are skipped, thus copying may be needed).
Result<std::shared_ptr<Array>> Flatten(
bool with_recursion = false, MemoryPool* memory_pool = default_memory_pool()) const;
MemoryPool* memory_pool = default_memory_pool()) const;

/// \brief Flatten all levels recursively until a non-list type is reached.
///
/// \param[in] memory_pool Memory pool handed to the flatten implementation;
/// defaults to default_memory_pool()
/// \return An Array of non-list type
Result<std::shared_ptr<Array>> FlattenRecursion(
MemoryPool* memory_pool = default_memory_pool()) const;

/// \brief Return list offsets as an Int32Array
///
Expand Down Expand Up @@ -253,13 +256,16 @@ class ARROW_EXPORT LargeListArray : public BaseListArray<LargeListType> {

/// \brief Return an Array that is a concatenation of the lists in this array.
///
/// \param[in] with_recursion Flatten recursively until reach non-list type
///
/// Note that it's different from `values()` in that it takes into
/// consideration of this array's offsets as well as null elements backed
/// by non-empty lists (they are skipped, thus copying may be needed).
Result<std::shared_ptr<Array>> Flatten(
bool with_recursion = false, MemoryPool* memory_pool = default_memory_pool()) const;
MemoryPool* memory_pool = default_memory_pool()) const;

/// \brief Flatten all levels recursively until a non-list type is reached.
///
/// \param[in] memory_pool Memory pool handed to the flatten implementation;
/// defaults to default_memory_pool()
/// \return An Array of non-list type
Result<std::shared_ptr<Array>> FlattenRecursion(
MemoryPool* memory_pool = default_memory_pool()) const;

/// \brief Return list offsets as an Int64Array
std::shared_ptr<Array> offsets() const;
Expand Down Expand Up @@ -357,8 +363,6 @@ class ARROW_EXPORT ListViewArray : public BaseListViewArray<ListViewType> {

/// \brief Return an Array that is a concatenation of the list-views in this array.
///
/// \param[in] with_recursion Flatten recursively until reach non-list type
///
/// Note that it's different from `values()` in that it takes into
/// consideration this array's offsets (which can be in any order)
/// and sizes. Nulls are skipped.
Expand All @@ -368,7 +372,12 @@ class ARROW_EXPORT ListViewArray : public BaseListViewArray<ListViewType> {
/// maximizing the size of each slice (containing as many contiguous
/// list-views as possible).
Result<std::shared_ptr<Array>> Flatten(
bool with_recursion = false, MemoryPool* memory_pool = default_memory_pool()) const;
MemoryPool* memory_pool = default_memory_pool()) const;

/// \brief Flatten all levels recursively until a non-list type is reached.
///
/// \param[in] memory_pool Memory pool handed to the flatten implementation;
/// defaults to default_memory_pool()
/// \return An Array of non-list type
Result<std::shared_ptr<Array>> FlattenRecursion(
MemoryPool* memory_pool = default_memory_pool()) const;

/// \brief Return list-view offsets as an Int32Array
///
Expand Down Expand Up @@ -448,13 +457,16 @@ class ARROW_EXPORT LargeListViewArray : public BaseListViewArray<LargeListViewTy
/// \brief Return an Array that is a concatenation of the large list-views in this
/// array.
///
/// \param[in] with_recursion Flatten recursively until reach non-list type
///
/// Note that it's different from `values()` in that it takes into
/// consideration this array's offsets (which can be in any order)
/// and sizes. Nulls are skipped.
Result<std::shared_ptr<Array>> Flatten(
bool with_recursion = false, MemoryPool* memory_pool = default_memory_pool()) const;
MemoryPool* memory_pool = default_memory_pool()) const;

/// \brief Flatten all levels recursively until a non-list type is reached.
///
/// \param[in] memory_pool Memory pool handed to the flatten implementation;
/// defaults to default_memory_pool()
/// \return An Array of non-list type
Result<std::shared_ptr<Array>> FlattenRecursion(
MemoryPool* memory_pool = default_memory_pool()) const;

/// \brief Return list-view offsets as an Int64Array
///
Expand Down Expand Up @@ -598,12 +610,15 @@ class ARROW_EXPORT FixedSizeListArray : public Array {

/// \brief Return an Array that is a concatenation of the lists in this array.
///
/// \param[in] with_recursion Flatten recursively until reach non-list type
///
/// Note that it's different from `values()` in that it takes into
/// consideration null elements (they are skipped, thus copying may be needed).
Result<std::shared_ptr<Array>> Flatten(
bool with_recursion = false, MemoryPool* memory_pool = default_memory_pool()) const;
MemoryPool* memory_pool = default_memory_pool()) const;

/// \brief Flatten all levels recursively until a non-list type is reached.
///
/// \param[in] memory_pool Memory pool handed to the flatten implementation;
/// defaults to default_memory_pool()
/// \return An Array of non-list type
Result<std::shared_ptr<Array>> FlattenRecursion(
MemoryPool* memory_pool = default_memory_pool()) const;

/// \brief Construct FixedSizeListArray from child value array and value_length
///
Expand Down

0 comments on commit 3d5a2ec

Please sign in to comment.