Skip to content

Commit

Permalink
Reformat aggregate functions (duckdb#14530)
Browse files Browse the repository at this point in the history
### Merge order
The function formatting PRs should be merged in this order (all pointing
to Feature branch):
- [14470 - Reformat compressed materialization
functions](duckdb#14470)
- [14489 - Reformat arithmetic
operators](duckdb#14489)
- [14495 - Reformat nested and sequence
functions](duckdb#14495)
- [14530 - Reformat aggregate
functions](duckdb#14530) (this PR)
  • Loading branch information
Mytherin authored Oct 28, 2024
2 parents e3b77e3 + eb2a5e8 commit 895a496
Show file tree
Hide file tree
Showing 19 changed files with 197 additions and 85 deletions.
2 changes: 1 addition & 1 deletion scripts/generate_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@


function_groups = {
('src', 'include/duckdb', 'function'): ['scalar'],
('src', 'include/duckdb', 'function'): ['scalar', 'aggregate'],
('extension', 'core_functions/include', 'core_functions'): ['scalar', 'aggregate'],
}

Expand Down
3 changes: 2 additions & 1 deletion src/execution/operator/join/physical_hash_join.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
#include "duckdb/execution/expression_executor.hpp"
#include "duckdb/execution/operator/aggregate/ungrouped_aggregate_state.hpp"
#include "duckdb/function/aggregate/distributive_functions.hpp"
#include "duckdb/function/aggregate/distributive_function_utils.hpp"
#include "duckdb/function/function_binder.hpp"
#include "duckdb/main/client_context.hpp"
#include "duckdb/main/query_profiler.hpp"
Expand Down Expand Up @@ -248,7 +249,7 @@ unique_ptr<JoinHashTable> PhysicalHashJoin::InitializeHashTable(ClientContext &c
delim_payload_types.push_back(aggr->return_type);
info.correlated_aggregates.push_back(std::move(aggr));

auto count_fun = CountFun::GetFunction();
auto count_fun = CountFunctionBase::GetFunction();
vector<unique_ptr<Expression>> children;
// this is a dummy but we need it to make the hash table understand whats going on
children.push_back(make_uniq_base<Expression, BoundReferenceExpression>(count_fun.return_type, 0U));
Expand Down
7 changes: 4 additions & 3 deletions src/execution/physical_plan/plan_distinct.cpp
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
#include "duckdb/execution/operator/aggregate/physical_hash_aggregate.hpp"
#include "duckdb/execution/operator/projection/physical_projection.hpp"
#include "duckdb/execution/physical_plan_generator.hpp"
#include "duckdb/function/aggregate/distributive_functions.hpp"
#include "duckdb/function/aggregate/distributive_function_utils.hpp"
#include "duckdb/planner/expression/bound_aggregate_expression.hpp"
#include "duckdb/planner/expression/bound_reference_expression.hpp"
#include "duckdb/planner/operator/logical_distinct.hpp"
Expand Down Expand Up @@ -59,8 +59,9 @@ unique_ptr<PhysicalOperator> PhysicalPlanGenerator::CreatePlan(LogicalDistinct &
first_children.push_back(std::move(bound));

FunctionBinder function_binder(context);
auto first_aggregate = function_binder.BindAggregateFunction(
FirstFun::GetFunction(logical_type), std::move(first_children), nullptr, AggregateType::NON_DISTINCT);
auto first_aggregate =
function_binder.BindAggregateFunction(FirstFunctionGetter::GetFunction(logical_type),
std::move(first_children), nullptr, AggregateType::NON_DISTINCT);
first_aggregate->order_bys = op.order_by ? op.order_by->Copy() : nullptr;

if (ClientConfig::GetConfig(context).enable_optimizer) {
Expand Down
3 changes: 1 addition & 2 deletions src/function/aggregate/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
add_subdirectory(distributive)

add_library_unity(duckdb_func_aggr OBJECT distributive_functions.cpp
sorted_aggregate_function.cpp)
add_library_unity(duckdb_func_aggr OBJECT sorted_aggregate_function.cpp)
set(ALL_OBJECT_FILES
${ALL_OBJECT_FILES} $<TARGET_OBJECTS:duckdb_func_aggr>
PARENT_SCOPE)
3 changes: 2 additions & 1 deletion src/function/aggregate/distributive/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
add_library_unity(duckdb_aggr_distr OBJECT count.cpp first.cpp minmax.cpp)
add_library_unity(duckdb_aggr_distr OBJECT count.cpp first_last_any.cpp
minmax.cpp)
set(ALL_OBJECT_FILES
${ALL_OBJECT_FILES} $<TARGET_OBJECTS:duckdb_aggr_distr>
PARENT_SCOPE)
16 changes: 5 additions & 11 deletions src/function/aggregate/distributive/count.cpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#include "duckdb/common/exception.hpp"
#include "duckdb/common/vector_operations/vector_operations.hpp"
#include "duckdb/function/aggregate/distributive_functions.hpp"
#include "duckdb/function/aggregate/distributive_function_utils.hpp"
#include "duckdb/planner/expression/bound_aggregate_expression.hpp"

namespace duckdb {
Expand Down Expand Up @@ -210,7 +211,7 @@ struct CountFunction : public BaseCountFunction {
}
};

AggregateFunction CountFun::GetFunction() {
AggregateFunction CountFunctionBase::GetFunction() {
AggregateFunction fun({LogicalType(LogicalTypeId::ANY)}, LogicalType::BIGINT, AggregateFunction::StateSize<int64_t>,
AggregateFunction::StateInitialize<int64_t, CountFunction>, CountFunction::CountScatter,
AggregateFunction::StateCombine<int64_t, CountFunction>,
Expand Down Expand Up @@ -241,21 +242,14 @@ unique_ptr<BaseStatistics> CountPropagateStats(ClientContext &context, BoundAggr
return nullptr;
}

void CountFun::RegisterFunction(BuiltinFunctions &set) {
AggregateFunction count_function = CountFun::GetFunction();
AggregateFunctionSet CountFun::GetFunctions() {
AggregateFunction count_function = CountFunctionBase::GetFunction();
count_function.statistics = CountPropagateStats;
AggregateFunctionSet count("count");
count.AddFunction(count_function);
// the count function can also be called without arguments
count_function = CountStarFun::GetFunction();
count.AddFunction(count_function);
set.AddFunction(count);
}

void CountStarFun::RegisterFunction(BuiltinFunctions &set) {
AggregateFunctionSet count("count_star");
count.AddFunction(CountStarFun::GetFunction());
set.AddFunction(count);
return count;
}

} // namespace duckdb
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#include "duckdb/common/exception.hpp"
#include "duckdb/common/vector_operations/vector_operations.hpp"
#include "duckdb/function/aggregate/distributive_functions.hpp"
#include "duckdb/function/aggregate/distributive_function_utils.hpp"
#include "duckdb/function/create_sort_key.hpp"
#include "duckdb/planner/expression.hpp"

Expand Down Expand Up @@ -293,7 +294,7 @@ static AggregateFunction GetFirstFunction(const LogicalType &type) {
}
}

AggregateFunction FirstFun::GetFunction(const LogicalType &type) {
AggregateFunction FirstFunctionGetter::GetFunction(const LogicalType &type) {
auto fun = GetFirstFunction<false, false>(type);
fun.name = "first";
return fun;
Expand Down Expand Up @@ -340,22 +341,22 @@ static void AddFirstOperator(AggregateFunctionSet &set) {
nullptr, BindFirst<LAST, SKIP_NULLS>));
}

void FirstFun::RegisterFunction(BuiltinFunctions &set) {
AggregateFunctionSet FirstFun::GetFunctions() {
AggregateFunctionSet first("first");
AggregateFunctionSet last("last");
AggregateFunctionSet any_value("any_value");

AddFirstOperator<false, false>(first);
AddFirstOperator<true, false>(last);
AddFirstOperator<false, true>(any_value);

set.AddFunction(first);
first.name = "arbitrary";
set.AddFunction(first);
return first;
}

set.AddFunction(last);
AggregateFunctionSet LastFun::GetFunctions() {
AggregateFunctionSet last("last");
AddFirstOperator<true, false>(last);
return last;
}

set.AddFunction(any_value);
AggregateFunctionSet AnyValueFun::GetFunctions() {
AggregateFunctionSet any_value("any_value");
AddFirstOperator<false, true>(any_value);
return any_value;
}

} // namespace duckdb
52 changes: 52 additions & 0 deletions src/function/aggregate/distributive/functions.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
[
{
"name": "count_star",
"parameters": "",
"description": "",
"example": "",
"type": "aggregate_function"
},
{
"name": "count",
"parameters": "arg",
"description": "Returns the number of non-null values in arg.",
"example": "count(A)",
"type": "aggregate_function_set"
},
{
"name": "first",
"parameters": "arg",
"description": "Returns the first value (null or non-null) from arg. This function is affected by ordering.",
"example": "first(A)",
"type": "aggregate_function_set",
"aliases": ["arbitrary"]
},
{
"name": "last",
"parameters": "arg",
"description": "Returns the last value of a column. This function is affected by ordering.",
"example": "last(A)",
"type": "aggregate_function_set"
},
{
"name": "any_value",
"parameters": "arg",
"description": "Returns the first non-null value from arg. This function is affected by ordering.",
"example": "",
"type": "aggregate_function_set"
},
{
"name": "min",
"parameters": "arg",
"description": "Returns the minimum value present in arg.",
"example": "min(A)",
"type": "aggregate_function_set"
},
{
"name": "max",
"parameters": "arg",
"description": "Returns the maximum value present in arg.",
"example": "max(A)",
"type": "aggregate_function_set"
}
]
17 changes: 9 additions & 8 deletions src/function/aggregate/distributive/minmax.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
#include "duckdb/common/types/null_value.hpp"
#include "duckdb/common/vector_operations/vector_operations.hpp"
#include "duckdb/function/aggregate/distributive_functions.hpp"
#include "duckdb/function/aggregate/distributive_function_utils.hpp"
#include "duckdb/function/aggregate/minmax_n_helpers.hpp"
#include "duckdb/function/aggregate/sort_key_helpers.hpp"
#include "duckdb/function/function_binder.hpp"
Expand Down Expand Up @@ -389,11 +390,11 @@ static AggregateFunction GetMinMaxOperator(string name) {
nullptr, nullptr, BindMinMax<OP, OP_STRING, OP_VECTOR>);
}

AggregateFunction MinFun::GetFunction() {
AggregateFunction MinFunction::GetFunction() {
return GetMinMaxOperator<MinOperation, MinOperationString, MinOperationVector>("min");
}

AggregateFunction MaxFun::GetFunction() {
AggregateFunction MaxFunction::GetFunction() {
return GetMinMaxOperator<MaxOperation, MaxOperationString, MaxOperationVector>("max");
}

Expand Down Expand Up @@ -537,18 +538,18 @@ static AggregateFunction GetMinMaxNFunction() {
//---------------------------------------------------
// Function Registration
//---------------------------------------------------s
void MinFun::RegisterFunction(BuiltinFunctions &set) {
AggregateFunctionSet MinFun::GetFunctions() {
AggregateFunctionSet min("min");
min.AddFunction(GetFunction());
min.AddFunction(MinFunction::GetFunction());
min.AddFunction(GetMinMaxNFunction<LessThan>());
set.AddFunction(min);
return min;
}

void MaxFun::RegisterFunction(BuiltinFunctions &set) {
AggregateFunctionSet MaxFun::GetFunctions() {
AggregateFunctionSet max("max");
max.AddFunction(GetFunction());
max.AddFunction(MaxFunction::GetFunction());
max.AddFunction(GetMinMaxNFunction<GreaterThan>());
set.AddFunction(max);
return max;
}

} // namespace duckdb
17 changes: 0 additions & 17 deletions src/function/aggregate/distributive_functions.cpp

This file was deleted.

2 changes: 0 additions & 2 deletions src/function/function.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -94,8 +94,6 @@ void BuiltinFunctions::Initialize() {
RegisterTableFunctions();
RegisterArrowFunctions();

RegisterDistributiveAggregates();

RegisterPragmaFunctions();

// initialize collations
Expand Down
10 changes: 10 additions & 0 deletions src/function/function_list.cpp
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
#include "duckdb/function/function_list.hpp"

#include "duckdb/function/aggregate/distributive_functions.hpp"
#include "duckdb/function/scalar/compressed_materialization_functions.hpp"
#include "duckdb/function/scalar/date_functions.hpp"
#include "duckdb/function/scalar/generic_functions.hpp"
Expand Down Expand Up @@ -66,6 +68,8 @@ static const StaticFunctionDefinition function[] = {
DUCKDB_SCALAR_FUNCTION_SET(InternalDecompressIntegralUsmallintFun),
DUCKDB_SCALAR_FUNCTION_SET(InternalDecompressStringFun),
DUCKDB_SCALAR_FUNCTION_SET_ALIAS(AddFun),
DUCKDB_AGGREGATE_FUNCTION_SET(AnyValueFun),
DUCKDB_AGGREGATE_FUNCTION_SET_ALIAS(ArbitraryFun),
DUCKDB_SCALAR_FUNCTION_ALIAS(ArrayCatFun),
DUCKDB_SCALAR_FUNCTION_ALIAS(ArrayConcatFun),
DUCKDB_SCALAR_FUNCTION_ALIAS(ArrayContainsFun),
Expand All @@ -84,13 +88,17 @@ static const StaticFunctionDefinition function[] = {
DUCKDB_SCALAR_FUNCTION(ConcatWsFun),
DUCKDB_SCALAR_FUNCTION(ConstantOrNullFun),
DUCKDB_SCALAR_FUNCTION_SET(ContainsFun),
DUCKDB_AGGREGATE_FUNCTION_SET(CountFun),
DUCKDB_AGGREGATE_FUNCTION(CountStarFun),
DUCKDB_SCALAR_FUNCTION(CreateSortKeyFun),
DUCKDB_SCALAR_FUNCTION(CurrvalFun),
DUCKDB_SCALAR_FUNCTION_SET_ALIAS(DivideFun),
DUCKDB_SCALAR_FUNCTION(ErrorFun),
DUCKDB_SCALAR_FUNCTION(FinalizeFun),
DUCKDB_AGGREGATE_FUNCTION_SET(FirstFun),
DUCKDB_SCALAR_FUNCTION(GetVariableFun),
DUCKDB_SCALAR_FUNCTION(IlikeEscapeFun),
DUCKDB_AGGREGATE_FUNCTION_SET(LastFun),
DUCKDB_SCALAR_FUNCTION_ALIAS(LcaseFun),
DUCKDB_SCALAR_FUNCTION_SET_ALIAS(LenFun),
DUCKDB_SCALAR_FUNCTION_SET(LengthFun),
Expand All @@ -110,8 +118,10 @@ static const StaticFunctionDefinition function[] = {
DUCKDB_SCALAR_FUNCTION(ListZipFun),
DUCKDB_SCALAR_FUNCTION(LowerFun),
DUCKDB_SCALAR_FUNCTION(MapContainsFun),
DUCKDB_AGGREGATE_FUNCTION_SET(MaxFun),
DUCKDB_SCALAR_FUNCTION_SET(MD5Fun),
DUCKDB_SCALAR_FUNCTION_SET(MD5NumberFun),
DUCKDB_AGGREGATE_FUNCTION_SET(MinFun),
DUCKDB_SCALAR_FUNCTION_SET_ALIAS(ModFun),
DUCKDB_SCALAR_FUNCTION_SET_ALIAS(MultiplyFun),
DUCKDB_SCALAR_FUNCTION(NextvalFun),
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
//===----------------------------------------------------------------------===//
// DuckDB
//
// duckdb/function/aggregate/distributive_functions.hpp
//
//
//===----------------------------------------------------------------------===//

#pragma once

#include "duckdb/function/function_set.hpp"

namespace duckdb {

struct CountFunctionBase {
static AggregateFunction GetFunction();
};

struct FirstFunctionGetter {
static AggregateFunction GetFunction(const LogicalType &type);
};

struct MinFunction {
static AggregateFunction GetFunction();
};

struct MaxFunction {
static AggregateFunction GetFunction();
};

} // namespace duckdb
Loading

0 comments on commit 895a496

Please sign in to comment.