Skip to content

Commit

Permalink
GH-41183: [C++][Python] Expose recursive flatten for lists on list_fl…
Browse files Browse the repository at this point in the history
…atten kernel function and pyarrow bindings (#41295)

### Rationale for this change
Expose recursive flatten for logical lists on list_flatten kernel function and pyarrow bindings.

### What changes are included in this PR?
1. Expose recursive flatten for logical lists on `list_flatten` kernel function
2. Support [Large]ListView for some kernel functions: `list_flatten`,`list_value_length`, `list_element`
3. Support recursive flatten for pyarrow bindinds and simplify [Large]ListView's pyarrow bindings
4. Refactor vector_nested_test.cc for better support [Large]ListView types.

### Are these changes tested?
Yes

### Are there any user-facing changes?
Yes.
1. Some kernel functions like: list_flatten, list_value_length, list_element would support [Large]ListView types
2. `list_flatten` and related pyarrow bindings could support flatten recursively with an ListFlattenOptions.

* GitHub Issue: #41183

Lead-authored-by: ZhangHuiGui <hugo.zhang@openpie.com>
Co-authored-by: ZhangHuiGui <2689496754@qq.com>
Signed-off-by: Felipe Oliveira Carvalho <felipekde@gmail.com>
  • Loading branch information
ZhangHuiGui and ZhangHuiGui authored Apr 30, 2024
1 parent 0ef7351 commit 5e986be
Show file tree
Hide file tree
Showing 15 changed files with 364 additions and 182 deletions.
7 changes: 7 additions & 0 deletions cpp/src/arrow/compute/api_vector.cc
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,8 @@ static auto kRankOptionsType = GetFunctionOptionsType<RankOptions>(
DataMember("tiebreaker", &RankOptions::tiebreaker));
static auto kPairwiseOptionsType = GetFunctionOptionsType<PairwiseOptions>(
DataMember("periods", &PairwiseOptions::periods));
static auto kListFlattenOptionsType = GetFunctionOptionsType<ListFlattenOptions>(
DataMember("recursive", &ListFlattenOptions::recursive));
} // namespace
} // namespace internal

Expand Down Expand Up @@ -224,6 +226,10 @@ PairwiseOptions::PairwiseOptions(int64_t periods)
: FunctionOptions(internal::kPairwiseOptionsType), periods(periods) {}
constexpr char PairwiseOptions::kTypeName[];

ListFlattenOptions::ListFlattenOptions(bool recursive)
: FunctionOptions(internal::kListFlattenOptionsType), recursive(recursive) {}
constexpr char ListFlattenOptions::kTypeName[];

