Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Enchancement](function) Inline some aggregate function && remove nullable combinator #17328

Merged
merged 15 commits into from
Mar 9, 2023
Prev Previous commit
Next Next commit
fix
  • Loading branch information
BiteTheDDDDt committed Mar 8, 2023
commit b05a87947de1b6bad24f0744a055602a04a5d889
49 changes: 25 additions & 24 deletions be/src/vec/aggregate_functions/aggregate_function_stddev.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,29 +28,33 @@ template <template <typename, bool> class AggregateFunctionTemplate,
bool is_stddev, bool is_nullable = false>
static IAggregateFunction* create_function_single_value(const String& name,
const DataTypes& argument_types,
const bool result_is_nullable) {
const bool result_is_nullable,
bool custom_nullable) {
WhichDataType which(remove_nullable(argument_types[0]));

IAggregateFunction* res;
#define DISPATCH(TYPE) \
if (which.idx == TypeIndex::TYPE) \
return creator_without_type::create<AggregateFunctionTemplate< \
res = creator_without_type::create<AggregateFunctionTemplate< \
NameData<Data<TYPE, BaseData<TYPE, is_stddev>>>, is_nullable>>( \
result_is_nullable, \
is_nullable ? remove_nullable(argument_types) : argument_types);
custom_nullable ? remove_nullable(argument_types) : argument_types);
FOR_NUMERIC_TYPES(DISPATCH)
#undef DISPATCH

#define DISPATCH(TYPE) \
if (which.idx == TypeIndex::TYPE) \
return creator_without_type::create<AggregateFunctionTemplate< \
res = creator_without_type::create<AggregateFunctionTemplate< \
NameData<Data<TYPE, BaseDatadecimal<TYPE, is_stddev>>>, is_nullable>>( \
result_is_nullable, \
is_nullable ? remove_nullable(argument_types) : argument_types);
custom_nullable ? remove_nullable(argument_types) : argument_types);
FOR_DECIMAL_TYPES(DISPATCH)
#undef DISPATCH

DCHECK(false) << "with unknowed type, failed in create_aggregate_function_stddev_variance";
return nullptr;
if (res == nullptr) {
LOG(WARNING) << fmt::format("create_function_single_value with unknowed type {}",
argument_types[0]->get_name());
}
return res;
}

template <bool is_stddev, bool is_nullable>
Expand All @@ -60,16 +64,16 @@ AggregateFunctionPtr create_aggregate_function_variance_samp(const std::string&
return AggregateFunctionPtr(
create_function_single_value<AggregateFunctionSamp, VarianceSampName, SampData,
is_stddev, is_nullable>(name, argument_types,
result_is_nullable));
result_is_nullable, true));
}

template <bool is_stddev, bool is_nullable>
AggregateFunctionPtr create_aggregate_function_stddev_samp(const std::string& name,
const DataTypes& argument_types,
const bool result_is_nullable) {
return AggregateFunctionPtr(
create_function_single_value<AggregateFunctionSamp, StddevSampName, SampData, is_stddev,
is_nullable>(name, argument_types, result_is_nullable));
return AggregateFunctionPtr(create_function_single_value<AggregateFunctionSamp, StddevSampName,
SampData, is_stddev, is_nullable>(
name, argument_types, result_is_nullable, true));
}

template <bool is_stddev>
Expand All @@ -78,7 +82,7 @@ AggregateFunctionPtr create_aggregate_function_variance_pop(const std::string& n
const bool result_is_nullable) {
return AggregateFunctionPtr(
create_function_single_value<AggregateFunctionPop, VarianceName, PopData, is_stddev>(
name, argument_types, result_is_nullable));
name, argument_types, result_is_nullable, false));
}

template <bool is_stddev>
Expand All @@ -87,7 +91,7 @@ AggregateFunctionPtr create_aggregate_function_stddev_pop(const std::string& nam
const bool result_is_nullable) {
return AggregateFunctionPtr(
create_function_single_value<AggregateFunctionPop, StddevName, PopData, is_stddev>(
name, argument_types, result_is_nullable));
name, argument_types, result_is_nullable, false));
}

void register_aggregate_function_stddev_variance_pop(AggregateFunctionSimpleFactory& factory) {
Expand All @@ -99,16 +103,13 @@ void register_aggregate_function_stddev_variance_pop(AggregateFunctionSimpleFact
}

void register_aggregate_function_stddev_variance_samp(AggregateFunctionSimpleFactory& factory) {
// _samp<bool, bool>: first indicate is stddev or variance function
// second indicate is arg nullable column
factory.register_function_both("variance_samp",
create_aggregate_function_variance_samp<false, false>);
factory.register_function_both("variance_samp",
create_aggregate_function_variance_samp<false, true>);
factory.register_function("variance_samp",
create_aggregate_function_variance_samp<false, false>);
factory.register_function("variance_samp", create_aggregate_function_variance_samp<false, true>,
true);
factory.register_alias("variance_samp", "var_samp");
factory.register_function_both("stddev_samp",
create_aggregate_function_stddev_samp<true, false>);
factory.register_function_both("stddev_samp",
create_aggregate_function_stddev_samp<true, true>);
factory.register_function("stddev_samp", create_aggregate_function_stddev_samp<true, false>);
factory.register_function("stddev_samp", create_aggregate_function_stddev_samp<true, true>,
true);
}
} // namespace doris::vectorized
4 changes: 2 additions & 2 deletions be/src/vec/functions/array/function_array_aggregation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -117,8 +117,8 @@ struct AggregateFunction {
using Function = typename Derived::template TypeTraits<T>::Function;

static auto create(const DataTypePtr& data_type_ptr) -> AggregateFunctionPtr {
return AggregateFunctionPtr(creator_with_type::create<Function>(
true, DataTypes {make_nullable(data_type_ptr)}));
return AggregateFunctionPtr(
creator_with_type::create<Function>(true, DataTypes {data_type_ptr}));
}
};

Expand Down