Skip to content

float sum aggregation has been fixed #19466

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Jun 12, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
170 changes: 169 additions & 1 deletion ydb/core/formats/arrow/program/functions.cpp
Original file line number Diff line number Diff line change
@@ -1,9 +1,176 @@
#include "functions.h"

#include <contrib/libs/apache/arrow/cpp/src/arrow/compute/api.h>
#include <contrib/libs/apache/arrow/cpp/src/arrow/compute/kernels/aggregate_basic_internal.h>
#include <contrib/libs/apache/arrow/cpp/src/arrow/compute/kernels/codegen_internal.h>
#include <contrib/libs/apache/arrow/cpp/src/arrow/compute/registry_internal.h>
#include <contrib/libs/apache/arrow/cpp/src/arrow/table.h>

namespace NKikimr::NArrow::NSSA {

namespace internal {

// Find the largest compatible primitive type for a primitive type.
template <typename I, typename Enable = void>
struct FindAccumulatorType {};

template <typename I>
struct FindAccumulatorType<I, arrow::enable_if_boolean<I>> {
using Type = arrow::UInt64Type;
};

template <typename I>
struct FindAccumulatorType<I, arrow::enable_if_signed_integer<I>> {
using Type = arrow::Int64Type;
};

template <typename I>
struct FindAccumulatorType<I, arrow::enable_if_unsigned_integer<I>> {
using Type = arrow::UInt64Type;
};

template <typename I>
struct FindAccumulatorType<I, arrow::enable_if_floating_point<I>> {
using Type = arrow::DoubleType;
};

template <>
struct FindAccumulatorType<arrow::FloatType, void> {
using Type = arrow::FloatType;
};

template <typename ArrowType, arrow::compute::SimdLevel::type SimdLevel>
struct SumImpl : public arrow::compute::ScalarAggregator {
using ThisType = SumImpl<ArrowType, SimdLevel>;
using CType = typename ArrowType::c_type;
using SumType = typename FindAccumulatorType<ArrowType>::Type;
using OutputType = typename arrow::TypeTraits<SumType>::ScalarType;

arrow::Status Consume(arrow::compute::KernelContext*, const arrow::compute::ExecBatch& batch) override {
if (batch[0].is_array()) {
const auto& data = batch[0].array();
this->Count += data->length - data->GetNullCount();
if (arrow::is_boolean_type<ArrowType>::value) {
this->Sum +=
static_cast<typename SumType::c_type>(arrow::BooleanArray(data).true_count());
} else {
this->Sum +=
arrow::compute::detail::SumArray<CType, typename SumType::c_type, SimdLevel>(
*data);
}
} else {
const auto& data = *batch[0].scalar();
this->Count += data.is_valid * batch.length;
if (data.is_valid) {
this->Sum += arrow::compute::internal::UnboxScalar<ArrowType>::Unbox(data) * batch.length;
}
}
return arrow::Status::OK();
}

arrow::Status MergeFrom(arrow::compute::KernelContext*, arrow::compute::KernelState&& src) override {
const auto& other = arrow::checked_cast<const ThisType&>(src);
this->Count += other.Count;
this->Sum += other.Sum;
return arrow::Status::OK();
}

arrow::Status Finalize(arrow::compute::KernelContext*, arrow::Datum* out) override {
if (this->Count < Options.min_count) {
out->value = std::make_shared<OutputType>();
} else {
out->value = arrow::MakeScalar(this->Sum);
}
return arrow::Status::OK();
}

size_t Count = 0;
typename SumType::c_type Sum = 0;
arrow::compute::ScalarAggregateOptions Options;
};

template <typename ArrowType>
struct SumImplDefault : public SumImpl<ArrowType, arrow::compute::SimdLevel::NONE> {
explicit SumImplDefault(const arrow::compute::ScalarAggregateOptions& options) {
this->Options = options;
}
};

void AddScalarAggKernels(arrow::compute::KernelInit init,
const std::vector<std::shared_ptr<arrow::DataType>>& types,
std::shared_ptr<arrow::DataType> out_ty,
arrow::compute::ScalarAggregateFunction* func) {
for (const auto& ty : types) {
// scalar[InT] -> scalar[OutT]
auto sig = arrow::compute::KernelSignature::Make({arrow::compute::InputType::Scalar(ty)}, arrow::ValueDescr::Scalar(out_ty));
AddAggKernel(std::move(sig), init, func, arrow::compute::SimdLevel::NONE);
}
}

void AddArrayScalarAggKernels(arrow::compute::KernelInit init,
const std::vector<std::shared_ptr<arrow::DataType>>& types,
std::shared_ptr<arrow::DataType> out_ty,
arrow::compute::ScalarAggregateFunction* func,
arrow::compute::SimdLevel::type simd_level = arrow::compute::SimdLevel::NONE) {
arrow::compute::aggregate::AddBasicAggKernels(init, types, out_ty, func, simd_level);
AddScalarAggKernels(init, types, out_ty, func);
}

arrow::Result<std::unique_ptr<arrow::compute::KernelState>> SumInit(arrow::compute::KernelContext* ctx,
const arrow::compute::KernelInitArgs& args) {
arrow::compute::aggregate::SumLikeInit<SumImplDefault> visitor(
ctx, *args.inputs[0].type,
static_cast<const arrow::compute::ScalarAggregateOptions&>(*args.options));
return visitor.Create();
}

static std::unique_ptr<arrow::compute::FunctionRegistry> CreateCustomRegistry() {
arrow::compute::FunctionRegistry* defaultRegistry = arrow::compute::GetFunctionRegistry();
auto registry = arrow::compute::FunctionRegistry::Make();
for (const auto& func : defaultRegistry->GetFunctionNames()) {
if (func == "sum") {
auto aggregateFunc = dynamic_cast<arrow::compute::ScalarAggregateFunction*>(defaultRegistry->GetFunction(func)->get());
if (!aggregateFunc) {
DCHECK_OK(registry->AddFunction(*defaultRegistry->GetFunction(func)));
continue;
}
arrow::compute::ScalarAggregateFunction newFunc(func, aggregateFunc->arity(), &aggregateFunc->doc(), aggregateFunc->default_options());
for (const arrow::compute::ScalarAggregateKernel* kernel : aggregateFunc->kernels()) {
auto shouldReplaceKernel = [](const arrow::compute::ScalarAggregateKernel& kernel) {
const auto& params = kernel.signature->in_types();
if (params.empty()) {
return false;
}

if (params[0].kind() == arrow::compute::InputType::Kind::EXACT_TYPE) {
auto type = params[0].type();
return type->id() == arrow::Type::FLOAT;
}

return false;
};

if (shouldReplaceKernel(*kernel)) {
AddArrayScalarAggKernels(SumInit, {arrow::float32()}, arrow::float32(), &newFunc);
} else {
DCHECK_OK(newFunc.AddKernel(*kernel));
}
}
DCHECK_OK(registry->AddFunction(std::make_shared<arrow::compute::ScalarAggregateFunction>(std::move(newFunc))));
} else {
DCHECK_OK(registry->AddFunction(*defaultRegistry->GetFunction(func)));
}
}

return registry;
}
arrow::compute::FunctionRegistry* GetCustomFunctionRegistry() {
static auto registry = internal::CreateCustomRegistry();
return registry.get();
}

} // namespace internal

TConclusion<arrow::Datum> TInternalFunction::Call(
const TExecFunctionContext& context, const std::shared_ptr<TAccessorsCollection>& resources) const {
auto funcNames = GetRegistryFunctionNames();
Expand All @@ -16,7 +183,8 @@ TConclusion<arrow::Datum> TInternalFunction::Call(
if (GetContext() && GetContext()->func_registry()->GetFunction(funcName).ok()) {
result = arrow::compute::CallFunction(funcName, *arguments, FunctionOptions.get(), GetContext());
} else {
result = arrow::compute::CallFunction(funcName, *arguments, FunctionOptions.get());
arrow::compute::ExecContext defaultContext(arrow::default_memory_pool(), nullptr, internal::GetCustomFunctionRegistry());
result = arrow::compute::CallFunction(funcName, *arguments, FunctionOptions.get(), &defaultContext);
}

if (result.ok() && funcName == "count"sv) {
Expand Down
4 changes: 4 additions & 0 deletions ydb/core/formats/arrow/program/ya.make
Original file line number Diff line number Diff line change
Expand Up @@ -57,4 +57,8 @@ GENERATE_ENUM_SERIALIZATION(execution.h)

YQL_LAST_ABI_VERSION()

CFLAGS(
-Wno-unused-parameter
)

END()
49 changes: 49 additions & 0 deletions ydb/core/kqp/ut/olap/aggregations_ut.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1375,6 +1375,55 @@ Y_UNIT_TEST_SUITE(KqpOlapAggregations) {

TestTableWithNulls({ testCase }, /* generic */ true);
}

Y_UNIT_TEST(FloatSum) {
NKikimrConfig::TAppConfig appConfig;
appConfig.MutableTableServiceConfig()->SetEnableOlapSink(true);
auto settings = TKikimrSettings()
.SetAppConfig(appConfig)
.SetWithSampleTables(false);
TKikimrRunner kikimr(settings);

auto queryClient = kikimr.GetQueryClient();
{
auto status = queryClient.ExecuteQuery(
R"(
CREATE TABLE `olap_table` (
id Uint64 NOT NULL,
value Float,
PRIMARY KEY (id)
) WITH (STORE = COLUMN);
)", NYdb::NQuery::TTxControl::NoTx()
).GetValueSync();
UNIT_ASSERT_C(status.IsSuccess(), status.GetIssues().ToString());
}

{
auto status = queryClient.ExecuteQuery(
R"(
INSERT INTO `olap_table` (id, value) VALUES (1u, 0.4f);
INSERT INTO `olap_table` (id, value) VALUES (2u, 0.85f);
INSERT INTO `olap_table` (id, value) VALUES (3u, 11.3f);
INSERT INTO `olap_table` (id, value) VALUES (4u, 7.15f);
INSERT INTO `olap_table` (id, value) VALUES (5u, 0.3f);
)", NYdb::NQuery::TTxControl::BeginTx().CommitTx()
).GetValueSync();
UNIT_ASSERT_C(status.IsSuccess(), status.GetIssues().ToString());
}

{
auto status = queryClient.ExecuteQuery(R"(
--!syntax_v1
SELECT SUM(value) FROM `olap_table`
WHERE id = 1
)", NYdb::NQuery::TTxControl::BeginTx().CommitTx()
).GetValueSync();

UNIT_ASSERT_C(status.IsSuccess(), status.GetIssues().ToString());
TString result = FormatResultSetYson(status.GetResultSet(0));
CompareYson(result, R"([[[0.400000006;]]])");
}
}
}

}
Loading