Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Enchancement](function) Inline some aggregate function && remove nullable combinator #17328

Merged
merged 15 commits into from
Mar 9, 2023
Prev Previous commit
Next Next commit
fix
  • Loading branch information
BiteTheDDDDt committed Mar 8, 2023
commit 6592bb68bfbcbb309c126877206c86955eb07813
6 changes: 0 additions & 6 deletions be/src/vec/aggregate_functions/aggregate_function_bit.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,12 +29,6 @@ template <template <typename> class Data>
AggregateFunctionPtr createAggregateFunctionBitwise(const std::string& name,
const DataTypes& argument_types,
const bool result_is_nullable) {
if (!argument_types[0]->can_be_used_in_bit_operations()) {
LOG(WARNING) << fmt::format("The type " + argument_types[0]->get_name() +
" of argument for aggregate function " + name +
" is illegal, because it cannot be used in bitwise operations");
}

AggregateFunctionPtr res(creator_with_integer_type::create<AggregateFunctionBitwise, Data>(
result_is_nullable, argument_types));
if (res) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,15 +30,15 @@ AggregateFunctionPtr create_aggregate_function_percentile_approx(const std::stri
if (argument_types.size() == 1) {
return AggregateFunctionPtr(
creator_without_type::create<AggregateFunctionPercentileApproxMerge<is_nullable>>(
result_is_nullable, argument_types));
result_is_nullable, remove_nullable(argument_types)));
} else if (argument_types.size() == 2) {
return AggregateFunctionPtr(creator_without_type::create<
AggregateFunctionPercentileApproxTwoParams<is_nullable>>(
result_is_nullable, argument_types));
result_is_nullable, remove_nullable(argument_types)));
} else if (argument_types.size() == 3) {
return AggregateFunctionPtr(creator_without_type::create<
AggregateFunctionPercentileApproxThreeParams<is_nullable>>(
result_is_nullable, argument_types));
result_is_nullable, remove_nullable(argument_types)));
}
LOG(WARNING) << fmt::format("Illegal number {} of argument for aggregate function {}",
argument_types.size(), name);
Expand All @@ -48,18 +48,20 @@ AggregateFunctionPtr create_aggregate_function_percentile_approx(const std::stri
AggregateFunctionPtr create_aggregate_function_percentile(const std::string& name,
const DataTypes& argument_types,
const bool result_is_nullable) {
return std::make_shared<AggregateFunctionPercentile>(argument_types);
return AggregateFunctionPtr(creator_without_type::create<AggregateFunctionPercentile>(
result_is_nullable, argument_types));
}

AggregateFunctionPtr create_aggregate_function_percentile_array(const std::string& name,
const DataTypes& argument_types,
const bool result_is_nullable) {
return std::make_shared<AggregateFunctionPercentileArray>(argument_types);
return AggregateFunctionPtr(creator_without_type::create<AggregateFunctionPercentileArray>(
result_is_nullable, argument_types));
}

void register_aggregate_function_percentile(AggregateFunctionSimpleFactory& factory) {
factory.register_function("percentile", create_aggregate_function_percentile);
factory.register_function("percentile_array", create_aggregate_function_percentile_array);
factory.register_function_both("percentile", create_aggregate_function_percentile);
factory.register_function_both("percentile_array", create_aggregate_function_percentile_array);
}

void register_aggregate_function_percentile_approx(AggregateFunctionSimpleFactory& factory) {
Expand Down
18 changes: 10 additions & 8 deletions be/src/vec/aggregate_functions/aggregate_function_stddev.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,19 +31,21 @@ static IAggregateFunction* create_function_single_value(const String& name,
const bool result_is_nullable) {
WhichDataType which(remove_nullable(argument_types[0]));

#define DISPATCH(TYPE) \
if (which.idx == TypeIndex::TYPE) \
return creator_without_type::create<AggregateFunctionTemplate< \
NameData<Data<TYPE, BaseData<TYPE, is_stddev>>>, is_nullable>>(result_is_nullable, \
argument_types);
#define DISPATCH(TYPE) \
if (which.idx == TypeIndex::TYPE) \
return creator_without_type::create<AggregateFunctionTemplate< \
NameData<Data<TYPE, BaseData<TYPE, is_stddev>>>, is_nullable>>( \
result_is_nullable, \
is_nullable ? remove_nullable(argument_types) : argument_types);
FOR_NUMERIC_TYPES(DISPATCH)
#undef DISPATCH

#define DISPATCH(TYPE) \
if (which.idx == TypeIndex::TYPE) \
return creator_without_type::create<AggregateFunctionTemplate< \
NameData<Data<TYPE, BaseDatadecimal<TYPE, is_stddev>>>, is_nullable>>( \
result_is_nullable, argument_types);
result_is_nullable, \
is_nullable ? remove_nullable(argument_types) : argument_types);
FOR_DECIMAL_TYPES(DISPATCH)
#undef DISPATCH

Expand Down Expand Up @@ -89,10 +91,10 @@ AggregateFunctionPtr create_aggregate_function_stddev_pop(const std::string& nam
}

void register_aggregate_function_stddev_variance_pop(AggregateFunctionSimpleFactory& factory) {
factory.register_function("variance", create_aggregate_function_variance_pop<false>);
factory.register_function_both("variance", create_aggregate_function_variance_pop<false>);
factory.register_alias("variance", "var_pop");
factory.register_alias("variance", "variance_pop");
factory.register_function("stddev", create_aggregate_function_stddev_pop<true>);
factory.register_function_both("stddev", create_aggregate_function_stddev_pop<true>);
factory.register_alias("stddev", "stddev_pop");
}

Expand Down
1 change: 1 addition & 0 deletions be/src/vec/aggregate_functions/aggregate_function_stddev.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
#include "common/status.h"
#include "vec/aggregate_functions/aggregate_function.h"
#include "vec/columns/columns_number.h"
#include "vec/common/assert_cast.h"
#include "vec/data_types/data_type_decimal.h"
#include "vec/data_types/data_type_nullable.h"
#include "vec/data_types/data_type_number.h"
Expand Down