Skip to content

Commit

Permalink
[feature](function) Support for aggregate function foreach combiner (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
Mryange authored Mar 6, 2024
1 parent 69c3238 commit 3c39734
Show file tree
Hide file tree
Showing 54 changed files with 691 additions and 68 deletions.
2 changes: 1 addition & 1 deletion be/src/vec/aggregate_functions/aggregate_function.h
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ class IAggregateFunction {
* row_num is the index of the row which should be added.
* The additional parameter arena should be used instead of the standard memory allocator if the addition requires memory allocation.
*/
virtual void add(AggregateDataPtr __restrict place, const IColumn** columns, size_t row_num,
virtual void add(AggregateDataPtr __restrict place, const IColumn** columns, ssize_t row_num,
Arena* arena) const = 0;

virtual void add_many(AggregateDataPtr __restrict place, const IColumn** columns,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ class AggregateFunctionApproxCountDistinct final

DataTypePtr get_return_type() const override { return std::make_shared<DataTypeInt64>(); }

void add(AggregateDataPtr __restrict place, const IColumn** columns, size_t row_num,
void add(AggregateDataPtr __restrict place, const IColumn** columns, ssize_t row_num,
Arena*) const override {
if constexpr (IsFixLenColumnType<ColumnDataType>::value) {
auto column = assert_cast<const ColumnDataType*>(columns[0]);
Expand Down
2 changes: 1 addition & 1 deletion be/src/vec/aggregate_functions/aggregate_function_avg.h
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ class AggregateFunctionAvg final
}
}

void add(AggregateDataPtr __restrict place, const IColumn** columns, size_t row_num,
void add(AggregateDataPtr __restrict place, const IColumn** columns, ssize_t row_num,
Arena*) const override {
#ifdef __clang__
#pragma clang fp reassociate(on)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ class AggregateFunctionAvgWeight final

DataTypePtr get_return_type() const override { return std::make_shared<DataTypeFloat64>(); }

void add(AggregateDataPtr __restrict place, const IColumn** columns, size_t row_num,
void add(AggregateDataPtr __restrict place, const IColumn** columns, ssize_t row_num,
Arena*) const override {
const auto& column = assert_cast<const ColVecType&>(*columns[0]);
const auto& weight = assert_cast<const ColumnVector<Float64>&>(*columns[1]);
Expand Down
2 changes: 1 addition & 1 deletion be/src/vec/aggregate_functions/aggregate_function_binary.h
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ struct AggregateFunctionBinary

bool allocates_memory_in_arena() const override { return false; }

void add(AggregateDataPtr __restrict place, const IColumn** columns, size_t row_num,
void add(AggregateDataPtr __restrict place, const IColumn** columns, ssize_t row_num,
Arena*) const override {
this->data(place).add(
static_cast<ResultType>(
Expand Down
2 changes: 1 addition & 1 deletion be/src/vec/aggregate_functions/aggregate_function_bit.h
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ class AggregateFunctionBitwise final

DataTypePtr get_return_type() const override { return std::make_shared<DataTypeNumber<T>>(); }

void add(AggregateDataPtr __restrict place, const IColumn** columns, size_t row_num,
void add(AggregateDataPtr __restrict place, const IColumn** columns, ssize_t row_num,
Arena*) const override {
const auto& column = assert_cast<const ColumnVector<T>&>(*columns[0]);
this->data(place).add(column.get_data()[row_num]);
Expand Down
4 changes: 2 additions & 2 deletions be/src/vec/aggregate_functions/aggregate_function_bitmap.h
Original file line number Diff line number Diff line change
Expand Up @@ -301,7 +301,7 @@ class AggregateFunctionBitmapOp final

DataTypePtr get_return_type() const override { return std::make_shared<DataTypeBitMap>(); }

void add(AggregateDataPtr __restrict place, const IColumn** columns, size_t row_num,
void add(AggregateDataPtr __restrict place, const IColumn** columns, ssize_t row_num,
Arena*) const override {
const auto& column = assert_cast<const ColVecType&>(*columns[0]);
this->data(place).add(column.get_data()[row_num]);
Expand Down Expand Up @@ -361,7 +361,7 @@ class AggregateFunctionBitmapCount final
String get_name() const override { return "count"; }
DataTypePtr get_return_type() const override { return std::make_shared<DataTypeInt64>(); }

void add(AggregateDataPtr __restrict place, const IColumn** columns, size_t row_num,
void add(AggregateDataPtr __restrict place, const IColumn** columns, ssize_t row_num,
Arena*) const override {
if constexpr (arg_is_nullable) {
auto& nullable_column = assert_cast<const ColumnNullable&>(*columns[0]);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ class AggregateFunctionBitmapAgg final
std::string get_name() const override { return "bitmap_agg"; }
DataTypePtr get_return_type() const override { return std::make_shared<DataTypeBitMap>(); }

void add(AggregateDataPtr __restrict place, const IColumn** columns, size_t row_num,
void add(AggregateDataPtr __restrict place, const IColumn** columns, ssize_t row_num,
Arena* arena) const override {
DCHECK_LT(row_num, columns[0]->size());
if constexpr (arg_nullable) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -469,7 +469,7 @@ class AggregateFunctionCollect

bool allocates_memory_in_arena() const override { return ENABLE_ARENA; }

void add(AggregateDataPtr __restrict place, const IColumn** columns, size_t row_num,
void add(AggregateDataPtr __restrict place, const IColumn** columns, ssize_t row_num,
Arena* arena) const override {
auto& data = this->data(place);
if constexpr (HasLimit::value) {
Expand Down
4 changes: 2 additions & 2 deletions be/src/vec/aggregate_functions/aggregate_function_count.h
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ class AggregateFunctionCount final

DataTypePtr get_return_type() const override { return std::make_shared<DataTypeInt64>(); }

void add(AggregateDataPtr __restrict place, const IColumn**, size_t, Arena*) const override {
void add(AggregateDataPtr __restrict place, const IColumn**, ssize_t, Arena*) const override {
++data(place).count;
}

Expand Down Expand Up @@ -194,7 +194,7 @@ class AggregateFunctionCountNotNullUnary final

DataTypePtr get_return_type() const override { return std::make_shared<DataTypeInt64>(); }

void add(AggregateDataPtr __restrict place, const IColumn** columns, size_t row_num,
void add(AggregateDataPtr __restrict place, const IColumn** columns, ssize_t row_num,
Arena*) const override {
data(place).count += !assert_cast<const ColumnNullable&>(*columns[0]).is_null_at(row_num);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ class AggregateFunctionCountByEnum final

DataTypePtr get_return_type() const override { return std::make_shared<DataTypeString>(); }

void add(AggregateDataPtr __restrict place, const IColumn** columns, size_t row_num,
void add(AggregateDataPtr __restrict place, const IColumn** columns, ssize_t row_num,
Arena* arena) const override {
for (int i = 0; i < arg_count; i++) {
const auto* nullable_column = check_and_get_column<ColumnNullable>(columns[i]);
Expand Down
2 changes: 1 addition & 1 deletion be/src/vec/aggregate_functions/aggregate_function_covar.h
Original file line number Diff line number Diff line change
Expand Up @@ -273,7 +273,7 @@ class AggregateFunctionSampCovariance
}
}

void add(AggregateDataPtr __restrict place, const IColumn** columns, size_t row_num,
void add(AggregateDataPtr __restrict place, const IColumn** columns, ssize_t row_num,
Arena*) const override {
if constexpr (is_pop) {
this->data(place).add(columns[0], columns[1], row_num);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,7 @@ class AggregateFunctionDistinct
prefix_size = (sizeof(Data) + nested_size - 1) / nested_size * nested_size;
}

void add(AggregateDataPtr __restrict place, const IColumn** columns, size_t row_num,
void add(AggregateDataPtr __restrict place, const IColumn** columns, ssize_t row_num,
Arena* arena) const override {
this->data(place).add(columns, arguments_num, row_num, arena);
}
Expand Down
64 changes: 64 additions & 0 deletions be/src/vec/aggregate_functions/aggregate_function_foreach.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
// This file is copied from
// https://github.com/ClickHouse/ClickHouse/blob/master/src/AggregateFunctions/Combinators/AggregateFunctionForEach.cpp
// and modified by Doris

#include "vec/aggregate_functions/aggregate_function_foreach.h"

#include <memory>
#include <ostream>

#include "common/logging.h"
#include "vec/aggregate_functions/aggregate_function.h"
#include "vec/aggregate_functions/aggregate_function_simple_factory.h"
#include "vec/aggregate_functions/helpers.h"
#include "vec/common/typeid_cast.h"
#include "vec/data_types/data_type_array.h"
#include "vec/data_types/data_type_nullable.h"

namespace doris::vectorized {

/// Registers the "foreach" combinator creator on `factory`.
///
/// The creator receives the combined function name (nested function name followed by
/// AggregateFunctionForEach::AGG_FOREACH_SUFFIX) together with the argument types, which
/// are cast to (optionally nullable) arrays. It resolves the nested aggregate function for
/// the array element types and wraps it in an AggregateFunctionForEach.
///
/// Throws Exception(ErrorCode::INTERNAL_ERROR) from the creator when the factory cannot
/// provide the nested function.
void register_aggregate_function_combinator_foreach(AggregateFunctionSimpleFactory& factory) {
    // The creator is stored inside `factory` and refers back to it to look up the nested
    // function. NOTE(review): this relies on `factory` outliving the creator — confirm.
    AggregateFunctionCreator creator = [&factory](const std::string& name,
                                                  const DataTypes& types,
                                                  const bool result_is_nullable)
            -> AggregateFunctionPtr {
        const std::string& foreach_suffix = AggregateFunctionForEach::AGG_FOREACH_SUFFIX;

        // Strip the nullable wrapper from each argument and take the array's element type.
        DataTypes element_types;
        for (const auto& argument_type : types) {
            const auto non_nullable_type = remove_nullable(argument_type);
            const auto* array_type = assert_cast<const DataTypeArray*>(non_nullable_type.get());
            element_types.push_back(array_type->get_nested_type());
        }

        // The nested function name is the combined name with the suffix removed from the end.
        const auto nested_name_length = name.size() - foreach_suffix.size();
        auto inner_function_name = name.substr(0, nested_name_length);

        auto inner_function = factory.get(inner_function_name, element_types, result_is_nullable);
        if (inner_function) {
            // NOTE(review): the literal `true` is forwarded as-is to
            // creator_without_type::create; its meaning is not visible here — confirm.
            return creator_without_type::create<AggregateFunctionForEach>(element_types, true,
                                                                          inner_function);
        }

        throw Exception(
                ErrorCode::INTERNAL_ERROR,
                "The combiner did not find a foreach combiner function. nested function "
                "name {} , args {}",
                inner_function_name, types_name(types));
    };

    // Register the same creator twice, once with `true` and once with `false` as the
    // last argument (presumably the nullable / non-nullable variants — verify against
    // AggregateFunctionSimpleFactory).
    const auto& combinator_suffix = AggregateFunctionForEach::AGG_FOREACH_SUFFIX;
    factory.register_foreach_function_combinator(creator, combinator_suffix, true);
    factory.register_foreach_function_combinator(creator, combinator_suffix, false);
}
} // namespace doris::vectorized
Loading

0 comments on commit 3c39734

Please sign in to comment.