Skip to content

Commit

Permalink
[opt](agg function) Signature verification of aggregate function. (ap…
Browse files Browse the repository at this point in the history
…ache#40682)

Prevent BE from crashing when the Nullable property of an aggregate function is
changed on FE.

Rule:

* With group by key:

The result type of the aggregate function must be exactly the same as the type
from the planner.

* Without group by key:

If the planner gives BE a nullable type but the aggregate function returns a
non-nullable type, it is valid.
Any other type mismatch is rejected, and BE will report an error.

Note, for `xxx_foreach`, type check is disabled for now.
  • Loading branch information
zhiqiang-hhhh authored Nov 4, 2024
1 parent ced1a1d commit 865028f
Show file tree
Hide file tree
Showing 7 changed files with 71 additions and 10 deletions.
2 changes: 1 addition & 1 deletion be/src/pipeline/exec/aggregation_sink_operator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -733,7 +733,7 @@ Status AggSinkOperatorX::init(const TPlanNode& tnode, RuntimeState* state) {
RETURN_IF_ERROR(vectorized::AggFnEvaluator::create(
_pool, tnode.agg_node.aggregate_functions[i],
tnode.agg_node.__isset.agg_sort_infos ? tnode.agg_node.agg_sort_infos[i] : dummy,
&evaluator));
tnode.agg_node.grouping_exprs.empty(), &evaluator));
_aggregate_evaluators.push_back(evaluator);
}

Expand Down
6 changes: 4 additions & 2 deletions be/src/pipeline/exec/analytic_source_operator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -502,11 +502,13 @@ Status AnalyticSourceOperatorX::init(const TPlanNode& tnode, RuntimeState* state
RETURN_IF_ERROR(OperatorX<AnalyticLocalState>::init(tnode, state));
const TAnalyticNode& analytic_node = tnode.analytic_node;
size_t agg_size = analytic_node.analytic_functions.size();

for (int i = 0; i < agg_size; ++i) {
vectorized::AggFnEvaluator* evaluator = nullptr;
// Window functions treat every NullableAggregateFunction as AlwaysNullable,
// so they behave the same as an aggregation executed without a group by key.
// https://github.com/apache/doris/pull/40693
RETURN_IF_ERROR(vectorized::AggFnEvaluator::create(
_pool, analytic_node.analytic_functions[i], {}, &evaluator));
_pool, analytic_node.analytic_functions[i], {}, /*wihout_key*/ true, &evaluator));
_agg_functions.emplace_back(evaluator);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -354,7 +354,7 @@ Status DistinctStreamingAggOperatorX::init(const TPlanNode& tnode, RuntimeState*
RETURN_IF_ERROR(vectorized::AggFnEvaluator::create(
_pool, tnode.agg_node.aggregate_functions[i],
tnode.agg_node.__isset.agg_sort_infos ? tnode.agg_node.agg_sort_infos[i] : dummy,
&evaluator));
tnode.agg_node.grouping_exprs.empty(), &evaluator));
_aggregate_evaluators.push_back(evaluator);
}

Expand Down
2 changes: 1 addition & 1 deletion be/src/pipeline/exec/streaming_aggregation_operator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1144,7 +1144,7 @@ Status StreamingAggOperatorX::init(const TPlanNode& tnode, RuntimeState* state)
RETURN_IF_ERROR(vectorized::AggFnEvaluator::create(
_pool, tnode.agg_node.aggregate_functions[i],
tnode.agg_node.__isset.agg_sort_infos ? tnode.agg_node.agg_sort_infos[i] : dummy,
&evaluator));
tnode.agg_node.grouping_exprs.empty(), &evaluator));
_aggregate_evaluators.push_back(evaluator);
}

Expand Down
44 changes: 44 additions & 0 deletions be/src/vec/aggregate_functions/aggregate_function.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@

#pragma once

#include "common/exception.h"
#include "common/status.h"
#include "util/defer_op.h"
#include "vec/columns/column_complex.h"
#include "vec/columns/column_string.h"
Expand All @@ -30,6 +32,7 @@
#include "vec/core/column_numbers.h"
#include "vec/core/field.h"
#include "vec/core/types.h"
#include "vec/data_types/data_type_nullable.h"
#include "vec/data_types/data_type_string.h"

namespace doris::vectorized {
Expand Down Expand Up @@ -222,6 +225,10 @@ class IAggregateFunction {

virtual AggregateFunctionPtr transmit_to_stable() { return nullptr; }

/// Verify the function signature: check that the result type produced by this aggregate
/// function is compatible with `result_type`, the type supplied by the planner.
/// `without_key` tells the check whether the aggregation is executed without a group by key.
/// Returns a non-OK Status when the types are not compatible.
virtual Status verify_result_type(const bool without_key, const DataTypes& argument_types,
const DataTypePtr result_type) const = 0;

protected:
DataTypes argument_types;
int version {};
Expand Down Expand Up @@ -494,6 +501,43 @@ class IAggregateFunctionHelper : public IAggregateFunction {
arena);
assert_cast<const Derived*, TypeCheckOnRelease::DISABLE>(this)->merge(place, rhs, arena);
}

/// Checks that the type returned by Derived::get_return_type() is compatible with the
/// result type sent by the planner.
///
/// @param without_key true when the aggregation is executed without a group by key.
/// @param argument_types_with_nullable argument types from the planner (not inspected here).
/// @param result_type_with_nullable result type the planner expects.
/// @return Status::OK() when the types are compatible, otherwise Status::InternalError
///         naming the function, both types and whether a group by key is present.
Status verify_result_type(const bool without_key, const DataTypes& argument_types_with_nullable,
                          const DataTypePtr result_type_with_nullable) const override {
    const DataTypePtr actual_type = assert_cast<const Derived*>(this)->get_return_type();

    // Both rejection paths report exactly the same message, so it is built in one place.
    const auto type_mismatch = [&]() {
        return Status::InternalError(
                "Result type of {} is not matched, planner expect {}, but get {}, with group by: "
                "{}",
                get_name(), result_type_with_nullable->get_name(), actual_type->get_name(),
                !without_key);
    };

    // Identical types are always accepted.
    if (actual_type->equals(*result_type_with_nullable)) {
        return Status::OK();
    }

    // Apart from nullability, the two types have to agree.
    const bool same_nested_type =
            remove_nullable(actual_type)->equals(*remove_nullable(result_type_with_nullable));
    if (!same_nested_type) {
        return type_mismatch();
    }

    // From here on only the nullability differs.
    // This branch is dedicated to NullableAggregateFunction: when it is executed without a
    // group by key, the planner marks the result as AlwaysNullable because it cannot know
    // whether invalid input will show up at runtime (in which case the result is Null).
    // Backend wraps the result in a ColumnNullable in this situation, for example in
    // AggLocalState::_get_without_key_result, so a nullable planner type is accepted.
    if (without_key && result_type_with_nullable->is_nullable()) {
        return Status::OK();
    }

    // Executed with a group by key (or the planner type is not nullable): the result type
    // must be exactly the same as the type from the planner.
    return type_mismatch();
}
};

