Skip to content

Commit

Permalink
Implement permute
Browse files Browse the repository at this point in the history
  • Loading branch information
zanmato1984 committed Oct 7, 2024
1 parent b445c36 commit 3707fa6
Show file tree
Hide file tree
Showing 4 changed files with 74 additions and 78 deletions.
8 changes: 4 additions & 4 deletions cpp/src/arrow/compute/api_vector.cc
Original file line number Diff line number Diff line change
Expand Up @@ -159,8 +159,8 @@ static auto kReverseIndexOptionsType = GetFunctionOptionsType<ReverseIndexOption
DataMember("output_length", &ReverseIndexOptions::output_length),
DataMember("output_type", &ReverseIndexOptions::output_type),
DataMember("output_non_taken", &ReverseIndexOptions::output_non_taken));
static auto kPermuteOptionsType =
GetFunctionOptionsType<PermuteOptions>(DataMember("bound", &PermuteOptions::bound));
static auto kPermuteOptionsType = GetFunctionOptionsType<PermuteOptions>(
DataMember("output_length", &PermuteOptions::output_length));
} // namespace
} // namespace internal

Expand Down Expand Up @@ -245,8 +245,8 @@ ReverseIndexOptions::ReverseIndexOptions(int64_t output_length,
output_non_taken(std::move(output_non_taken)) {}
constexpr char ReverseIndexOptions::kTypeName[];

PermuteOptions::PermuteOptions(int64_t bound)
: FunctionOptions(internal::kPermuteOptionsType), bound(bound) {}
PermuteOptions::PermuteOptions(int64_t output_length)
: FunctionOptions(internal::kPermuteOptionsType), output_length(output_length) {}
constexpr char PermuteOptions::kTypeName[];

namespace internal {
Expand Down
23 changes: 13 additions & 10 deletions cpp/src/arrow/compute/api_vector.h
Original file line number Diff line number Diff line change
Expand Up @@ -266,27 +266,30 @@ class ARROW_EXPORT ReverseIndexOptions : public FunctionOptions {
static constexpr char const kTypeName[] = "ReverseIndexOptions";
static ReverseIndexOptions Defaults() { return ReverseIndexOptions(); }

/// \brief The upper bound of the permutation. If -1, the output will be sized as the
/// maximum value in the indices array + 1. Otherwise, the output will be of size bound,
/// and any indices that are greater than or equal to bound will be ignored.
/// \brief The length of the output reverse index. Must be non-negative. Any indices
/// that are greater than or equal to this length will be ignored.
int64_t output_length = 0;
/// \brief The type of the output reverse index. If null, the output type will be the
/// smallest possible integer type that can hold the maximum value in the indices array.
/// \brief The type of the output reverse index. Must be an integer type. An overflow
/// error will be reported if any reverse indices are out of the bounds of the output
/// type.
std::shared_ptr<DataType> output_type = int32();
/// \brief The value to fill in the output reverse index for indices that are not taken.
/// Must be of the same type as `output_type` if not null. If not provided or an invalid
/// scalar is provided, the non-taken indices will be filled with nulls. Using non-null
/// scalars properly may enable efficient processing of the reverse index function.
std::shared_ptr<Scalar> output_non_taken = NULLPTR;
};

/// \brief Options for permute function
class ARROW_EXPORT PermuteOptions : public FunctionOptions {
public:
explicit PermuteOptions(int64_t bound = -1);
explicit PermuteOptions(int64_t output_length = 0);
static constexpr char const kTypeName[] = "PermuteOptions";
static PermuteOptions Defaults() { return PermuteOptions(); }

/// \brief The upper bound of the permutation. If -1, the output will be sized as the
/// maximum value in the indices array + 1. Otherwise, the output will be of size bound,
/// and any indices that are greater than or equal to bound will be ignored.
int64_t bound = -1;
/// \brief The length of the output permutation. Must be non-negative. Any values with
/// indices that are greater than or equal to this length will be ignored.
int64_t output_length = 0;
};

/// @}
Expand Down
111 changes: 47 additions & 64 deletions cpp/src/arrow/compute/kernels/vector_placement.cc
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,18 @@ namespace arrow::compute::internal {

namespace {

// ----------------------------------------------------------------------
// ReverseIndex

// Documentation object for the reverse index function: a one-line summary, a
// longer description of the index mapping, and the argument list {"indices"}.
const FunctionDoc reverse_index_doc(
"Compute the reverse indices from an input indices",
"For the `i`-th `index` in `indices`, the `index`-th output is `i`", {"indices"});

// Returns a pointer to a function-local static options object that is
// initialized once from PermuteOptions::Defaults().
// NOTE(review): the function name says "ReverseIndexOptions" but both the
// static object and the return type are PermuteOptions -- this looks like a
// copy/paste mismatch; confirm which options type the caller actually expects.
const PermuteOptions* GetDefaultReverseIndexOptions() {
static const auto kDefaultPermuteOptions = PermuteOptions::Defaults();
return &kDefaultPermuteOptions;
}

struct ReverseIndexState : public KernelState {
explicit ReverseIndexState(int64_t length, std::shared_ptr<DataType> type,
std::shared_ptr<Buffer> validity,
Expand Down Expand Up @@ -203,15 +215,8 @@ struct ReverseIndexChunked {
}
};

const FunctionDoc reverse_index_doc(
"Compute the reverse indices from an input indices",
"Place each input value to the output array at position specified by `indices`",
{"indices"});

const PermuteOptions* GetDefaultReverseIndexOptions() {
static const auto kDefaultPermuteOptions = PermuteOptions::Defaults();
return &kDefaultPermuteOptions;
}
// ----------------------------------------------------------------------
// Permute

const FunctionDoc permute_doc(
"Permute values of an input based on indices from another array",
Expand All @@ -232,68 +237,40 @@ class PermuteMetaFunction : public MetaFunction {
Result<Datum> ExecuteImpl(const std::vector<Datum>& args,
const FunctionOptions* options,
ExecContext* ctx) const override {
const auto& permute_options = checked_cast<const PermuteOptions&>(*options);
if (args[0].length() != args[1].length()) {
return Status::Invalid("Input and indices must have the same length");
const auto& values = args[0];
const auto& indices = args[1];
auto* permute_options = checked_cast<const PermuteOptions*>(options);
if (values.length() != indices.length()) {
return Status::Invalid(
"Input and indices of permute must have the same length, got " +
std::to_string(values.length()) + " and " + std::to_string(indices.length()));
}
if (args[0].length() == 0) {
return args[0];
if (!is_integer(indices.type()->id())) {
return Status::Invalid("Indices of permute must be of integer type, got ",
indices.type()->ToString());
}
int64_t output_length = permute_options.bound;
int64_t output_length = permute_options->output_length;
if (output_length < 0) {
ARROW_ASSIGN_OR_RAISE(auto max_scalar, CallFunction("max", {args[1]}, ctx));
DCHECK(max_scalar.is_scalar());
ARROW_ASSIGN_OR_RAISE(auto max_i64_scalar, max_scalar.scalar()->CastTo(int64()));
output_length = checked_cast<const Int64Scalar*>(max_i64_scalar.get())->value + 1;
return Status::Invalid("Output length of permute must be non-negative, got " +
std::to_string(output_length));
}
if (output_length <= 0) {
ARROW_ASSIGN_OR_RAISE(auto output, MakeEmptyArray(args[0].type()));
return output->data();
std::shared_ptr<Scalar> output_non_taken = nullptr;
if (is_signed_integer(indices.type()->id())) {
// Using -1 (as opposed to null) as output_non_taken for signed integer types to
// enable efficient reverse_index.
ARROW_ASSIGN_OR_RAISE(output_non_taken, MakeScalar(indices.type(), -1));
}
ARROW_ASSIGN_OR_RAISE(auto reverse_indices,
MakeArrayOfNull(int64(), output_length, ctx->memory_pool()));
switch (args[1].kind()) {
case Datum::ARRAY:
RETURN_NOT_OK(ReverseIndices(*args[1].array(), reverse_indices->data()));
break;
case Datum::CHUNKED_ARRAY:
for (const auto& chunk : args[1].chunked_array()->chunks()) {
RETURN_NOT_OK(ReverseIndices(*chunk->data(), reverse_indices->data()));
}
break;
default:
return Status::NotImplemented("Unsupported shape for permute operation: indices=",
args[1].ToString());
break;
}
return CallFunction("take", {args[0], reverse_indices}, ctx);
ReverseIndexOptions reverse_index_options{output_length, indices.type(),
std::move(output_non_taken)};
ARROW_ASSIGN_OR_RAISE(
auto reverse_indices,
CallFunction("reverse_index", {indices}, &reverse_index_options, ctx));
TakeOptions take_options{/*boundcheck=*/false};
return CallFunction("take", {values, reverse_indices}, &take_options, ctx);
}

private:
Status ReverseIndices(const ArraySpan& indices,
const std::shared_ptr<ArrayData>& reverse_indices) const {
auto reverse_indices_validity = reverse_indices->GetMutableValues<uint8_t>(0);
auto reverse_indices_data = reverse_indices->GetMutableValues<int64_t>(1);
auto length = reverse_indices->length;
int64_t reverse_index = 0;
return VisitArraySpanInline<Int64Type>(
indices,
[&](int64_t index) {
if (ARROW_PREDICT_TRUE(index > 0 && index < length)) {
bit_util::SetBitTo(reverse_indices_validity, index, true);
reverse_indices_data[index] = reverse_index;
}
++reverse_index;
return Status::OK();
},
[&]() {
++reverse_index;
return Status::OK();
});
};
};

} // namespace
// ----------------------------------------------------------------------

void RegisterVectorReverseIndex(FunctionRegistry* registry) {
auto function =
Expand All @@ -319,9 +296,15 @@ void RegisterVectorReverseIndex(FunctionRegistry* registry) {
DCHECK_OK(registry->AddFunction(std::move(function)));
}

// Adds a newly constructed PermuteMetaFunction to `registry`. The status
// returned by AddFunction is passed to DCHECK_OK rather than propagated.
void RegisterVectorPermute(FunctionRegistry* registry) {
DCHECK_OK(registry->AddFunction(std::make_shared<PermuteMetaFunction>()));
}

} // namespace

void RegisterVectorPlacement(FunctionRegistry* registry) {
RegisterVectorReverseIndex(registry);
DCHECK_OK(registry->AddFunction(std::make_shared<PermuteMetaFunction>()));
RegisterVectorPermute(registry);
}

} // namespace arrow::compute::internal
10 changes: 10 additions & 0 deletions cpp/src/arrow/compute/kernels/vector_placement_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -145,4 +145,14 @@ TEST(ReverseIndex, Overflow) {
}
}

// Basic check of the "permute" function: ten int64 values are permuted with a
// fully reversed index array ([9..0]) and an output length of 10, and the
// result is expected to be the values in reverse order.
TEST(Permute, Basic) {
auto values = ArrayFromJSON(int64(), "[10, 11, 12, 13, 14, 15, 16, 17, 18, 19]");
auto indices = ArrayFromJSON(int64(), "[9, 8, 7, 6, 5, 4, 3, 2, 1, 0]");
auto expected = ArrayFromJSON(int64(), "[19, 18, 17, 16, 15, 14, 13, 12, 11, 10]");
// output_length == number of input values, so no value is dropped.
PermuteOptions options{10};
ASSERT_OK_AND_ASSIGN(Datum result,
CallFunction("permute", {values, indices}, &options));
AssertDatumsEqual(expected, result);
}

}; // namespace arrow::compute

0 comments on commit 3707fa6

Please sign in to comment.