2525#include  " arrow/compute/kernels/codegen_internal.h" 
2626#include  " arrow/compute/kernels/common_internal.h" 
2727#include  " arrow/result.h" 
28+ #include  " arrow/type_traits.h" 
2829#include  " arrow/util/bit_util.h" 
2930#include  " arrow/visit_type_inline.h" 
3031
31- namespace  arrow  {
32- namespace  compute  {
33- namespace  internal  {
32+ namespace  arrow ::compute::internal {
3433
3534namespace  {
3635
@@ -63,19 +62,60 @@ struct CumulativeOptionsWrapper : public OptionsWrapper<OptionsType> {
6362  }
6463};
6564
66- //  The driver kernel for all cumulative compute functions. Op is a compute kernel
67- //  representing any binary associative operation with an identity element (add, product,
68- //  min, max, etc.), i.e. ones that form a monoid, and OptionsType the options type
69- //  corresponding to Op. ArgType and OutType are the input and output types, which will
65+ //  The cumulative value is computed based on a simple arithmetic binary op
66+ //  such as Add, Mul, Min, Max, etc.
67+ template  <typename  Op, typename  ArgType>
68+ struct  CumulativeBinaryOp  {
69+   using  OutType = ArgType;
70+   using  OutValue = typename  GetOutputType<OutType>::T;
71+   using  ArgValue = typename  GetViewType<ArgType>::T;
72+ 
73+   OutValue current_value;
74+ 
75+   CumulativeBinaryOp () { current_value = Identity<Op>::template  value<OutValue>; }
76+ 
77+   explicit  CumulativeBinaryOp (const  std::shared_ptr<Scalar> start) {
78+     current_value = UnboxScalar<OutType>::Unbox (*start);
79+   }
80+ 
81+   OutValue Call (KernelContext* ctx, ArgValue arg, Status* st) {
82+     current_value =
83+         Op::template  Call<OutValue, ArgValue, ArgValue>(ctx, arg, current_value, st);
84+     return  current_value;
85+   }
86+ };
87+ 
88+ template  <typename  ArgType>
89+ struct  CumulativeMean  {
90+   using  OutType = DoubleType;
91+   using  ArgValue = typename  GetViewType<ArgType>::T;
92+   int64_t  count = 0 ;
93+   double  sum = 0 ;
94+ 
95+   CumulativeMean () = default ;
96+ 
97+   //  start value is ignored for CumulativeMean
98+   explicit  CumulativeMean (const  std::shared_ptr<Scalar> start) {}
99+ 
100+   double  Call (KernelContext* ctx, ArgValue arg, Status* st) {
101+     sum += static_cast <double >(arg);
102+     ++count;
103+     return  sum / count;
104+   }
105+ };
106+ 
107+ //  The driver kernel for all cumulative compute functions.
108+ //  ArgType and OutType are the input and output types, which will
70109//  normally be the same (e.g. the cumulative sum of an array of Int64Type will result in
71- //  an array of Int64Type).
72- template  <typename  OutType, typename  ArgType, typename  Op, typename  OptionsType>
110+ //  an array of Int64Type) with the exception of CumulativeMean, which will always return
111+ //  a double.
112+ template  <typename  ArgType, typename  CumulativeState>
73113struct  Accumulator  {
74-   using  OutValue  = typename  GetOutputType<OutType>::T ;
114+   using  OutType  = typename  CumulativeState::OutType ;
75115  using  ArgValue = typename  GetViewType<ArgType>::T;
76116
77117  KernelContext* ctx;
78-   ArgValue current_value ;
118+   CumulativeState current_state ;
79119  bool  skip_nulls;
80120  bool  encountered_null = false ;
81121  NumericBuilder<OutType> builder;
@@ -88,21 +128,15 @@ struct Accumulator {
88128    if  (skip_nulls || (input.GetNullCount () == 0  && !encountered_null)) {
89129      VisitArrayValuesInline<ArgType>(
90130          input,
91-           [&](ArgValue v) {
92-             current_value = Op::template  Call<OutValue, ArgValue, ArgValue>(
93-                 ctx, v, current_value, &st);
94-             builder.UnsafeAppend (current_value);
95-           },
131+           [&](ArgValue v) { builder.UnsafeAppend (current_state.Call (ctx, v, &st)); },
96132          [&]() { builder.UnsafeAppendNull (); });
97133    } else  {
98134      int64_t  nulls_start_idx = 0 ;
99135      VisitArrayValuesInline<ArgType>(
100136          input,
101137          [&](ArgValue v) {
102138            if  (!encountered_null) {
103-               current_value = Op::template  Call<OutValue, ArgValue, ArgValue>(
104-                   ctx, v, current_value, &st);
105-               builder.UnsafeAppend (current_value);
139+               builder.UnsafeAppend (current_state.Call (ctx, v, &st));
106140              ++nulls_start_idx;
107141            }
108142          },
@@ -115,16 +149,17 @@ struct Accumulator {
115149  }
116150};
117151
118- template  <typename  OutType,  typename   ArgType, typename  Op , typename  OptionsType>
152+ template  <typename  ArgType, typename  CumulativeState , typename  OptionsType>
119153struct  CumulativeKernel  {
154+   using  OutType = typename  CumulativeState::OutType;
120155  using  OutValue = typename  GetOutputType<OutType>::T;
121156  static  Status Exec (KernelContext* ctx, const  ExecSpan& batch, ExecResult* out) {
122157    const  auto & options = CumulativeOptionsWrapper<OptionsType>::Get (ctx);
123-     Accumulator<OutType,  ArgType, Op, OptionsType > accumulator (ctx);
158+     Accumulator<ArgType, CumulativeState > accumulator (ctx);
124159    if  (options.start .has_value ()) {
125-       accumulator.current_value  = UnboxScalar<OutType>:: Unbox (*( options.start .value () ));
160+       accumulator.current_state  = CumulativeState ( options.start .value ());
126161    } else  {
127-       accumulator.current_value  = Identity<Op>:: template  value<OutValue> ;
162+       accumulator.current_state  = CumulativeState () ;
128163    }
129164    accumulator.skip_nulls  = options.skip_nulls ;
130165
@@ -138,16 +173,17 @@ struct CumulativeKernel {
138173  }
139174};
140175
141- template  <typename  OutType,  typename   ArgType, typename  Op , typename  OptionsType>
176+ template  <typename  ArgType, typename  CumulativeState , typename  OptionsType>
142177struct  CumulativeKernelChunked  {
178+   using  OutType = typename  CumulativeState::OutType;
143179  using  OutValue = typename  GetOutputType<OutType>::T;
144180  static  Status Exec (KernelContext* ctx, const  ExecBatch& batch, Datum* out) {
145181    const  auto & options = CumulativeOptionsWrapper<OptionsType>::Get (ctx);
146-     Accumulator<OutType,  ArgType, Op, OptionsType > accumulator (ctx);
182+     Accumulator<ArgType, CumulativeState > accumulator (ctx);
147183    if  (options.start .has_value ()) {
148-       accumulator.current_value  = UnboxScalar<OutType>:: Unbox (*( options.start .value () ));
184+       accumulator.current_state  = CumulativeState ( options.start .value ());
149185    } else  {
150-       accumulator.current_value  = Identity<Op>:: template  value<OutValue> ;
186+       accumulator.current_state  = CumulativeState () ;
151187    }
152188    accumulator.skip_nulls  = options.skip_nulls ;
153189
@@ -217,53 +253,102 @@ const FunctionDoc cumulative_min_doc{
217253     " start as the new minimum)." 
218254    {" values" 
219255    " CumulativeOptions" 
220- }  //  namespace
221256
222- template  <typename  Op, typename  OptionsType>
223- void  MakeVectorCumulativeFunction (FunctionRegistry* registry, const  std::string func_name,
224-                                   const  FunctionDoc doc) {
257+ const  FunctionDoc cumulative_mean_doc{
258+     " Compute the cumulative mean over a numeric input" 
259+     (" `values` must be numeric. Return an array/chunked array which is the\n " 
260+      " cumulative mean computed over `values`. CumulativeOptions::start_value is \n " 
261+      " ignored." 
262+     {" values" 
263+     " CumulativeOptions" 
264+ 
265+ //  Kernel factory for complex stateful computations.
266+ template  <template  <typename  ArgType> typename  State, typename  OptionsType>
267+ struct  CumulativeStatefulKernelFactory  {
268+   VectorKernel kernel;
269+ 
270+   CumulativeStatefulKernelFactory () {
271+     kernel.can_execute_chunkwise  = false ;
272+     kernel.null_handling  = NullHandling::type::COMPUTED_NO_PREALLOCATE;
273+     kernel.mem_allocation  = MemAllocation::type::NO_PREALLOCATE;
274+     kernel.init  = CumulativeOptionsWrapper<OptionsType>::Init;
275+   }
276+ 
277+   template  <typename  Type>
278+   enable_if_number<Type, Status> Visit (const  Type& type) {
279+     kernel.signature  = KernelSignature::Make (
280+         {type.GetSharedPtr ()},
281+         OutputType (TypeTraits<typename  State<Type>::OutType>::type_singleton ()));
282+     kernel.exec  = CumulativeKernel<Type, State<Type>, OptionsType>::Exec;
283+     kernel.exec_chunked  = CumulativeKernelChunked<Type, State<Type>, OptionsType>::Exec;
284+     return  arrow::Status::OK ();
285+   }
286+ 
287+   Status Visit (const  DataType& type) {
288+     return  Status::NotImplemented (" Cumulative kernel not implemented for type " 
289+                                   type.ToString ());
290+   }
291+ 
292+   Result<VectorKernel> Make (const  DataType& type) {
293+     RETURN_NOT_OK (VisitTypeInline (type, this ));
294+     return  kernel;
295+   }
296+ };
297+ 
298+ template  <template  <typename  ArgType> typename  State, typename  OptionsType>
299+ void  MakeVectorCumulativeStatefulFunction (FunctionRegistry* registry,
300+                                           const  std::string func_name,
301+                                           const  FunctionDoc doc) {
225302  static  const  OptionsType kDefaultOptions  = OptionsType::Defaults ();
226303  auto  func =
227304      std::make_shared<VectorFunction>(func_name, Arity::Unary (), doc, &kDefaultOptions );
228305
229306  std::vector<std::shared_ptr<DataType>> types;
230307  types.insert (types.end (), NumericTypes ().begin (), NumericTypes ().end ());
231308
309+   CumulativeStatefulKernelFactory<State, OptionsType> kernel_factory;
232310  for  (const  auto & ty : types) {
233-     VectorKernel kernel;
234-     kernel.can_execute_chunkwise  = false ;
235-     kernel.null_handling  = NullHandling::type::COMPUTED_NO_PREALLOCATE;
236-     kernel.mem_allocation  = MemAllocation::type::NO_PREALLOCATE;
237-     kernel.signature  = KernelSignature::Make ({ty}, OutputType (ty));
238-     kernel.exec  =
239-         ArithmeticExecFromOp<CumulativeKernel, Op, ArrayKernelExec, OptionsType>(ty);
240-     kernel.exec_chunked  =
241-         ArithmeticExecFromOp<CumulativeKernelChunked, Op, VectorKernel::ChunkedExec,
242-                              OptionsType>(ty);
243-     kernel.init  = CumulativeOptionsWrapper<OptionsType>::Init;
311+     auto  kernel = kernel_factory.Make (*ty).ValueOrDie ();
244312    DCHECK_OK (func->AddKernel (std::move (kernel)));
245313  }
246314
247315  DCHECK_OK (registry->AddFunction (std::move (func)));
248316}
249317
318+ //  A kernel factory that forwards to CumulativeBinaryOp<Op, ...> for the given type.
319+ //  Need to use a struct because template-using declarations cannot appear in
320+ //  function scope.
321+ template  <typename  Op, typename  OptionsType>
322+ struct  MakeVectorCumulativeBinaryOpFunction  {
323+   template  <typename  ArgType>
324+   using  State = CumulativeBinaryOp<Op, ArgType>;
325+ 
326+   static  void  Call (FunctionRegistry* registry, std::string func_name, FunctionDoc doc) {
327+     MakeVectorCumulativeStatefulFunction<State, OptionsType>(
328+         registry, std::move (func_name), std::move (doc));
329+   }
330+ };
331+ 
332+ }  //  namespace
333+ 
250334void  RegisterVectorCumulativeSum (FunctionRegistry* registry) {
251-   MakeVectorCumulativeFunction <Add, CumulativeOptions>(registry,  " cumulative_sum " , 
252-                                                         cumulative_sum_doc);
253-   MakeVectorCumulativeFunction <AddChecked, CumulativeOptions>(
335+   MakeVectorCumulativeBinaryOpFunction <Add, CumulativeOptions>:: Call ( 
336+       registry,  " cumulative_sum " ,  cumulative_sum_doc);
337+   MakeVectorCumulativeBinaryOpFunction <AddChecked, CumulativeOptions>:: Call (
254338      registry, " cumulative_sum_checked" 
255339
256-   MakeVectorCumulativeFunction <Multiply, CumulativeOptions>(registry,  " cumulative_prod " , 
257-                                                              cumulative_prod_doc);
258-   MakeVectorCumulativeFunction <MultiplyChecked, CumulativeOptions>(
340+   MakeVectorCumulativeBinaryOpFunction <Multiply, CumulativeOptions>:: Call ( 
341+       registry,  " cumulative_prod " ,  cumulative_prod_doc);
342+   MakeVectorCumulativeBinaryOpFunction <MultiplyChecked, CumulativeOptions>:: Call (
259343      registry, " cumulative_prod_checked" 
260344
261-   MakeVectorCumulativeFunction<Min, CumulativeOptions>(registry, " cumulative_min" 
262-                                                        cumulative_min_doc);
263-   MakeVectorCumulativeFunction<Max, CumulativeOptions>(registry, " cumulative_max" 
264-                                                        cumulative_max_doc);
345+   MakeVectorCumulativeBinaryOpFunction<Min, CumulativeOptions>::Call (
346+       registry, " cumulative_min" 
347+   MakeVectorCumulativeBinaryOpFunction<Max, CumulativeOptions>::Call (
348+       registry, " cumulative_max" 
349+ 
350+   MakeVectorCumulativeStatefulFunction<CumulativeMean, CumulativeOptions>(
351+       registry, " cumulative_mean" 
265352}
266353
267- }  //  namespace internal
268- }  //  namespace compute
269- }  //  namespace arrow
354+ }  //  namespace arrow::compute::internal
0 commit comments