Skip to content

Commit 26c25d1

Browse files
js8544 and bkietz
authored
GH-35786: [C++] Add pairwise_diff function (#35787)
### Rationale for this change Add a `pairwise_diff` function similar to pandas' [Series.diff](https://pandas.pydata.org/docs/reference/api/pandas.Series.diff.html); the function computes the first-order difference of an array. ### What changes are included in this PR? I followed [these instructions](#12460 (comment)). The function is implemented for numerical, temporal and decimal types. Chunked arrays are not yet supported. ### Are these changes tested? Yes. They are tested in vector_pairwise_test.cc and in python/pyarrow/tests/compute.py. ### Are there any user-facing changes? Yes, and docs are also updated in this PR. * Closes: #35786 Lead-authored-by: Jin Shang <shangjin1997@gmail.com> Co-authored-by: Benjamin Kietzman <bengilgit@gmail.com> Signed-off-by: Benjamin Kietzman <bengilgit@gmail.com>
1 parent 1ab00ae commit 26c25d1

File tree

16 files changed

+549
-3
lines changed

16 files changed

+549
-3
lines changed

cpp/src/arrow/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -456,6 +456,7 @@ if(ARROW_COMPUTE)
456456
compute/kernels/scalar_validity.cc
457457
compute/kernels/vector_array_sort.cc
458458
compute/kernels/vector_cumulative_ops.cc
459+
compute/kernels/vector_pairwise.cc
459460
compute/kernels/vector_nested.cc
460461
compute/kernels/vector_rank.cc
461462
compute/kernels/vector_replace.cc

cpp/src/arrow/compute/api_vector.cc

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
#include "arrow/result.h"
3636
#include "arrow/util/checked_cast.h"
3737
#include "arrow/util/logging.h"
38+
#include "arrow/util/reflection_internal.h"
3839

3940
namespace arrow {
4041

@@ -150,6 +151,8 @@ static auto kRankOptionsType = GetFunctionOptionsType<RankOptions>(
150151
DataMember("sort_keys", &RankOptions::sort_keys),
151152
DataMember("null_placement", &RankOptions::null_placement),
152153
DataMember("tiebreaker", &RankOptions::tiebreaker));
154+
static auto kPairwiseOptionsType = GetFunctionOptionsType<PairwiseOptions>(
155+
DataMember("periods", &PairwiseOptions::periods));
153156
} // namespace
154157
} // namespace internal
155158

