11#include " functions.h"
22
33#include < contrib/libs/apache/arrow/cpp/src/arrow/compute/api.h>
4+ #include < contrib/libs/apache/arrow/cpp/src/arrow/compute/kernels/aggregate_basic_internal.h>
5+ #include < contrib/libs/apache/arrow/cpp/src/arrow/compute/kernels/codegen_internal.h>
6+ #include < contrib/libs/apache/arrow/cpp/src/arrow/compute/registry_internal.h>
47#include < contrib/libs/apache/arrow/cpp/src/arrow/table.h>
58
69namespace NKikimr ::NArrow::NSSA {
10+
11+ namespace internal {
12+
13+ // Find the largest compatible primitive type for a primitive type.
14+ template <typename I, typename Enable = void >
15+ struct FindAccumulatorType {};
16+
17+ template <typename I>
18+ struct FindAccumulatorType <I, arrow::enable_if_boolean<I>> {
19+ using Type = arrow::UInt64Type;
20+ };
21+
22+ template <typename I>
23+ struct FindAccumulatorType <I, arrow::enable_if_signed_integer<I>> {
24+ using Type = arrow::Int64Type;
25+ };
26+
27+ template <typename I>
28+ struct FindAccumulatorType <I, arrow::enable_if_unsigned_integer<I>> {
29+ using Type = arrow::UInt64Type;
30+ };
31+
32+ template <typename I>
33+ struct FindAccumulatorType <I, arrow::enable_if_floating_point<I>> {
34+ using Type = arrow::DoubleType;
35+ };
36+
37+ template <>
38+ struct FindAccumulatorType <arrow::FloatType, void > {
39+ using Type = arrow::FloatType;
40+ };
41+
// Sum aggregator state for one primitive input column. Because it is
// instantiated with FindAccumulatorType above, a float32 input accumulates in
// float32 rather than widening to double as Arrow's default kernel does.
template <typename ArrowType, arrow::compute::SimdLevel::type SimdLevel>
struct SumImpl : public arrow::compute::ScalarAggregator {
    using ThisType = SumImpl<ArrowType, SimdLevel>;
    using CType = typename ArrowType::c_type;
    using SumType = typename FindAccumulatorType<ArrowType>::Type;
    using OutputType = typename arrow::TypeTraits<SumType>::ScalarType;

    // Folds one ExecBatch into the running state: Count gains the number of
    // non-null values, Sum gains their total. Handles both array-shaped and
    // scalar-shaped inputs.
    arrow::Status Consume(arrow::compute::KernelContext*, const arrow::compute::ExecBatch& batch) override {
        if (batch[0].is_array()) {
            const auto& data = batch[0].array();
            this->Count += data->length - data->GetNullCount();
            if (arrow::is_boolean_type<ArrowType>::value) {
                // For booleans the sum is just the number of true bits.
                this->Sum +=
                    static_cast<typename SumType::c_type>(arrow::BooleanArray(data).true_count());
            } else {
                this->Sum +=
                    arrow::compute::detail::SumArray<CType, typename SumType::c_type, SimdLevel>(
                        *data);
            }
        } else {
            // A scalar input represents batch.length identical rows.
            const auto& data = *batch[0].scalar();
            this->Count += data.is_valid * batch.length;
            if (data.is_valid) {
                this->Sum += arrow::compute::internal::UnboxScalar<ArrowType>::Unbox(data) * batch.length;
            }
        }
        return arrow::Status::OK();
    }

    // Merges state produced by another instance (parallel execution).
    arrow::Status MergeFrom(arrow::compute::KernelContext*, arrow::compute::KernelState&& src) override {
        const auto& other = arrow::checked_cast<const ThisType&>(src);
        this->Count += other.Count;
        this->Sum += other.Sum;
        return arrow::Status::OK();
    }

    // Emits a null scalar when fewer than Options.min_count non-null values
    // were consumed; otherwise the accumulated sum.
    arrow::Status Finalize(arrow::compute::KernelContext*, arrow::Datum* out) override {
        if (this->Count < Options.min_count) {
            out->value = std::make_shared<OutputType>();  // default-constructed => null scalar
        } else {
            out->value = arrow::MakeScalar(this->Sum);
        }
        return arrow::Status::OK();
    }

    size_t Count = 0;                  // non-null values consumed so far
    typename SumType::c_type Sum = 0;  // running total in the accumulator type
    arrow::compute::ScalarAggregateOptions Options;
};
91+
92+ template <typename ArrowType>
93+ struct SumImplDefault : public SumImpl <ArrowType, arrow::compute::SimdLevel::NONE> {
94+ explicit SumImplDefault (const arrow::compute::ScalarAggregateOptions& options) {
95+ this ->Options = options;
96+ }
97+ };
98+
99+ void AddScalarAggKernels (arrow::compute::KernelInit init,
100+ const std::vector<std::shared_ptr<arrow::DataType>>& types,
101+ std::shared_ptr<arrow::DataType> out_ty,
102+ arrow::compute::ScalarAggregateFunction* func) {
103+ for (const auto & ty : types) {
104+ // scalar[InT] -> scalar[OutT]
105+ auto sig = arrow::compute::KernelSignature::Make ({arrow::compute::InputType::Scalar (ty)}, arrow::ValueDescr::Scalar (out_ty));
106+ AddAggKernel (std::move (sig), init, func, arrow::compute::SimdLevel::NONE);
107+ }
108+ }
109+
110+ void AddArrayScalarAggKernels (arrow::compute::KernelInit init,
111+ const std::vector<std::shared_ptr<arrow::DataType>>& types,
112+ std::shared_ptr<arrow::DataType> out_ty,
113+ arrow::compute::ScalarAggregateFunction* func,
114+ arrow::compute::SimdLevel::type simd_level = arrow::compute::SimdLevel::NONE) {
115+ arrow::compute::aggregate::AddBasicAggKernels (init, types, out_ty, func, simd_level);
116+ AddScalarAggKernels (init, types, out_ty, func);
117+ }
118+
119+ arrow::Result<std::unique_ptr<arrow::compute::KernelState>> SumInit (arrow::compute::KernelContext* ctx,
120+ const arrow::compute::KernelInitArgs& args) {
121+ arrow::compute::aggregate::SumLikeInit<SumImplDefault> visitor (
122+ ctx, *args.inputs [0 ].type ,
123+ static_cast <const arrow::compute::ScalarAggregateOptions&>(*args.options ));
124+ return visitor.Create ();
125+ }
126+
127+ static std::unique_ptr<arrow::compute::FunctionRegistry> CreateCustomRegistry () {
128+ arrow::compute::FunctionRegistry* defaultRegistry = arrow::compute::GetFunctionRegistry ();
129+ auto registry = arrow::compute::FunctionRegistry::Make ();
130+ for (const auto & func : defaultRegistry->GetFunctionNames ()) {
131+ if (func == " sum" ) {
132+ auto aggregateFunc = dynamic_cast <arrow::compute::ScalarAggregateFunction*>(defaultRegistry->GetFunction (func)->get ());
133+ if (!aggregateFunc) {
134+ DCHECK_OK (registry->AddFunction (*defaultRegistry->GetFunction (func)));
135+ continue ;
136+ }
137+ arrow::compute::ScalarAggregateFunction newFunc (func, aggregateFunc->arity (), &aggregateFunc->doc (), aggregateFunc->default_options ());
138+ for (const arrow::compute::ScalarAggregateKernel* kernel : aggregateFunc->kernels ()) {
139+ auto shouldReplaceKernel = [](const arrow::compute::ScalarAggregateKernel& kernel) {
140+ const auto & params = kernel.signature ->in_types ();
141+ if (params.empty ()) {
142+ return false ;
143+ }
144+
145+ if (params[0 ].kind () == arrow::compute::InputType::Kind::EXACT_TYPE) {
146+ auto type = params[0 ].type ();
147+ return type->id () == arrow::Type::FLOAT;
148+ }
149+
150+ return false ;
151+ };
152+
153+ if (shouldReplaceKernel (*kernel)) {
154+ AddArrayScalarAggKernels (SumInit, {arrow::float32 ()}, arrow::float32 (), &newFunc);
155+ } else {
156+ DCHECK_OK (newFunc.AddKernel (*kernel));
157+ }
158+ }
159+ DCHECK_OK (registry->AddFunction (std::make_shared<arrow::compute::ScalarAggregateFunction>(std::move (newFunc))));
160+ } else {
161+ DCHECK_OK (registry->AddFunction (*defaultRegistry->GetFunction (func)));
162+ }
163+ }
164+
165+ return registry;
166+ }
167+ arrow::compute::FunctionRegistry* GetCustomFunctionRegistry () {
168+ static auto registry = internal::CreateCustomRegistry ();
169+ return registry.get ();
170+ }
171+
172+ } // namespace internal
173+
7174TConclusion<arrow::Datum> TInternalFunction::Call (
8175 const TExecFunctionContext& context, const std::shared_ptr<TAccessorsCollection>& resources) const {
9176 auto funcNames = GetRegistryFunctionNames ();
@@ -16,7 +183,8 @@ TConclusion<arrow::Datum> TInternalFunction::Call(
16183 if (GetContext () && GetContext ()->func_registry ()->GetFunction (funcName).ok ()) {
17184 result = arrow::compute::CallFunction (funcName, *arguments, FunctionOptions.get (), GetContext ());
18185 } else {
19- result = arrow::compute::CallFunction (funcName, *arguments, FunctionOptions.get ());
186+ arrow::compute::ExecContext defaultContext (arrow::default_memory_pool (), nullptr , internal::GetCustomFunctionRegistry ());
187+ result = arrow::compute::CallFunction (funcName, *arguments, FunctionOptions.get (), &defaultContext);
20188 }
21189
22190 if (result.ok () && funcName == " count" sv) {
0 commit comments