Skip to content

Commit 6e6e6f0

Browse files
js8544pitrou
andauthored
GH-36931: [C++] Add cumulative_mean function (#36932)
### Rationale for this change Add `cumulative_mean` function ### What changes are included in this PR? Implement `cumulative_mean` function. The current cumulative_* kernel generator can only be based on a simple binary arithmetic op and the state can only be a single value. I refactored it to using of a generic state such that it can handle complex operations such as `mean`, `median`, `var` etc. ### Are these changes tested? Yes ### Are there any user-facing changes? No * Closes: #36931 Lead-authored-by: Jin Shang <shangjin1997@gmail.com> Co-authored-by: Antoine Pitrou <antoine@python.org> Signed-off-by: Antoine Pitrou <antoine@python.org>
1 parent 9f183fc commit 6e6e6f0

File tree

5 files changed

+268
-77
lines changed

5 files changed

+268
-77
lines changed

cpp/src/arrow/compute/api_vector.cc

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -417,5 +417,10 @@ Result<Datum> CumulativeMin(const Datum& values, const CumulativeOptions& option
417417
return CallFunction("cumulative_min", {Datum(values)}, &options, ctx);
418418
}
419419

420+
Result<Datum> CumulativeMean(const Datum& values, const CumulativeOptions& options,
421+
ExecContext* ctx) {
422+
return CallFunction("cumulative_mean", {Datum(values)}, &options, ctx);
423+
}
424+
420425
} // namespace compute
421426
} // namespace arrow

