From 3707fa6fb165b9d3985b5ebdb39d430fbbdcd02c Mon Sep 17 00:00:00 2001 From: Ruoxi Sun Date: Tue, 8 Oct 2024 00:50:39 +0800 Subject: [PATCH] Implement permute --- cpp/src/arrow/compute/api_vector.cc | 8 +- cpp/src/arrow/compute/api_vector.h | 23 ++-- .../arrow/compute/kernels/vector_placement.cc | 111 ++++++++---------- .../compute/kernels/vector_placement_test.cc | 10 ++ 4 files changed, 74 insertions(+), 78 deletions(-) diff --git a/cpp/src/arrow/compute/api_vector.cc b/cpp/src/arrow/compute/api_vector.cc index c7e344ef28097..524f16d877119 100644 --- a/cpp/src/arrow/compute/api_vector.cc +++ b/cpp/src/arrow/compute/api_vector.cc @@ -159,8 +159,8 @@ static auto kReverseIndexOptionsType = GetFunctionOptionsType(DataMember("bound", &PermuteOptions::bound)); +static auto kPermuteOptionsType = GetFunctionOptionsType( + DataMember("output_length", &PermuteOptions::output_length)); } // namespace } // namespace internal @@ -245,8 +245,8 @@ ReverseIndexOptions::ReverseIndexOptions(int64_t output_length, output_non_taken(std::move(output_non_taken)) {} constexpr char ReverseIndexOptions::kTypeName[]; -PermuteOptions::PermuteOptions(int64_t bound) - : FunctionOptions(internal::kPermuteOptionsType), bound(bound) {} +PermuteOptions::PermuteOptions(int64_t output_length) + : FunctionOptions(internal::kPermuteOptionsType), output_length(output_length) {} constexpr char PermuteOptions::kTypeName[]; namespace internal { diff --git a/cpp/src/arrow/compute/api_vector.h b/cpp/src/arrow/compute/api_vector.h index 60c53fec1be10..152f1bdf80832 100644 --- a/cpp/src/arrow/compute/api_vector.h +++ b/cpp/src/arrow/compute/api_vector.h @@ -266,27 +266,30 @@ class ARROW_EXPORT ReverseIndexOptions : public FunctionOptions { static constexpr char const kTypeName[] = "ReverseIndexOptions"; static ReverseIndexOptions Defaults() { return ReverseIndexOptions(); } - /// \brief The upper bound of the permutation. If -1, the output will be sized as the - /// maximum value in the indices array + 1. 
Otherwise, the output will be of size bound, - /// and any indices that are greater of equal to bound will be ignored. + /// \brief The length of the output reverse index. Must be non-negative. Any indices + /// that are greater than or equal to this length will be ignored. int64_t output_length = 0; - /// \brief The type of the output reverse index. If null, the output type will be the - /// smallest possible integer type that can hold the maximum value in the indices array. + /// \brief The type of the output reverse index. Must be an integer type. An overflow + /// error will be reported if any reverse indices are out of the bounds of the output + /// type. std::shared_ptr output_type = int32(); + /// \brief The value to fill in the output reverse index for indices that are not taken. + /// Must be of the same type as `output_type` if not null. If not provided or an invalid + /// scalar is provided, the non-taken indices will be filled with nulls. Using non-null + /// scalars properly may enable efficient processing of reverse index function. std::shared_ptr output_non_taken = NULLPTR; }; /// \brief Options for permute function class ARROW_EXPORT PermuteOptions : public FunctionOptions { public: - explicit PermuteOptions(int64_t bound = -1); + explicit PermuteOptions(int64_t output_length = 0); static constexpr char const kTypeName[] = "PermuteOptions"; static PermuteOptions Defaults() { return PermuteOptions(); } - /// \brief The upper bound of the permutation. If -1, the output will be sized as the - /// maximum value in the indices array + 1. Otherwise, the output will be of size bound, - /// and any indices that are greater of equal to bound will be ignored. - int64_t bound = -1; + /// \brief The length of the output permutation. Must be non-negative. Any values with + /// indices that are greater than or equal to this length will be ignored. 
+ int64_t output_length = 0; }; /// @} diff --git a/cpp/src/arrow/compute/kernels/vector_placement.cc b/cpp/src/arrow/compute/kernels/vector_placement.cc index f5e878bc93c14..a1d8f4054f45f 100644 --- a/cpp/src/arrow/compute/kernels/vector_placement.cc +++ b/cpp/src/arrow/compute/kernels/vector_placement.cc @@ -10,6 +10,18 @@ namespace arrow::compute::internal { namespace { +// ---------------------------------------------------------------------- +// ReverseIndex + +const FunctionDoc reverse_index_doc( + "Compute the reverse indices from an input indices", + "For the `i`-th `index` in `indices`, the `index`-th output is `i`", {"indices"}); + +const PermuteOptions* GetDefaultReverseIndexOptions() { + static const auto kDefaultPermuteOptions = PermuteOptions::Defaults(); + return &kDefaultPermuteOptions; +} + struct ReverseIndexState : public KernelState { explicit ReverseIndexState(int64_t length, std::shared_ptr type, std::shared_ptr validity, @@ -203,15 +215,8 @@ struct ReverseIndexChunked { } }; -const FunctionDoc reverse_index_doc( - "Compute the reverse indices from an input indices", - "Place each input value to the output array at position specified by `indices`", - {"indices"}); - -const PermuteOptions* GetDefaultReverseIndexOptions() { - static const auto kDefaultPermuteOptions = PermuteOptions::Defaults(); - return &kDefaultPermuteOptions; -} +// ---------------------------------------------------------------------- +// Permute const FunctionDoc permute_doc( "Permute values of an input based on indices from another array", @@ -232,68 +237,40 @@ class PermuteMetaFunction : public MetaFunction { Result ExecuteImpl(const std::vector& args, const FunctionOptions* options, ExecContext* ctx) const override { - const auto& permute_options = checked_cast(*options); - if (args[0].length() != args[1].length()) { - return Status::Invalid("Input and indices must have the same length"); + const auto& values = args[0]; + const auto& indices = args[1]; + auto* 
permute_options = checked_cast(options); + if (values.length() != indices.length()) { + return Status::Invalid( + "Input and indices of permute must have the same length, got " + + std::to_string(values.length()) + " and " + std::to_string(indices.length())); } - if (args[0].length() == 0) { - return args[0]; + if (!is_integer(indices.type()->id())) { + return Status::Invalid("Indices of permute must be of integer type, got ", + indices.type()->ToString()); } - int64_t output_length = permute_options.bound; + int64_t output_length = permute_options->output_length; if (output_length < 0) { - ARROW_ASSIGN_OR_RAISE(auto max_scalar, CallFunction("max", {args[1]}, ctx)); - DCHECK(max_scalar.is_scalar()); - ARROW_ASSIGN_OR_RAISE(auto max_i64_scalar, max_scalar.scalar()->CastTo(int64())); - output_length = checked_cast(max_i64_scalar.get())->value + 1; + return Status::Invalid("Output length of permute must be non-negative, got " + + std::to_string(output_length)); } - if (output_length <= 0) { - ARROW_ASSIGN_OR_RAISE(auto output, MakeEmptyArray(args[0].type())); - return output->data(); + std::shared_ptr output_non_taken = nullptr; + if (is_signed_integer(indices.type()->id())) { + // Using -1 (as opposed to null) as output_non_taken for signed integer types to + // enable efficient reverse_index. 
+ ARROW_ASSIGN_OR_RAISE(output_non_taken, MakeScalar(indices.type(), -1)); } - ARROW_ASSIGN_OR_RAISE(auto reverse_indices, - MakeArrayOfNull(int64(), output_length, ctx->memory_pool())); - switch (args[1].kind()) { - case Datum::ARRAY: - RETURN_NOT_OK(ReverseIndices(*args[1].array(), reverse_indices->data())); - break; - case Datum::CHUNKED_ARRAY: - for (const auto& chunk : args[1].chunked_array()->chunks()) { - RETURN_NOT_OK(ReverseIndices(*chunk->data(), reverse_indices->data())); - } - break; - default: - return Status::NotImplemented("Unsupported shape for permute operation: indices=", - args[1].ToString()); - break; - } - return CallFunction("take", {args[0], reverse_indices}, ctx); + ReverseIndexOptions reverse_index_options{output_length, indices.type(), + std::move(output_non_taken)}; + ARROW_ASSIGN_OR_RAISE( + auto reverse_indices, + CallFunction("reverse_index", {indices}, &reverse_index_options, ctx)); + TakeOptions take_options{/*boundcheck=*/false}; + return CallFunction("take", {values, reverse_indices}, &take_options, ctx); } - - private: - Status ReverseIndices(const ArraySpan& indices, - const std::shared_ptr& reverse_indices) const { - auto reverse_indices_validity = reverse_indices->GetMutableValues(0); - auto reverse_indices_data = reverse_indices->GetMutableValues(1); - auto length = reverse_indices->length; - int64_t reverse_index = 0; - return VisitArraySpanInline( - indices, - [&](int64_t index) { - if (ARROW_PREDICT_TRUE(index > 0 && index < length)) { - bit_util::SetBitTo(reverse_indices_validity, index, true); - reverse_indices_data[index] = reverse_index; - } - ++reverse_index; - return Status::OK(); - }, - [&]() { - ++reverse_index; - return Status::OK(); - }); - }; }; -} // namespace +// ---------------------------------------------------------------------- void RegisterVectorReverseIndex(FunctionRegistry* registry) { auto function = @@ -319,9 +296,15 @@ void RegisterVectorReverseIndex(FunctionRegistry* registry) { 
DCHECK_OK(registry->AddFunction(std::move(function))); } +void RegisterVectorPermute(FunctionRegistry* registry) { + DCHECK_OK(registry->AddFunction(std::make_shared())); +} + +} // namespace + void RegisterVectorPlacement(FunctionRegistry* registry) { RegisterVectorReverseIndex(registry); - DCHECK_OK(registry->AddFunction(std::make_shared())); + RegisterVectorPermute(registry); } } // namespace arrow::compute::internal diff --git a/cpp/src/arrow/compute/kernels/vector_placement_test.cc b/cpp/src/arrow/compute/kernels/vector_placement_test.cc index bfe38c731dc7c..86203ad53d651 100644 --- a/cpp/src/arrow/compute/kernels/vector_placement_test.cc +++ b/cpp/src/arrow/compute/kernels/vector_placement_test.cc @@ -145,4 +145,14 @@ TEST(ReverseIndex, Overflow) { } } +TEST(Permute, Basic) { + auto values = ArrayFromJSON(int64(), "[10, 11, 12, 13, 14, 15, 16, 17, 18, 19]"); + auto indices = ArrayFromJSON(int64(), "[9, 8, 7, 6, 5, 4, 3, 2, 1, 0]"); + auto expected = ArrayFromJSON(int64(), "[19, 18, 17, 16, 15, 14, 13, 12, 11, 10]"); + PermuteOptions options{10}; + ASSERT_OK_AND_ASSIGN(Datum result, + CallFunction("permute", {values, indices}, &options)); + AssertDatumsEqual(expected, result); +} + }; // namespace arrow::compute