Skip to content

Commit

Permalink
[Opt](exec) opt aggregate function performance in nullable column
Browse files Browse the repository at this point in the history
  • Loading branch information
HappenLee authored Feb 16, 2023
1 parent 4c7f19a commit 24ef60b
Show file tree
Hide file tree
Showing 9 changed files with 339 additions and 118 deletions.
21 changes: 17 additions & 4 deletions be/src/vec/aggregate_functions/aggregate_function_avg.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -45,11 +45,23 @@ AggregateFunctionPtr create_aggregate_function_avg(const std::string& name,

AggregateFunctionPtr res;
DataTypePtr data_type = argument_types[0];
if (is_decimal(data_type)) {
res.reset(
create_with_decimal_type<AggregateFuncAvg>(*data_type, *data_type, argument_types));
if (data_type->is_nullable()) {
auto no_null_argument_types = remove_nullable(argument_types);
if (is_decimal(no_null_argument_types[0])) {
res.reset(create_with_decimal_type_null<AggregateFuncAvg>(
no_null_argument_types, parameters, *no_null_argument_types[0],
no_null_argument_types));
} else {
res.reset(create_with_numeric_type_null<AggregateFuncAvg>(
no_null_argument_types, parameters, no_null_argument_types));
}
} else {
res.reset(create_with_numeric_type<AggregateFuncAvg>(*data_type, argument_types));
if (is_decimal(data_type)) {
res.reset(create_with_decimal_type<AggregateFuncAvg>(*data_type, *data_type,
argument_types));
} else {
res.reset(create_with_numeric_type<AggregateFuncAvg>(*data_type, argument_types));
}
}

if (!res) {
Expand All @@ -61,5 +73,6 @@ AggregateFunctionPtr create_aggregate_function_avg(const std::string& name,

void register_aggregate_function_avg(AggregateFunctionSimpleFactory& factory) {
factory.register_function("avg", create_aggregate_function_avg);
factory.register_function("avg", create_aggregate_function_avg, true);
}
} // namespace doris::vectorized
3 changes: 2 additions & 1 deletion be/src/vec/aggregate_functions/aggregate_function_count.h
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,8 @@ class AggregateFunctionCount final
DataTypePtr get_serialized_type() const override { return std::make_shared<DataTypeUInt64>(); }
};

/// Simply count number of not-NULL values.
// TODO: Maybe AggregateFunctionCountNotNullUnary should be a subclass of AggregateFunctionCount
// Simply count number of not-NULL values.
class AggregateFunctionCountNotNullUnary final
: public IAggregateFunctionDataHelper<AggregateFunctionCountData,
AggregateFunctionCountNotNullUnary> {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@
#include "vec/aggregate_functions/helpers.h"

namespace doris::vectorized {

/// min, max, any
template <template <typename, bool> class AggregateFunctionTemplate, template <typename> class Data>
static IAggregateFunction* create_aggregate_function_single_value(const String& name,
Expand Down
1 change: 0 additions & 1 deletion be/src/vec/aggregate_functions/aggregate_function_null.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,6 @@ class AggregateFunctionCombinatorNull final : public IAggregateFunctionCombinato
};

void register_aggregate_function_combinator_null(AggregateFunctionSimpleFactory& factory) {
// factory.registerCombinator(std::make_shared<AggregateFunctionCombinatorNull>());
AggregateFunctionCreator creator = [&](const std::string& name, const DataTypes& types,
const Array& params, const bool result_is_nullable) {
auto function_combinator = std::make_shared<AggregateFunctionCombinatorNull>();
Expand Down
245 changes: 245 additions & 0 deletions be/src/vec/aggregate_functions/aggregate_function_null.h
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ namespace doris::vectorized {
/// If all rows had NULL, the behaviour is determined by "result_is_nullable" template parameter.
/// true - return NULL; false - return value from empty aggregation state of nested function.

// TODO: only keep class xxxInline after we support all aggregate function
template <bool result_is_nullable, typename Derived>
class AggregateFunctionNullBase : public IAggregateFunctionHelper<Derived> {
protected:
Expand Down Expand Up @@ -409,4 +410,248 @@ class AggregateFunctionNullVariadic final
is_nullable; /// Plain array is better than std::vector due to one indirection less.
};

/// Wrapper that adapts a concrete aggregate function (NestFunction) to Nullable arguments.
///
/// The nested function is stored by its concrete type instead of through the generic
/// IAggregateFunction interface. Everything that does not depend on how rows are fed in
/// (state lifecycle, merge, (de)serialization, result emission) lives here; Derived
/// supplies the add* entry points.
///
/// When result_is_nullable is true, the state carries a "has at least one non-NULL value"
/// flag and the result is NULL while that flag is unset. When it is false, no flag is
/// stored and the nested state is used as-is.
template <typename NestFunction, bool result_is_nullable, typename Derived>
class AggregateFunctionNullBaseInline : public IAggregateFunctionHelper<Derived> {
protected:
    std::unique_ptr<NestFunction> nested_function;
    size_t prefix_size;

    /** State layout: [flag prefix, prefix_size bytes][state of the nested function].
      *
      * The flag records whether at least one non-NULL value was accumulated; without
      * any such value the function returns NULL.
      *
      * The prefix is as wide as the nested state's alignment so that the nested state
      * stays correctly aligned behind it.
      */

    /// Address of the nested function's state inside a mutable combined state.
    AggregateDataPtr nested_place(AggregateDataPtr __restrict place) const noexcept {
        return place + prefix_size;
    }

    /// Address of the nested function's state inside a read-only combined state.
    ConstAggregateDataPtr nested_place(ConstAggregateDataPtr __restrict place) const noexcept {
        return place + prefix_size;
    }

    /// Clears the flag. No-op when the result is not nullable (no flag is stored).
    static void init_flag(AggregateDataPtr __restrict place) noexcept {
        if constexpr (result_is_nullable) {
            place[0] = false;
        }
    }

    /// Raises the flag. No-op when the result is not nullable (no flag is stored).
    static void set_flag(AggregateDataPtr __restrict place) noexcept {
        if constexpr (result_is_nullable) {
            place[0] = true;
        }
    }

    /// Reads the flag; always true when the result is not nullable.
    static bool get_flag(ConstAggregateDataPtr __restrict place) noexcept {
        if constexpr (result_is_nullable) {
            return place[0];
        } else {
            return true;
        }
    }

public:
    /// Takes ownership of nested_function_, which must actually be a NestFunction.
    AggregateFunctionNullBaseInline(IAggregateFunction* nested_function_,
                                    const DataTypes& arguments, const Array& params)
            : IAggregateFunctionHelper<Derived>(arguments, params),
              nested_function {assert_cast<NestFunction*>(nested_function_)},
              // Initialized after nested_function (declaration order), so it may be queried here.
              prefix_size {result_is_nullable ? nested_function->align_of_data() : 0} {}

    /// The wrapper is transparent: it reports the name of the function it wraps.
    String get_name() const override { return nested_function->get_name(); }

    /// Nested return type, wrapped in Nullable when result_is_nullable is set.
    DataTypePtr get_return_type() const override {
        if constexpr (result_is_nullable) {
            return make_nullable(nested_function->get_return_type());
        } else {
            return nested_function->get_return_type();
        }
    }

    /// Clears the flag, then constructs the nested state behind it.
    void create(AggregateDataPtr __restrict place) const override {
        init_flag(place);
        nested_function->create(nested_place(place));
    }

    /// Only the nested state needs destruction; the flag is a plain byte.
    void destroy(AggregateDataPtr __restrict place) const noexcept override {
        nested_function->destroy(nested_place(place));
    }

    /// Clears the flag and resets the nested state.
    void reset(AggregateDataPtr place) const override {
        init_flag(place);
        nested_function->reset(nested_place(place));
    }

    bool has_trivial_destructor() const override {
        return nested_function->has_trivial_destructor();
    }

    /// Flag prefix plus the nested state.
    size_t size_of_data() const override { return nested_function->size_of_data() + prefix_size; }

    size_t align_of_data() const override { return nested_function->align_of_data(); }

    /// Merges rhs into place: the flag is OR-ed in, the nested states are always merged.
    void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs,
               Arena* arena) const override {
        if constexpr (result_is_nullable) {
            if (get_flag(rhs)) {
                set_flag(place);
            }
        }

        nested_function->merge(nested_place(place), nested_place(rhs), arena);
    }

    /// Wire format: [flag, only if result_is_nullable][nested state, only if flag is set].
    void serialize(ConstAggregateDataPtr __restrict place, BufferWritable& buf) const override {
        bool has_value = get_flag(place);
        if constexpr (result_is_nullable) {
            write_binary(has_value, buf);
        }
        if (!has_value) {
            return;
        }
        nested_function->serialize(nested_place(place), buf);
    }

    /// Inverse of serialize(). The flag in place is only ever raised here, never cleared.
    void deserialize(AggregateDataPtr __restrict place, BufferReadable& buf,
                     Arena* arena) const override {
        bool has_value = true;
        if constexpr (result_is_nullable) {
            read_binary(has_value, buf);
        }
        if (!has_value) {
            return;
        }
        set_flag(place);
        nested_function->deserialize(nested_place(place), buf, arena);
    }

    /// Reads one serialized state from buf and merges it into place.
    void deserialize_and_merge(AggregateDataPtr __restrict place, BufferReadable& buf,
                               Arena* arena) const override {
        bool has_value = true;
        if constexpr (result_is_nullable) {
            read_binary(has_value, buf);
        }
        if (!has_value) {
            return;
        }
        set_flag(place);
        nested_function->deserialize_and_merge(nested_place(place), buf, arena);
    }

    /// Treats every row of column (a ColumnString) as one serialized state and merges
    /// each of them into place.
    void deserialize_and_merge_from_column(AggregateDataPtr __restrict place, const IColumn& column,
                                           Arena* arena) const override {
        const size_t row_count = column.size();
        for (size_t row = 0; row < row_count; ++row) {
            VectorBufferReader row_reader(
                    (assert_cast<const ColumnString&>(column)).get_data_at(row));
            deserialize_and_merge(place, row_reader, arena);
        }
    }

    /// Appends the final result to to. For a nullable result, to must be a ColumnNullable:
    /// a default row is inserted when the flag is unset, otherwise the nested result is
    /// appended together with a 0 null-map entry.
    void insert_result_into(ConstAggregateDataPtr __restrict place, IColumn& to) const override {
        if constexpr (!result_is_nullable) {
            nested_function->insert_result_into(nested_place(place), to);
        } else {
            ColumnNullable& nullable_to = assert_cast<ColumnNullable&>(to);
            if (!get_flag(place)) {
                nullable_to.insert_default();
                return;
            }
            nested_function->insert_result_into(nested_place(place),
                                                nullable_to.get_nested_column());
            nullable_to.get_null_map_data().push_back(0);
        }
    }

    bool allocates_memory_in_arena() const override {
        return nested_function->allocates_memory_in_arena();
    }

    bool is_state() const override { return nested_function->is_state(); }
};

/** There are two cases: for single argument and variadic.
* Code for single argument is much more efficient.
*/
/// Single-argument specialization of the inline Nullable wrapper.
///
/// columns[0] is expected to be a ColumnNullable. Rows that are NULL are skipped; every
/// other row is forwarded to the nested function together with the nested (non-nullable)
/// column, and the "has value" flag of the state is raised.
template <typename NestFunction, bool result_is_nullable>
class AggregateFunctionNullUnaryInline final
        : public AggregateFunctionNullBaseInline<
                  NestFunction, result_is_nullable,
                  AggregateFunctionNullUnaryInline<NestFunction, result_is_nullable>> {
public:
    AggregateFunctionNullUnaryInline(IAggregateFunction* nested_function_,
                                     const DataTypes& arguments, const Array& params)
            : AggregateFunctionNullBaseInline<
                      NestFunction, result_is_nullable,
                      AggregateFunctionNullUnaryInline<NestFunction, result_is_nullable>>(
                      nested_function_, arguments, params) {}

    /// Accumulates row row_num into place unless that row is NULL.
    void add(AggregateDataPtr __restrict place, const IColumn** columns, size_t row_num,
             Arena* arena) const override {
        const ColumnNullable* column = assert_cast<const ColumnNullable*>(columns[0]);
        if (!column->is_null_at(row_num)) {
            this->set_flag(place);
            const IColumn* nested_column = &column->get_nested_column();
            this->nested_function->add(this->nested_place(place), &nested_column, row_num, arena);
        }
    }

    /// Same as add(), but without the NULL check: the caller guarantees that row
    /// row_num is not NULL.
    void add_not_nullable(AggregateDataPtr __restrict place, const IColumn** columns,
                          size_t row_num, Arena* arena) const {
        const ColumnNullable* column = assert_cast<const ColumnNullable*>(columns[0]);
        this->set_flag(place);
        const IColumn* nested_column = &column->get_nested_column();
        this->nested_function->add(this->nested_place(place), &nested_column, row_num, arena);
    }

    /// Accumulates row i into the state at places[i] + place_offset for every row of the
    /// batch whose null-map entry is zero. agg_many is not used by this implementation.
    void add_batch(size_t batch_size, AggregateDataPtr* places, size_t place_offset,
                   const IColumn** columns, Arena* arena, bool agg_many) const override {
        const ColumnNullable* column = assert_cast<const ColumnNullable*>(columns[0]);
        // The overhead introduced is negligible here, just an extra memory read from NullMap
        const auto* __restrict null_map_data = column->get_null_map_data().data();
        const IColumn* nested_column = &column->get_nested_column();
        // size_t index: batch_size is a size_t, so an int counter would mix signedness
        // and overflow for batches larger than INT_MAX.
        for (size_t i = 0; i < batch_size; ++i) {
            if (!null_map_data[i]) {
                AggregateDataPtr __restrict place = places[i] + place_offset;
                this->set_flag(place);
                this->nested_function->add(this->nested_place(place), &nested_column, i, arena);
            }
        }
    }

    /// Accumulates the first batch_size rows into a single state.
    /// If the column contains no NULL at all, the whole batch is handed to the nested
    /// function in one call; otherwise rows are filtered one by one.
    void add_batch_single_place(size_t batch_size, AggregateDataPtr place, const IColumn** columns,
                                Arena* arena) const override {
        const ColumnNullable* column = assert_cast<const ColumnNullable*>(columns[0]);
        const IColumn* nested_column = &column->get_nested_column();

        if (column->has_null()) {
            for (size_t i = 0; i < batch_size; ++i) {
                if (!column->is_null_at(i)) {
                    // Forward directly to the nested function: going through add()
                    // would repeat the cast, the NULL check and the flag write.
                    this->set_flag(place);
                    this->nested_function->add(this->nested_place(place), &nested_column, i,
                                               arena);
                }
            }
        } else {
            // Only mark the state as "has value" when at least one row is accumulated;
            // an empty batch must leave a NULL result NULL.
            if (batch_size > 0) {
                this->set_flag(place);
            }
            this->nested_function->add_batch_single_place(batch_size, this->nested_place(place),
                                                          &nested_column, arena);
        }
    }

    /// Accumulates rows batch_begin..batch_end (both inclusive) into a single state.
    /// has_null is supplied by the caller: when false, the range is handed to the nested
    /// function in one call without consulting the null map.
    void add_batch_range(size_t batch_begin, size_t batch_end, AggregateDataPtr place,
                         const IColumn** columns, Arena* arena, bool has_null) override {
        const ColumnNullable* column = assert_cast<const ColumnNullable*>(columns[0]);
        const IColumn* nested_column = &column->get_nested_column();

        if (has_null) {
            for (size_t i = batch_begin; i <= batch_end; ++i) {
                if (!column->is_null_at(i)) {
                    // Same direct forwarding as in add_batch_single_place().
                    this->set_flag(place);
                    this->nested_function->add(this->nested_place(place), &nested_column, i,
                                               arena);
                }
            }
        } else {
            // An empty range (batch_begin > batch_end) accumulates nothing and therefore
            // must not turn a NULL result into a non-NULL one.
            if (batch_begin <= batch_end) {
                this->set_flag(place);
            }
            this->nested_function->add_batch_range(batch_begin, batch_end,
                                                   this->nested_place(place), &nested_column, arena,
                                                   false);
        }
    }
};
} // namespace doris::vectorized
24 changes: 18 additions & 6 deletions be/src/vec/aggregate_functions/aggregate_function_sum.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
#include "common/logging.h"
#include "vec/aggregate_functions/aggregate_function_simple_factory.h"
#include "vec/aggregate_functions/helpers.h"
#include "vec/data_types/data_type_nullable.h"

namespace doris::vectorized {

Expand All @@ -45,15 +46,24 @@ AggregateFunctionPtr create_aggregate_function_sum(const std::string& name,
const DataTypes& argument_types,
const Array& parameters,
const bool result_is_nullable) {
// assert_no_parameters(name, parameters);
// assert_unary(name, argument_types);

AggregateFunctionPtr res;
DataTypePtr data_type = argument_types[0];
if (is_decimal(data_type)) {
res.reset(create_with_decimal_type<Function>(*data_type, *data_type, argument_types));
if (data_type->is_nullable()) {
auto no_null_argument_types = remove_nullable(argument_types);
if (is_decimal(no_null_argument_types[0])) {
res.reset(create_with_decimal_type_null<Function>(no_null_argument_types, parameters,
*no_null_argument_types[0],
no_null_argument_types));
} else {
res.reset(create_with_numeric_type_null<Function>(no_null_argument_types, parameters,
no_null_argument_types));
}
} else {
res.reset(create_with_numeric_type<Function>(*data_type, argument_types));
if (is_decimal(data_type)) {
res.reset(create_with_decimal_type<Function>(*data_type, *data_type, argument_types));
} else {
res.reset(create_with_numeric_type<Function>(*data_type, argument_types));
}
}

if (!res) {
Expand Down Expand Up @@ -84,6 +94,8 @@ AggregateFunctionPtr create_aggregate_function_sum_reader(const std::string& nam

void register_aggregate_function_sum(AggregateFunctionSimpleFactory& factory) {
factory.register_function("sum", create_aggregate_function_sum<AggregateFunctionSumSimple>);
factory.register_function("sum", create_aggregate_function_sum<AggregateFunctionSumSimple>,
true);
}

} // namespace doris::vectorized
Loading

0 comments on commit 24ef60b

Please sign in to comment.