Skip to content

Commit

Permalink
Add AggregateDestructorType which signifies whether or not an aggrega…
Browse files Browse the repository at this point in the history
…te state can be trivially destructible - only AggregateDestructorType::LEGACY can be trivially destructible
  • Loading branch information
Mytherin committed Oct 28, 2024
1 parent b83a0be commit 6b00cdf
Show file tree
Hide file tree
Showing 7 changed files with 71 additions and 42 deletions.
25 changes: 14 additions & 11 deletions extension/core_functions/aggregate/distributive/arg_min_max.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -314,21 +314,22 @@ struct VectorArgMinMaxBase : ArgMinMaxBase<COMPARATOR, IGNORE_NULL> {
template <class OP>
AggregateFunction GetGenericArgMinMaxFunction() {
using STATE = ArgMinMaxState<string_t, string_t>;
return AggregateFunction({LogicalType::ANY, LogicalType::ANY}, LogicalType::ANY,
AggregateFunction::StateSize<STATE>, AggregateFunction::StateInitialize<STATE, OP>,
OP::template Update<STATE>, AggregateFunction::StateCombine<STATE, OP>,
AggregateFunction::StateVoidFinalize<STATE, OP>, nullptr, OP::Bind,
AggregateFunction::StateDestroy<STATE, OP>);
return AggregateFunction(
{LogicalType::ANY, LogicalType::ANY}, LogicalType::ANY, AggregateFunction::StateSize<STATE>,
AggregateFunction::StateInitialize<STATE, OP, AggregateDestructorType::LEGACY>, OP::template Update<STATE>,
AggregateFunction::StateCombine<STATE, OP>, AggregateFunction::StateVoidFinalize<STATE, OP>, nullptr, OP::Bind,
AggregateFunction::StateDestroy<STATE, OP>);
}

template <class OP, class ARG_TYPE, class BY_TYPE>
AggregateFunction GetVectorArgMinMaxFunctionInternal(const LogicalType &by_type, const LogicalType &type) {
#ifndef DUCKDB_SMALLER_BINARY
using STATE = ArgMinMaxState<ARG_TYPE, BY_TYPE>;
return AggregateFunction(
{type, by_type}, type, AggregateFunction::StateSize<STATE>, AggregateFunction::StateInitialize<STATE, OP>,
OP::template Update<STATE>, AggregateFunction::StateCombine<STATE, OP>,
AggregateFunction::StateVoidFinalize<STATE, OP>, nullptr, OP::Bind, AggregateFunction::StateDestroy<STATE, OP>);
return AggregateFunction({type, by_type}, type, AggregateFunction::StateSize<STATE>,
AggregateFunction::StateInitialize<STATE, OP, AggregateDestructorType::LEGACY>,
OP::template Update<STATE>, AggregateFunction::StateCombine<STATE, OP>,
AggregateFunction::StateVoidFinalize<STATE, OP>, nullptr, OP::Bind,
AggregateFunction::StateDestroy<STATE, OP>);
#else
auto function = GetGenericArgMinMaxFunction<OP>();
function.arguments = {type, by_type};
Expand Down Expand Up @@ -380,7 +381,9 @@ template <class OP, class ARG_TYPE, class BY_TYPE>
AggregateFunction GetArgMinMaxFunctionInternal(const LogicalType &by_type, const LogicalType &type) {
#ifndef DUCKDB_SMALLER_BINARY
using STATE = ArgMinMaxState<ARG_TYPE, BY_TYPE>;
auto function = AggregateFunction::BinaryAggregate<STATE, ARG_TYPE, BY_TYPE, ARG_TYPE, OP>(type, by_type, type);
auto function =
AggregateFunction::BinaryAggregate<STATE, ARG_TYPE, BY_TYPE, ARG_TYPE, OP, AggregateDestructorType::LEGACY>(
type, by_type, type);
if (type.InternalType() == PhysicalType::VARCHAR || by_type.InternalType() == PhysicalType::VARCHAR) {
function.destructor = AggregateFunction::StateDestroy<STATE, OP>;
}
Expand Down Expand Up @@ -618,7 +621,7 @@ static void SpecializeArgMinMaxNFunction(AggregateFunction &function) {
using OP = MinMaxNOperation;

function.state_size = AggregateFunction::StateSize<STATE>;
function.initialize = AggregateFunction::StateInitialize<STATE, OP>;
function.initialize = AggregateFunction::StateInitialize<STATE, OP, AggregateDestructorType::LEGACY>;
function.combine = AggregateFunction::StateCombine<STATE, OP>;
function.destructor = AggregateFunction::StateDestroy<STATE, OP>;

Expand Down
3 changes: 2 additions & 1 deletion extension/core_functions/aggregate/holistic/mad.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -265,7 +265,8 @@ AggregateFunction GetTypedMedianAbsoluteDeviationAggregateFunction(const Logical
const LogicalType &target_type) {
using STATE = QuantileState<INPUT_TYPE, QuantileStandardType>;
using OP = MedianAbsoluteDeviationOperation<MEDIAN_TYPE>;
auto fun = AggregateFunction::UnaryAggregateDestructor<STATE, INPUT_TYPE, TARGET_TYPE, OP>(input_type, target_type);
auto fun = AggregateFunction::UnaryAggregateDestructor<STATE, INPUT_TYPE, TARGET_TYPE, OP,
AggregateDestructorType::LEGACY>(input_type, target_type);
fun.bind = BindMAD;
fun.order_dependent = AggregateOrderDependent::NOT_ORDER_DEPENDENT;
#ifndef DUCKDB_SMALLER_BINARY
Expand Down
12 changes: 8 additions & 4 deletions extension/core_functions/aggregate/holistic/mode.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -424,7 +424,7 @@ AggregateFunction GetFallbackModeFunction(const LogicalType &type) {
using STATE = ModeState<string_t, ModeString>;
using OP = ModeFallbackFunction<ModeString>;
AggregateFunction aggr({type}, type, AggregateFunction::StateSize<STATE>,
AggregateFunction::StateInitialize<STATE, OP>,
AggregateFunction::StateInitialize<STATE, OP, AggregateDestructorType::LEGACY>,
AggregateSortKeyHelpers::UnaryUpdate<STATE, OP>, AggregateFunction::StateCombine<STATE, OP>,
AggregateFunction::StateVoidFinalize<STATE, OP>, nullptr);
aggr.destructor = AggregateFunction::StateDestroy<STATE, OP>;
Expand All @@ -435,7 +435,9 @@ template <typename INPUT_TYPE, typename TYPE_OP = ModeStandard<INPUT_TYPE>>
AggregateFunction GetTypedModeFunction(const LogicalType &type) {
using STATE = ModeState<INPUT_TYPE, TYPE_OP>;
using OP = ModeFunction<TYPE_OP>;
auto func = AggregateFunction::UnaryAggregateDestructor<STATE, INPUT_TYPE, INPUT_TYPE, OP>(type, type);
auto func =
AggregateFunction::UnaryAggregateDestructor<STATE, INPUT_TYPE, INPUT_TYPE, OP, AggregateDestructorType::LEGACY>(
type, type);
func.window = OP::template Window<STATE, INPUT_TYPE, INPUT_TYPE>;
return func;
}
Expand Down Expand Up @@ -528,7 +530,9 @@ template <typename INPUT_TYPE, typename TYPE_OP = ModeStandard<INPUT_TYPE>>
AggregateFunction GetTypedEntropyFunction(const LogicalType &type) {
using STATE = ModeState<INPUT_TYPE, TYPE_OP>;
using OP = EntropyFunction<TYPE_OP>;
auto func = AggregateFunction::UnaryAggregateDestructor<STATE, INPUT_TYPE, double, OP>(type, LogicalType::DOUBLE);
auto func =
AggregateFunction::UnaryAggregateDestructor<STATE, INPUT_TYPE, double, OP, AggregateDestructorType::LEGACY>(
type, LogicalType::DOUBLE);
func.null_handling = FunctionNullHandling::SPECIAL_HANDLING;
return func;
}
Expand All @@ -537,7 +541,7 @@ AggregateFunction GetFallbackEntropyFunction(const LogicalType &type) {
using STATE = ModeState<string_t, ModeString>;
using OP = EntropyFallbackFunction<ModeString>;
AggregateFunction func({type}, LogicalType::DOUBLE, AggregateFunction::StateSize<STATE>,
AggregateFunction::StateInitialize<STATE, OP>,
AggregateFunction::StateInitialize<STATE, OP, AggregateDestructorType::LEGACY>,
AggregateSortKeyHelpers::UnaryUpdate<STATE, OP>, AggregateFunction::StateCombine<STATE, OP>,
AggregateFunction::StateFinalize<STATE, double, OP>, nullptr);
func.destructor = AggregateFunction::StateDestroy<STATE, OP>;
Expand Down
31 changes: 18 additions & 13 deletions extension/core_functions/aggregate/holistic/quantile.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -420,7 +420,8 @@ struct ScalarDiscreteQuantile {
static AggregateFunction GetFunction(const LogicalType &type) {
using STATE = QuantileState<INPUT_TYPE, TYPE_OP>;
using OP = QuantileScalarOperation<true>;
auto fun = AggregateFunction::UnaryAggregateDestructor<STATE, INPUT_TYPE, INPUT_TYPE, OP>(type, type);
auto fun = AggregateFunction::UnaryAggregateDestructor<STATE, INPUT_TYPE, INPUT_TYPE, OP,
AggregateDestructorType::LEGACY>(type, type);
#ifndef DUCKDB_SMALLER_BINARY
fun.window = OP::Window<STATE, INPUT_TYPE, INPUT_TYPE>;
fun.window_init = OP::WindowInit<STATE, INPUT_TYPE>;
Expand All @@ -432,11 +433,12 @@ struct ScalarDiscreteQuantile {
using STATE = QuantileState<string_t, QuantileStringType>;
using OP = QuantileScalarFallback;

AggregateFunction fun(
{type}, type, AggregateFunction::StateSize<STATE>, AggregateFunction::StateInitialize<STATE, OP>,
AggregateSortKeyHelpers::UnaryUpdate<STATE, OP>, AggregateFunction::StateCombine<STATE, OP>,
AggregateFunction::StateVoidFinalize<STATE, OP>, nullptr, nullptr,
AggregateFunction::StateDestroy<STATE, OP>);
AggregateFunction fun({type}, type, AggregateFunction::StateSize<STATE>,
AggregateFunction::StateInitialize<STATE, OP, AggregateDestructorType::LEGACY>,
AggregateSortKeyHelpers::UnaryUpdate<STATE, OP>,
AggregateFunction::StateCombine<STATE, OP>,
AggregateFunction::StateVoidFinalize<STATE, OP>, nullptr, nullptr,
AggregateFunction::StateDestroy<STATE, OP>);
return fun;
}
};
Expand All @@ -445,7 +447,8 @@ template <class STATE, class INPUT_TYPE, class RESULT_TYPE, class OP>
static AggregateFunction QuantileListAggregate(const LogicalType &input_type, const LogicalType &child_type) { // NOLINT
LogicalType result_type = LogicalType::LIST(child_type);
return AggregateFunction(
{input_type}, result_type, AggregateFunction::StateSize<STATE>, AggregateFunction::StateInitialize<STATE, OP>,
{input_type}, result_type, AggregateFunction::StateSize<STATE>,
AggregateFunction::StateInitialize<STATE, OP, AggregateDestructorType::LEGACY>,
AggregateFunction::UnaryScatterUpdate<STATE, INPUT_TYPE, OP>, AggregateFunction::StateCombine<STATE, OP>,
AggregateFunction::StateFinalize<STATE, RESULT_TYPE, OP>, AggregateFunction::UnaryUpdate<STATE, INPUT_TYPE, OP>,
nullptr, AggregateFunction::StateDestroy<STATE, OP>);
Expand All @@ -469,11 +472,12 @@ struct ListDiscreteQuantile {
using STATE = QuantileState<string_t, QuantileStringType>;
using OP = QuantileListFallback;

AggregateFunction fun(
{type}, LogicalType::LIST(type), AggregateFunction::StateSize<STATE>,
AggregateFunction::StateInitialize<STATE, OP>, AggregateSortKeyHelpers::UnaryUpdate<STATE, OP>,
AggregateFunction::StateCombine<STATE, OP>, AggregateFunction::StateFinalize<STATE, list_entry_t, OP>,
nullptr, nullptr, AggregateFunction::StateDestroy<STATE, OP>);
AggregateFunction fun({type}, LogicalType::LIST(type), AggregateFunction::StateSize<STATE>,
AggregateFunction::StateInitialize<STATE, OP, AggregateDestructorType::LEGACY>,
AggregateSortKeyHelpers::UnaryUpdate<STATE, OP>,
AggregateFunction::StateCombine<STATE, OP>,
AggregateFunction::StateFinalize<STATE, list_entry_t, OP>, nullptr, nullptr,
AggregateFunction::StateDestroy<STATE, OP>);
return fun;
}
};
Expand Down Expand Up @@ -547,7 +551,8 @@ struct ScalarContinuousQuantile {
using STATE = QuantileState<INPUT_TYPE, QuantileStandardType>;
using OP = QuantileScalarOperation<false>;
auto fun =
AggregateFunction::UnaryAggregateDestructor<STATE, INPUT_TYPE, TARGET_TYPE, OP>(input_type, target_type);
AggregateFunction::UnaryAggregateDestructor<STATE, INPUT_TYPE, TARGET_TYPE, OP,
AggregateDestructorType::LEGACY>(input_type, target_type);
fun.order_dependent = AggregateOrderDependent::NOT_ORDER_DEPENDENT;
#ifndef DUCKDB_SMALLER_BINARY
fun.window = OP::template Window<STATE, INPUT_TYPE, TARGET_TYPE>;
Expand Down
2 changes: 1 addition & 1 deletion src/function/aggregate/distributive/minmax.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -478,7 +478,7 @@ static void SpecializeMinMaxNFunction(AggregateFunction &function) {
using OP = MinMaxNOperation;

function.state_size = AggregateFunction::StateSize<STATE>;
function.initialize = AggregateFunction::StateInitialize<STATE, OP>;
function.initialize = AggregateFunction::StateInitialize<STATE, OP, AggregateDestructorType::LEGACY>;
function.combine = AggregateFunction::StateCombine<STATE, OP>;
function.destructor = AggregateFunction::StateDestroy<STATE, OP>;

Expand Down
3 changes: 2 additions & 1 deletion src/function/aggregate/sorted_aggregate_function.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -744,7 +744,8 @@ void FunctionBinder::BindSortedAggregate(ClientContext &context, BoundAggregateE
// Replace the aggregate with the wrapper
AggregateFunction ordered_aggregate(
bound_function.name, arguments, bound_function.return_type, AggregateFunction::StateSize<SortedAggregateState>,
AggregateFunction::StateInitialize<SortedAggregateState, SortedAggregateFunction>,
AggregateFunction::StateInitialize<SortedAggregateState, SortedAggregateFunction,
AggregateDestructorType::LEGACY>,
SortedAggregateFunction::ScatterUpdate,
AggregateFunction::StateCombine<SortedAggregateState, SortedAggregateFunction>,
SortedAggregateFunction::Finalize, bound_function.null_handling, SortedAggregateFunction::SimpleUpdate, nullptr,
Expand Down
37 changes: 26 additions & 11 deletions src/include/duckdb/function/aggregate_function.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,13 @@ struct AggregateFunctionInfo {
}
};

enum class AggregateDestructorType {
STANDARD,
// legacy destructors allow non-trivial destructors in aggregate states
// these might not be trivial to off-load to disk
LEGACY
};

class AggregateFunction : public BaseScalarFunction { // NOLINT: work-around bug in clang-tidy
public:
AggregateFunction(const string &name, const vector<LogicalType> &arguments, const LogicalType &return_type,
Expand Down Expand Up @@ -206,29 +213,33 @@ class AggregateFunction : public BaseScalarFunction { // NOLINT: work-around bug
AggregateFunction::StateFinalize<STATE, RESULT_TYPE, OP>, AggregateFunction::NullaryUpdate<STATE, OP>);
}

template <class STATE, class INPUT_TYPE, class RESULT_TYPE, class OP>
template <class STATE, class INPUT_TYPE, class RESULT_TYPE, class OP,
AggregateDestructorType destructor_type = AggregateDestructorType::STANDARD>
static AggregateFunction
UnaryAggregate(const LogicalType &input_type, LogicalType return_type,
FunctionNullHandling null_handling = FunctionNullHandling::DEFAULT_NULL_HANDLING) {
return AggregateFunction(
{input_type}, return_type, AggregateFunction::StateSize<STATE>,
AggregateFunction::StateInitialize<STATE, OP>, AggregateFunction::UnaryScatterUpdate<STATE, INPUT_TYPE, OP>,
AggregateFunction::StateCombine<STATE, OP>, AggregateFunction::StateFinalize<STATE, RESULT_TYPE, OP>,
null_handling, AggregateFunction::UnaryUpdate<STATE, INPUT_TYPE, OP>);
return AggregateFunction({input_type}, return_type, AggregateFunction::StateSize<STATE>,
AggregateFunction::StateInitialize<STATE, OP, destructor_type>,
AggregateFunction::UnaryScatterUpdate<STATE, INPUT_TYPE, OP>,
AggregateFunction::StateCombine<STATE, OP>,
AggregateFunction::StateFinalize<STATE, RESULT_TYPE, OP>, null_handling,
AggregateFunction::UnaryUpdate<STATE, INPUT_TYPE, OP>);
}

template <class STATE, class INPUT_TYPE, class RESULT_TYPE, class OP>
template <class STATE, class INPUT_TYPE, class RESULT_TYPE, class OP,
AggregateDestructorType destructor_type = AggregateDestructorType::STANDARD>
static AggregateFunction UnaryAggregateDestructor(LogicalType input_type, LogicalType return_type) {
auto aggregate = UnaryAggregate<STATE, INPUT_TYPE, RESULT_TYPE, OP>(input_type, return_type);
auto aggregate = UnaryAggregate<STATE, INPUT_TYPE, RESULT_TYPE, OP, destructor_type>(input_type, return_type);
aggregate.destructor = AggregateFunction::StateDestroy<STATE, OP>;
return aggregate;
}

template <class STATE, class A_TYPE, class B_TYPE, class RESULT_TYPE, class OP>
template <class STATE, class A_TYPE, class B_TYPE, class RESULT_TYPE, class OP,
AggregateDestructorType destructor_type = AggregateDestructorType::STANDARD>
static AggregateFunction BinaryAggregate(const LogicalType &a_type, const LogicalType &b_type,
LogicalType return_type) {
return AggregateFunction({a_type, b_type}, return_type, AggregateFunction::StateSize<STATE>,
AggregateFunction::StateInitialize<STATE, OP>,
AggregateFunction::StateInitialize<STATE, OP, destructor_type>,
AggregateFunction::BinaryScatterUpdate<STATE, A_TYPE, B_TYPE, OP>,
AggregateFunction::StateCombine<STATE, OP>,
AggregateFunction::StateFinalize<STATE, RESULT_TYPE, OP>,
Expand All @@ -241,8 +252,12 @@ class AggregateFunction : public BaseScalarFunction { // NOLINT: work-around bug
return sizeof(STATE);
}

template <class STATE, class OP>
template <class STATE, class OP, AggregateDestructorType destructor_type = AggregateDestructorType::STANDARD>
static void StateInitialize(const AggregateFunction &, data_ptr_t state) {
// FIXME: we should remove the "destructor_type" option in the future
static_assert(std::is_trivially_destructible<STATE>::value ||
destructor_type == AggregateDestructorType::LEGACY,
"Aggregate state must be trivially destructible");
OP::Initialize(*reinterpret_cast<STATE *>(state));
}

Expand Down

0 comments on commit 6b00cdf

Please sign in to comment.