cpp/src/arrow/compute/api_vector.h

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -226,6 +226,7 @@ class ARROW_EXPORT CumulativeOptions : public FunctionOptions {
226226
/// - prod: 1
227227
/// - min: maximum of the input type
228228
/// - max: minimum of the input type
229+
/// - mean: start is ignored because it has no meaning for mean
229230
std::optional<std::shared_ptr<Scalar>> start;
230231

231232
/// If true, nulls in the input are ignored and produce a corresponding null output.
@@ -661,6 +662,16 @@ Result<Datum> CumulativeMin(
661662
const Datum& values, const CumulativeOptions& options = CumulativeOptions::Defaults(),
662663
ExecContext* ctx = NULLPTR);
663664

665+
/// \brief Compute the cumulative mean of an array-like object
666+
///
667+
/// \param[in] values array-like input
668+
/// \param[in] options configures cumulative mean behavior, `start` is ignored
669+
/// \param[in] ctx the function execution context, optional
670+
ARROW_EXPORT
671+
Result<Datum> CumulativeMean(
672+
const Datum& values, const CumulativeOptions& options = CumulativeOptions::Defaults(),
673+
ExecContext* ctx = NULLPTR);
674+
664675
/// \brief Return the first order difference of an array.
665676
///
666677
/// Computes the first order difference of an array, i.e.

cpp/src/arrow/compute/kernels/vector_cumulative_ops.cc

Lines changed: 140 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -25,12 +25,11 @@
2525
#include "arrow/compute/kernels/codegen_internal.h"
2626
#include "arrow/compute/kernels/common_internal.h"
2727
#include "arrow/result.h"
28+
#include "arrow/type_traits.h"
2829
#include "arrow/util/bit_util.h"
2930
#include "arrow/visit_type_inline.h"
3031

31-
namespace arrow {
32-
namespace compute {
33-
namespace internal {
32+
namespace arrow::compute::internal {
3433

3534
namespace {
3635

@@ -63,19 +62,60 @@ struct CumulativeOptionsWrapper : public OptionsWrapper<OptionsType> {
6362
}
6463
};
6564

66-
// The driver kernel for all cumulative compute functions. Op is a compute kernel
67-
// representing any binary associative operation with an identity element (add, product,
68-
// min, max, etc.), i.e. ones that form a monoid, and OptionsType the options type
69-
// corresponding to Op. ArgType and OutType are the input and output types, which will
65+
// The cumulative value is computed based on a simple arithmetic binary op
66+
// such as Add, Mul, Min, Max, etc.
67+
template <typename Op, typename ArgType>
68+
struct CumulativeBinaryOp {
69+
using OutType = ArgType;
70+
using OutValue = typename GetOutputType<OutType>::T;
71+
using ArgValue = typename GetViewType<ArgType>::T;
72+
73+
OutValue current_value;
74+
75+
CumulativeBinaryOp() { current_value = Identity<Op>::template value<OutValue>; }
76+
77+
explicit CumulativeBinaryOp(const std::shared_ptr<Scalar> start) {
78+
current_value = UnboxScalar<OutType>::Unbox(*start);
79+
}
80+
81+
OutValue Call(KernelContext* ctx, ArgValue arg, Status* st) {
82+
current_value =
83+
Op::template Call<OutValue, ArgValue, ArgValue>(ctx, arg, current_value, st);
84+
return current_value;
85+
}
86+
};
87+
88+
template <typename ArgType>
89+
struct CumulativeMean {
90+
using OutType = DoubleType;
91+
using ArgValue = typename GetViewType<ArgType>::T;
92+
int64_t count = 0;
93+
double sum = 0;
94+
95+
CumulativeMean() = default;
96+
97+
// start value is ignored for CumulativeMean
98+
explicit CumulativeMean(const std::shared_ptr<Scalar> start) {}
99+
100+
double Call(KernelContext* ctx, ArgValue arg, Status* st) {
101+
sum += static_cast<double>(arg);
102+
++count;
103+
return sum / count;
104+
}
105+
};
106+
107+
// The driver kernel for all cumulative compute functions.
108+
// ArgType and OutType are the input and output types, which will
70109
// normally be the same (e.g. the cumulative sum of an array of Int64Type will result in
71-
// an array of Int64Type).
72-
template <typename OutType, typename ArgType, typename Op, typename OptionsType>
110+
// an array of Int64Type) with the exception of CumulativeMean, which will always return
111+
// a double.
112+
template <typename ArgType, typename CumulativeState>
73113
struct Accumulator {
74-
using OutValue = typename GetOutputType<OutType>::T;
114+
using OutType = typename CumulativeState::OutType;
75115
using ArgValue = typename GetViewType<ArgType>::T;
76116

77117
KernelContext* ctx;
78-
ArgValue current_value;
118+
CumulativeState current_state;
79119
bool skip_nulls;
80120
bool encountered_null = false;
81121
NumericBuilder<OutType> builder;
@@ -88,21 +128,15 @@ struct Accumulator {
88128
if (skip_nulls || (input.GetNullCount() == 0 && !encountered_null)) {
89129
VisitArrayValuesInline<ArgType>(
90130
input,
91-
[&](ArgValue v) {
92-
current_value = Op::template Call<OutValue, ArgValue, ArgValue>(
93-
ctx, v, current_value, &st);
94-
builder.UnsafeAppend(current_value);
95-
},
131+
[&](ArgValue v) { builder.UnsafeAppend(current_state.Call(ctx, v, &st)); },
96132
[&]() { builder.UnsafeAppendNull(); });
97133
} else {
98134
int64_t nulls_start_idx = 0;
99135
VisitArrayValuesInline<ArgType>(
100136
input,
101137
[&](ArgValue v) {
102138
if (!encountered_null) {
103-
current_value = Op::template Call<OutValue, ArgValue, ArgValue>(
104-
ctx, v, current_value, &st);
105-
builder.UnsafeAppend(current_value);
139+
builder.UnsafeAppend(current_state.Call(ctx, v, &st));
106140
++nulls_start_idx;
107141
}
108142
},
@@ -115,16 +149,17 @@ struct Accumulator {
115149
}
116150
};
117151

118-
template <typename OutType, typename ArgType, typename Op, typename OptionsType>
152+
template <typename ArgType, typename CumulativeState, typename OptionsType>
119153
struct CumulativeKernel {
154+
using OutType = typename CumulativeState::OutType;
120155
using OutValue = typename GetOutputType<OutType>::T;
121156
static Status Exec(KernelContext* ctx, const ExecSpan& batch, ExecResult* out) {
122157
const auto& options = CumulativeOptionsWrapper<OptionsType>::Get(ctx);
123-
Accumulator<OutType, ArgType, Op, OptionsType> accumulator(ctx);
158+
Accumulator<ArgType, CumulativeState> accumulator(ctx);
124159
if (options.start.has_value()) {
125-
accumulator.current_value = UnboxScalar<OutType>::Unbox(*(options.start.value()));
160+
accumulator.current_state = CumulativeState(options.start.value());
126161
} else {
127-
accumulator.current_value = Identity<Op>::template value<OutValue>;
162+
accumulator.current_state = CumulativeState();
128163
}
129164
accumulator.skip_nulls = options.skip_nulls;
130165

@@ -138,16 +173,17 @@ struct CumulativeKernel {
138173
}
139174
};
140175

141-
template <typename OutType, typename ArgType, typename Op, typename OptionsType>
176+
template <typename ArgType, typename CumulativeState, typename OptionsType>
142177
struct CumulativeKernelChunked {
178+
using OutType = typename CumulativeState::OutType;
143179
using OutValue = typename GetOutputType<OutType>::T;
144180
static Status Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
145181
const auto& options = CumulativeOptionsWrapper<OptionsType>::Get(ctx);
146-
Accumulator<OutType, ArgType, Op, OptionsType> accumulator(ctx);
182+
Accumulator<ArgType, CumulativeState> accumulator(ctx);
147183
if (options.start.has_value()) {
148-
accumulator.current_value = UnboxScalar<OutType>::Unbox(*(options.start.value()));
184+
accumulator.current_state = CumulativeState(options.start.value());
149185
} else {
150-
accumulator.current_value = Identity<Op>::template value<OutValue>;
186+
accumulator.current_state = CumulativeState();
151187
}
152188
accumulator.skip_nulls = options.skip_nulls;
153189

@@ -217,53 +253,102 @@ const FunctionDoc cumulative_min_doc{
217253
"start as the new minimum)."),
218254
{"values"},
219255
"CumulativeOptions"};
220-
} // namespace
221256

222-
template <typename Op, typename OptionsType>
223-
void MakeVectorCumulativeFunction(FunctionRegistry* registry, const std::string func_name,
224-
const FunctionDoc doc) {
257+
const FunctionDoc cumulative_mean_doc{
258+
"Compute the cumulative mean over a numeric input",
259+
("`values` must be numeric. Return an array/chunked array which is the\n"
260+
"cumulative mean computed over `values`. CumulativeOptions::start_value is \n"
261+
"ignored."),
262+
{"values"},
263+
"CumulativeOptions"};
264+
265+
// Kernel factory for complex stateful computations.
266+
template <template <typename ArgType> typename State, typename OptionsType>
267+
struct CumulativeStatefulKernelFactory {
268+
VectorKernel kernel;
269+
270+
CumulativeStatefulKernelFactory() {
271+
kernel.can_execute_chunkwise = false;
272+
kernel.null_handling = NullHandling::type::COMPUTED_NO_PREALLOCATE;
273+
kernel.mem_allocation = MemAllocation::type::NO_PREALLOCATE;
274+
kernel.init = CumulativeOptionsWrapper<OptionsType>::Init;
275+
}
276+
277+
template <typename Type>
278+
enable_if_number<Type, Status> Visit(const Type& type) {
279+
kernel.signature = KernelSignature::Make(
280+
{type.GetSharedPtr()},
281+
OutputType(TypeTraits<typename State<Type>::OutType>::type_singleton()));
282+
kernel.exec = CumulativeKernel<Type, State<Type>, OptionsType>::Exec;
283+
kernel.exec_chunked = CumulativeKernelChunked<Type, State<Type>, OptionsType>::Exec;
284+
return arrow::Status::OK();
285+
}
286+
287+
Status Visit(const DataType& type) {
288+
return Status::NotImplemented("Cumulative kernel not implemented for type ",
289+
type.ToString());
290+
}
291+
292+
Result<VectorKernel> Make(const DataType& type) {
293+
RETURN_NOT_OK(VisitTypeInline(type, this));
294+
return kernel;
295+
}
296+
};
297+
298+
template <template <typename ArgType> typename State, typename OptionsType>
299+
void MakeVectorCumulativeStatefulFunction(FunctionRegistry* registry,
300+
const std::string func_name,
301+
const FunctionDoc doc) {
225302
static const OptionsType kDefaultOptions = OptionsType::Defaults();
226303
auto func =
227304
std::make_shared<VectorFunction>(func_name, Arity::Unary(), doc, &kDefaultOptions);
228305

229306
std::vector<std::shared_ptr<DataType>> types;
230307
types.insert(types.end(), NumericTypes().begin(), NumericTypes().end());
231308

309+
CumulativeStatefulKernelFactory<State, OptionsType> kernel_factory;
232310
for (const auto& ty : types) {
233-
VectorKernel kernel;
234-
kernel.can_execute_chunkwise = false;
235-
kernel.null_handling = NullHandling::type::COMPUTED_NO_PREALLOCATE;
236-
kernel.mem_allocation = MemAllocation::type::NO_PREALLOCATE;
237-
kernel.signature = KernelSignature::Make({ty}, OutputType(ty));
238-
kernel.exec =
239-
ArithmeticExecFromOp<CumulativeKernel, Op, ArrayKernelExec, OptionsType>(ty);
240-
kernel.exec_chunked =
241-
ArithmeticExecFromOp<CumulativeKernelChunked, Op, VectorKernel::ChunkedExec,
242-
OptionsType>(ty);
243-
kernel.init = CumulativeOptionsWrapper<OptionsType>::Init;
311+
auto kernel = kernel_factory.Make(*ty).ValueOrDie();
244312
DCHECK_OK(func->AddKernel(std::move(kernel)));
245313
}
246314

247315
DCHECK_OK(registry->AddFunction(std::move(func)));
248316
}
249317

318+
// A kernel factory that forwards to CumulativeBinaryOp<Op, ...> for the given type.
319+
// Need to use a struct because template-using declarations cannot appear in
320+
// function scope.
321+
template <typename Op, typename OptionsType>
322+
struct MakeVectorCumulativeBinaryOpFunction {
323+
template <typename ArgType>
324+
using State = CumulativeBinaryOp<Op, ArgType>;
325+
326+
static void Call(FunctionRegistry* registry, std::string func_name, FunctionDoc doc) {
327+
MakeVectorCumulativeStatefulFunction<State, OptionsType>(
328+
registry, std::move(func_name), std::move(doc));
329+
}
330+
};
331+
332+
} // namespace
333+
250334
void RegisterVectorCumulativeSum(FunctionRegistry* registry) {
251-
MakeVectorCumulativeFunction<Add, CumulativeOptions>(registry, "cumulative_sum",
252-
cumulative_sum_doc);
253-
MakeVectorCumulativeFunction<AddChecked, CumulativeOptions>(
335+
MakeVectorCumulativeBinaryOpFunction<Add, CumulativeOptions>::Call(
336+
registry, "cumulative_sum", cumulative_sum_doc);
337+
MakeVectorCumulativeBinaryOpFunction<AddChecked, CumulativeOptions>::Call(
254338
registry, "cumulative_sum_checked", cumulative_sum_checked_doc);
255339

256-
MakeVectorCumulativeFunction<Multiply, CumulativeOptions>(registry, "cumulative_prod",
257-
cumulative_prod_doc);
258-
MakeVectorCumulativeFunction<MultiplyChecked, CumulativeOptions>(
340+
MakeVectorCumulativeBinaryOpFunction<Multiply, CumulativeOptions>::Call(
341+
registry, "cumulative_prod", cumulative_prod_doc);
342+
MakeVectorCumulativeBinaryOpFunction<MultiplyChecked, CumulativeOptions>::Call(
259343
registry, "cumulative_prod_checked", cumulative_prod_checked_doc);
260344

261-
MakeVectorCumulativeFunction<Min, CumulativeOptions>(registry, "cumulative_min",
262-
cumulative_min_doc);
263-
MakeVectorCumulativeFunction<Max, CumulativeOptions>(registry, "cumulative_max",
264-
cumulative_max_doc);
345+
MakeVectorCumulativeBinaryOpFunction<Min, CumulativeOptions>::Call(
346+
registry, "cumulative_min", cumulative_min_doc);
347+
MakeVectorCumulativeBinaryOpFunction<Max, CumulativeOptions>::Call(
348+
registry, "cumulative_max", cumulative_max_doc);
349+
350+
MakeVectorCumulativeStatefulFunction<CumulativeMean, CumulativeOptions>(
351+
registry, "cumulative_mean", cumulative_max_doc);
265352
}
266353

267-
} // namespace internal
268-
} // namespace compute
269-
} // namespace arrow
354+
} // namespace arrow::compute::internal

0 commit comments

Comments
 (0)