namespace internal {
void RegisterVectorOptions(FunctionRegistry* registry) {
DCHECK_OK(registry->AddFunctionOptionsType(kFilterOptionsType));
Expand All @@ -237,6 +243,7 @@ void RegisterVectorOptions(FunctionRegistry* registry) {
DCHECK_OK(registry->AddFunctionOptionsType(kCumulativeOptionsType));
DCHECK_OK(registry->AddFunctionOptionsType(kRankOptionsType));
DCHECK_OK(registry->AddFunctionOptionsType(kPairwiseOptionsType));
DCHECK_OK(registry->AddFunctionOptionsType(kListFlattenOptionsType));
}
} // namespace internal

Expand Down
12 changes: 12 additions & 0 deletions cpp/src/arrow/compute/api_vector.h
Original file line number Diff line number Diff line change
Expand Up @@ -245,6 +245,18 @@ class ARROW_EXPORT PairwiseOptions : public FunctionOptions {
int64_t periods = 1;
};

/// \brief Options for list_flatten function
class ARROW_EXPORT ListFlattenOptions : public FunctionOptions {
public:
explicit ListFlattenOptions(bool recursive = false);
static constexpr char const kTypeName[] = "ListFlattenOptions";
static ListFlattenOptions Defaults() { return ListFlattenOptions(); }

/// \brief If true, the list is flattened recursively until a non-list
/// array is formed.
bool recursive = false;
};

/// @}

/// \brief Filter with a boolean selection filter
Expand Down
21 changes: 18 additions & 3 deletions cpp/src/arrow/compute/kernels/codegen_internal.cc
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
#include <mutex>
#include <vector>

#include "arrow/compute/api_vector.h"
#include "arrow/type_fwd.h"

namespace arrow {
Expand Down Expand Up @@ -56,9 +57,23 @@ Result<TypeHolder> LastType(KernelContext*, const std::vector<TypeHolder>& types
return types.back();
}

Result<TypeHolder> ListValuesType(KernelContext*, const std::vector<TypeHolder>& args) {
const auto& list_type = checked_cast<const BaseListType&>(*args[0].type);
return list_type.value_type().get();
Result<TypeHolder> ListValuesType(KernelContext* ctx,
const std::vector<TypeHolder>& args) {
auto list_type = checked_cast<const BaseListType*>(args[0].type);
auto value_type = list_type->value_type().get();

auto recursive =
ctx->state() ? OptionsWrapper<ListFlattenOptions>::Get(ctx).recursive : false;
if (!recursive) {
return value_type;
}

for (auto value_kind = value_type->id();
is_list(value_kind) || is_list_view(value_kind); value_kind = value_type->id()) {
list_type = checked_cast<const BaseListType*>(list_type->value_type().get());
value_type = list_type->value_type().get();
}
return value_type;
}

void EnsureDictionaryDecoded(std::vector<TypeHolder>* types) {
Expand Down
3 changes: 2 additions & 1 deletion cpp/src/arrow/compute/kernels/codegen_internal.h
Original file line number Diff line number Diff line change
Expand Up @@ -423,7 +423,8 @@ static void VisitTwoArrayValuesInline(const ArraySpan& arr0, const ArraySpan& ar

Result<TypeHolder> FirstType(KernelContext*, const std::vector<TypeHolder>& types);
Result<TypeHolder> LastType(KernelContext*, const std::vector<TypeHolder>& types);
Result<TypeHolder> ListValuesType(KernelContext*, const std::vector<TypeHolder>& types);
Result<TypeHolder> ListValuesType(KernelContext* ctx,
const std::vector<TypeHolder>& types);

// ----------------------------------------------------------------------
// Helpers for iterating over common DataType instances for adding kernels to
Expand Down
49 changes: 39 additions & 10 deletions cpp/src/arrow/compute/kernels/scalar_nested.cc
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
#include "arrow/compute/api_scalar.h"
#include "arrow/compute/kernels/common_internal.h"
#include "arrow/result.h"
#include "arrow/type_fwd.h"
#include "arrow/util/bit_block_counter.h"
#include "arrow/util/bit_util.h"
#include "arrow/util/bitmap_generate.h"
Expand All @@ -41,10 +42,17 @@ Status ListValueLength(KernelContext* ctx, const ExecSpan& batch, ExecResult* ou
const ArraySpan& arr = batch[0].array;
ArraySpan* out_arr = out->array_span_mutable();
auto out_values = out_arr->GetValues<offset_type>(1);
const offset_type* offsets = arr.GetValues<offset_type>(1);
// Offsets are always well-defined and monotonic, even for null values
for (int64_t i = 0; i < arr.length; ++i) {
*out_values++ = offsets[i + 1] - offsets[i];
if (is_list_view(*arr.type)) {
const auto* sizes = arr.GetValues<offset_type>(2);
if (arr.length > 0) {
memcpy(out_values, sizes, arr.length * sizeof(offset_type));
}
} else {
const offset_type* offsets = arr.GetValues<offset_type>(1);
// Offsets are always well-defined and monotonic, even for null values
for (int64_t i = 0; i < arr.length; ++i) {
*out_values++ = offsets[i + 1] - offsets[i];
}
}
return Status::OK();
}
Expand All @@ -59,6 +67,30 @@ Status FixedSizeListValueLength(KernelContext* ctx, const ExecSpan& batch,
return Status::OK();
}

template <typename InListType>
void AddListValueLengthKernel(ScalarFunction* func,
const std::shared_ptr<DataType>& out_type) {
auto in_type = {InputType(InListType::type_id)};
ScalarKernel kernel(in_type, out_type, ListValueLength<InListType>);
DCHECK_OK(func->AddKernel(std::move(kernel)));
}

template <>
void AddListValueLengthKernel<FixedSizeListType>(
ScalarFunction* func, const std::shared_ptr<DataType>& out_type) {
auto in_type = {InputType(Type::FIXED_SIZE_LIST)};
ScalarKernel kernel(in_type, out_type, FixedSizeListValueLength);
DCHECK_OK(func->AddKernel(std::move(kernel)));
}

void AddListValueLengthKernels(ScalarFunction* func) {
AddListValueLengthKernel<ListType>(func, int32());
AddListValueLengthKernel<LargeListType>(func, int64());
AddListValueLengthKernel<ListViewType>(func, int32());
AddListValueLengthKernel<LargeListViewType>(func, int64());
AddListValueLengthKernel<FixedSizeListType>(func, int32());
}

const FunctionDoc list_value_length_doc{
"Compute list lengths",
("`lists` must have a list-like type.\n"
Expand Down Expand Up @@ -399,6 +431,8 @@ void AddListElementKernels(ScalarFunction* func) {
void AddListElementKernels(ScalarFunction* func) {
AddListElementKernels<ListType, ListElement>(func);
AddListElementKernels<LargeListType, ListElement>(func);
AddListElementKernels<ListViewType, ListElement>(func);
AddListElementKernels<LargeListViewType, ListElement>(func);
AddListElementKernels<FixedSizeListType, FixedSizeListElement>(func);
}

Expand Down Expand Up @@ -824,12 +858,7 @@ const FunctionDoc map_lookup_doc{
void RegisterScalarNested(FunctionRegistry* registry) {
auto list_value_length = std::make_shared<ScalarFunction>(
"list_value_length", Arity::Unary(), list_value_length_doc);
DCHECK_OK(list_value_length->AddKernel({InputType(Type::LIST)}, int32(),
ListValueLength<ListType>));
DCHECK_OK(list_value_length->AddKernel({InputType(Type::FIXED_SIZE_LIST)}, int32(),
FixedSizeListValueLength));
DCHECK_OK(list_value_length->AddKernel({InputType(Type::LARGE_LIST)}, int64(),
ListValueLength<LargeListType>));
AddListValueLengthKernels(list_value_length.get());
DCHECK_OK(registry->AddFunction(std::move(list_value_length)));

auto list_element =
Expand Down
17 changes: 14 additions & 3 deletions cpp/src/arrow/compute/kernels/scalar_nested_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -30,11 +30,21 @@ namespace arrow {
namespace compute {

static std::shared_ptr<DataType> GetOffsetType(const DataType& type) {
return type.id() == Type::LIST ? int32() : int64();
switch (type.id()) {
case Type::LIST:
case Type::LIST_VIEW:
return int32();
case Type::LARGE_LIST:
case Type::LARGE_LIST_VIEW:
return int64();
default:
Unreachable("Unexpected type");
}
}

TEST(TestScalarNested, ListValueLength) {
for (auto ty : {list(int32()), large_list(int32())}) {
for (auto ty : {list(int32()), large_list(int32()), list_view(int32()),
large_list_view(int32())}) {
CheckScalarUnary("list_value_length", ty, "[[0, null, 1], null, [2, 3], []]",
GetOffsetType(*ty), "[3, null, 2, 0]");
}
Expand All @@ -47,7 +57,8 @@ TEST(TestScalarNested, ListValueLength) {
TEST(TestScalarNested, ListElementNonFixedListWithNulls) {
auto sample = "[[7, 5, 81], [6, null, 4, 7, 8], [3, 12, 2, 0], [1, 9], null]";
for (auto ty : NumericTypes()) {
for (auto list_type : {list(ty), large_list(ty)}) {
for (auto list_type :
{list(ty), large_list(ty), list_view(ty), large_list_view(ty)}) {
auto input = ArrayFromJSON(list_type, sample);
auto null_input = ArrayFromJSON(list_type, "[null]");
for (auto index_type : IntTypes()) {
Expand Down
54 changes: 41 additions & 13 deletions cpp/src/arrow/compute/kernels/vector_nested.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
// Vector kernels involving nested types

#include "arrow/array/array_base.h"
#include "arrow/compute/api_vector.h"
#include "arrow/compute/kernels/common_internal.h"
#include "arrow/result.h"
#include "arrow/visit_type_inline.h"
Expand All @@ -29,8 +30,13 @@ namespace {

template <typename Type>
Status ListFlatten(KernelContext* ctx, const ExecSpan& batch, ExecResult* out) {
auto recursive = OptionsWrapper<ListFlattenOptions>::Get(ctx).recursive;
typename TypeTraits<Type>::ArrayType list_array(batch[0].array.ToArrayData());
ARROW_ASSIGN_OR_RAISE(auto result, list_array.Flatten(ctx->memory_pool()));

auto pool = ctx->memory_pool();
ARROW_ASSIGN_OR_RAISE(auto result, (recursive ? list_array.FlattenRecursively(pool)
: list_array.Flatten(pool)));

out->value = std::move(result->data());
return Status::OK();
}
Expand Down Expand Up @@ -107,10 +113,15 @@ struct ListParentIndicesArray {

const FunctionDoc list_flatten_doc(
"Flatten list values",
("`lists` must have a list-like type.\n"
"Return an array with the top list level flattened.\n"
"Top-level null values in `lists` do not emit anything in the input."),
{"lists"});
("`lists` must have a list-like type (lists, list-views, and\n"
"fixed-size lists).\n"
"Return an array with the top list level flattened unless\n"
"`recursive` is set to true in ListFlattenOptions. When that\n"
"is that case, flattening happens recursively until a non-list\n"
"array is formed.\n"
"\n"
"Null list values do not emit anything to the output."),
{"lists"}, "ListFlattenOptions");

const FunctionDoc list_parent_indices_doc(
"Compute parent indices of nested list values",
Expand Down Expand Up @@ -153,17 +164,34 @@ class ListParentIndicesFunction : public MetaFunction {
}
};

const ListFlattenOptions* GetDefaultListFlattenOptions() {
static const auto kDefaultListFlattenOptions = ListFlattenOptions::Defaults();
return &kDefaultListFlattenOptions;
}

template <typename InListType>
void AddBaseListFlattenKernels(VectorFunction* func) {
auto in_type = {InputType(InListType::type_id)};
auto out_type = OutputType(ListValuesType);
VectorKernel kernel(in_type, out_type, ListFlatten<InListType>,
OptionsWrapper<ListFlattenOptions>::Init);
DCHECK_OK(func->AddKernel(std::move(kernel)));
}

void AddBaseListFlattenKernels(VectorFunction* func) {
AddBaseListFlattenKernels<ListType>(func);
AddBaseListFlattenKernels<LargeListType>(func);
AddBaseListFlattenKernels<FixedSizeListType>(func);
AddBaseListFlattenKernels<ListViewType>(func);
AddBaseListFlattenKernels<LargeListViewType>(func);
}

} // namespace

void RegisterVectorNested(FunctionRegistry* registry) {
auto flatten =
std::make_shared<VectorFunction>("list_flatten", Arity::Unary(), list_flatten_doc);
DCHECK_OK(flatten->AddKernel({Type::LIST}, OutputType(ListValuesType),
ListFlatten<ListType>));
DCHECK_OK(flatten->AddKernel({Type::FIXED_SIZE_LIST}, OutputType(ListValuesType),
ListFlatten<FixedSizeListType>));
DCHECK_OK(flatten->AddKernel({Type::LARGE_LIST}, OutputType(ListValuesType),
ListFlatten<LargeListType>));
auto flatten = std::make_shared<VectorFunction>(
"list_flatten", Arity::Unary(), list_flatten_doc, GetDefaultListFlattenOptions());
AddBaseListFlattenKernels(flatten.get());
DCHECK_OK(registry->AddFunction(std::move(flatten)));

DCHECK_OK(registry->AddFunction(std::make_shared<ListParentIndicesFunction>()));
Expand Down
Loading

0 comments on commit 5e986be

Please sign in to comment.