Skip to content

Commit 04575bc

Browse files
committed
ARROW-13681: [C++] Fix list_parent_indices behaviour on chunked array
Closes #10985 from pitrou/ARROW-13681-list-parent-indices-chunked

Authored-by: Antoine Pitrou <antoine@python.org>
Signed-off-by: Antoine Pitrou <antoine@python.org>
1 parent f4717df commit 04575bc

File tree

4 files changed: 150 additions (+150) and 46 deletions (-46).

cpp/src/arrow/compute/kernels/test_util.cc

Lines changed: 4 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -183,12 +183,11 @@ void CheckScalarUnary(std::string func_name, std::shared_ptr<DataType> in_ty,
183183
ArrayFromJSON(out_ty, json_expected), options);
184184
}
185185

186-
void CheckVectorUnary(std::string func_name, Datum input, std::shared_ptr<Array> expected,
186+
void CheckVectorUnary(std::string func_name, Datum input, Datum expected,
187187
const FunctionOptions* options) {
188-
ASSERT_OK_AND_ASSIGN(Datum out, CallFunction(func_name, {input}, options));
189-
std::shared_ptr<Array> actual = std::move(out).make_array();
190-
ValidateOutput(*actual);
191-
AssertArraysEqual(*expected, *actual, /*verbose=*/true);
188+
ASSERT_OK_AND_ASSIGN(Datum actual, CallFunction(func_name, {input}, options));
189+
ValidateOutput(actual);
190+
AssertDatumsEqual(expected, actual, /*verbose=*/true);
192191
}
193192

194193
void CheckScalarBinary(std::string func_name, Datum left_input, Datum right_input,

cpp/src/arrow/compute/kernels/test_util.h

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -84,7 +84,7 @@ void CheckScalarUnary(std::string func_name, Datum input, Datum expected,
8484
void CheckScalarBinary(std::string func_name, Datum left_input, Datum right_input,
8585
Datum expected, const FunctionOptions* options = nullptr);
8686

87-
void CheckVectorUnary(std::string func_name, Datum input, std::shared_ptr<Array> expected,
87+
void CheckVectorUnary(std::string func_name, Datum input, Datum expected,
8888
const FunctionOptions* options = nullptr);
8989

9090
void ValidateOutput(const Datum& output);

cpp/src/arrow/compute/kernels/vector_nested.cc

Lines changed: 105 additions & 31 deletions
Original file line number | Diff line number | Diff line change
@@ -20,6 +20,7 @@
2020
#include "arrow/array/array_base.h"
2121
#include "arrow/compute/kernels/common.h"
2222
#include "arrow/result.h"
23+
#include "arrow/visitor_inline.h"
2324

2425
namespace arrow {
2526
namespace compute {
@@ -34,35 +35,79 @@ Status ListFlatten(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
3435
return Status::OK();
3536
}
3637

37-
template <typename Type, typename offset_type = typename Type::offset_type>
38-
Status ListParentIndices(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
39-
typename TypeTraits<Type>::ArrayType list(batch[0].array());
40-
ArrayData* out_arr = out->mutable_array();
41-
42-
const offset_type* offsets = list.raw_value_offsets();
43-
offset_type values_length = offsets[list.length()] - offsets[0];
44-
45-
out_arr->length = values_length;
46-
out_arr->null_count = 0;
47-
ARROW_ASSIGN_OR_RAISE(out_arr->buffers[1],
48-
ctx->Allocate(values_length * sizeof(offset_type)));
49-
auto out_indices = reinterpret_cast<offset_type*>(out_arr->buffers[1]->mutable_data());
50-
for (int64_t i = 0; i < list.length(); ++i) {
51-
// Note: In most cases, null slots are empty, but when they are non-empty
52-
// we write out the indices so make sure they are accounted for. This
53-
// behavior could be changed if needed in the future.
54-
for (offset_type j = offsets[i]; j < offsets[i + 1]; ++j) {
55-
*out_indices++ = static_cast<offset_type>(i);
38+
struct ListParentIndicesArray {
39+
KernelContext* ctx;
40+
const std::shared_ptr<ArrayData>& input;
41+
int64_t base_output_offset;
42+
std::shared_ptr<ArrayData> out;
43+
44+
template <typename Type, typename offset_type = typename Type::offset_type>
45+
Status VisitList(const Type&) {
46+
typename TypeTraits<Type>::ArrayType list(input);
47+
48+
const offset_type* offsets = list.raw_value_offsets();
49+
offset_type values_length = offsets[list.length()] - offsets[0];
50+
51+
ARROW_ASSIGN_OR_RAISE(auto indices,
52+
ctx->Allocate(values_length * sizeof(offset_type)));
53+
auto out_indices = reinterpret_cast<offset_type*>(indices->mutable_data());
54+
for (int64_t i = 0; i < list.length(); ++i) {
55+
// Note: In most cases, null slots are empty, but when they are non-empty
56+
// we write out the indices so make sure they are accounted for. This
57+
// behavior could be changed if needed in the future.
58+
for (offset_type j = offsets[i]; j < offsets[i + 1]; ++j) {
59+
*out_indices++ = static_cast<offset_type>(i + base_output_offset);
60+
}
5661
}
62+
63+
BufferVector buffers{nullptr, std::move(indices)};
64+
int64_t null_count = 0;
65+
if (sizeof(offset_type) == 4) {
66+
out = std::make_shared<ArrayData>(int32(), values_length, std::move(buffers),
67+
null_count);
68+
} else {
69+
out = std::make_shared<ArrayData>(int64(), values_length, std::move(buffers),
70+
null_count);
71+
}
72+
return Status::OK();
73+
}
74+
75+
Status Visit(const ListType& type) { return VisitList(type); }
76+
77+
Status Visit(const LargeListType& type) { return VisitList(type); }
78+
79+
Status Visit(const DataType& type) {
80+
return Status::TypeError("Function 'list_parent_indices' expects list input, got ",
81+
type.ToString());
5782
}
58-
return Status::OK();
59-
}
6083

61-
Result<ValueDescr> ValuesType(KernelContext*, const std::vector<ValueDescr>& args) {
84+
static Result<std::shared_ptr<ArrayData>> Exec(KernelContext* ctx,
85+
const std::shared_ptr<ArrayData>& input,
86+
int64_t base_output_offset = 0) {
87+
ListParentIndicesArray self{ctx, input, base_output_offset, /*out=*/nullptr};
88+
RETURN_NOT_OK(VisitTypeInline(*input->type, &self));
89+
DCHECK_NE(self.out, nullptr);
90+
return self.out;
91+
}
92+
};
93+
94+
Result<ValueDescr> ListValuesType(KernelContext*, const std::vector<ValueDescr>& args) {
6295
const auto& list_type = checked_cast<const BaseListType&>(*args[0].type);
6396
return ValueDescr::Array(list_type.value_type());
6497
}
6598

99+
Result<std::shared_ptr<DataType>> ListParentIndicesType(const DataType& input_type) {
100+
switch (input_type.id()) {
101+
case Type::LIST:
102+
return int32();
103+
case Type::LARGE_LIST:
104+
return int64();
105+
default:
106+
return Status::TypeError("Function 'list_parent_indices' expects list input, got ",
107+
input_type.ToString());
108+
}
109+
}
110+
66111
const FunctionDoc list_flatten_doc(
67112
"Flatten list values",
68113
("`lists` must have a list-like type.\n"
@@ -77,24 +122,53 @@ const FunctionDoc list_parent_indices_doc(
77122
"is emitted."),
78123
{"lists"});
79124

125+
class ListParentIndicesFunction : public MetaFunction {
126+
public:
127+
ListParentIndicesFunction()
128+
: MetaFunction("list_parent_indices", Arity::Unary(), &list_parent_indices_doc) {}
129+
130+
Result<Datum> ExecuteImpl(const std::vector<Datum>& args,
131+
const FunctionOptions* options,
132+
ExecContext* ctx) const override {
133+
KernelContext kernel_ctx(ctx);
134+
switch (args[0].kind()) {
135+
case Datum::ARRAY:
136+
return ListParentIndicesArray::Exec(&kernel_ctx, args[0].array());
137+
case Datum::CHUNKED_ARRAY: {
138+
const auto& input = args[0].chunked_array();
139+
ARROW_ASSIGN_OR_RAISE(auto out_ty, ListParentIndicesType(*input->type()));
140+
141+
int64_t base_output_offset = 0;
142+
ArrayVector out_chunks;
143+
for (const auto& chunk : input->chunks()) {
144+
ARROW_ASSIGN_OR_RAISE(auto out_chunk,
145+
ListParentIndicesArray::Exec(&kernel_ctx, chunk->data(),
146+
base_output_offset));
147+
out_chunks.push_back(MakeArray(std::move(out_chunk)));
148+
base_output_offset += chunk->length();
149+
}
150+
return std::make_shared<ChunkedArray>(std::move(out_chunks), std::move(out_ty));
151+
}
152+
default:
153+
return Status::NotImplemented(
154+
"Unsupported input type for function 'list_parent_indices': ",
155+
args[0].ToString());
156+
}
157+
}
158+
};
159+
80160
} // namespace
81161

82162
void RegisterVectorNested(FunctionRegistry* registry) {
83163
auto flatten =
84164
std::make_shared<VectorFunction>("list_flatten", Arity::Unary(), &list_flatten_doc);
85-
DCHECK_OK(flatten->AddKernel({InputType::Array(Type::LIST)}, OutputType(ValuesType),
165+
DCHECK_OK(flatten->AddKernel({InputType::Array(Type::LIST)}, OutputType(ListValuesType),
86166
ListFlatten<ListType>));
87167
DCHECK_OK(flatten->AddKernel({InputType::Array(Type::LARGE_LIST)},
88-
OutputType(ValuesType), ListFlatten<LargeListType>));
168+
OutputType(ListValuesType), ListFlatten<LargeListType>));
89169
DCHECK_OK(registry->AddFunction(std::move(flatten)));
90170

91-
auto list_parent_indices = std::make_shared<VectorFunction>(
92-
"list_parent_indices", Arity::Unary(), &list_parent_indices_doc);
93-
DCHECK_OK(list_parent_indices->AddKernel({InputType::Array(Type::LIST)}, int32(),
94-
ListParentIndices<ListType>));
95-
DCHECK_OK(list_parent_indices->AddKernel({InputType::Array(Type::LARGE_LIST)}, int64(),
96-
ListParentIndices<LargeListType>));
97-
DCHECK_OK(registry->AddFunction(std::move(list_parent_indices)));
171+
DCHECK_OK(registry->AddFunction(std::make_shared<ListParentIndicesFunction>()));
98172
}
99173

100174
} // namespace internal

cpp/src/arrow/compute/kernels/vector_nested_test.cc

Lines changed: 40 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -17,6 +17,7 @@
1717

1818
#include <gtest/gtest.h>
1919

20+
#include "arrow/chunked_array.h"
2021
#include "arrow/compute/api.h"
2122
#include "arrow/compute/kernels/test_util.h"
2223
#include "arrow/result.h"
@@ -26,29 +27,59 @@ namespace arrow {
2627
namespace compute {
2728

2829
TEST(TestVectorNested, ListFlatten) {
29-
for (auto ty : {list(int32()), large_list(int32())}) {
30+
for (auto ty : {list(int16()), large_list(int16())}) {
3031
auto input = ArrayFromJSON(ty, "[[0, null, 1], null, [2, 3], []]");
31-
auto expected = ArrayFromJSON(int32(), "[0, null, 1, 2, 3]");
32+
auto expected = ArrayFromJSON(int16(), "[0, null, 1, 2, 3]");
33+
CheckVectorUnary("list_flatten", input, expected);
34+
35+
// Construct a list with a non-empty null slot
36+
TweakValidityBit(input, 0, false);
37+
expected = ArrayFromJSON(int16(), "[2, 3]");
38+
CheckVectorUnary("list_flatten", input, expected);
39+
}
40+
}
41+
42+
TEST(TestVectorNested, ListFlattenChunkedArray) {
43+
for (auto ty : {list(int16()), large_list(int16())}) {
44+
auto input = ChunkedArrayFromJSON(ty, {"[[0, null, 1], null]", "[[2, 3], []]"});
45+
auto expected = ChunkedArrayFromJSON(int16(), {"[0, null, 1]", "[2, 3]"});
46+
CheckVectorUnary("list_flatten", input, expected);
47+
48+
input = ChunkedArrayFromJSON(ty, {});
49+
expected = ChunkedArrayFromJSON(int16(), {});
3250
CheckVectorUnary("list_flatten", input, expected);
3351
}
3452
}
3553

3654
TEST(TestVectorNested, ListParentIndices) {
37-
for (auto ty : {list(int32()), large_list(int32())}) {
55+
for (auto ty : {list(int16()), large_list(int16())}) {
3856
auto input = ArrayFromJSON(ty, "[[0, null, 1], null, [2, 3], [], [4, 5]]");
3957

4058
auto out_ty = ty->id() == Type::LIST ? int32() : int64();
4159
auto expected = ArrayFromJSON(out_ty, "[0, 0, 0, 2, 2, 4, 4]");
4260
CheckVectorUnary("list_parent_indices", input, expected);
4361
}
4462

45-
// Construct a list with non-empty null slots
46-
auto input = ArrayFromJSON(list(int32()), "[[0, null, 1], [0, 0], [2, 3], [], [4, 5]]");
47-
std::shared_ptr<ArrayData> data = input->data()->Copy();
48-
data->buffers[0] =
49-
(ArrayFromJSON(boolean(), "[true, false, true, true, true]")->data()->buffers[1]);
63+
// Construct a list with a non-empty null slot
64+
auto input = ArrayFromJSON(list(int16()), "[[0, null, 1], [0, 0], [2, 3], [], [4, 5]]");
65+
TweakValidityBit(input, 1, false);
5066
auto expected = ArrayFromJSON(int32(), "[0, 0, 0, 1, 1, 2, 2, 4, 4]");
51-
CheckVectorUnary("list_parent_indices", data, expected);
67+
CheckVectorUnary("list_parent_indices", input, expected);
68+
}
69+
70+
TEST(TestVectorNested, ListParentIndicesChunkedArray) {
71+
for (auto ty : {list(int16()), large_list(int16())}) {
72+
auto input =
73+
ChunkedArrayFromJSON(ty, {"[[0, null, 1], null]", "[[2, 3], [], [4, 5]]"});
74+
75+
auto out_ty = ty->id() == Type::LIST ? int32() : int64();
76+
auto expected = ChunkedArrayFromJSON(out_ty, {"[0, 0, 0]", "[2, 2, 4, 4]"});
77+
CheckVectorUnary("list_parent_indices", input, expected);
78+
79+
input = ChunkedArrayFromJSON(ty, {});
80+
expected = ChunkedArrayFromJSON(out_ty, {});
81+
CheckVectorUnary("list_parent_indices", input, expected);
82+
}
5283
}
5384

5485
} // namespace compute

0 commit comments

Comments
 (0)