Skip to content

Commit

Permalink
[Enhancement] Add exception safe interface for AggregateFunction (Sta…
Browse files Browse the repository at this point in the history
…rRocks#55392)

Signed-off-by: trueeyu <lxhhust350@qq.com>
  • Loading branch information
trueeyu authored Jan 24, 2025
1 parent a2c62c0 commit d47385d
Show file tree
Hide file tree
Showing 8 changed files with 60 additions and 0 deletions.
45 changes: 45 additions & 0 deletions be/src/exprs/agg/aggregate.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,22 @@
#include <type_traits>

#include "column/column.h"
#include "runtime/current_thread.h"

namespace starrocks {
class FunctionContext;
}

namespace starrocks {

#define EXCEPTION_SAFE_FUNC_CALL(func) \
if (is_exception_safe()) { \
return func; \
} else { \
SCOPED_SET_CATCHED(false); \
return func; \
}

/**
* For each aggregate function, it may use different agg state kind for Incremental MV:
* - RESULT : Use result table data as AggStateData, eg sum/count
Expand Down Expand Up @@ -56,6 +65,8 @@ class AggregateFunction {
public:
virtual ~AggregateFunction() = default;

virtual bool is_exception_safe() const { return true; }

// Reset the aggregation state, for aggregate window functions
virtual void reset(FunctionContext* ctx, const Columns& args, AggDataPtr __restrict state) const {}

Expand Down Expand Up @@ -160,14 +171,31 @@ class AggregateFunction {
virtual void update_batch(FunctionContext* ctx, size_t chunk_size, size_t state_offset, const Column** columns,
AggDataPtr* states) const = 0;

virtual void update_batch_exception_safe(FunctionContext* ctx, size_t chunk_size, size_t state_offset,
const Column** columns, AggDataPtr* states) const final {
EXCEPTION_SAFE_FUNC_CALL(update_batch(ctx, chunk_size, state_offset, columns, states));
}

// filter[i] = 0, will be update
virtual void update_batch_selectively(FunctionContext* ctx, size_t chunk_size, size_t state_offset,
const Column** columns, AggDataPtr* states, const Filter& filter) const = 0;

virtual void update_batch_selectively_exception_safe(FunctionContext* ctx, size_t chunk_size, size_t state_offset,
const Column** columns, AggDataPtr* states,
const Filter& filter) const final {
EXCEPTION_SAFE_FUNC_CALL(update_batch_selectively(ctx, chunk_size, state_offset, columns, states, filter));
}

// update result to single state
virtual void update_batch_single_state(FunctionContext* ctx, size_t chunk_size, const Column** columns,
AggDataPtr __restrict state) const = 0;

virtual void update_batch_single_state_exception_safe(FunctionContext* ctx, size_t chunk_size,
const Column** columns,
AggDataPtr __restrict state) const final {
EXCEPTION_SAFE_FUNC_CALL(update_batch_single_state(ctx, chunk_size, columns, state));
}

// For window functions
// A peer group is all of the rows that are peers within the specified ordering.
// Rows are peers if they compare equal to each other using the specified ordering expression.
Expand Down Expand Up @@ -202,17 +230,33 @@ class AggregateFunction {
virtual void merge_batch(FunctionContext* ctx, size_t chunk_size, size_t state_offset, const Column* column,
AggDataPtr* states) const = 0;

virtual void merge_batch_exception_safe(FunctionContext* ctx, size_t chunk_size, size_t state_offset,
const Column* column, AggDataPtr* states) const final {
EXCEPTION_SAFE_FUNC_CALL(merge_batch(ctx, chunk_size, state_offset, column, states));
}

// filter[i] = 0, will be merged
virtual void merge_batch_selectively(FunctionContext* ctx, size_t chunk_size, size_t state_offset,
const Column* column, AggDataPtr* states, const Filter& filter) const = 0;

virtual void merge_batch_selectively_exception_safe(FunctionContext* ctx, size_t chunk_size, size_t state_offset,
const Column* column, AggDataPtr* states,
const Filter& filter) const final {
EXCEPTION_SAFE_FUNC_CALL(merge_batch_selectively(ctx, chunk_size, state_offset, column, states, filter));
}

// Merge some continuous portion of a chunk to a given state.
// This will be useful for sorted streaming aggregation.
// 'start': the start position of the continuous portion
// 'size': the length of the continuous portion
virtual void merge_batch_single_state(FunctionContext* ctx, AggDataPtr __restrict state, const Column* column,
size_t start, size_t size) const = 0;

virtual void merge_batch_single_state_exception_safe(FunctionContext* ctx, AggDataPtr __restrict state,
const Column* column, size_t start, size_t size) const final {
EXCEPTION_SAFE_FUNC_CALL(merge_batch_single_state(ctx, state, column, start, size));
}

///////////////// STREAM MV METHODS /////////////////

// Return stream agg function's state table kind, see AggStateTableKind's description.
Expand Down Expand Up @@ -420,4 +464,5 @@ using AggregateFunctionPtr = std::shared_ptr<AggregateFunction>;

struct AggregateFunctionEmptyState {};

#undef EXCEPTION_SAFE_FUNC_CALL
} // namespace starrocks
2 changes: 2 additions & 0 deletions be/src/exprs/agg/bitmap_agg.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@ class BitmapAggAggregateFunction final
using InputColumnType = RunTimeColumnType<LT>;
using InputCppType = RunTimeCppType<LT>;

bool is_exception_safe() const override { return false; }

void update(FunctionContext* ctx, const Column** columns, AggDataPtr state, size_t row_num) const override {
const auto* col = down_cast<const InputColumnType*>(columns[0]);
auto value = col->get_data()[row_num];
Expand Down
2 changes: 2 additions & 0 deletions be/src/exprs/agg/bitmap_intersect.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@ struct BitmapValuePacked {
class BitmapIntersectAggregateFunction final
: public AggregateFunctionBatchHelper<BitmapValuePacked, BitmapIntersectAggregateFunction> {
public:
bool is_exception_safe() const override { return false; }

void update(FunctionContext* ctx, const Column** columns, AggDataPtr state, size_t row_num) const override {
const auto* col = down_cast<const BitmapColumn*>(columns[0]);
if (!this->data(state).initial) {
Expand Down
2 changes: 2 additions & 0 deletions be/src/exprs/agg/bitmap_union.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@ namespace starrocks {
class BitmapUnionAggregateFunction final
: public AggregateFunctionBatchHelper<BitmapValue, BitmapUnionAggregateFunction> {
public:
bool is_exception_safe() const override { return false; }

void update(FunctionContext* ctx, const Column** columns, AggDataPtr state, size_t row_num) const override {
const auto* col = down_cast<const BitmapColumn*>(columns[0]);
this->data(state) |= *(col->get_object(row_num));
Expand Down
2 changes: 2 additions & 0 deletions be/src/exprs/agg/bitmap_union_count.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@ namespace starrocks {
class BitmapUnionCountAggregateFunction final
: public AggregateFunctionBatchHelper<BitmapValue, BitmapUnionCountAggregateFunction> {
public:
bool is_exception_safe() const override { return false; }

void reset(FunctionContext* ctx, const Columns& args, AggDataPtr __restrict state) const override {
this->data(state).clear();
}
Expand Down
3 changes: 3 additions & 0 deletions be/src/exprs/agg/bitmap_union_int.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,9 @@ class BitmapUnionIntAggregateFunction final
: public AggregateFunctionBatchHelper<BitmapValue, BitmapUnionIntAggregateFunction<LT, T>> {
public:
using InputColumnType = RunTimeColumnType<LT>;

bool is_exception_safe() const override { return false; }

void update(FunctionContext* ctx, const Column** columns, AggDataPtr state, size_t row_num) const override {
DCHECK((*columns[0]).is_numeric());
if constexpr (std::is_integral_v<T>) {
Expand Down
2 changes: 2 additions & 0 deletions be/src/exprs/agg/java_udaf_function.h
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,8 @@ class JavaUDAFAggregateFunction : public AggregateFunction {
public:
using State = JavaUDAFState;

bool is_exception_safe() const override { return false; }

void update(FunctionContext* ctx, const Column** columns, AggDataPtr __restrict state, size_t row_num) const final {
CHECK(false) << "unreadable path";
}
Expand Down
2 changes: 2 additions & 0 deletions be/src/exprs/agg/nullable_aggregate.h
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,8 @@ class NullableAggregateFunctionBase : public AggregateFunctionStateHelper<State>
static constexpr bool is_result_always_nullable = !std::is_same_v<AggNullPred, AggNonNullPred<NestedState>>;

public:
bool is_exception_safe() const override { return nested_function->is_exception_safe(); }

explicit NullableAggregateFunctionBase(NestedAggregateFunctionPtr nested_function_,
AggNullPred null_pred = AggNullPred())
: nested_function(std::move(nested_function_)), null_pred(std::move(null_pred)) {}
Expand Down

0 comments on commit d47385d

Please sign in to comment.