/// Implements several methods for manipulation with data. T - type of structure with data for aggregation.
Expand Down
17 changes: 14 additions & 3 deletions be/src/vec/exprs/vectorized_agg_fn.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,8 @@
#include "vec/exprs/vexpr_context.h"
#include "vec/utils/util.hpp"

// Minimum be_exec_version for which AggFnEvaluator::prepare verifies the result type of the
// aggregate function; the check is skipped when the runtime state reports a lower version.
static constexpr int64_t BE_VERSION_THAT_SUPPORT_NULLABLE_CHECK = 8;

namespace doris {
class RowDescriptor;
namespace vectorized {
Expand All @@ -63,9 +65,10 @@ AggregateFunctionPtr get_agg_state_function(const DataTypes& argument_types,
argument_types, return_type);
}

AggFnEvaluator::AggFnEvaluator(const TExprNode& desc)
AggFnEvaluator::AggFnEvaluator(const TExprNode& desc, const bool without_key)
: _fn(desc.fn),
_is_merge(desc.agg_expr.is_merge_agg),
_without_key(without_key),
_return_type(TypeDescriptor::from_thrift(desc.fn.ret_type)) {
bool nullable = true;
if (desc.__isset.is_nullable) {
Expand All @@ -83,8 +86,8 @@ AggFnEvaluator::AggFnEvaluator(const TExprNode& desc)
}

Status AggFnEvaluator::create(ObjectPool* pool, const TExpr& desc, const TSortInfo& sort_info,
AggFnEvaluator** result) {
*result = pool->add(AggFnEvaluator::create_unique(desc.nodes[0]).release());
const bool without_key, AggFnEvaluator** result) {
*result = pool->add(AggFnEvaluator::create_unique(desc.nodes[0], without_key).release());
auto& agg_fn_evaluator = *result;
int node_idx = 0;
for (int i = 0; i < desc.nodes[0].num_children; ++i) {
Expand Down Expand Up @@ -213,6 +216,13 @@ Status AggFnEvaluator::prepare(RuntimeState* state, const RowDescriptor& desc,
_function = transform_to_sort_agg_function(_function, _argument_types_with_sort,
_sort_description, state);
}

if (!AggregateFunctionSimpleFactory::is_foreach(_fn.name.function_name)) {
if (state->be_exec_version() >= BE_VERSION_THAT_SUPPORT_NULLABLE_CHECK) {
RETURN_IF_ERROR(
_function->verify_result_type(_without_key, argument_types, _data_type));
}
}
_expr_name = fmt::format("{}({})", _fn.name.function_name, child_expr_name);
return Status::OK();
}
Expand Down Expand Up @@ -320,6 +330,7 @@ AggFnEvaluator* AggFnEvaluator::clone(RuntimeState* state, ObjectPool* pool) {
AggFnEvaluator::AggFnEvaluator(AggFnEvaluator& evaluator, RuntimeState* state)
: _fn(evaluator._fn),
_is_merge(evaluator._is_merge),
_without_key(evaluator._without_key),
_argument_types_with_sort(evaluator._argument_types_with_sort),
_real_argument_types(evaluator._real_argument_types),
_return_type(evaluator._return_type),
Expand Down
8 changes: 6 additions & 2 deletions be/src/vec/exprs/vectorized_agg_fn.h
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ class AggFnEvaluator {

public:
static Status create(ObjectPool* pool, const TExpr& desc, const TSortInfo& sort_info,
AggFnEvaluator** result);
const bool without_key, AggFnEvaluator** result);

Status prepare(RuntimeState* state, const RowDescriptor& desc,
const SlotDescriptor* intermediate_slot_desc,
Expand Down Expand Up @@ -109,8 +109,12 @@ class AggFnEvaluator {
const TFunction _fn;

const bool _is_merge;
// We need this flag to distinguish between the two types of aggregation functions:
// 1. executed without group by key (agg function used with window function is also regarded as this type)
// 2. executed with group by key
const bool _without_key;

AggFnEvaluator(const TExprNode& desc);
AggFnEvaluator(const TExprNode& desc, const bool without_key);
AggFnEvaluator(AggFnEvaluator& evaluator, RuntimeState* state);

Status _calc_argument_columns(Block* block);
Expand Down

0 comments on commit 865028f

Please sign in to comment.