@@ -217,6 +220,10 @@ RankOptions::RankOptions(std::vector<SortKey> sort_keys, NullPlacement null_plac
217220
tiebreaker(tiebreaker) {}
218221
constexpr char RankOptions::kTypeName[];
219222

223+
PairwiseOptions::PairwiseOptions(int64_t periods)
224+
: FunctionOptions(internal::kPairwiseOptionsType), periods(periods) {}
225+
constexpr char PairwiseOptions::kTypeName[];
226+
220227
namespace internal {
221228
void RegisterVectorOptions(FunctionRegistry* registry) {
222229
DCHECK_OK(registry->AddFunctionOptionsType(kFilterOptionsType));
@@ -229,6 +236,7 @@ void RegisterVectorOptions(FunctionRegistry* registry) {
229236
DCHECK_OK(registry->AddFunctionOptionsType(kSelectKOptionsType));
230237
DCHECK_OK(registry->AddFunctionOptionsType(kCumulativeOptionsType));
231238
DCHECK_OK(registry->AddFunctionOptionsType(kRankOptionsType));
239+
DCHECK_OK(registry->AddFunctionOptionsType(kPairwiseOptionsType));
232240
}
233241
} // namespace internal
234242

@@ -338,6 +346,15 @@ Result<std::shared_ptr<StructArray>> ValueCounts(const Datum& value, ExecContext
338346
return checked_pointer_cast<StructArray>(result.make_array());
339347
}
340348

349+
// Convenience wrapper that dispatches to the registered "pairwise_diff" /
// "pairwise_diff_checked" compute function and unwraps the result as an Array.
Result<std::shared_ptr<Array>> PairwiseDiff(const Array& array,
                                            const PairwiseOptions& options,
                                            bool check_overflow, ExecContext* ctx) {
  // Only route to the overflow-checking variant when the caller asked for it.
  const char* kernel_name = "pairwise_diff";
  if (check_overflow) {
    kernel_name = "pairwise_diff_checked";
  }
  ARROW_ASSIGN_OR_RAISE(Datum diff,
                        CallFunction(kernel_name, {Datum(array)}, &options, ctx));
  return diff.make_array();
}
357+
341358
// ----------------------------------------------------------------------
342359
// Filter- and take-related selection functions
343360

cpp/src/arrow/compute/api_vector.h

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -234,6 +234,17 @@ class ARROW_EXPORT CumulativeOptions : public FunctionOptions {
234234
};
235235
using CumulativeSumOptions = CumulativeOptions; // For backward compatibility
236236

237+
/// \brief Options for pairwise functions
238+
class ARROW_EXPORT PairwiseOptions : public FunctionOptions {
239+
public:
240+
explicit PairwiseOptions(int64_t periods = 1);
241+
static constexpr char const kTypeName[] = "PairwiseOptions";
242+
static PairwiseOptions Defaults() { return PairwiseOptions(); }
243+
244+
/// Periods to shift for applying the binary operation, accepts negative values.
245+
int64_t periods = 1;
246+
};
247+
237248
/// @}
238249

239250
/// \brief Filter with a boolean selection filter
@@ -650,6 +661,28 @@ Result<Datum> CumulativeMin(
650661
const Datum& values, const CumulativeOptions& options = CumulativeOptions::Defaults(),
651662
ExecContext* ctx = NULLPTR);
652663

664+
/// \brief Return the first order difference of an array.
665+
///
666+
/// Computes the first order difference of an array, i.e.
667+
/// output[i] = input[i] - input[i - p] if i >= p
668+
/// output[i] = null otherwise
669+
/// where p is the period. For example, with p = 1,
670+
/// Diff([1, 4, 9, 10, 15]) = [null, 3, 5, 1, 5].
671+
/// With p = 2,
672+
/// Diff([1, 4, 9, 10, 15]) = [null, null, 8, 6, 6]
673+
/// p can also be negative, in which case the diff is computed in
674+
/// the opposite direction.
675+
/// \param[in] array array input
676+
/// \param[in] options options, specifying overflow behavior and period
677+
/// \param[in] check_overflow whether to return error on overflow
678+
/// \param[in] ctx the function execution context, optional
679+
/// \return result as array
680+
ARROW_EXPORT
681+
Result<std::shared_ptr<Array>> PairwiseDiff(const Array& array,
682+
const PairwiseOptions& options,
683+
bool check_overflow = false,
684+
ExecContext* ctx = NULLPTR);
685+
653686
// ----------------------------------------------------------------------
654687
// Deprecated functions
655688

cpp/src/arrow/compute/exec.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -356,6 +356,9 @@ struct ARROW_EXPORT ExecResult {
356356
const std::shared_ptr<ArrayData>& array_data() const {
357357
return std::get<std::shared_ptr<ArrayData>>(this->value);
358358
}
359+
/// Raw, non-owning pointer to the ArrayData held by this result, for callers
/// that need to write into it. As with array_data(), std::get throws if the
/// variant currently holds the other alternative.
ArrayData* array_data_mutable() {
  const auto& holder = std::get<std::shared_ptr<ArrayData>>(this->value);
  return holder.get();
}
359362

360363
bool is_array_data() const { return this->value.index() == 1; }
361364
};

cpp/src/arrow/compute/kernel.h

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -283,14 +283,16 @@ class ARROW_EXPORT OutputType {
283283
///
284284
/// This function SHOULD _not_ be used to check for arity, that is to be
285285
/// performed one or more layers above.
286-
using Resolver = Result<TypeHolder> (*)(KernelContext*, const std::vector<TypeHolder>&);
286+
using Resolver =
287+
std::function<Result<TypeHolder>(KernelContext*, const std::vector<TypeHolder>&)>;
287288

288289
/// \brief Output an exact type
289290
OutputType(std::shared_ptr<DataType> type) // NOLINT implicit construction
290291
: kind_(FIXED), type_(std::move(type)) {}
291292

292293
/// \brief Output a computed type depending on actual input types
293-
OutputType(Resolver resolver) // NOLINT implicit construction
294+
template <typename Fn>
295+
OutputType(Fn resolver) // NOLINT implicit construction
294296
: kind_(COMPUTED), resolver_(std::move(resolver)) {}
295297

296298
OutputType(const OutputType& other) {

cpp/src/arrow/compute/kernels/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,7 @@ add_arrow_benchmark(scalar_temporal_benchmark PREFIX "arrow-compute")
6969
add_arrow_compute_test(vector_test
7070
SOURCES
7171
vector_cumulative_ops_test.cc
72+
vector_pairwise_test.cc
7273
vector_hash_test.cc
7374
vector_nested_test.cc
7475
vector_replace_test.cc
Lines changed: 183 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,183 @@
1+
// Licensed to the Apache Software Foundation (ASF) under one
2+
// or more contributor license agreements. See the NOTICE file
3+
// distributed with this work for additional information
4+
// regarding copyright ownership. The ASF licenses this file
5+
// to you under the Apache License, Version 2.0 (the
6+
// "License"); you may not use this file except in compliance
7+
// with the License. You may obtain a copy of the License at
8+
//
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
//
11+
// Unless required by applicable law or agreed to in writing,
12+
// software distributed under the License is distributed on an
13+
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
// KIND, either express or implied. See the License for the
15+
// specific language governing permissions and limitations
16+
// under the License.
17+
18+
// Vector kernels for pairwise computation
19+
20+
#include <iostream>
21+
#include <memory>
22+
#include "arrow/builder.h"
23+
#include "arrow/compute/api_vector.h"
24+
#include "arrow/compute/exec.h"
25+
#include "arrow/compute/function.h"
26+
#include "arrow/compute/kernel.h"
27+
#include "arrow/compute/kernels/base_arithmetic_internal.h"
28+
#include "arrow/compute/kernels/codegen_internal.h"
29+
#include "arrow/compute/registry.h"
30+
#include "arrow/compute/util.h"
31+
#include "arrow/status.h"
32+
#include "arrow/type.h"
33+
#include "arrow/type_fwd.h"
34+
#include "arrow/type_traits.h"
35+
#include "arrow/util/bit_util.h"
36+
#include "arrow/util/checked_cast.h"
37+
#include "arrow/util/logging.h"
38+
#include "arrow/visit_type_inline.h"
39+
40+
namespace arrow::compute::internal {
41+
42+
// We reuse the kernel exec function of a scalar binary function to compute pairwise
43+
// results. For example, for pairwise_diff, we reuse subtract's kernel exec.
44+
struct PairwiseState : KernelState {
45+
PairwiseState(const PairwiseOptions& options, ArrayKernelExec scalar_exec)
46+
: periods(options.periods), scalar_exec(scalar_exec) {}
47+
48+
int64_t periods;
49+
ArrayKernelExec scalar_exec;
50+
};
51+
52+
/// A generic pairwise implementation that can be reused by different ops.
53+
Status PairwiseExecImpl(KernelContext* ctx, const ArraySpan& input,
54+
const ArrayKernelExec& scalar_exec, int64_t periods,
55+
ArrayData* result) {
56+
// We only compute values in the region where the input-with-offset overlaps
57+
// the original input. The margin where these do not overlap gets filled with null.
58+
auto margin_length = std::min(abs(periods), input.length);
59+
auto computed_length = input.length - margin_length;
60+
auto margin_start = periods > 0 ? 0 : computed_length;
61+
auto computed_start = periods > 0 ? margin_length : 0;
62+
auto left_start = computed_start;
63+
auto right_start = margin_length - computed_start;
64+
// prepare bitmap
65+
bit_util::ClearBitmap(result->buffers[0]->mutable_data(), margin_start, margin_length);
66+
for (int64_t i = computed_start; i < computed_start + computed_length; i++) {
67+
if (input.IsValid(i) && input.IsValid(i - periods)) {
68+
bit_util::SetBit(result->buffers[0]->mutable_data(), i);
69+
} else {
70+
bit_util::ClearBit(result->buffers[0]->mutable_data(), i);
71+
}
72+
}
73+
// prepare input span
74+
ArraySpan left(input);
75+
left.SetSlice(left_start, computed_length);
76+
ArraySpan right(input);
77+
right.SetSlice(right_start, computed_length);
78+
// prepare output span
79+
ArraySpan output_span;
80+
output_span.SetMembers(*result);
81+
output_span.offset = computed_start;
82+
output_span.length = computed_length;
83+
ExecResult output{output_span};
84+
// execute scalar function
85+
RETURN_NOT_OK(scalar_exec(ctx, ExecSpan({left, right}, computed_length), &output));
86+
87+
return Status::OK();
88+
}
89+
90+
/// Vector kernel entry point: reads the shift distance and the reused scalar
/// exec function from the kernel state and applies them to the single input
/// array, writing into the ArrayData held by `out`.
Status PairwiseExec(KernelContext* ctx, const ExecSpan& batch, ExecResult* out) {
  const auto& state = checked_cast<const PairwiseState&>(*ctx->state());
  return PairwiseExecImpl(ctx, batch[0].array, state.scalar_exec, state.periods,
                          out->array_data_mutable());
}
97+
98+
// User-facing documentation for the wrapping (unchecked) "pairwise_diff" function.
const FunctionDoc pairwise_diff_doc(
    "Compute first order difference of an array",
    ("Computes the first order difference of an array. It internally calls\n"
     "the scalar function \"subtract\" to compute differences, so its\n"
     "behavior and supported types are the same as \"subtract\".\n"
     "The period can be specified in :struct:`PairwiseOptions`.\n"
     "\n"
     "Results will wrap around on integer overflow. Use function\n"
     "\"pairwise_diff_checked\" if you want overflow to return an error."),
    {"input"}, "PairwiseOptions");
108+
109+
// User-facing documentation for the overflow-checking "pairwise_diff_checked" function.
const FunctionDoc pairwise_diff_checked_doc(
    "Compute first order difference of an array",
    ("Computes the first order difference of an array. It internally calls\n"
     "the scalar function \"subtract_checked\" to compute differences, so its\n"
     "behavior and supported types are the same as \"subtract_checked\".\n"
     "The period can be specified in :struct:`PairwiseOptions`.\n"
     "\n"
     "This function returns an error on overflow. For a variant that doesn't\n"
     "fail on overflow, use function \"pairwise_diff\"."),
    {"input"}, "PairwiseOptions");
119+
120+
/// Address of a function-local, lazily constructed default PairwiseOptions
/// instance; the pointer stays valid for the lifetime of the program.
const PairwiseOptions* GetDefaultPairwiseOptions() {
  static const PairwiseOptions defaults = PairwiseOptions::Defaults();
  return &defaults;
}
124+
125+
struct PairwiseKernelData {
126+
InputType input;
127+
OutputType output;
128+
ArrayKernelExec exec;
129+
};
130+
131+
void RegisterPairwiseDiffKernels(std::string_view func_name,
132+
std::string_view base_func_name, const FunctionDoc& doc,
133+
FunctionRegistry* registry) {
134+
VectorKernel kernel;
135+
kernel.can_execute_chunkwise = false;
136+
kernel.null_handling = NullHandling::COMPUTED_PREALLOCATE;
137+
kernel.mem_allocation = MemAllocation::PREALLOCATE;
138+
kernel.init = OptionsWrapper<PairwiseOptions>::Init;
139+
auto func = std::make_shared<VectorFunction>(std::string(func_name), Arity::Unary(),
140+
doc, GetDefaultPairwiseOptions());
141+
142+
auto base_func_result = registry->GetFunction(std::string(base_func_name));
143+
DCHECK_OK(base_func_result.status());
144+
const auto& base_func = checked_cast<const ScalarFunction&>(**base_func_result);
145+
DCHECK_EQ(base_func.arity().num_args, 2);
146+
147+
for (const auto& base_func_kernel : base_func.kernels()) {
148+
const auto& base_func_kernel_sig = base_func_kernel->signature;
149+
if (!base_func_kernel_sig->in_types()[0].Equals(
150+
base_func_kernel_sig->in_types()[1])) {
151+
continue;
152+
}
153+
OutputType out_type(base_func_kernel_sig->out_type());
154+
// Need to wrap base output resolver
155+
if (out_type.kind() == OutputType::COMPUTED) {
156+
out_type =
157+
OutputType([base_resolver = base_func_kernel_sig->out_type().resolver()](
158+
KernelContext* ctx, const std::vector<TypeHolder>& input_types) {
159+
return base_resolver(ctx, {input_types[0], input_types[0]});
160+
});
161+
}
162+
163+
kernel.signature =
164+
KernelSignature::Make({base_func_kernel_sig->in_types()[0]}, out_type);
165+
kernel.exec = PairwiseExec;
166+
kernel.init = [scalar_exec = base_func_kernel->exec](KernelContext* ctx,
167+
const KernelInitArgs& args) {
168+
return std::make_unique<PairwiseState>(
169+
checked_cast<const PairwiseOptions&>(*args.options), scalar_exec);
170+
};
171+
DCHECK_OK(func->AddKernel(kernel));
172+
}
173+
174+
DCHECK_OK(registry->AddFunction(std::move(func)));
175+
}
176+
177+
/// Adds the pairwise vector functions to `registry`: one variant built on
/// "subtract" and one built on "subtract_checked".
void RegisterVectorPairwise(FunctionRegistry* registry) {
  RegisterPairwiseDiffKernels(/*func_name=*/"pairwise_diff",
                              /*base_func_name=*/"subtract", pairwise_diff_doc,
                              registry);
  RegisterPairwiseDiffKernels(/*func_name=*/"pairwise_diff_checked",
                              /*base_func_name=*/"subtract_checked",
                              pairwise_diff_checked_doc, registry);
}
182+
183+
} // namespace arrow::compute::internal

0 commit comments

Comments
 (0)